@@ -443,6 +443,9 @@ def __init__(self) -> None:
443443 self ._register_custom ("FLIP" , 1 , 1 , self ._flip )
444444 self ._register_custom ("TFLIP" , 2 , 2 , self ._tflip )
445445 self ._register_custom ("SCATTER" , 3 , 3 , self ._scatter )
446+ # PARALLEL accepts either a single TNS of functions, or any number
447+ # of function arguments passed directly (variadic form).
448+ self ._register_custom ("PARALLEL" , 1 , None , self ._parallel )
446449
447450 def _register_int_only (self , name : str , arity : int , func : Callable [..., int ]) -> None :
448451 def impl (_ : "Interpreter" , args : List [Value ], __ : List [Expression ], ___ : Environment , location : SourceLocation ) -> Value :
@@ -2412,6 +2415,66 @@ def _scatter(
24122415 dst_view [tuple (slice (start , end ) for start , end in slices )] = src_view
24132416 return Value (TYPE_TNS , Tensor (shape = list (dst .shape ), data = out_data ))
24142417
def _parallel(
    self,
    interpreter: "Interpreter",
    args: List[Value],
    __: List[Expression],
    ___: Environment,
    location: SourceLocation,
) -> Value:
    """Run a group of zero-argument functions concurrently and wait for all.

    Two call forms are accepted:

    - ``PARALLEL(TNS: functions)``: a single tensor argument whose
      elements (taken from ``tensor.data.ravel()``) are ``FUNC`` values.
    - ``PARALLEL(FUNC, FUNC, ...)``: the functions passed directly as
      arguments.

    Every candidate is type-checked before any thread is started, so an
    invalid argument can never leave the valid functions half-executed.
    Each function is then invoked on its own thread with no positional
    and no keyword arguments. Return values of the functions are
    discarded.

    Args:
        interpreter: Interpreter whose ``_invoke_function_object`` is used
            to call each function.
        args: Evaluated arguments (one tensor, or the functions themselves).
        __: Unevaluated argument expressions (unused).
        ___: Environment of the call site; forwarded unchanged as the last
            argument of ``_invoke_function_object``.
        location: Source location attached to any raised error.

    Returns:
        ``Value(TYPE_INT, 0)`` once every worker finished without error.

    Raises:
        ASMRuntimeError: If any element is not a ``FUNC`` value, or if any
            invocation raised. When several workers fail, the error of the
            lowest-indexed one is reported. Non-``ASMRuntimeError``
            failures are wrapped, with the original exception chained as
            ``__cause__``.
    """
    elems: List[Any]
    if len(args) == 1 and args[0].type == TYPE_TNS:
        # Tensor form: flatten so that tensors of any rank are supported.
        flat = args[0].value.data.ravel()
        elems = [flat[i] for i in range(flat.size)]
    else:
        # Variadic form: every argument is itself a candidate function.
        elems = list(args)

    # Validate up front: rejecting a bad element only after the other
    # functions have already run would leak their side effects.
    for elem in elems:
        if not isinstance(elem, Value) or elem.type != TYPE_FUNC:
            raise ASMRuntimeError("PARALLEL expects functions (either a tensor of FUNC or FUNC arguments)", location=location, rewrite_rule="PARALLEL")

    env = ___
    # One slot per worker; each thread writes only its own index, so no
    # lock is required.
    errors: List[Optional[BaseException]] = [None] * len(elems)

    def worker(idx: int, func: Any) -> None:
        try:
            interpreter._invoke_function_object(func, [], {}, location, env)
        except BaseException as exc:
            # An exception escaping a thread would only reach
            # threading.excepthook; record it so the caller can re-raise.
            errors[idx] = exc

    started: List[threading.Thread] = []
    try:
        for idx, elem in enumerate(elems):
            thread = threading.Thread(target=worker, args=(idx, elem.value))
            thread.start()
            started.append(thread)
    finally:
        # Join whatever was started, even if a later start() failed, so no
        # worker is still running when control returns to the caller.
        for thread in started:
            thread.join()

    # Propagate the first error (by argument order), if any.
    first_error = next((err for err in errors if err is not None), None)
    if first_error is None:
        return Value(TYPE_INT, 0)
    if isinstance(first_error, ASMRuntimeError):
        raise first_error
    raise ASMRuntimeError(f"PARALLEL worker failed: {first_error}", location=location, rewrite_rule="PARALLEL") from first_error
2477+
24152478 def _convolve (
24162479 self ,
24172480 interpreter : "Interpreter" ,
0 commit comments