diff --git a/flowrep/models/parsers/import_parser.py b/flowrep/models/parsers/import_parser.py index b9c4c86c..3c5be777 100644 --- a/flowrep/models/parsers/import_parser.py +++ b/flowrep/models/parsers/import_parser.py @@ -1,11 +1,14 @@ import ast import importlib +from flowrep.models.parsers import object_scope +from flowrep.models.parsers.object_scope import ScopeProxy + def build_scope( imports: list[ast.Import] | None = None, import_froms: list[ast.ImportFrom] | None = None, -) -> dict: +) -> object_scope.ScopeProxy: """ Build a scope dictionary from a list of `import` and `from ... import ...` statements. @@ -14,9 +17,10 @@ def build_scope( import_froms (list | None): A list of `ast.ImportFrom` nodes. Returns: - dict: A dictionary representing the scope with imported modules and objects. + object_scope.ScopeProxy: A mutable mapping representing the scope with imported + modules and objects. """ - scope = {} + scope = ScopeProxy() imports = imports or [] import_froms = import_froms or [] @@ -26,7 +30,7 @@ def build_scope( for alias in imp.names: asname = alias.asname or alias.name module = importlib.import_module(alias.name) - scope[asname] = module + scope.register(name=asname, obj=module) # Handle `from ... import ...` statements for imp_from in import_froms: @@ -42,6 +46,6 @@ def build_scope( name = alias.name asname = alias.asname or name obj = getattr(module, name) - scope[asname] = obj + scope.register(name=asname, obj=obj) return scope diff --git a/flowrep/models/parsers/object_scope.py b/flowrep/models/parsers/object_scope.py index c8699cb5..eac31e9d 100644 --- a/flowrep/models/parsers/object_scope.py +++ b/flowrep/models/parsers/object_scope.py @@ -18,8 +18,12 @@ class ScopeProxy(MutableMapping[str, object]): By default, does not allow re-registration of existing symbols to new values. """ - def __init__(self, d: MutableMapping[str, object], allow_overwrite: bool = False): - self._d = {k: v for k, v in d.items()} + def __init__( + self, + d: MutableMapping[str, object] | None = None, + allow_overwrite: bool = False, + ): + self._d = {} if d is None else {k: v for k, v in d.items()} self.allow_overwrite = allow_overwrite def __getitem__(self, name: str): diff --git a/tests/unit/models/parsers/test_object_scope.py b/tests/unit/models/parsers/test_object_scope.py index 87009276..9b1d41ba 100644 --- a/tests/unit/models/parsers/test_object_scope.py +++ b/tests/unit/models/parsers/test_object_scope.py @@ -84,6 +84,10 @@ def test_len(self): proxy = object_scope.ScopeProxy({"a": 1, "b": 2}) self.assertEqual(len(proxy), 2) + def test_empty_construction(self): + proxy = object_scope.ScopeProxy() + self.assertEqual(len(proxy), 0) + def test_str(self): proxy = object_scope.ScopeProxy({"x": 1}) self.assertEqual(str(proxy), "{'x': 1}")