Skip to content

Commit 21eeb49

Browse files
authored
Merge pull request #156 from pyiron/crawler-edits
Crawler edits
2 parents ac96a5e + ad8f016 commit 21eeb49

2 files changed

Lines changed: 298 additions & 24 deletions

File tree

flowrep/crawler.py

Lines changed: 83 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -1,45 +1,104 @@
11
import ast
22
import inspect
33
import types
4-
from typing import Any
4+
from collections.abc import Callable
55

66
from pyiron_snippets import versions
77

8+
from flowrep.models.parsers import object_scope, parser_helpers
89

9-
class CallCollector(ast.NodeVisitor):
10-
def __init__(self):
11-
self.calls = []
10+
CallDependencies = dict[versions.VersionInfo, list[Callable]]
1211

13-
def visit_Call(self, node):
14-
self.calls.append(node.func)
15-
self.generic_visit(node)
1612

13+
def get_call_dependencies(
14+
func: types.FunctionType,
15+
version_scraping: versions.VersionScrapingMap | None = None,
16+
_call_dependencies: CallDependencies | None = None,
17+
_visited: set[str] | None = None,
18+
) -> CallDependencies:
19+
"""
20+
Recursively collect all callable dependencies of *func* via AST introspection.
1721
18-
def _build_global_namespace(func) -> dict[str, object]:
19-
return dict(func.__globals__)
22+
Each dependency is keyed by its :class:`~pyiron_snippets.versions.VersionInfo`
23+
and maps to the list of concrete callables sharing that identity. The search
24+
is depth-first: for every resolved callee that is a
25+
:class:`~types.FunctionType` (i.e. has inspectable source), the function
26+
recurses into the callee's own scope.
2027
28+
Args:
29+
func: The function whose call-graph to analyse.
30+
version_scraping (VersionScrapingMap | None): Since some modules may store
31+
their version in other ways, this provides an optional map between module
32+
names and callables to leverage for extracting that module's version.
33+
_call_dependencies: Accumulator for recursive calls — do not pass manually.
34+
_visited: Fully-qualified names already traversed — do not pass manually.
2135
22-
def _resolve_ast_node(node: ast.AST, namespace: dict[str, object]) -> Any:
36+
Returns:
37+
A mapping from :class:`VersionInfo` to the callables found under that
38+
identity across the entire (sub-)tree.
2339
"""
24-
Resolve an AST node to its corresponding object in the given namespace.
40+
call_dependencies: CallDependencies = _call_dependencies or {}
41+
visited: set[str] = _visited or set()
42+
43+
func_fqn = versions.VersionInfo.of(func).fully_qualified_name
44+
if func_fqn in visited:
45+
return call_dependencies
46+
visited.add(func_fqn)
47+
48+
scope = object_scope.get_scope(func)
49+
tree = parser_helpers.get_ast_function_node(func)
50+
collector = CallCollector()
51+
collector.visit(tree)
52+
53+
for call in collector.calls:
54+
try:
55+
caller = object_scope.resolve_symbol_to_object(call, scope)
56+
except (ValueError, TypeError):
57+
continue
58+
59+
if not callable(caller):
60+
continue
61+
62+
info = versions.VersionInfo.of(caller, version_scraping=version_scraping)
63+
call_dependencies.setdefault(info, []).append(caller)
64+
65+
# Depth-first search on dependencies — only possible when we have source
66+
if isinstance(caller, types.FunctionType):
67+
get_call_dependencies(caller, version_scraping, call_dependencies, visited)
68+
69+
return call_dependencies
70+
71+
72+
def split_by_version_availability(
73+
call_dependencies: CallDependencies,
74+
) -> tuple[CallDependencies, CallDependencies]:
75+
"""
76+
Partition *call_dependencies* by whether a version string is available.
2577
2678
Args:
27-
node (ast.AST): The AST node to resolve.
28-
namespace (dict[str, object]): The namespace to use for resolution.
79+
call_dependencies: The dependency map to partition.
2980
3081
Returns:
31-
Any: The resolved object, or None if it cannot be resolved.
82+
A ``(has_version, no_version)`` tuple of :data:`CallDependencies` dicts.
3283
"""
33-
if isinstance(node, ast.Name):
34-
return namespace.get(node.id)
84+
has_version: CallDependencies = {}
85+
no_version: CallDependencies = {}
86+
for info, dependents in call_dependencies.items():
87+
if info.version is None:
88+
no_version[info] = dependents
89+
else:
90+
has_version[info] = dependents
3591

36-
if isinstance(node, ast.Attribute):
37-
base = _resolve_ast_node(node.value, namespace)
38-
if base is None:
39-
return None
40-
return getattr(base, node.attr, None)
92+
return has_version, no_version
4193

42-
return None
94+
95+
class CallCollector(ast.NodeVisitor):
96+
def __init__(self):
97+
self.calls: list[ast.expr] = []
98+
99+
def visit_Call(self, node: ast.Call) -> None:
100+
self.calls.append(node.func)
101+
self.generic_visit(node)
43102

44103

45104
def extract_called_functions(func: types.FunctionType) -> set[types.FunctionType]:
@@ -58,11 +117,11 @@ def extract_called_functions(func: types.FunctionType) -> set[types.FunctionType
58117
collector = CallCollector()
59118
collector.visit(tree)
60119

61-
namespace = _build_global_namespace(func)
120+
namespace = object_scope.get_scope(func)
62121
resolved = set()
63122

64123
for call_node in collector.calls:
65-
obj = _resolve_ast_node(call_node, namespace)
124+
obj = object_scope.resolve_symbol_to_object(call_node, namespace)
66125
if callable(obj):
67126
resolved.add(obj)
68127

tests/unit/test_crawler.py

Lines changed: 215 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,8 @@
11
import math
22
import unittest
33

4+
from pyiron_snippets import versions
5+
46
from flowrep import crawler
57

68

@@ -19,6 +21,77 @@ def more_op(a, b):
1921
return c
2022

2123

24+
# ---------------------------------------------------------------------------
25+
# Helper functions defined at module level so they have inspectable source,
26+
# a proper __module__, and a stable __qualname__.
27+
# ---------------------------------------------------------------------------
28+
29+
30+
def _leaf():
31+
return 42
32+
33+
34+
def _single_call():
35+
return _leaf()
36+
37+
38+
def _diamond_a():
39+
return _leaf()
40+
41+
42+
def _diamond_b():
43+
return _leaf()
44+
45+
46+
def _diamond_root():
47+
_diamond_a()
48+
_diamond_b()
49+
50+
51+
def _mutual_b():
52+
return _leaf()
53+
54+
55+
def _mutual_a():
56+
return _mutual_b()
57+
58+
59+
# Mutual recursion to exercise cycle detection.
60+
def _cycle_a():
61+
return _cycle_b() # noqa: F821 — defined below
62+
63+
64+
def _cycle_b():
65+
return _cycle_a()
66+
67+
68+
def _no_calls():
69+
x = 1 + 2
70+
return x
71+
72+
73+
def _calls_len():
74+
return len([1, 2, 3])
75+
76+
77+
def _nested_call():
78+
return _single_call()
79+
80+
81+
def _multi_call():
82+
a = _leaf()
83+
b = _leaf()
84+
return a + b
85+
86+
87+
def _fqn(func) -> str:
88+
return versions.VersionInfo.of(func).fully_qualified_name
89+
90+
91+
def _fqns(deps: crawler.CallDependencies) -> set[str]:
92+
return {info.fully_qualified_name for info in deps}
93+
94+
2295
class TestCrawler(unittest.TestCase):
2396
def test_analyze_function_dependencies(self):
2497
loc, ext = crawler.analyze_function_dependencies(op)
@@ -39,5 +112,147 @@ def test_extract_called_functions(self):
39112
self.assertEqual(called, {op})
40113

41114

115+
class TestGetCallDependencies(unittest.TestCase):
116+
"""Tests for :func:`crawler.get_call_dependencies`."""
117+
118+
# --- basic behaviour ---
119+
120+
def test_no_calls_returns_empty(self):
121+
deps = crawler.get_call_dependencies(_no_calls)
122+
self.assertEqual(deps, {})
123+
124+
def test_single_direct_call(self):
125+
deps = crawler.get_call_dependencies(_single_call)
126+
self.assertIn(_fqn(_leaf), _fqns(deps))
127+
128+
def test_transitive_dependencies(self):
129+
deps = crawler.get_call_dependencies(_nested_call)
130+
fqns = _fqns(deps)
131+
# Should find both _single_call and _leaf
132+
self.assertIn(_fqn(_single_call), fqns)
133+
self.assertIn(_fqn(_leaf), fqns)
134+
135+
def test_diamond_dependency_no_duplicate_keys(self):
136+
"""
137+
_diamond_root -> _diamond_a -> _leaf AND _diamond_root -> _diamond_b -> _leaf.
138+
_leaf's VersionInfo should appear exactly once as a key.
139+
"""
140+
deps = crawler.get_call_dependencies(_diamond_root)
141+
matching = [info for info in deps if info.fully_qualified_name == _fqn(_leaf)]
142+
self.assertEqual(len(matching), 1)
143+
144+
# --- cycle safety ---
145+
146+
def test_cycle_does_not_recurse_infinitely(self):
147+
# Should terminate without RecursionError
148+
deps = crawler.get_call_dependencies(_cycle_a)
149+
self.assertIn(_fqn(_cycle_b), _fqns(deps))
150+
151+
# --- builtins / non-FunctionType callables ---
152+
153+
def test_builtin_callable_included(self):
154+
deps = crawler.get_call_dependencies(_calls_len)
155+
self.assertIn(_fqn(len), _fqns(deps))
156+
157+
# --- accumulator semantics ---
158+
159+
def test_same_function_called_twice_appears_multiple_times_in_list(self):
160+
deps = crawler.get_call_dependencies(_multi_call)
161+
matching = [info for info in deps if info.fully_qualified_name == _fqn(_leaf)]
162+
self.assertEqual(len(matching), 1, "single key expected")
163+
# The list value should have two entries (one per call-site)
164+
self.assertEqual(len(deps[matching[0]]), 2)
165+
166+
def test_returns_dict_type(self):
167+
deps = crawler.get_call_dependencies(_leaf)
168+
self.assertIsInstance(deps, dict)
169+
170+
171+
class TestSplitByVersionAvailability(unittest.TestCase):
172+
"""Tests for :func:`crawler.split_by_version_availability`."""
173+
174+
@staticmethod
175+
def _make_info(
176+
module: str, qualname: str, version: str | None = None
177+
) -> versions.VersionInfo:
178+
return versions.VersionInfo(
179+
module=module,
180+
qualname=qualname,
181+
version=version,
182+
)
183+
184+
def test_empty_input(self):
185+
has, no = crawler.split_by_version_availability({})
186+
self.assertEqual(has, {})
187+
self.assertEqual(no, {})
188+
189+
def test_all_versioned(self):
190+
info_a = self._make_info("pkg", "a", "1.0")
191+
info_b = self._make_info("pkg", "b", "2.0")
192+
deps: crawler.CallDependencies = {info_a: [_leaf], info_b: [_leaf]}
193+
194+
has, no = crawler.split_by_version_availability(deps)
195+
self.assertEqual(len(has), 2)
196+
self.assertEqual(len(no), 0)
197+
198+
def test_all_unversioned(self):
199+
info_a = self._make_info("local", "a")
200+
info_b = self._make_info("local", "b")
201+
deps: crawler.CallDependencies = {info_a: [_leaf], info_b: [_leaf]}
202+
203+
has, no = crawler.split_by_version_availability(deps)
204+
self.assertEqual(len(has), 0)
205+
self.assertEqual(len(no), 2)
206+
207+
def test_mixed(self):
208+
versioned = self._make_info("pkg", "x", "3.1")
209+
unversioned = self._make_info("local", "y")
210+
deps: crawler.CallDependencies = {
211+
versioned: [_leaf],
212+
unversioned: [_single_call],
213+
}
214+
215+
has, no = crawler.split_by_version_availability(deps)
216+
self.assertIn(versioned, has)
217+
self.assertIn(unversioned, no)
218+
self.assertNotIn(versioned, no)
219+
self.assertNotIn(unversioned, has)
220+
221+
def test_preserves_callable_lists(self):
222+
info = self._make_info("pkg", "z", "1.0")
223+
callables = [_leaf, _single_call, _no_calls]
224+
deps: crawler.CallDependencies = {info: callables}
225+
226+
has, _ = crawler.split_by_version_availability(deps)
227+
self.assertIs(has[info], callables)
228+
229+
def test_partition_is_exhaustive_and_disjoint(self):
230+
"""Every key in the input appears in exactly one partition."""
231+
infos = [
232+
self._make_info("pkg", "a", "1.0"),
233+
self._make_info("local", "b"),
234+
self._make_info("pkg", "c", "0.1"),
235+
self._make_info("local", "d"),
236+
]
237+
deps: crawler.CallDependencies = {info: [_leaf] for info in infos}
238+
239+
has, no = crawler.split_by_version_availability(deps)
240+
self.assertEqual(set(has) | set(no), set(deps))
241+
self.assertTrue(set(has).isdisjoint(set(no)))
242+
243+
def test_version_none_vs_empty_string(self):
244+
"""Only ``None`` counts as unversioned; an empty string is still 'versioned'."""
245+
none_version = self._make_info("local", "f", None)
246+
empty_version = self._make_info("local", "g", "")
247+
deps: crawler.CallDependencies = {
248+
none_version: [_leaf],
249+
empty_version: [_leaf],
250+
}
251+
252+
has, no = crawler.split_by_version_availability(deps)
253+
self.assertIn(none_version, no)
254+
self.assertIn(empty_version, has)
255+
256+
42257
if __name__ == "__main__":
43258
unittest.main()

0 commit comments

Comments
 (0)