Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions flowrep/models/parsers/dependency_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from pyiron_snippets import versions

from flowrep.models.parsers import object_scope, parser_helpers
from flowrep.models.parsers import import_parser, object_scope, parser_helpers

CallDependencies = dict[versions.VersionInfo, Callable]

Expand Down Expand Up @@ -43,10 +43,13 @@ def get_call_dependencies(
return call_dependencies
visited.add(func_fqn)

scope = object_scope.get_scope(func)
tree = parser_helpers.get_ast_function_node(func)
collector = CallCollector()
collector.visit(tree)
local_modules = import_parser.build_scope(collector.imports, collector.import_froms)
scope = object_scope.get_scope(func)
for name, obj in local_modules.items():
scope.register(name=name, obj=obj)

for call in collector.calls:
try:
Expand Down Expand Up @@ -105,7 +108,17 @@ def split_by_version_availability(
class CallCollector(ast.NodeVisitor):
def __init__(self):
self.calls: list[ast.expr] = []
self.imports: list[ast.Import] = []
self.import_froms: list[ast.ImportFrom] = []

def visit_Call(self, node: ast.Call) -> None:
self.calls.append(node.func)
self.generic_visit(node)

def visit_Import(self, node: ast.Import) -> None:
self.imports.append(node)
self.generic_visit(node)

def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
self.import_froms.append(node)
self.generic_visit(node)
47 changes: 47 additions & 0 deletions flowrep/models/parsers/import_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import ast
import importlib


def build_scope(
imports: list[ast.Import] | None = None,
import_froms: list[ast.ImportFrom] | None = None,
) -> dict:
"""
Build a scope dictionary from a list of `import` and `from ... import ...` statements.

Args:
imports (list | None): A list of `ast.Import` nodes.
import_froms (list | None): A list of `ast.ImportFrom` nodes.

Returns:
dict: A dictionary representing the scope with imported modules and objects.
"""
scope = {}

imports = imports or []
import_froms = import_froms or []

# Handle `import` statements
for imp in imports:
for alias in imp.names:
asname = alias.asname or alias.name
module = importlib.import_module(alias.name)
scope[asname] = module

# Handle `from ... import ...` statements
for imp_from in import_froms:
level = imp_from.level
# Dynamically import the module (absolute or relative)
if imp_from.module is None or level > 0:
raise ValueError(
f"Relative imports are not supported in dependency parsing. "
f"Encountered importing from {imp_from.module}."
)
module = importlib.import_module(imp_from.module)
for alias in imp_from.names:
name = alias.name
asname = alias.asname or name
obj = getattr(module, name)
scope[asname] = obj

return scope
14 changes: 14 additions & 0 deletions tests/unit/models/parsers/test_dependency_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,14 @@ def _fqns(deps: dependency_parser.CallDependencies) -> set[str]:
return {info.fully_qualified_name for info in deps}


def _local_imports(x):
import sys as s
from math import sqrt

a = s.getsizeof(x)
return sqrt(a)


class TestGetCallDependencies(unittest.TestCase):
"""Tests for :func:`dependency_parser.get_call_dependencies`."""

Expand Down Expand Up @@ -171,6 +179,12 @@ def test_non_callable_resolved_symbol_is_skipped(self):
deps = dependency_parser.get_call_dependencies(_calls_non_callable)
self.assertIsInstance(deps, dict)

def test_local_imports_included(self):
deps = dependency_parser.get_call_dependencies(_local_imports)
fqns = _fqns(deps)
self.assertIn("sys.getsizeof", fqns)
self.assertIn("math.sqrt", fqns)


class TestSplitByVersionAvailability(unittest.TestCase):
"""Tests for :func:`dependency_parser.split_by_version_availability`."""
Expand Down
Loading