Skip to content

Commit 0813585

Browse files
authored
Merge pull request #157 from pyiron/leverage-ast
Leverage `ast.NodeVisitor`
2 parents d376d6b + 1ac15ca commit 0813585

6 files changed

Lines changed: 104 additions & 117 deletions

File tree

flowrep/models/parsers/case_helpers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,8 @@ def walk_branch(
7777
walker_factory: parser_protocol.WalkerFactory,
7878
) -> WalkedBranch:
7979
fork = symbol_map.fork_scope()
80-
w = walker_factory(fork)
81-
w.walk(stmts, scope)
80+
w = walker_factory(scope, fork)
81+
w.walk(stmts)
8282
assigned = fork.assigned_symbols
8383
fork.produce_symbols(assigned)
8484
return WalkedBranch(label, w, assigned)

flowrep/models/parsers/for_parser.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,8 @@ def parse_for_node(
5252
available_accumulators=symbol_map.declared_accumulators.copy(),
5353
)
5454

55-
body_walker = walker_factory(body_symbol_map)
56-
body_walker.walk(body_tree.body, scope)
55+
body_walker = walker_factory(scope, body_symbol_map)
56+
body_walker.walk(body_tree.body)
5757
consumed = body_walker.symbol_map.consumed_accumulators
5858

5959
_validate_some_output_exists(consumed)
Lines changed: 7 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,23 @@
11
from __future__ import annotations
22

33
import ast
4-
from collections.abc import Callable, Collection
5-
from types import FunctionType
4+
from collections.abc import Callable
65
from typing import Protocol, runtime_checkable
76

87
from flowrep.models import edge_models
98
from flowrep.models.nodes import union, workflow_model
109
from flowrep.models.parsers import object_scope, symbol_scope
1110

12-
WalkerFactory = Callable[[symbol_scope.SymbolScope], "BodyWalker"]
11+
WalkerFactory = Callable[
12+
[object_scope.ScopeProxy, symbol_scope.SymbolScope], "BodyWalker"
13+
]
1314

1415

1516
@runtime_checkable
1617
class BodyWalker(Protocol):
1718
"""What control flow parsers need to walk a sub-body."""
1819

20+
scope: object_scope.ScopeProxy
1921
symbol_map: symbol_scope.SymbolScope
2022
nodes: union.Nodes
2123

@@ -34,31 +36,8 @@ def output_edges(self) -> edge_models.OutputEdges: ...
3436
@property
3537
def outputs(self) -> list[str]: ...
3638

37-
def visit(self, stmt: ast.stmt, scope: object_scope.ScopeProxy) -> None: ...
39+
def visit(self, stmt: ast.AST) -> None: ...
3840

39-
def walk(
40-
self, statements: list[ast.stmt], scope: object_scope.ScopeProxy
41-
) -> None: ...
42-
43-
def handle_assign(
44-
self, body: ast.Assign | ast.AnnAssign, scope: object_scope.ScopeProxy
45-
) -> None: ...
46-
47-
def handle_for(self, tree: ast.For, scope: object_scope.ScopeProxy) -> None: ...
48-
49-
def handle_if(self, tree: ast.If, scope: object_scope.ScopeProxy) -> None: ...
50-
51-
def handle_try(self, tree: ast.Try, scope: object_scope.ScopeProxy) -> None: ...
52-
53-
def handle_while(self, tree: ast.While, scope: object_scope.ScopeProxy) -> None: ...
54-
55-
def handle_appending_to_accumulator(self, append_call: ast.Call) -> None: ...
56-
57-
def handle_return(
58-
self,
59-
body: ast.Return,
60-
func: FunctionType,
61-
output_labels: Collection[str],
62-
) -> None: ...
41+
def walk(self, statements: list[ast.stmt]) -> None: ...
6342

6443
def build_model(self) -> workflow_model.WorkflowNode: ...

flowrep/models/parsers/while_parser.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ def parse_while_node(
4040
tree.test, scope, symbol_map, WHILE_CONDITION_LABEL
4141
)
4242

43-
body_walker = walker_factory(symbol_map.fork_scope())
44-
body_walker.walk(tree.body, scope)
43+
body_walker = walker_factory(scope, symbol_map.fork_scope())
44+
body_walker.walk(tree.body)
4545
reassigned_symbols = body_walker.symbol_map.reassigned_symbols
4646

4747
_validate_some_output_exists(reassigned_symbols)

flowrep/models/parsers/workflow_parser.py

Lines changed: 84 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import ast
22
from collections.abc import Callable, Collection
33
from types import FunctionType
4-
from typing import Any, cast
4+
from typing import cast
55

66
from pyiron_snippets import versions
77

@@ -20,8 +20,6 @@
2020
while_parser,
2121
)
2222

23-
SpecialHandlers = dict[type[ast.stmt], Callable[[Any, object_scope.ScopeProxy], None]]
24-
2523

2624
def workflow(
2725
func: FunctionType | str | None = None,
@@ -121,31 +119,19 @@ def parse_workflow(
121119
require_version=require_version,
122120
)
123121
inputs = label_helpers.get_input_labels(func)
124-
state = WorkflowParser(
122+
state = _WorkflowFunctionParser(
123+
object_scope.get_scope(func),
125124
symbol_scope.SymbolScope({p: edge_models.InputSource(port=p) for p in inputs}),
126125
fully_qualified_name=info.fully_qualified_name,
127126
version=info.version,
127+
func=func,
128+
output_labels=output_labels,
128129
)
129130
tree = parser_helpers.get_ast_function_node(func)
130131

131-
found_return = False
132-
133-
def handle_return(stmt: ast.Return, scope: object_scope.ScopeProxy):
134-
nonlocal found_return
135-
if found_return:
136-
raise ValueError(
137-
"Workflow python definitions must have exactly one return."
138-
)
139-
found_return = True
140-
state.handle_return(stmt, func, output_labels)
132+
state.walk(skip_docstring(tree.body))
141133

142-
state.walk(
143-
skip_docstring(tree.body),
144-
object_scope.get_scope(func),
145-
special_handlers={ast.Return: handle_return},
146-
)
147-
148-
if not found_return:
134+
if not state.found_return:
149135
raise ValueError("Workflow python definitions must have a return statement.")
150136

151137
source_code = parser_helpers.get_available_source_code(func)
@@ -165,7 +151,7 @@ def skip_docstring(body: list[ast.stmt]) -> list[ast.stmt]:
165151
)
166152

167153

168-
class WorkflowParser(parser_protocol.BodyWalker):
154+
class WorkflowParser(ast.NodeVisitor, parser_protocol.BodyWalker):
169155
"""
170156
Aggregates state until there is enough data to successfully build the pydantic
171157
data model.
@@ -180,10 +166,12 @@ class WorkflowParser(parser_protocol.BodyWalker):
180166

181167
def __init__(
182168
self,
169+
scope: object_scope.ScopeProxy,
183170
symbol_map: symbol_scope.SymbolScope,
184171
fully_qualified_name: str | None = None,
185172
version: str | None = None,
186173
):
174+
self.scope = scope
187175
self.symbol_map = symbol_map
188176
self.nodes: union.Nodes = {}
189177
self.fully_qualified_name = fully_qualified_name
@@ -226,54 +214,24 @@ def build_model(
226214
source_code=source_code,
227215
)
228216

229-
def visit(self, stmt: ast.stmt, scope: object_scope.ScopeProxy) -> None:
230-
if isinstance(stmt, ast.Assign | ast.AnnAssign):
231-
self.handle_assign(stmt, scope)
232-
elif isinstance(stmt, ast.For):
233-
self.handle_for(stmt, scope)
234-
elif isinstance(stmt, ast.While):
235-
self.handle_while(stmt, scope)
236-
elif isinstance(stmt, ast.If):
237-
self.handle_if(stmt, scope)
238-
elif isinstance(stmt, ast.Try):
239-
self.handle_try(stmt, scope)
240-
elif isinstance(stmt, ast.Expr) and is_append_call(stmt.value):
241-
self.handle_appending_to_accumulator(cast(ast.Call, stmt.value))
242-
else:
243-
raise TypeError(
244-
f"Workflow python definitions can only interpret assignments, a subset "
245-
f"of flow control (for/while/if/try) and a return, but ast found "
246-
f"{type(stmt)}"
247-
)
217+
def walk(self, statements: list[ast.stmt]) -> None:
218+
for statement in statements:
219+
self.visit(statement)
248220

249-
def walk(
250-
self,
251-
statements: list[ast.stmt],
252-
scope: object_scope.ScopeProxy,
253-
*,
254-
special_handlers: SpecialHandlers | None = None,
255-
) -> None:
256-
for stmt in statements:
257-
if special_handlers:
258-
for ast_type, handler in special_handlers.items():
259-
if isinstance(stmt, ast_type):
260-
handler(stmt, scope)
261-
break
262-
else:
263-
self.visit(stmt, scope)
264-
else:
265-
self.visit(stmt, scope)
266-
267-
def handle_assign(
268-
self, body: ast.Assign | ast.AnnAssign, scope: object_scope.ScopeProxy
269-
):
221+
def visit_Assign(self, stmt: ast.Assign) -> None:
222+
self._handle_assign(stmt)
223+
224+
def visit_AnnAssign(self, stmt: ast.AnnAssign) -> None:
225+
self._handle_assign(stmt)
226+
227+
def _handle_assign(self, body: ast.Assign | ast.AnnAssign):
270228
# Get returned symbols from the left-hand side
271229
lhs = body.targets[0] if isinstance(body, ast.Assign) else body.target
272230
new_symbols = parser_helpers.resolve_symbols_to_strings(lhs)
273231

274232
rhs = body.value
275233
if isinstance(rhs, ast.Call):
276-
child = atomic_parser.get_labeled_recipe(rhs, self.nodes.keys(), scope)
234+
child = atomic_parser.get_labeled_recipe(rhs, self.nodes.keys(), self.scope)
277235
self.nodes[child.label] = child.node
278236
parser_helpers.consume_call_arguments(self.symbol_map, rhs, child)
279237
self.symbol_map.register(new_symbols, child)
@@ -303,30 +261,85 @@ def _connect_node_to_enclosing_scope(self, label: str, node: union.NodeType):
303261
labeled_node = helper_models.LabeledNode(label=label, node=node)
304262
self.symbol_map.register(new_symbols=node.outputs, child=labeled_node)
305263

306-
def handle_for(self, tree: ast.For, scope: object_scope.ScopeProxy) -> None:
264+
def visit_For(self, tree: ast.For) -> None:
307265
for_node = for_parser.parse_for_node(
308-
tree, scope, self.symbol_map, WorkflowParser
266+
tree, self.scope, self.symbol_map, WorkflowParser
309267
)
310268
# Accumulators consumed by the for body are no longer available here
311269
self.symbol_map.declared_accumulators -= set(for_node.outputs)
312270
self._digest_flow_control("for", for_node)
313271

314-
def handle_while(self, tree: ast.While, scope: object_scope.ScopeProxy) -> None:
272+
def visit_While(self, tree: ast.While) -> None:
315273
while_node = while_parser.parse_while_node(
316-
tree, scope, self.symbol_map, WorkflowParser
274+
tree, self.scope, self.symbol_map, WorkflowParser
317275
)
318276
self._digest_flow_control("while", while_node)
319277

320-
def handle_if(self, tree: ast.If, scope: object_scope.ScopeProxy) -> None:
321-
if_node = if_parser.parse_if_node(tree, scope, self.symbol_map, WorkflowParser)
278+
def visit_If(self, tree: ast.If) -> None:
279+
if_node = if_parser.parse_if_node(
280+
tree, self.scope, self.symbol_map, WorkflowParser
281+
)
322282
self._digest_flow_control("if", if_node)
323283

324-
def handle_try(self, tree: ast.Try, scope: object_scope.ScopeProxy) -> None:
284+
def visit_Try(self, tree: ast.Try) -> None:
325285
try_node = try_parser.parse_try_node(
326-
tree, scope, self.symbol_map, WorkflowParser
286+
tree, self.scope, self.symbol_map, WorkflowParser
327287
)
328288
self._digest_flow_control("try", try_node)
329289

290+
def visit_Expr(self, stmt: ast.Expr) -> None:
291+
if is_append_call(stmt.value):
292+
self._handle_appending_to_accumulator(cast(ast.Call, stmt.value))
293+
else:
294+
self.generic_visit(stmt)
295+
296+
def _handle_appending_to_accumulator(self, append_call: ast.Call) -> None:
297+
used_accumulator = cast(
298+
ast.Name, cast(ast.Attribute, append_call.func).value
299+
).id
300+
appended_symbol = cast(ast.Name, append_call.args[0]).id
301+
self.symbol_map.use_accumulator(used_accumulator, appended_symbol)
302+
appended_source = self.symbol_map[appended_symbol]
303+
if isinstance(appended_source, edge_models.SourceHandle):
304+
self.symbol_map.produce(appended_symbol)
305+
306+
def generic_visit(self, stmt: ast.AST) -> None:
307+
raise TypeError(
308+
f"Workflow python definitions can only interpret a subset of assignments, "
309+
f"and flow controls (for/while/if/try) and (when parsing a function "
310+
f"definition) a return, but ast found "
311+
f"{type(stmt)}"
312+
)
313+
314+
315+
class _WorkflowFunctionParser(WorkflowParser):
316+
def __init__(
317+
self,
318+
scope: object_scope.ScopeProxy,
319+
symbol_map: symbol_scope.SymbolScope,
320+
*,
321+
fully_qualified_name: str | None = None,
322+
version: str | None = None,
323+
func: FunctionType,
324+
output_labels: Collection[str],
325+
):
326+
super().__init__(scope, symbol_map, fully_qualified_name, version)
327+
self._func = func
328+
self._output_labels = output_labels
329+
self._found_return = False
330+
331+
@property
332+
def found_return(self) -> bool:
333+
return self._found_return
334+
335+
def visit_Return(self, stmt: ast.Return) -> None:
336+
if self._found_return:
337+
raise ValueError(
338+
"Workflow python definitions must have exactly one return."
339+
)
340+
self._found_return = True
341+
self.handle_return(stmt, self._func, self._output_labels)
342+
330343
def handle_return(
331344
self,
332345
body: ast.Return,
@@ -373,16 +386,6 @@ def handle_return(
373386
)
374387
self.symbol_map.produce(port, symbol)
375388

376-
def handle_appending_to_accumulator(self, append_call: ast.Call) -> None:
377-
used_accumulator = cast(
378-
ast.Name, cast(ast.Attribute, append_call.func).value
379-
).id
380-
appended_symbol = cast(ast.Name, append_call.args[0]).id
381-
self.symbol_map.use_accumulator(used_accumulator, appended_symbol)
382-
appended_source = self.symbol_map[appended_symbol]
383-
if isinstance(appended_source, edge_models.SourceHandle):
384-
self.symbol_map.produce(appended_symbol)
385-
386389

387390
def is_append_call(node: ast.expr | ast.Expr) -> bool:
388391
"""Check if node is an append call to a known accumulator."""

tests/unit/models/parsers/test_workflow_parser.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from flowrep.models.nodes import atomic_model, workflow_model
66
from flowrep.models.parsers import (
77
atomic_parser,
8+
object_scope,
89
parser_protocol,
910
symbol_scope,
1011
workflow_parser,
@@ -145,7 +146,9 @@ class MyClass:
145146
class TestParseWorkflowBasic(unittest.TestCase):
146147
def test_protocol_fulfillment(self):
147148
self.assertIsInstance(
148-
workflow_parser.WorkflowParser(symbol_scope.SymbolScope({})),
149+
workflow_parser.WorkflowParser(
150+
object_scope.ScopeProxy({}), symbol_scope.SymbolScope({})
151+
),
149152
parser_protocol.BodyWalker,
150153
)
151154

@@ -609,7 +612,9 @@ def outer_wf(a):
609612
)
610613

611614
def test_fqn_defaults_to_none_on_raw_parser(self):
612-
parser = workflow_parser.WorkflowParser(symbol_scope.SymbolScope({}))
615+
parser = workflow_parser.WorkflowParser(
616+
object_scope.ScopeProxy({}), symbol_scope.SymbolScope({})
617+
)
613618
self.assertIsNone(parser.fully_qualified_name)
614619

615620
def test_fqn_roundtrips_through_serialization(self):

0 commit comments

Comments
 (0)