diff --git a/flowrep/models/parsers/dependency_parser.py b/flowrep/models/parsers/dependency_parser.py index 189b28e4..4c605185 100644 --- a/flowrep/models/parsers/dependency_parser.py +++ b/flowrep/models/parsers/dependency_parser.py @@ -4,7 +4,7 @@ from pyiron_snippets import versions -from flowrep.models.parsers import object_scope, parser_helpers +from flowrep.models.parsers import import_parser, object_scope, parser_helpers CallDependencies = dict[versions.VersionInfo, Callable] @@ -43,10 +43,13 @@ def get_call_dependencies( return call_dependencies visited.add(func_fqn) - scope = object_scope.get_scope(func) tree = parser_helpers.get_ast_function_node(func) collector = CallCollector() collector.visit(tree) + local_modules = import_parser.build_scope(collector.imports, collector.import_froms) + scope = object_scope.get_scope(func) + for name, obj in local_modules.items(): + scope.register(name=name, obj=obj) for call in collector.calls: try: @@ -105,7 +108,17 @@ def split_by_version_availability( class CallCollector(ast.NodeVisitor): def __init__(self): self.calls: list[ast.expr] = [] + self.imports: list[ast.Import] = [] + self.import_froms: list[ast.ImportFrom] = [] def visit_Call(self, node: ast.Call) -> None: self.calls.append(node.func) self.generic_visit(node) + + def visit_Import(self, node: ast.Import) -> None: + self.imports.append(node) + self.generic_visit(node) + + def visit_ImportFrom(self, node: ast.ImportFrom) -> None: + self.import_froms.append(node) + self.generic_visit(node) diff --git a/flowrep/models/parsers/import_parser.py b/flowrep/models/parsers/import_parser.py new file mode 100644 index 00000000..b9c4c86c --- /dev/null +++ b/flowrep/models/parsers/import_parser.py @@ -0,0 +1,47 @@ +import ast +import importlib + + +def build_scope( + imports: list[ast.Import] | None = None, + import_froms: list[ast.ImportFrom] | None = None, +) -> dict: + """ + Build a scope dictionary from a list of `import` and `from ... import ...` statements. + + Args: + imports (list | None): A list of `ast.Import` nodes. + import_froms (list | None): A list of `ast.ImportFrom` nodes. + + Returns: + dict: A dictionary representing the scope with imported modules and objects. + """ + scope = {} + + imports = imports or [] + import_froms = import_froms or [] + + # Handle `import` statements + for imp in imports: + for alias in imp.names: + asname = alias.asname or alias.name + module = importlib.import_module(alias.name) + scope[asname] = module + + # Handle `from ... import ...` statements + for imp_from in import_froms: + level = imp_from.level + # Dynamically import the module (absolute or relative) + if imp_from.module is None or level > 0: + raise ValueError( + f"Relative imports are not supported in dependency parsing. " + f"Encountered importing from {imp_from.module}." + ) + module = importlib.import_module(imp_from.module) + for alias in imp_from.names: + name = alias.name + asname = alias.asname or name + obj = getattr(module, name) + scope[asname] = obj + + return scope diff --git a/tests/unit/models/parsers/test_dependency_parser.py b/tests/unit/models/parsers/test_dependency_parser.py index d4b69cf9..7814d566 100644 --- a/tests/unit/models/parsers/test_dependency_parser.py +++ b/tests/unit/models/parsers/test_dependency_parser.py @@ -86,6 +86,21 @@ def _fqns(deps: dependency_parser.CallDependencies) -> set[str]: return {info.fully_qualified_name for info in deps} +def _local_imports(x): + import sys as s + from math import sqrt + + a = s.getsizeof(x) + return sqrt(a) + + +def _import_from_sibling(x, y): + from .test_for_parser import pair + + a, b = pair(x, y) + return a, b + + class TestGetCallDependencies(unittest.TestCase): """Tests for :func:`dependency_parser.get_call_dependencies`.""" @@ -171,6 +186,18 @@ def test_non_callable_resolved_symbol_is_skipped(self): deps = dependency_parser.get_call_dependencies(_calls_non_callable) self.assertIsInstance(deps, dict) + def test_local_imports_included(self): + deps = dependency_parser.get_call_dependencies(_local_imports) + fqns = _fqns(deps) + self.assertIn("sys.getsizeof", fqns) + self.assertIn("math.sqrt", fqns) + + def test_relative_import_raises(self): + with self.assertRaises(ValueError) as ctx: + dependency_parser.get_call_dependencies(_import_from_sibling) + self.assertIn("Relative imports are not supported", str(ctx.exception)) + self.assertIn("test_for_parser", str(ctx.exception)) + class TestSplitByVersionAvailability(unittest.TestCase): """Tests for :func:`dependency_parser.split_by_version_availability`."""