diff --git a/flowrep/models/parsers/object_scope.py b/flowrep/models/parsers/object_scope.py index bcae785f..54dc8edc 100644 --- a/flowrep/models/parsers/object_scope.py +++ b/flowrep/models/parsers/object_scope.py @@ -24,28 +24,57 @@ def get_scope(func: FunctionType) -> ScopeProxy: return ScopeProxy(inspect.getmodule(func).__dict__ | vars(builtins)) +def resolve_attribute_to_object(attribute: str, scope: ScopeProxy | object) -> object: + """ + Resolve a dot-separated attribute string to the actual object it references in the + given scope. For example, if attribute is "os.path.join", this function will + return the actual join function from the os.path module. + + Args: + attribute: A dot-separated string representing the attribute to resolve. + scope: The scope in which to resolve the attribute. This can be a ScopeProxy + or any object that supports attribute access. + + Returns: + The object that the attribute resolves to in the given scope. + """ + obj = None + try: + for attr in attribute.split("."): + obj = getattr(obj or scope, attr) + return obj + except AttributeError as e: + raise ValueError(f"Could not find attribute '{attr}' of {attribute}") from e + + def resolve_symbol_to_object( node: ast.expr, # Expecting a Name or Attribute here, and will otherwise TypeError scope: ScopeProxy | object, _chain: list[str] | None = None, ) -> object: - """ """ + """ + Recursively resolve a symbol in the form of an ast.Name or ast.Attribute to the + actual object it references in the given scope. The _chain parameter is used + internally to keep track of the attribute chain being resolved, and should not + be provided by the caller. + + Args: + node: An ast.expr representing the symbol to resolve. Expected to be an + ast.Name or ast.Attribute. + scope: The scope in which to resolve the symbol. This can be a ScopeProxy + or any object that supports attribute access. + + Returns: + The object that the symbol resolves to in the given scope. + """ _chain = _chain or [] - error_suffix = f" while attempting to resolve the symbol chain '{'.'.join(_chain)}'" if isinstance(node, ast.Name): - attr = node.id - try: - obj = getattr(scope, attr) - for attr in _chain: - obj = getattr(obj, attr) - return obj - except AttributeError as e: - raise ValueError(f"Could not find attribute '{attr}' {error_suffix}") from e + return resolve_attribute_to_object(".".join([node.id] + _chain), scope) elif isinstance(node, ast.Attribute): return resolve_symbol_to_object(node.value, scope, [node.attr] + _chain) else: raise TypeError( - f"Cannot resolve symbol {node} {error_suffix}. " - f"Expected an ast.Name or chain of ast.Attribute and ast.Name, but got " - f"{node}." + f"Cannot resolve symbol {node} while building the symbol chain " + f"'{'.'.join(_chain)}'. Expected an ast.Name or chain of ast.Attribute " + f"and ast.Name, but got {node}." ) diff --git a/tests/unit/models/parsers/test_object_scope.py b/tests/unit/models/parsers/test_object_scope.py index e442579f..7fc2bcb0 100644 --- a/tests/unit/models/parsers/test_object_scope.py +++ b/tests/unit/models/parsers/test_object_scope.py @@ -74,3 +74,12 @@ def test_unrecognized_node_raises(self): node = ast.Constant(value=42) with self.assertRaises(TypeError): object_scope.resolve_symbol_to_object(node, scope) + + def test_resolve_attribute_to_object(self): + scope = object_scope.ScopeProxy({"ast": ast}) + f = object_scope.resolve_attribute_to_object("ast.literal_eval", scope) + self.assertIs(f, ast.literal_eval) + + +if __name__ == "__main__": + unittest.main()