11import ast
22from collections .abc import Callable , Collection
33from types import FunctionType
4- from typing import Any , cast
4+ from typing import cast
55
66from pyiron_snippets import versions
77
2020 while_parser ,
2121)
2222
23- SpecialHandlers = dict [type [ast .stmt ], Callable [[Any , object_scope .ScopeProxy ], None ]]
24-
2523
2624def workflow (
2725 func : FunctionType | str | None = None ,
@@ -121,31 +119,19 @@ def parse_workflow(
121119 require_version = require_version ,
122120 )
123121 inputs = label_helpers .get_input_labels (func )
124- state = WorkflowParser (
122+ state = _WorkflowFunctionParser (
123+ object_scope .get_scope (func ),
125124 symbol_scope .SymbolScope ({p : edge_models .InputSource (port = p ) for p in inputs }),
126125 fully_qualified_name = info .fully_qualified_name ,
127126 version = info .version ,
127+ func = func ,
128+ output_labels = output_labels ,
128129 )
129130 tree = parser_helpers .get_ast_function_node (func )
130131
131- found_return = False
132-
133- def handle_return (stmt : ast .Return , scope : object_scope .ScopeProxy ):
134- nonlocal found_return
135- if found_return :
136- raise ValueError (
137- "Workflow python definitions must have exactly one return."
138- )
139- found_return = True
140- state .handle_return (stmt , func , output_labels )
132+ state .walk (skip_docstring (tree .body ))
141133
142- state .walk (
143- skip_docstring (tree .body ),
144- object_scope .get_scope (func ),
145- special_handlers = {ast .Return : handle_return },
146- )
147-
148- if not found_return :
134+ if not state .found_return :
149135 raise ValueError ("Workflow python definitions must have a return statement." )
150136
151137 source_code = parser_helpers .get_available_source_code (func )
@@ -165,7 +151,7 @@ def skip_docstring(body: list[ast.stmt]) -> list[ast.stmt]:
165151 )
166152
167153
168- class WorkflowParser (parser_protocol .BodyWalker ):
154+ class WorkflowParser (ast . NodeVisitor , parser_protocol .BodyWalker ):
169155 """
170156 Aggregates state until there is enough data to successfully build the pydantic
171157 data model.
@@ -180,10 +166,12 @@ class WorkflowParser(parser_protocol.BodyWalker):
180166
181167 def __init__ (
182168 self ,
169+ scope : object_scope .ScopeProxy ,
183170 symbol_map : symbol_scope .SymbolScope ,
184171 fully_qualified_name : str | None = None ,
185172 version : str | None = None ,
186173 ):
174+ self .scope = scope
187175 self .symbol_map = symbol_map
188176 self .nodes : union .Nodes = {}
189177 self .fully_qualified_name = fully_qualified_name
@@ -226,54 +214,24 @@ def build_model(
226214 source_code = source_code ,
227215 )
228216
229- def visit (self , stmt : ast .stmt , scope : object_scope .ScopeProxy ) -> None :
230- if isinstance (stmt , ast .Assign | ast .AnnAssign ):
231- self .handle_assign (stmt , scope )
232- elif isinstance (stmt , ast .For ):
233- self .handle_for (stmt , scope )
234- elif isinstance (stmt , ast .While ):
235- self .handle_while (stmt , scope )
236- elif isinstance (stmt , ast .If ):
237- self .handle_if (stmt , scope )
238- elif isinstance (stmt , ast .Try ):
239- self .handle_try (stmt , scope )
240- elif isinstance (stmt , ast .Expr ) and is_append_call (stmt .value ):
241- self .handle_appending_to_accumulator (cast (ast .Call , stmt .value ))
242- else :
243- raise TypeError (
244- f"Workflow python definitions can only interpret assignments, a subset "
245- f"of flow control (for/while/if/try) and a return, but ast found "
246- f"{ type (stmt )} "
247- )
217+ def walk (self , statements : list [ast .stmt ]) -> None :
218+ for statement in statements :
219+ self .visit (statement )
248220
249- def walk (
250- self ,
251- statements : list [ast .stmt ],
252- scope : object_scope .ScopeProxy ,
253- * ,
254- special_handlers : SpecialHandlers | None = None ,
255- ) -> None :
256- for stmt in statements :
257- if special_handlers :
258- for ast_type , handler in special_handlers .items ():
259- if isinstance (stmt , ast_type ):
260- handler (stmt , scope )
261- break
262- else :
263- self .visit (stmt , scope )
264- else :
265- self .visit (stmt , scope )
266-
267- def handle_assign (
268- self , body : ast .Assign | ast .AnnAssign , scope : object_scope .ScopeProxy
269- ):
221+ def visit_Assign (self , stmt : ast .Assign ) -> None :
222+ self ._handle_assign (stmt )
223+
224+ def visit_AnnAssign (self , stmt : ast .AnnAssign ) -> None :
225+ self ._handle_assign (stmt )
226+
227+ def _handle_assign (self , body : ast .Assign | ast .AnnAssign ):
270228 # Get returned symbols from the left-hand side
271229 lhs = body .targets [0 ] if isinstance (body , ast .Assign ) else body .target
272230 new_symbols = parser_helpers .resolve_symbols_to_strings (lhs )
273231
274232 rhs = body .value
275233 if isinstance (rhs , ast .Call ):
276- child = atomic_parser .get_labeled_recipe (rhs , self .nodes .keys (), scope )
234+ child = atomic_parser .get_labeled_recipe (rhs , self .nodes .keys (), self . scope )
277235 self .nodes [child .label ] = child .node
278236 parser_helpers .consume_call_arguments (self .symbol_map , rhs , child )
279237 self .symbol_map .register (new_symbols , child )
@@ -303,30 +261,85 @@ def _connect_node_to_enclosing_scope(self, label: str, node: union.NodeType):
303261 labeled_node = helper_models .LabeledNode (label = label , node = node )
304262 self .symbol_map .register (new_symbols = node .outputs , child = labeled_node )
305263
306- def handle_for (self , tree : ast .For , scope : object_scope . ScopeProxy ) -> None :
264+ def visit_For (self , tree : ast .For ) -> None :
307265 for_node = for_parser .parse_for_node (
308- tree , scope , self .symbol_map , WorkflowParser
266+ tree , self . scope , self .symbol_map , WorkflowParser
309267 )
310268 # Accumulators consumed by the for body are no longer available here
311269 self .symbol_map .declared_accumulators -= set (for_node .outputs )
312270 self ._digest_flow_control ("for" , for_node )
313271
314- def handle_while (self , tree : ast .While , scope : object_scope . ScopeProxy ) -> None :
272+ def visit_While (self , tree : ast .While ) -> None :
315273 while_node = while_parser .parse_while_node (
316- tree , scope , self .symbol_map , WorkflowParser
274+ tree , self . scope , self .symbol_map , WorkflowParser
317275 )
318276 self ._digest_flow_control ("while" , while_node )
319277
320- def handle_if (self , tree : ast .If , scope : object_scope .ScopeProxy ) -> None :
321- if_node = if_parser .parse_if_node (tree , scope , self .symbol_map , WorkflowParser )
278+ def visit_If (self , tree : ast .If ) -> None :
279+ if_node = if_parser .parse_if_node (
280+ tree , self .scope , self .symbol_map , WorkflowParser
281+ )
322282 self ._digest_flow_control ("if" , if_node )
323283
324- def handle_try (self , tree : ast .Try , scope : object_scope . ScopeProxy ) -> None :
284+ def visit_Try (self , tree : ast .Try ) -> None :
325285 try_node = try_parser .parse_try_node (
326- tree , scope , self .symbol_map , WorkflowParser
286+ tree , self . scope , self .symbol_map , WorkflowParser
327287 )
328288 self ._digest_flow_control ("try" , try_node )
329289
290+ def visit_Expr (self , stmt : ast .Expr ) -> None :
291+ if is_append_call (stmt .value ):
292+ self ._handle_appending_to_accumulator (cast (ast .Call , stmt .value ))
293+ else :
294+ self .generic_visit (stmt )
295+
296+ def _handle_appending_to_accumulator (self , append_call : ast .Call ) -> None :
297+ used_accumulator = cast (
298+ ast .Name , cast (ast .Attribute , append_call .func ).value
299+ ).id
300+ appended_symbol = cast (ast .Name , append_call .args [0 ]).id
301+ self .symbol_map .use_accumulator (used_accumulator , appended_symbol )
302+ appended_source = self .symbol_map [appended_symbol ]
303+ if isinstance (appended_source , edge_models .SourceHandle ):
304+ self .symbol_map .produce (appended_symbol )
305+
306+ def generic_visit (self , stmt : ast .AST ) -> None :
307+ raise TypeError (
308+ f"Workflow python definitions can only interpret a subset of assignments, "
309+ f"and flow controls (for/while/if/try) and (when parsing a function "
310+ f"definition) a return, but ast found "
311+ f"{ type (stmt )} "
312+ )
313+
314+
315+ class _WorkflowFunctionParser (WorkflowParser ):
316+ def __init__ (
317+ self ,
318+ scope : object_scope .ScopeProxy ,
319+ symbol_map : symbol_scope .SymbolScope ,
320+ * ,
321+ fully_qualified_name : str | None = None ,
322+ version : str | None = None ,
323+ func : FunctionType ,
324+ output_labels : Collection [str ],
325+ ):
326+ super ().__init__ (scope , symbol_map , fully_qualified_name , version )
327+ self ._func = func
328+ self ._output_labels = output_labels
329+ self ._found_return = False
330+
331+ @property
332+ def found_return (self ) -> bool :
333+ return self ._found_return
334+
335+ def visit_Return (self , stmt : ast .Return ) -> None :
336+ if self ._found_return :
337+ raise ValueError (
338+ "Workflow python definitions must have exactly one return."
339+ )
340+ self ._found_return = True
341+ self .handle_return (stmt , self ._func , self ._output_labels )
342+
330343 def handle_return (
331344 self ,
332345 body : ast .Return ,
@@ -373,16 +386,6 @@ def handle_return(
373386 )
374387 self .symbol_map .produce (port , symbol )
375388
376- def handle_appending_to_accumulator (self , append_call : ast .Call ) -> None :
377- used_accumulator = cast (
378- ast .Name , cast (ast .Attribute , append_call .func ).value
379- ).id
380- appended_symbol = cast (ast .Name , append_call .args [0 ]).id
381- self .symbol_map .use_accumulator (used_accumulator , appended_symbol )
382- appended_source = self .symbol_map [appended_symbol ]
383- if isinstance (appended_source , edge_models .SourceHandle ):
384- self .symbol_map .produce (appended_symbol )
385-
386389
387390def is_append_call (node : ast .expr | ast .Expr ) -> bool :
388391 """Check if node is an append call to a known accumulator."""
0 commit comments