-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgraph_builder.py
More file actions
66 lines (58 loc) · 2.3 KB
/
graph_builder.py
File metadata and controls
66 lines (58 loc) · 2.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import networkx as nx
import matplotlib.pyplot as plt
# Sample search tree drawn by default by plot_tree_with_labels().
# Each entry is (parent_state, child_state, action_label); the label is
# rendered on the edge between the two states.
edges = [
    ("S0", "S1", "a1"),
    ("S0", "S2", "a2"),
    ("S1", "S3", "a3"),
    ("S1", "S4", "a4"),
    ("S2", "S5", "a5"),
    ("S4", "S6", "a6"),
]

# Per-state statistics as (state, w, n). They are displayed as "W=" and "N="
# under each node; presumably accumulated wins and visit count — confirm.
_NODE_STATS = (
    ("S0", 4, 6),
    ("S1", 3, 4),
    ("S2", 1, 2),
    ("S3", 0, 1),
    ("S4", 2, 2),
    ("S5", 0, 1),
    ("S6", 1, 1),
)
nodes = {state: {"w": w, "n": n} for state, w, n in _NODE_STATS}
def _hierarchy_pos(G, root, width, vert_gap, vert_loc, xcenter):
    """Compute a top-down layout for the directed tree ``G`` rooted at ``root``.

    Each node's horizontal span is split into equal slots, one per child, and
    every level sits ``vert_gap`` below its parent. Children are placed left to
    right in ``G.successors`` order (edge insertion order for a ``DiGraph``).

    Returns:
        dict mapping every node reachable from ``root`` to an ``(x, y)`` tuple.
    """
    pos = {}
    # Explicit stack instead of recursion so deep trees cannot hit the
    # interpreter's recursion limit.
    pending = [(root, width, vert_loc, xcenter)]
    while pending:
        node, span, y, x = pending.pop()
        pos[node] = (x, y)
        children = list(G.successors(node))
        if not children:
            continue
        dx = span / len(children)
        left = x - span / 2
        for i, child in enumerate(children):
            # Centre child i inside the i-th equal-width slot of the parent span.
            pending.append((child, dx, y - vert_gap, left + dx * (i + 0.5)))
    return pos


def plot_tree_with_labels(edges=edges, nodes=nodes, root="S0", width=1.0,
                          vert_gap=0.1, vert_loc=0, xcenter=1, show=True):
    """Draw a labelled directed tree (e.g. an MCTS search tree).

    Args:
        edges: iterable of ``(parent, child, label)`` triples; ``label`` is
            drawn on the edge. Defaults to the module-level sample ``edges``.
        nodes: mapping from node name to a dict with ``"w"`` and ``"n"`` keys,
            shown as ``W=`` / ``N=`` under the node name. Nodes absent from
            this mapping are labelled with their name only. Defaults to the
            module-level sample ``nodes``.
        root: node placed at the top of the drawing; must have no parent.
        width: horizontal extent given to the root's subtree.
        vert_gap: vertical distance between consecutive levels.
        vert_loc: y coordinate of the root.
        xcenter: x coordinate of the root.
        show: if true (the default), call ``plt.show()`` after drawing.

    Returns:
        The matplotlib ``Figure`` containing the drawing.

    Raises:
        ValueError: if ``root`` is not a node of the graph, or has a parent.
        TypeError: if the edges do not form a directed tree (arborescence).
    """
    G = nx.DiGraph()
    for parent, child, label in edges:
        G.add_edge(parent, child, label=label)

    # Checked first: an empty edge list yields an empty graph, on which
    # networkx's tree predicates raise NetworkXPointlessConcept.
    if root not in G:
        raise ValueError(f"root {root!r} is not a node of the graph")
    # nx.is_tree ignores edge direction and would accept e.g. a node with two
    # parents; the layout from `root` would then miss nodes and drawing would
    # fail. is_arborescence also requires at most one parent per node.
    if not nx.is_arborescence(G):
        raise TypeError("Il grafo deve essere un albero!")
    if G.in_degree(root) != 0:
        raise ValueError(f"root {root!r} has a parent; pass the tree's root node")

    pos = _hierarchy_pos(G, root, width, vert_gap, vert_loc, xcenter)

    node_labels = {}
    for node in G.nodes():
        stats = nodes.get(node)
        if stats is None:
            node_labels[node] = str(node)
        else:
            node_labels[node] = f"{node}\nW={stats['w']}\nN={stats['n']}"

    fig, ax = plt.subplots(figsize=(10, 8))
    nx.draw(G, pos, ax=ax, with_labels=False, arrows=True, node_size=2000,
            node_color='#ADD8E6')
    nx.draw_networkx_labels(G, pos, labels=node_labels, font_size=10,
                            font_color='black', ax=ax)
    edge_labels = nx.get_edge_attributes(G, 'label')
    nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, ax=ax)
    ax.set_title("MonteCarlo Tree Search")
    ax.axis("off")
    if show:
        plt.show()
    return fig
if __name__ == "__main__":
    # Script entry point: draw the sample tree from the module-level data,
    # rooted at "S0" (the same values the function uses by default).
    plot_tree_with_labels(edges=edges, nodes=nodes, root="S0")