Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions flowrep/models/parsers/case_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ def walk_branch(
walker_factory: parser_protocol.WalkerFactory,
) -> WalkedBranch:
fork = symbol_map.fork_scope()
w = walker_factory(fork)
w.walk(stmts, scope)
w = walker_factory(scope, fork)
w.walk(stmts)
assigned = fork.assigned_symbols
fork.produce_symbols(assigned)
return WalkedBranch(label, w, assigned)
Expand Down
4 changes: 2 additions & 2 deletions flowrep/models/parsers/for_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ def parse_for_node(
available_accumulators=symbol_map.declared_accumulators.copy(),
)

body_walker = walker_factory(body_symbol_map)
body_walker.walk(body_tree.body, scope)
body_walker = walker_factory(scope, body_symbol_map)
body_walker.walk(body_tree.body)
consumed = body_walker.symbol_map.consumed_accumulators

_validate_some_output_exists(consumed)
Expand Down
35 changes: 7 additions & 28 deletions flowrep/models/parsers/parser_protocol.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,23 @@
from __future__ import annotations

import ast
from collections.abc import Callable, Collection
from types import FunctionType
from collections.abc import Callable
from typing import Protocol, runtime_checkable

from flowrep.models import edge_models
from flowrep.models.nodes import union, workflow_model
from flowrep.models.parsers import object_scope, symbol_scope

WalkerFactory = Callable[[symbol_scope.SymbolScope], "BodyWalker"]
WalkerFactory = Callable[
[object_scope.ScopeProxy, symbol_scope.SymbolScope], "BodyWalker"
]


@runtime_checkable
class BodyWalker(Protocol):
"""What control flow parsers need to walk a sub-body."""

scope: object_scope.ScopeProxy
symbol_map: symbol_scope.SymbolScope
nodes: union.Nodes

Expand All @@ -34,31 +36,8 @@ def output_edges(self) -> edge_models.OutputEdges: ...
@property
def outputs(self) -> list[str]: ...

def visit(self, stmt: ast.stmt, scope: object_scope.ScopeProxy) -> None: ...
def visit(self, stmt: ast.AST) -> None: ...

def walk(
self, statements: list[ast.stmt], scope: object_scope.ScopeProxy
) -> None: ...

def handle_assign(
self, body: ast.Assign | ast.AnnAssign, scope: object_scope.ScopeProxy
) -> None: ...

def handle_for(self, tree: ast.For, scope: object_scope.ScopeProxy) -> None: ...

def handle_if(self, tree: ast.If, scope: object_scope.ScopeProxy) -> None: ...

def handle_try(self, tree: ast.Try, scope: object_scope.ScopeProxy) -> None: ...

def handle_while(self, tree: ast.While, scope: object_scope.ScopeProxy) -> None: ...

def handle_appending_to_accumulator(self, append_call: ast.Call) -> None: ...

def handle_return(
self,
body: ast.Return,
func: FunctionType,
output_labels: Collection[str],
) -> None: ...
def walk(self, statements: list[ast.stmt]) -> None: ...

def build_model(self) -> workflow_model.WorkflowNode: ...
4 changes: 2 additions & 2 deletions flowrep/models/parsers/while_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ def parse_while_node(
tree.test, scope, symbol_map, WHILE_CONDITION_LABEL
)

body_walker = walker_factory(symbol_map.fork_scope())
body_walker.walk(tree.body, scope)
body_walker = walker_factory(scope, symbol_map.fork_scope())
body_walker.walk(tree.body)
reassigned_symbols = body_walker.symbol_map.reassigned_symbols

_validate_some_output_exists(reassigned_symbols)
Expand Down
165 changes: 84 additions & 81 deletions flowrep/models/parsers/workflow_parser.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import ast
from collections.abc import Callable, Collection
from types import FunctionType
from typing import Any, cast
from typing import cast

from pyiron_snippets import versions

Expand All @@ -20,8 +20,6 @@
while_parser,
)

SpecialHandlers = dict[type[ast.stmt], Callable[[Any, object_scope.ScopeProxy], None]]


def workflow(
func: FunctionType | str | None = None,
Expand Down Expand Up @@ -121,31 +119,19 @@ def parse_workflow(
require_version=require_version,
)
inputs = label_helpers.get_input_labels(func)
state = WorkflowParser(
state = _WorkflowFunctionParser(
object_scope.get_scope(func),
symbol_scope.SymbolScope({p: edge_models.InputSource(port=p) for p in inputs}),
fully_qualified_name=info.fully_qualified_name,
version=info.version,
func=func,
output_labels=output_labels,
)
tree = parser_helpers.get_ast_function_node(func)

found_return = False

def handle_return(stmt: ast.Return, scope: object_scope.ScopeProxy):
nonlocal found_return
if found_return:
raise ValueError(
"Workflow python definitions must have exactly one return."
)
found_return = True
state.handle_return(stmt, func, output_labels)
state.walk(skip_docstring(tree.body))

state.walk(
skip_docstring(tree.body),
object_scope.get_scope(func),
special_handlers={ast.Return: handle_return},
)

if not found_return:
if not state.found_return:
raise ValueError("Workflow python definitions must have a return statement.")

source_code = parser_helpers.get_available_source_code(func)
Expand All @@ -165,7 +151,7 @@ def skip_docstring(body: list[ast.stmt]) -> list[ast.stmt]:
)


class WorkflowParser(parser_protocol.BodyWalker):
class WorkflowParser(ast.NodeVisitor, parser_protocol.BodyWalker):
"""
Aggregates state until there is enough data to successfully build the pydantic
data model.
Expand All @@ -180,10 +166,12 @@ class WorkflowParser(parser_protocol.BodyWalker):

def __init__(
self,
scope: object_scope.ScopeProxy,
symbol_map: symbol_scope.SymbolScope,
fully_qualified_name: str | None = None,
version: str | None = None,
):
self.scope = scope
self.symbol_map = symbol_map
self.nodes: union.Nodes = {}
self.fully_qualified_name = fully_qualified_name
Expand Down Expand Up @@ -226,54 +214,24 @@ def build_model(
source_code=source_code,
)

def visit(self, stmt: ast.stmt, scope: object_scope.ScopeProxy) -> None:
if isinstance(stmt, ast.Assign | ast.AnnAssign):
self.handle_assign(stmt, scope)
elif isinstance(stmt, ast.For):
self.handle_for(stmt, scope)
elif isinstance(stmt, ast.While):
self.handle_while(stmt, scope)
elif isinstance(stmt, ast.If):
self.handle_if(stmt, scope)
elif isinstance(stmt, ast.Try):
self.handle_try(stmt, scope)
elif isinstance(stmt, ast.Expr) and is_append_call(stmt.value):
self.handle_appending_to_accumulator(cast(ast.Call, stmt.value))
else:
raise TypeError(
f"Workflow python definitions can only interpret assignments, a subset "
f"of flow control (for/while/if/try) and a return, but ast found "
f"{type(stmt)}"
)
def walk(self, statements: list[ast.stmt]) -> None:
    """Dispatch each statement of a body, in source order, to ``visit``."""
    for node in statements:
        self.visit(node)

def walk(
self,
statements: list[ast.stmt],
scope: object_scope.ScopeProxy,
*,
special_handlers: SpecialHandlers | None = None,
) -> None:
for stmt in statements:
if special_handlers:
for ast_type, handler in special_handlers.items():
if isinstance(stmt, ast_type):
handler(stmt, scope)
break
else:
self.visit(stmt, scope)
else:
self.visit(stmt, scope)

def handle_assign(
self, body: ast.Assign | ast.AnnAssign, scope: object_scope.ScopeProxy
):
def visit_Assign(self, stmt: ast.Assign) -> None:
    """Parse a plain assignment statement (``x = ...``)."""
    self._handle_assign(stmt)

def visit_AnnAssign(self, stmt: ast.AnnAssign) -> None:
    """Parse an annotated assignment statement (``x: T = ...``)."""
    self._handle_assign(stmt)

def _handle_assign(self, body: ast.Assign | ast.AnnAssign):
# Get returned symbols from the left-hand side
lhs = body.targets[0] if isinstance(body, ast.Assign) else body.target
new_symbols = parser_helpers.resolve_symbols_to_strings(lhs)

rhs = body.value
if isinstance(rhs, ast.Call):
child = atomic_parser.get_labeled_recipe(rhs, self.nodes.keys(), scope)
child = atomic_parser.get_labeled_recipe(rhs, self.nodes.keys(), self.scope)
self.nodes[child.label] = child.node
parser_helpers.consume_call_arguments(self.symbol_map, rhs, child)
self.symbol_map.register(new_symbols, child)
Expand Down Expand Up @@ -303,30 +261,85 @@ def _connect_node_to_enclosing_scope(self, label: str, node: union.NodeType):
labeled_node = helper_models.LabeledNode(label=label, node=node)
self.symbol_map.register(new_symbols=node.outputs, child=labeled_node)

def handle_for(self, tree: ast.For, scope: object_scope.ScopeProxy) -> None:
def visit_For(self, tree: ast.For) -> None:
for_node = for_parser.parse_for_node(
tree, scope, self.symbol_map, WorkflowParser
tree, self.scope, self.symbol_map, WorkflowParser
)
# Accumulators consumed by the for body are no longer available here
self.symbol_map.declared_accumulators -= set(for_node.outputs)
self._digest_flow_control("for", for_node)

def handle_while(self, tree: ast.While, scope: object_scope.ScopeProxy) -> None:
def visit_While(self, tree: ast.While) -> None:
    """Parse a ``while`` statement into a flow-control node.

    Delegates to ``while_parser.parse_while_node``, passing this same class
    as the walker factory so the loop body is parsed recursively.
    """
    while_node = while_parser.parse_while_node(
        tree, self.scope, self.symbol_map, WorkflowParser
    )
    self._digest_flow_control("while", while_node)

def handle_if(self, tree: ast.If, scope: object_scope.ScopeProxy) -> None:
if_node = if_parser.parse_if_node(tree, scope, self.symbol_map, WorkflowParser)
def visit_If(self, tree: ast.If) -> None:
if_node = if_parser.parse_if_node(
tree, self.scope, self.symbol_map, WorkflowParser
)
self._digest_flow_control("if", if_node)

def handle_try(self, tree: ast.Try, scope: object_scope.ScopeProxy) -> None:
def visit_Try(self, tree: ast.Try) -> None:
try_node = try_parser.parse_try_node(
tree, scope, self.symbol_map, WorkflowParser
tree, self.scope, self.symbol_map, WorkflowParser
)
self._digest_flow_control("try", try_node)

def visit_Expr(self, stmt: ast.Expr) -> None:
    """Handle a bare expression statement.

    Only ``<accumulator>.append(<symbol>)`` calls are meaningful here; any
    other expression falls through to ``generic_visit``, which rejects it.
    """
    if is_append_call(stmt.value):
        self._handle_appending_to_accumulator(cast(ast.Call, stmt.value))
    else:
        self.generic_visit(stmt)

def _handle_appending_to_accumulator(self, append_call: ast.Call) -> None:
    """Record a ``<accumulator>.append(<symbol>)`` call in the symbol map."""
    # `append_call.func` is the `<accumulator>.append` attribute access;
    # its `.value` is the Name node of the accumulator itself.
    attribute = cast(ast.Attribute, append_call.func)
    accumulator_name = cast(ast.Name, attribute.value).id
    appended_name = cast(ast.Name, append_call.args[0]).id
    self.symbol_map.use_accumulator(accumulator_name, appended_name)
    source = self.symbol_map[appended_name]
    if isinstance(source, edge_models.SourceHandle):
        self.symbol_map.produce(appended_name)

def generic_visit(self, stmt: ast.AST) -> None:
    """Reject any statement type lacking a dedicated ``visit_*`` handler.

    Overrides the base visitor's fallback: unsupported syntax is an error
    here, not something to silently recurse into.
    """
    message = (
        f"Workflow python definitions can only interpret a subset of assignments, "
        f"and flow controls (for/while/if/try) and (when parsing a function "
        f"definition) a return, but ast found "
        f"{type(stmt)}"
    )
    raise TypeError(message)


class _WorkflowFunctionParser(WorkflowParser):
def __init__(
    self,
    scope: object_scope.ScopeProxy,
    symbol_map: symbol_scope.SymbolScope,
    *,
    fully_qualified_name: str | None = None,
    version: str | None = None,
    func: FunctionType,
    output_labels: Collection[str],
):
    """Extend the base parser with the function needed to handle a return.

    Args:
        scope: Proxy for resolving objects referenced in the function body.
        symbol_map: Symbol-to-source mapping for the enclosing scope.
        fully_qualified_name: Optional dotted name recorded on the model.
        version: Optional version string recorded on the model.
        func: The python function whose body is being parsed.
        output_labels: Labels to assign to the returned outputs.
    """
    super().__init__(scope, symbol_map, fully_qualified_name, version)
    # Kept for handle_return, which needs the original callable and labels
    self._func = func
    self._output_labels = output_labels
    self._found_return = False

@property
def found_return(self) -> bool:
    """Whether a ``return`` statement has been visited yet (read-only)."""
    return self._found_return

def visit_Return(self, stmt: ast.Return) -> None:
    """Parse the function's single ``return`` statement.

    Raises:
        ValueError: If a return statement was already encountered.
    """
    if self._found_return:
        raise ValueError(
            "Workflow python definitions must have exactly one return."
        )
    self._found_return = True
    self.handle_return(stmt, self._func, self._output_labels)

def handle_return(
self,
body: ast.Return,
Expand Down Expand Up @@ -373,16 +386,6 @@ def handle_return(
)
self.symbol_map.produce(port, symbol)

def handle_appending_to_accumulator(self, append_call: ast.Call) -> None:
used_accumulator = cast(
ast.Name, cast(ast.Attribute, append_call.func).value
).id
appended_symbol = cast(ast.Name, append_call.args[0]).id
self.symbol_map.use_accumulator(used_accumulator, appended_symbol)
appended_source = self.symbol_map[appended_symbol]
if isinstance(appended_source, edge_models.SourceHandle):
self.symbol_map.produce(appended_symbol)


def is_append_call(node: ast.expr | ast.Expr) -> bool:
"""Check if node is an append call to a known accumulator."""
Expand Down
9 changes: 7 additions & 2 deletions tests/unit/models/parsers/test_workflow_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from flowrep.models.nodes import atomic_model, workflow_model
from flowrep.models.parsers import (
atomic_parser,
object_scope,
parser_protocol,
symbol_scope,
workflow_parser,
Expand Down Expand Up @@ -145,7 +146,9 @@ class MyClass:
class TestParseWorkflowBasic(unittest.TestCase):
def test_protocol_fulfillment(self):
self.assertIsInstance(
workflow_parser.WorkflowParser(symbol_scope.SymbolScope({})),
workflow_parser.WorkflowParser(
object_scope.ScopeProxy({}), symbol_scope.SymbolScope({})
),
parser_protocol.BodyWalker,
)

Expand Down Expand Up @@ -609,7 +612,9 @@ def outer_wf(a):
)

def test_fqn_defaults_to_none_on_raw_parser(self):
parser = workflow_parser.WorkflowParser(symbol_scope.SymbolScope({}))
parser = workflow_parser.WorkflowParser(
object_scope.ScopeProxy({}), symbol_scope.SymbolScope({})
)
self.assertIsNone(parser.fully_qualified_name)

def test_fqn_roundtrips_through_serialization(self):
Expand Down
Loading