-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy path: visualize.py
More file actions
128 lines (115 loc) · 3.59 KB
/
visualize.py
File metadata and controls
128 lines (115 loc) · 3.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
""" Network architecture visualizer using graphviz """
import sys
from graphviz import Digraph
import genotypes as gt
def subplot(genotype, prefix, dag_layers):
    """Draw one DAG level of ``genotype`` as a left-to-right graphviz Digraph.

    Args:
        genotype: sequence with one entry per intermediate node; each entry is
            an iterable of ``(g_child, sidx, n_state)`` triples. When
            ``g_child[0]`` is a string it is used as the edge (operation)
            label; otherwise ``g_child`` is drawn recursively as a nested DAG
            and ``sidx`` is wired into that child's input nodes.
        prefix: prepended (after ``str()``) to every node name so nodes of
            nested graphs do not clash.
        dag_layers: object exposing ``n_input``, ``merge_out_range`` and
            ``edges``; note that ``edges[0]`` is used for every nested child.

    Returns:
        tuple ``(graph, input_node_names, output_node_name)``.
    """
    tag = str(prefix)
    graph = Digraph(
        format='png',
        edge_attr={
            'fontsize': '15',
            'fontname': 'Sans',
        },
        node_attr={
            'style': 'filled',
            'shape': 'rect',
            'align': 'center',
            'fontsize': '20',
            'height': '0.5',
            'width': '0.5',
            'penwidth': '2',
            'fontname': 'Sans',
        },
        engine='dot')
    graph.body.extend(['rankdir=LR'])

    n_input = dag_layers.n_input

    # Input nodes come first, then one node per genotype entry.
    in_names = ['{}in_{}'.format(tag, k) for k in range(n_input)]
    for name in in_names:
        graph.node(name, fillcolor='darkseagreen2')

    node_names = ['{}n_{}'.format(tag, k) for k in range(len(genotype))]
    for name in node_names:
        graph.node(name, fillcolor='lightblue')

    def state_name(idx):
        # State indices below n_input refer to inputs; the rest are offsets
        # into the intermediate nodes.
        if idx < n_input:
            return in_names[idx]
        return node_names[idx - n_input]

    child_count = 0  # numbers nested sub-DAGs so their prefixes are unique
    for node_edges in genotype:
        for g_child, sidx, n_state in node_edges:
            dst = node_names[n_state - n_input]
            is_op = isinstance(g_child[0], str)
            if is_op:
                label = g_child[0]
            else:
                child_graph, child_ins, child_out = subplot(
                    g_child,
                    '{}{}_'.format(tag, child_count),
                    dag_layers.edges[0])
                graph.subgraph(child_graph)
                graph.edge(child_out, dst, label='', fillcolor="gray")
                child_count += 1
            for pos, src_idx in enumerate(sidx):
                src = state_name(src_idx)
                if is_op:
                    graph.edge(src, dst, label=label, fillcolor="gray")
                else:
                    # Feed each source state into the matching child input.
                    graph.edge(src, child_ins[pos], label='', fillcolor="gray")

    out_name = tag + 'out'
    graph.node(out_name, fillcolor='palegoldenrod')
    for idx in dag_layers.merge_out_range:
        graph.edge(state_name(idx), out_name, fillcolor="gray")
    return graph, in_names, out_name
def plot(genotype, dag_layers, file_path, caption=None):
    """Make a top-to-bottom DAG plot of ``genotype`` and render it as PNG.

    Args:
        genotype: genotype handed to ``subplot`` (one entry per intermediate
            node).
        dag_layers: layer description handed to ``subplot``.
        file_path: target passed to ``Digraph.render``; graphviz writes the
            DOT source to this path and the rendered image next to it with
            the format extension appended.
        caption: optional label drawn on the whole graph.

    Rendering errors are caught and printed rather than raised. In that case
    the DOT source is written to ``file_path`` so it can be rendered by hand.
    """
    edge_attr = {
        'fontsize': '15',
        'fontname': 'Sans'
    }
    node_attr = {
        'style': 'filled',
        'shape': 'rect',
        'align': 'center',
        'fontsize': '20',
        'height': '0.5',
        'width': '0.5',
        'penwidth': '2',
        'fontname': 'Sans'
    }
    g = Digraph(
        format='png',
        edge_attr=edge_attr,
        node_attr=node_attr,
        engine='dot')
    g.body.extend(['rankdir=TB'])
    g_child, g_in, _ = subplot(genotype, '', dag_layers)
    # A single shared 'input' node feeds every input of the top-level DAG.
    for n in g_in:
        g.edge('input', n, label='', fillcolor='gray')
    g.subgraph(g_child)
    # add image caption
    if caption:
        g.attr(label=caption, overlap='false', fontsize='20', fontname='Sans')
    try:
        # render() returns the path of the rendered image (source path + ext).
        rendered_path = g.render(file_path, view=False)
    except Exception as e:
        # Deliberate best effort (e.g. the graphviz `dot` binary is missing):
        # report the error and keep the DOT source. Write it as UTF-8, which is
        # graphviz's default source encoding, so non-ASCII captions cannot
        # raise UnicodeEncodeError inside this handler.
        print('render failed: {}'.format(str(e)))
        with open(file_path, 'w', encoding='utf-8') as f:
            f.write(g.source)
    else:
        print('plot saved to: {}'.format(rendered_path))
if __name__ == '__main__':
    # CLI entry: parse a genotype string given as the single argument.
    if len(sys.argv) != 2:
        raise ValueError("usage:\n python {} GENOTYPE".format(sys.argv[0]))
    genotype_str = sys.argv[1]
    try:
        genotype = gt.from_str(genotype_str)
    except AttributeError as err:
        # Keep the original parse error attached for debugging.
        raise ValueError("Cannot parse {}".format(genotype_str)) from err
    # plot() takes (genotype, dag_layers, file_path[, caption]). The previous
    # calls `plot(genotype.normal, "normal")` and
    # `plot(genotype.reduce, "reduction")` passed the output name as
    # dag_layers and omitted file_path, so they always raised TypeError.
    # A dag_layers object (exposing n_input, merge_out_range and edges)
    # cannot be recovered from the genotype string here, so fail explicitly.
    # TODO(review): obtain dag_layers from the search-space configuration and
    # plot genotype.normal / genotype.reduce with it.
    raise NotImplementedError(
        "plot() requires a dag_layers object, which cannot be derived from "
        "the genotype string alone; call plot(genotype, dag_layers, file_path) "
        "from code that has the DAG layer configuration")