|
| 1 | +import pytest |
| 2 | +import sys |
| 3 | +import pathlib |
| 4 | +import importlib.util |
| 5 | +import glob |
| 6 | +import os |
| 7 | +import torch |
| 8 | + |
| 9 | +# allow imports from repo root |
| 10 | +sys.path.insert(0, str(pathlib.Path(__file__).resolve().parents[1])) |
| 11 | + |
| 12 | +from compare_encoders import optimizer_to_data |
| 13 | + |
| 14 | + |
| 15 | +def canonical_representation(opt): |
| 16 | + data = optimizer_to_data(opt) |
| 17 | + |
| 18 | + def simplify(attrs): |
| 19 | + result = {} |
| 20 | + for k, v in attrs.items(): |
| 21 | + key = str(getattr(k, "name", k)) |
| 22 | + result[key] = v.tolist() if hasattr(v, "tolist") else v |
| 23 | + return result |
| 24 | + |
| 25 | + nodes = sorted( |
| 26 | + (int(t), tuple(sorted((k, str(v)) for k, v in simplify(attrs).items()))) |
| 27 | + for t, attrs in zip(data.node_types.tolist(), data.node_attributes) |
| 28 | + ) |
| 29 | + edges = set(map(tuple, data.edge_index.numpy().T)) |
| 30 | + return nodes, edges |
| 31 | + |
| 32 | + |
| 33 | +@pytest.mark.parametrize( |
| 34 | + "pt_path", glob.glob(os.path.join("computation_graphs", "optimizers", "*.pt"))) |
| 35 | +def test_graph_dict_matches_source(pt_path): |
| 36 | + base = os.path.splitext(pt_path)[0] |
| 37 | + py_path = base + ".py" |
| 38 | + spec = importlib.util.spec_from_file_location("mod", py_path) |
| 39 | + module = importlib.util.module_from_spec(spec) |
| 40 | + spec.loader.exec_module(module) |
| 41 | + |
| 42 | + cls_candidates = [v for v in module.__dict__.values() if isinstance(v, type) and issubclass(v, torch.nn.Module)] |
| 43 | + assert len(cls_candidates) == 1, f"Could not find optimizer class in {py_path}" |
| 44 | + cls = cls_candidates[0] |
| 45 | + |
| 46 | + scripted_py = torch.jit.script(cls()) |
| 47 | + scripted_pt = torch.jit.load(pt_path) |
| 48 | + |
| 49 | + nodes_py, edges_py = canonical_representation(scripted_py) |
| 50 | + nodes_pt, edges_pt = canonical_representation(scripted_pt) |
| 51 | + |
| 52 | + assert nodes_py == nodes_pt |
| 53 | + assert edges_py == edges_pt |
0 commit comments