Merge remote-tracking branch 'kushaangupta/feature/tensorflow-2-support'
endolith committed Apr 2, 2022
2 parents 031f497 + e4bce27 commit 3a80ab6
Showing 1 changed file with 27 additions and 21 deletions.
ann_visualizer/visualize.py

@@ -16,10 +16,10 @@
 """
 
 def ann_viz(model, view=True, filename="network.gv", title="My Neural Network"):
-    """Vizualizez a Sequential model.
+    """Visualize a Sequential model.
 
     # Arguments
-        model: A Keras model instance.
+        model: A tensorflow.keras model instance.
 
         view: whether to display the model after generation.
@@ -28,9 +28,15 @@ def ann_viz(model, view=True, filename="network.gv", title="My Neural Network"):
         title: A title for the graph
     """
     from graphviz import Digraph;
-    import keras;
-    from keras.models import Sequential;
-    from keras.layers import Dense, Conv2D, MaxPooling2D, Dropout, Flatten;
+    from tensorflow.keras.models import Sequential;
+    from tensorflow.keras.layers import (
+        Activation,
+        Dense,
+        Conv2D,
+        MaxPooling2D,
+        Dropout,
+        Flatten
+    );
     import json;
     input_layer = 0;
     hidden_layers_nr = 0;
@@ -41,44 +47,44 @@ def ann_viz(model, view=True, filename="network.gv", title="My Neural Network"):
         if(layer == model.layers[0]):
             input_layer = int(str(layer.input_shape).split(",")[1][1:-1]);
             hidden_layers_nr += 1;
-            if (type(layer) == keras.layers.core.Dense):
+            if (type(layer) == Dense):
                 hidden_layers.append(int(str(layer.output_shape).split(",")[1][1:-1]));
                 layer_types.append("Dense");
             else:
                 hidden_layers.append(1);
-                if (type(layer) == keras.layers.convolutional.Conv2D):
+                if (type(layer) == Conv2D):
                     layer_types.append("Conv2D");
-                elif (type(layer) == keras.layers.pooling.MaxPooling2D):
+                elif (type(layer) == MaxPooling2D):
                     layer_types.append("MaxPooling2D");
-                elif (type(layer) == keras.layers.core.Dropout):
+                elif (type(layer) == Dropout):
                     layer_types.append("Dropout");
-                elif (type(layer) == keras.layers.core.Flatten):
+                elif (type(layer) == Flatten):
                     layer_types.append("Flatten");
-                elif (type(layer) == keras.layers.core.Activation):
+                elif (type(layer) == Activation):
                     layer_types.append("Activation");
         else:
             if(layer == model.layers[-1]):
                 output_layer = int(str(layer.output_shape).split(",")[1][1:-1]);
             else:
                 hidden_layers_nr += 1;
-                if (type(layer) == keras.layers.core.Dense):
+                if (type(layer) == Dense):
                     hidden_layers.append(int(str(layer.output_shape).split(",")[1][1:-1]));
                     layer_types.append("Dense");
                 else:
                     hidden_layers.append(1);
-                    if (type(layer) == keras.layers.convolutional.Conv2D):
+                    if (type(layer) == Conv2D):
                         layer_types.append("Conv2D");
-                    elif (type(layer) == keras.layers.pooling.MaxPooling2D):
+                    elif (type(layer) == MaxPooling2D):
                         layer_types.append("MaxPooling2D");
-                    elif (type(layer) == keras.layers.core.Dropout):
+                    elif (type(layer) == Dropout):
                         layer_types.append("Dropout");
-                    elif (type(layer) == keras.layers.core.Flatten):
+                    elif (type(layer) == Flatten):
                         layer_types.append("Flatten");
-                    elif (type(layer) == keras.layers.core.Activation):
+                    elif (type(layer) == Activation):
                         layer_types.append("Activation");
     last_layer_nodes = input_layer;
     nodes_up = input_layer;
-    if(type(model.layers[0]) != keras.layers.core.Dense):
+    if(type(model.layers[0]) != Dense):
         last_layer_nodes = 1;
         nodes_up = 1;
         input_layer = 1;
@@ -88,7 +94,7 @@ def ann_viz(model, view=True, filename="network.gv", title="My Neural Network"):
     g.graph_attr.update(splines="false", nodesep='1', ranksep='2');
     #Input Layer
     with g.subgraph(name='cluster_input') as c:
-        if(type(model.layers[0]) == keras.layers.core.Dense):
+        if(type(model.layers[0]) == Dense):
             the_label = title+'\n\n\n\nInput Layer';
             if (int(str(model.layers[0].input_shape).split(",")[1][1:-1]) > 10):
                 the_label += " (+"+str(int(str(model.layers[0].input_shape).split(",")[1][1:-1]) - 10)+")";
@@ -101,7 +107,7 @@ def ann_viz(model, view=True, filename="network.gv", title="My Neural Network"):
             c.attr(rank='same');
             c.node_attr.update(color="#2ecc71", style="filled", fontcolor="#2ecc71", shape="circle");
 
-        elif(type(model.layers[0]) == keras.layers.convolutional.Conv2D):
+        elif(type(model.layers[0]) == Conv2D):
             #Conv2D Input visualizing
             the_label = title+'\n\n\n\nInput Layer';
             c.attr(color="white", label=the_label);
@@ -188,7 +194,7 @@ def ann_viz(model, view=True, filename="network.gv", title="My Neural Network"):
 
 
     with g.subgraph(name='cluster_output') as c:
-        if (type(model.layers[-1]) == keras.layers.core.Dense):
+        if (type(model.layers[-1]) == Dense):
            c.attr(color='white')
            c.attr(rank='same');
            c.attr(labeljust="1");
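With this merge, ann_viz imports the layer classes from tensorflow.keras and compares against them directly (type(layer) == Dense) instead of the fully qualified keras.layers.core.* paths, which are not stable across Keras versions. A minimal usage sketch against the merged code; the two-layer model is illustrative, not from the repository:

    # Minimal usage sketch for the TF2-compatible ann_viz.
    # The model below is illustrative; any Sequential model whose first
    # layer declares an input_shape should be handled the same way.
    from tensorflow.keras.models import Sequential
    from tensorflow.keras.layers import Dense

    from ann_visualizer.visualize import ann_viz

    model = Sequential([
        Dense(8, input_shape=(4,), activation="relu"),
        Dense(3, activation="softmax"),
    ])

    # Writes the Graphviz source to network.gv; view=True also renders and
    # opens it (requires the graphviz Python package and binaries).
    ann_viz(model, view=True, filename="network.gv", title="My Neural Network")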

4 comments on commit 3a80ab6

@endolith (Owner) commented on 3a80ab6, Apr 3, 2022:

From RedaOps#36

Actually, this doesn't really fix RedaOps#25, RedaOps#34, or RedaOps#35, as @kushaangupta claimed in pull request RedaOps#36?

Those issues are all about Input layers, not about tensorflow versions. Input layers were actually fixed in ff773b1.
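Context for the Input-layer distinction: those issues involve models that begin with an explicit Input layer rather than passing input_shape to the first layer. A hypothetical minimal pair, for illustration only (see the linked issues for the actual reports):

    # Hypothetical minimal pair, for illustration only.
    from tensorflow.keras.models import Sequential
    from tensorflow.keras.layers import Input, Dense

    # Style A: input shape passed to the first layer.
    model_a = Sequential([Dense(8, input_shape=(4,), activation="relu")])

    # Style B: an explicit Input layer, the case the Input-layer issues concern.
    model_b = Sequential([Input(shape=(4,)), Dense(8, activation="relu")])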

@kushaangupta commented:

Oh, you're right! I assumed that the error they got had the same cause as the one I was having. I should've taken a look at the seed code. I guess it's better to point people to your fork rather than mine. Thanks for pointing it out.

@endolith (Owner) commented:

@kushaangupta Yeah, I'm trying to collect as many of the different forks as possible and get this working again.

@endolith (Owner) commented:

These tensorflow changes were also submitted as RedaOps#31 by @aliakbarhamzeh1378, but that fork no longer exists.
