Skip to content

Commit 1ff5b61

Browse files
Add PARALLEL
1 parent 1f3fad5 commit 1ff5b61

File tree

2 files changed

+63
-0
lines changed

2 files changed

+63
-0
lines changed

asm-lang.exe

3.6 KB
Binary file not shown.

interpreter.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,9 @@ def __init__(self) -> None:
443443
self._register_custom("FLIP", 1, 1, self._flip)
444444
self._register_custom("TFLIP", 2, 2, self._tflip)
445445
self._register_custom("SCATTER", 3, 3, self._scatter)
446+
# PARALLEL accepts either a single TNS of functions, or any number
447+
# of function arguments passed directly (variadic form).
448+
self._register_custom("PARALLEL", 1, None, self._parallel)
446449

447450
def _register_int_only(self, name: str, arity: int, func: Callable[..., int]) -> None:
448451
def impl(_: "Interpreter", args: List[Value], __: List[Expression], ___: Environment, location: SourceLocation) -> Value:
@@ -2412,6 +2415,66 @@ def _scatter(
24122415
dst_view[tuple(slice(start, end) for start, end in slices)] = src_view
24132416
return Value(TYPE_TNS, Tensor(shape=list(dst.shape), data=out_data))
24142417

2418+
def _parallel(
2419+
self,
2420+
interpreter: "Interpreter",
2421+
args: List[Value],
2422+
__: List[Expression],
2423+
___: Environment,
2424+
location: SourceLocation,
2425+
) -> Value:
2426+
"""PARALLEL(TNS: functions):INT
2427+
2428+
Execute each element of the `functions` tensor in parallel. Each
2429+
element must be a `FUNC` value. Wait for all to complete and return
2430+
integer 0 on success. If any element is not a function or any
2431+
invocation raises, an ASMRuntimeError is raised.
2432+
"""
2433+
# Support two forms:
2434+
# - PARALLEL(TNS: functions)
2435+
# - PARALLEL(FUNC, FUNC, ...)
2436+
elems: List[Any]
2437+
if len(args) == 1 and args[0].type == TYPE_TNS:
2438+
tensor = args[0].value
2439+
data = tensor.data
2440+
flat = data.ravel()
2441+
elems = [flat[i] for i in range(flat.size)]
2442+
else:
2443+
elems = list(args)
2444+
2445+
n = len(elems)
2446+
results: List[Optional[Value]] = [None] * n
2447+
errors: List[Optional[BaseException]] = [None] * n
2448+
2449+
def worker(idx: int, elem: Any) -> None:
2450+
try:
2451+
if not isinstance(elem, Value) or elem.type != TYPE_FUNC:
2452+
raise ASMRuntimeError("PARALLEL expects functions (either a tensor of FUNC or FUNC arguments)", location=location, rewrite_rule="PARALLEL")
2453+
func = elem.value
2454+
# Invoke with no args and the provided environment as closure
2455+
res = interpreter._invoke_function_object(func, [], {}, location, ___)
2456+
results[idx] = res
2457+
except BaseException as exc:
2458+
errors[idx] = exc
2459+
2460+
threads: List[threading.Thread] = []
2461+
for i in range(n):
2462+
t = threading.Thread(target=worker, args=(i, elems[i]))
2463+
t.start()
2464+
threads.append(t)
2465+
2466+
for t in threads:
2467+
t.join()
2468+
2469+
# Propagate first error if any
2470+
for err in errors:
2471+
if err is not None:
2472+
if isinstance(err, ASMRuntimeError):
2473+
raise err
2474+
raise ASMRuntimeError(f"PARALLEL worker failed: {err}", location=location, rewrite_rule="PARALLEL")
2475+
2476+
return Value(TYPE_INT, 0)
2477+
24152478
def _convolve(
24162479
self,
24172480
interpreter: "Interpreter",

0 commit comments

Comments
 (0)