Skip to content

Commit e3c4741

Browse files
authored
Merge pull request #23 from TimeDelta/codex/write-test-for-torchscript-graph
Add optimizer graph dict test
2 parents 9933fd7 + a2ba15d commit e3c4741

File tree

1 file changed

+53
-0
lines changed

1 file changed

+53
-0
lines changed

tests/test_graph_dict.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import pytest
2+
import sys
3+
import pathlib
4+
import importlib.util
5+
import glob
6+
import os
7+
import torch
8+
9+
# allow imports from repo root
10+
sys.path.insert(0, str(pathlib.Path(__file__).resolve().parents[1]))
11+
12+
from compare_encoders import optimizer_to_data
13+
14+
15+
def canonical_representation(opt):
16+
data = optimizer_to_data(opt)
17+
18+
def simplify(attrs):
19+
result = {}
20+
for k, v in attrs.items():
21+
key = str(getattr(k, "name", k))
22+
result[key] = v.tolist() if hasattr(v, "tolist") else v
23+
return result
24+
25+
nodes = sorted(
26+
(int(t), tuple(sorted((k, str(v)) for k, v in simplify(attrs).items())))
27+
for t, attrs in zip(data.node_types.tolist(), data.node_attributes)
28+
)
29+
edges = set(map(tuple, data.edge_index.numpy().T))
30+
return nodes, edges
31+
32+
33+
@pytest.mark.parametrize(
34+
"pt_path", glob.glob(os.path.join("computation_graphs", "optimizers", "*.pt")))
35+
def test_graph_dict_matches_source(pt_path):
36+
base = os.path.splitext(pt_path)[0]
37+
py_path = base + ".py"
38+
spec = importlib.util.spec_from_file_location("mod", py_path)
39+
module = importlib.util.module_from_spec(spec)
40+
spec.loader.exec_module(module)
41+
42+
cls_candidates = [v for v in module.__dict__.values() if isinstance(v, type) and issubclass(v, torch.nn.Module)]
43+
assert len(cls_candidates) == 1, f"Could not find optimizer class in {py_path}"
44+
cls = cls_candidates[0]
45+
46+
scripted_py = torch.jit.script(cls())
47+
scripted_pt = torch.jit.load(pt_path)
48+
49+
nodes_py, edges_py = canonical_representation(scripted_py)
50+
nodes_pt, edges_pt = canonical_representation(scripted_pt)
51+
52+
assert nodes_py == nodes_pt
53+
assert edges_py == edges_pt

0 commit comments

Comments
 (0)