diff --git a/flowrep/models/parsers/case_helpers.py b/flowrep/models/parsers/case_helpers.py index 6f1d2582..4a2197ec 100644 --- a/flowrep/models/parsers/case_helpers.py +++ b/flowrep/models/parsers/case_helpers.py @@ -77,8 +77,8 @@ def walk_branch( walker_factory: parser_protocol.WalkerFactory, ) -> WalkedBranch: fork = symbol_map.fork_scope() - w = walker_factory(fork) - w.walk(stmts, scope) + w = walker_factory(scope, fork) + w.walk(stmts) assigned = fork.assigned_symbols fork.produce_symbols(assigned) return WalkedBranch(label, w, assigned) diff --git a/flowrep/models/parsers/for_parser.py b/flowrep/models/parsers/for_parser.py index ef03e1db..9e360b6d 100644 --- a/flowrep/models/parsers/for_parser.py +++ b/flowrep/models/parsers/for_parser.py @@ -52,8 +52,8 @@ def parse_for_node( available_accumulators=symbol_map.declared_accumulators.copy(), ) - body_walker = walker_factory(body_symbol_map) - body_walker.walk(body_tree.body, scope) + body_walker = walker_factory(scope, body_symbol_map) + body_walker.walk(body_tree.body) consumed = body_walker.symbol_map.consumed_accumulators _validate_some_output_exists(consumed) diff --git a/flowrep/models/parsers/parser_protocol.py b/flowrep/models/parsers/parser_protocol.py index 6d2861b4..9e49861c 100644 --- a/flowrep/models/parsers/parser_protocol.py +++ b/flowrep/models/parsers/parser_protocol.py @@ -1,21 +1,23 @@ from __future__ import annotations import ast -from collections.abc import Callable, Collection -from types import FunctionType +from collections.abc import Callable from typing import Protocol, runtime_checkable from flowrep.models import edge_models from flowrep.models.nodes import union, workflow_model from flowrep.models.parsers import object_scope, symbol_scope -WalkerFactory = Callable[[symbol_scope.SymbolScope], "BodyWalker"] +WalkerFactory = Callable[ + [object_scope.ScopeProxy, symbol_scope.SymbolScope], "BodyWalker" +] @runtime_checkable class BodyWalker(Protocol): """What control flow parsers need to walk a sub-body.""" + scope: object_scope.ScopeProxy symbol_map: symbol_scope.SymbolScope nodes: union.Nodes @@ -34,31 +36,8 @@ def output_edges(self) -> edge_models.OutputEdges: ... @property def outputs(self) -> list[str]: ... - def visit(self, stmt: ast.stmt, scope: object_scope.ScopeProxy) -> None: ... + def visit(self, stmt: ast.AST) -> None: ... - def walk( - self, statements: list[ast.stmt], scope: object_scope.ScopeProxy - ) -> None: ... - - def handle_assign( - self, body: ast.Assign | ast.AnnAssign, scope: object_scope.ScopeProxy - ) -> None: ... - - def handle_for(self, tree: ast.For, scope: object_scope.ScopeProxy) -> None: ... - - def handle_if(self, tree: ast.If, scope: object_scope.ScopeProxy) -> None: ... - - def handle_try(self, tree: ast.Try, scope: object_scope.ScopeProxy) -> None: ... - - def handle_while(self, tree: ast.While, scope: object_scope.ScopeProxy) -> None: ... - - def handle_appending_to_accumulator(self, append_call: ast.Call) -> None: ... - - def handle_return( - self, - body: ast.Return, - func: FunctionType, - output_labels: Collection[str], - ) -> None: ... + def walk(self, statements: list[ast.stmt]) -> None: ... def build_model(self) -> workflow_model.WorkflowNode: ... diff --git a/flowrep/models/parsers/while_parser.py b/flowrep/models/parsers/while_parser.py index 9f3a4f45..989ea6b0 100644 --- a/flowrep/models/parsers/while_parser.py +++ b/flowrep/models/parsers/while_parser.py @@ -40,8 +40,8 @@ def parse_while_node( tree.test, scope, symbol_map, WHILE_CONDITION_LABEL ) - body_walker = walker_factory(symbol_map.fork_scope()) - body_walker.walk(tree.body, scope) + body_walker = walker_factory(scope, symbol_map.fork_scope()) + body_walker.walk(tree.body) reassigned_symbols = body_walker.symbol_map.reassigned_symbols _validate_some_output_exists(reassigned_symbols) diff --git a/flowrep/models/parsers/workflow_parser.py b/flowrep/models/parsers/workflow_parser.py index 6ce15043..75b892f5 100644 --- a/flowrep/models/parsers/workflow_parser.py +++ b/flowrep/models/parsers/workflow_parser.py @@ -1,7 +1,7 @@ import ast from collections.abc import Callable, Collection from types import FunctionType -from typing import Any, cast +from typing import cast from pyiron_snippets import versions @@ -20,8 +20,6 @@ while_parser, ) -SpecialHandlers = dict[type[ast.stmt], Callable[[Any, object_scope.ScopeProxy], None]] - def workflow( func: FunctionType | str | None = None, @@ -121,31 +119,19 @@ def parse_workflow( require_version=require_version, ) inputs = label_helpers.get_input_labels(func) - state = WorkflowParser( + state = _WorkflowFunctionParser( + object_scope.get_scope(func), symbol_scope.SymbolScope({p: edge_models.InputSource(port=p) for p in inputs}), fully_qualified_name=info.fully_qualified_name, version=info.version, + func=func, + output_labels=output_labels, ) tree = parser_helpers.get_ast_function_node(func) - found_return = False - - def handle_return(stmt: ast.Return, scope: object_scope.ScopeProxy): - nonlocal found_return - if found_return: - raise ValueError( - "Workflow python definitions must have exactly one return." - ) - found_return = True - state.handle_return(stmt, func, output_labels) + state.walk(skip_docstring(tree.body)) - state.walk( - skip_docstring(tree.body), - object_scope.get_scope(func), - special_handlers={ast.Return: handle_return}, - ) - - if not found_return: + if not state.found_return: raise ValueError("Workflow python definitions must have a return statement.") source_code = parser_helpers.get_available_source_code(func) @@ -165,7 +151,7 @@ def skip_docstring(body: list[ast.stmt]) -> list[ast.stmt]: ) -class WorkflowParser(parser_protocol.BodyWalker): +class WorkflowParser(ast.NodeVisitor, parser_protocol.BodyWalker): """ Aggregates state until there is enough data to successfully build the pydantic data model. @@ -180,10 +166,12 @@ class WorkflowParser(parser_protocol.BodyWalker): def __init__( self, + scope: object_scope.ScopeProxy, symbol_map: symbol_scope.SymbolScope, fully_qualified_name: str | None = None, version: str | None = None, ): + self.scope = scope self.symbol_map = symbol_map self.nodes: union.Nodes = {} self.fully_qualified_name = fully_qualified_name @@ -226,54 +214,24 @@ def build_model( source_code=source_code, ) - def visit(self, stmt: ast.stmt, scope: object_scope.ScopeProxy) -> None: - if isinstance(stmt, ast.Assign | ast.AnnAssign): - self.handle_assign(stmt, scope) - elif isinstance(stmt, ast.For): - self.handle_for(stmt, scope) - elif isinstance(stmt, ast.While): - self.handle_while(stmt, scope) - elif isinstance(stmt, ast.If): - self.handle_if(stmt, scope) - elif isinstance(stmt, ast.Try): - self.handle_try(stmt, scope) - elif isinstance(stmt, ast.Expr) and is_append_call(stmt.value): - self.handle_appending_to_accumulator(cast(ast.Call, stmt.value)) - else: - raise TypeError( - f"Workflow python definitions can only interpret assignments, a subset " - f"of flow control (for/while/if/try) and a return, but ast found " - f"{type(stmt)}" - ) + def walk(self, statements: list[ast.stmt]) -> None: + for statement in statements: + self.visit(statement) - def walk( - self, - statements: list[ast.stmt], - scope: object_scope.ScopeProxy, - *, - special_handlers: SpecialHandlers | None = None, - ) -> None: - for stmt in statements: - if special_handlers: - for ast_type, handler in special_handlers.items(): - if isinstance(stmt, ast_type): - handler(stmt, scope) - break - else: - self.visit(stmt, scope) - else: - self.visit(stmt, scope) - - def handle_assign( - self, body: ast.Assign | ast.AnnAssign, scope: object_scope.ScopeProxy - ): + def visit_Assign(self, stmt: ast.Assign) -> None: + self._handle_assign(stmt) + + def visit_AnnAssign(self, stmt: ast.AnnAssign) -> None: + self._handle_assign(stmt) + + def _handle_assign(self, body: ast.Assign | ast.AnnAssign): # Get returned symbols from the left-hand side lhs = body.targets[0] if isinstance(body, ast.Assign) else body.target new_symbols = parser_helpers.resolve_symbols_to_strings(lhs) rhs = body.value if isinstance(rhs, ast.Call): - child = atomic_parser.get_labeled_recipe(rhs, self.nodes.keys(), scope) + child = atomic_parser.get_labeled_recipe(rhs, self.nodes.keys(), self.scope) self.nodes[child.label] = child.node parser_helpers.consume_call_arguments(self.symbol_map, rhs, child) self.symbol_map.register(new_symbols, child) @@ -303,30 +261,85 @@ def _connect_node_to_enclosing_scope(self, label: str, node: union.NodeType): labeled_node = helper_models.LabeledNode(label=label, node=node) self.symbol_map.register(new_symbols=node.outputs, child=labeled_node) - def handle_for(self, tree: ast.For, scope: object_scope.ScopeProxy) -> None: + def visit_For(self, tree: ast.For) -> None: for_node = for_parser.parse_for_node( - tree, scope, self.symbol_map, WorkflowParser + tree, self.scope, self.symbol_map, WorkflowParser ) # Accumulators consumed by the for body are no longer available here self.symbol_map.declared_accumulators -= set(for_node.outputs) self._digest_flow_control("for", for_node) - def handle_while(self, tree: ast.While, scope: object_scope.ScopeProxy) -> None: + def visit_While(self, tree: ast.While) -> None: while_node = while_parser.parse_while_node( - tree, scope, self.symbol_map, WorkflowParser + tree, self.scope, self.symbol_map, WorkflowParser ) self._digest_flow_control("while", while_node) - def handle_if(self, tree: ast.If, scope: object_scope.ScopeProxy) -> None: - if_node = if_parser.parse_if_node(tree, scope, self.symbol_map, WorkflowParser) + def visit_If(self, tree: ast.If) -> None: + if_node = if_parser.parse_if_node( + tree, self.scope, self.symbol_map, WorkflowParser + ) self._digest_flow_control("if", if_node) - def handle_try(self, tree: ast.Try, scope: object_scope.ScopeProxy) -> None: + def visit_Try(self, tree: ast.Try) -> None: try_node = try_parser.parse_try_node( - tree, scope, self.symbol_map, WorkflowParser + tree, self.scope, self.symbol_map, WorkflowParser ) self._digest_flow_control("try", try_node) + def visit_Expr(self, stmt: ast.Expr) -> None: + if is_append_call(stmt.value): + self._handle_appending_to_accumulator(cast(ast.Call, stmt.value)) + else: + self.generic_visit(stmt) + + def _handle_appending_to_accumulator(self, append_call: ast.Call) -> None: + used_accumulator = cast( + ast.Name, cast(ast.Attribute, append_call.func).value + ).id + appended_symbol = cast(ast.Name, append_call.args[0]).id + self.symbol_map.use_accumulator(used_accumulator, appended_symbol) + appended_source = self.symbol_map[appended_symbol] + if isinstance(appended_source, edge_models.SourceHandle): + self.symbol_map.produce(appended_symbol) + + def generic_visit(self, stmt: ast.AST) -> None: + raise TypeError( + f"Workflow python definitions can only interpret a subset of assignments, " + f"and flow controls (for/while/if/try) and (when parsing a function " + f"definition) a return, but ast found " + f"{type(stmt)}" + ) + + +class _WorkflowFunctionParser(WorkflowParser): + def __init__( + self, + scope: object_scope.ScopeProxy, + symbol_map: symbol_scope.SymbolScope, + *, + fully_qualified_name: str | None = None, + version: str | None = None, + func: FunctionType, + output_labels: Collection[str], + ): + super().__init__(scope, symbol_map, fully_qualified_name, version) + self._func = func + self._output_labels = output_labels + self._found_return = False + + @property + def found_return(self) -> bool: + return self._found_return + + def visit_Return(self, stmt: ast.Return) -> None: + if self._found_return: + raise ValueError( + "Workflow python definitions must have exactly one return." + ) + self._found_return = True + self.handle_return(stmt, self._func, self._output_labels) + def handle_return( self, body: ast.Return, @@ -373,16 +386,6 @@ def handle_return( ) self.symbol_map.produce(port, symbol) - def handle_appending_to_accumulator(self, append_call: ast.Call) -> None: - used_accumulator = cast( - ast.Name, cast(ast.Attribute, append_call.func).value - ).id - appended_symbol = cast(ast.Name, append_call.args[0]).id - self.symbol_map.use_accumulator(used_accumulator, appended_symbol) - appended_source = self.symbol_map[appended_symbol] - if isinstance(appended_source, edge_models.SourceHandle): - self.symbol_map.produce(appended_symbol) - def is_append_call(node: ast.expr | ast.Expr) -> bool: """Check if node is an append call to a known accumulator.""" diff --git a/tests/unit/models/parsers/test_workflow_parser.py b/tests/unit/models/parsers/test_workflow_parser.py index e9339104..875185c1 100644 --- a/tests/unit/models/parsers/test_workflow_parser.py +++ b/tests/unit/models/parsers/test_workflow_parser.py @@ -5,6 +5,7 @@ from flowrep.models.nodes import atomic_model, workflow_model from flowrep.models.parsers import ( atomic_parser, + object_scope, parser_protocol, symbol_scope, workflow_parser, @@ -145,7 +146,9 @@ class MyClass: class TestParseWorkflowBasic(unittest.TestCase): def test_protocol_fulfillment(self): self.assertIsInstance( - workflow_parser.WorkflowParser(symbol_scope.SymbolScope({})), + workflow_parser.WorkflowParser( + object_scope.ScopeProxy({}), symbol_scope.SymbolScope({}) + ), parser_protocol.BodyWalker, ) @@ -609,7 +612,9 @@ def outer_wf(a): ) def test_fqn_defaults_to_none_on_raw_parser(self): - parser = workflow_parser.WorkflowParser(symbol_scope.SymbolScope({})) + parser = workflow_parser.WorkflowParser( + object_scope.ScopeProxy({}), symbol_scope.SymbolScope({}) + ) self.assertIsNone(parser.fully_qualified_name) def test_fqn_roundtrips_through_serialization(self):