diff --git a/docs/python-callable-serialization.md b/docs/python-callable-serialization.md
new file mode 100644
index 000000000..186357a92
--- /dev/null
+++ b/docs/python-callable-serialization.md
@@ -0,0 +1,641 @@
+# Python Callable Serialization for L3+ Register
+
+This document specifies a design for registering Python callables after an
+L3+ `Worker` has already initialized, and in the common case after child
+processes have already started.
+
+The design is separate from
+[callable-ipc-dynamic-register.md](callable-ipc-dynamic-register.md). That
+document covers `ChipCallable` binary registration for chip children. This
+document covers Python callables consumed by SUB workers and by higher-level
+Worker-child dispatch loops.
+
+It is the design and contract for the implementation in
+`python/simpler/worker.py` and `src/common/hierarchical/worker_manager.*`.
+
+---
+
+## 1. Context
+
+Every task submitted through the hierarchical runtime carries a `callable_id`.
+For L3+ Python execution paths, that id is resolved in a Python registry:
+
+| Submit path | Recipient | Registry entry |
+| ----------- | --------- | -------------- |
+| `orch.submit_sub(cid, ...)` | SUB child | Python sub callable |
+| L4+ `submit_next_level(cid, ...)` | Worker child | Python orch callable |
+
+Today, these entries must be registered before fork. The child process sees
+the parent's `_callable_registry` only through fork-time copy-on-write. Any
+parent-side mutation after fork is invisible to the already-running child.
+
+`ChipCallable` post-init registration already uses a control-plane plus
+side-band shm payload because binary callables can be copied and prepared in
+chip children. Python callables need the same high-level shape, but the
+payload is serialized Python code/data and the recipients are Python-capable
+children, not chip children.
+
+### Goals
+
+- Allow `Worker.register(py_callable)` after `Worker.init()` at level >= 3.
+- Make the returned `cid` usable when `register()` returns.
+- Preserve the current registration behavior before children start.
+- Reuse the existing mailbox control-plane and per-mailbox serialization
+  against in-flight dispatch.
+- Support unregister and cid reuse.
+- Keep the API synchronous and deterministic from the caller's perspective.
+
+### Non-goals
+
+- Dynamic registration of `ChipCallable`; that protocol is covered by the
+  binary callable design. This document only adds the Python-residue cleanup
+  hook that the existing `ChipCallable` register implementation needs when a
+  shared cid is reused across target types.
+- Cross-host or cross-Python-version serialization.
+- Recovering a child process that crashes or wedges while a mailbox control
+  request is in flight. This design specifies timeout reporting for Python
+  callable broadcasts, but rebuilding the child process tree is broader
+  control-plane reliability work.
+- Loading untrusted serialized bytes safely. This feature unpickles code from
+  the same user process and is not a security boundary.
+- Automatically registering callables inside arbitrary descendant Workers.
+  A `Worker.register()` call updates the registry owned by that Worker and
+  the already-started children that consume that registry.
+- Changing `MAX_REGISTERED_CALLABLE_IDS`.
+
+---
+
+## 2. Public Contract
+
+`Worker.register(target)` keeps one cid space for both `ChipCallable` and
+Python callables. The target type selects the dynamic-register route.
+
+- L2 `ChipCallable`: existing prepare path.
+- L2 Python callable: invalid target.
+- L3+ before this Worker has started child processes: store the target in the
+  parent registry; future children will inherit the registry when they start.
+  This preserves the pre-`init()` behavior and extends it to the post-`init()`,
+  before-first-`run()` window, where no child process has been forked yet.
+- L3+ after this Worker has started child processes: existing binary IPC for
+  `ChipCallable`, and the new serialized Python IPC path for Python callables.
+
+The post-start Python path is synchronous:
+
+1. The parent serializes `target`.
+2. The parent allocates a cid and stores `target` in its registry.
+3. The parent broadcasts the payload to every Python-capable child that may
+   resolve this Worker's registry.
+4. Each child deserializes the payload and updates its local registry.
+5. `register()` returns only after every required child has acknowledged.
+
+The parent must not submit a newly registered cid until `register()` returns.
+The runtime does not attempt to make a cid visible before the synchronous
+broadcast completes.
+
+### Recipients
+
+The parent routes Python callable registration to Python-capable children:
+
+- SUB child processes of the same Worker.
+- L4+ next-level Worker-child dispatch loops, because they resolve the
+  parent's registered orch functions before calling `inner_worker.run(...)`.
+
+L3 chip children are not recipients for Python callable payloads. They can
+only consume prepared `ChipCallable` ids.
+
+To keep the `NEXT_LEVEL` control pool unambiguous, L4+ Workers must use
+`add_worker(...)` for next-level children and must not also configure direct
+`device_ids`. Direct chip children remain an L3-only configuration.
+
+Because `Worker.register()` does not currently take a "sub" versus
+"next-level orch" kind, the simplest compatible policy is to broadcast to all
+Python-capable child groups owned by this Worker. Extra registry entries are
+inert if a cid is never submitted to that worker type.
+
+This preserves the current public API: `Worker.register(target)` does not gain
+an explicit target-kind parameter. Submit-time APIs continue to decide how the
+cid is interpreted.
+
+If no Python-capable child exists after children start, registering a Python
+callable should fail with a clear `RuntimeError`. Keeping a cid that no child
+can ever resolve is more confusing than rejecting it.
+
+### Callable Shape
+
+The runtime does not validate function signatures at register time. Existing
+dispatch-time behavior remains:
+
+- SUB callables are invoked as `fn(args)`.
+- Worker-child orchestration callables are invoked through
+  `inner_worker.run(orch_fn, args, cfg)`, so they must match the usual
+  orchestration shape.
+
+Signature errors surface from the child execution path and are reported
+through the mailbox error field, as they are today.
+
+---
+
+## 3. Serialization
+
+The payload must fit outside the 4 KB mailbox, so Python callables use a
+side-band POSIX shm exactly like dynamic `ChipCallable` registration. The
+mailbox carries only a shm name and cid.
+
+### Serializer Policy
+
+Dynamic Python callable registration uses `cloudpickle`.
+
+`cloudpickle` is a runtime dependency of the `simpler` package, not only a test
+dependency, because child processes deserialize user callables during normal
+`Worker.register()` operation.
+
+Registration before children start already allows lambdas and closures because
+the startup path copies the registry directly. A dynamic feature that rejects
+these common shapes would be surprising and would make several existing L3/L4
+test patterns impossible to move to dynamic registration.
+
+Stdlib `pickle` is not used for this path because it serializes most
+functions by module/name reference and is therefore limited to importable
+top-level functions and callable classes. It is a useful mental model for the
+trust boundary, but it is not the runtime format.
+
+This design assumes child processes are forked from the same Python process and
+therefore share the same Python major/minor version, installed package set, and
+`cloudpickle` runtime. If a future startup mode uses `spawn` or independently
+provisioned interpreters, dynamic Python callable registration is supported only
+when the child environment is version-compatible with the parent and can import
+the callable's dependencies.
+
+### Callable Shape and Closure Semantics
+
+Post-start registration supports callable shapes that `cloudpickle` can
+serialize and the child can deserialize in the same Python environment:
+
+- importable top-level functions;
+- lambdas and nested functions whose captured values are serializable;
+- callable class instances whose instance state is serializable.
+
+This is not identical to registration before children start. Startup children
+inherit a snapshot of the parent's address space, so a closure may appear to
+work because the child inherited the captured object at startup. Post-start
+registration sends serialized bytes to an already-running child, so captured
+objects are copied or reconstructed through `cloudpickle`.
+
+Callables should not rely on captured process-local resources being equivalent
+to fork inheritance. Examples include locks, events, open files, sockets,
+`SharedMemory.buf` memoryviews, mmap views, `Worker` or `ChipWorker` instances,
+nanobind/C++ handles, and device-pointer wrappers. Prefer capturing stable
+identifiers that the child can reopen or reconstruct, such as a shared-memory
+name instead of a live `SharedMemory.buf` object.
+
+### Payload Format
+
+The parent serializes the callable with `cloudpickle`, then wraps those bytes in
+the Python-callable wire header described below. The resulting complete payload
+is an in-memory byte blob. The C++ broadcast binding creates the side-band POSIX
+shm, copies that complete payload into it, fans out the shm name to children,
+and unlinks the shm after all child round-trips have completed. Python does not
+create or unlink the broadcast shm.
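
The staging lifecycle in the paragraph above (create, copy, open by name, copy out, unlink last) can be illustrated with a small Python stand-in. This is only a sketch: in the design the C++ binding owns shm creation and unlink, and the helper names `stage_payload` and `read_staged` are hypothetical and not part of the runtime.

```python
from multiprocessing import shared_memory


def stage_payload(payload: bytes) -> shared_memory.SharedMemory:
    """Parent side (stand-in for the C++ binding): create a side-band
    segment and copy the complete payload into it."""
    shm = shared_memory.SharedMemory(create=True, size=len(payload))
    shm.buf[:len(payload)] = payload
    return shm


def read_staged(name: str, nbytes: int) -> bytes:
    """Child side: open the segment by name, copy the bytes out, close."""
    shm = shared_memory.SharedMemory(name=name)
    try:
        # The mapping may be larger than requested, so slice explicitly.
        return bytes(shm.buf[:nbytes])
    finally:
        shm.close()


payload = b"example payload bytes"
staged = stage_payload(payload)
try:
    received = read_staged(staged.name, len(payload))
finally:
    staged.close()
    staged.unlink()  # unlink only after every reader has finished
```

Copying into `bytes` before closing keeps the reader independent of the segment's lifetime, which is the same reason the creator can unlink as soon as all round trips are done.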
+
+The Python binding must accept a Python buffer object, preferably `bytes`, not
+only a raw integer pointer. The binding copies the buffer into the staging shm
+before it releases the Python object reference or fans out worker threads. The
+binding must not retain a raw pointer into the Python buffer after returning or
+after releasing the GIL for control fan-out.
+
+The shm starts with a minimal Python-callable payload header followed by the
+exact bytes returned by `cloudpickle.dumps(target)`:
+
+| Field | Size | Value |
+| ----- | ---- | ----- |
+| magic | 4 bytes | `SPYC` |
+| version | 1 byte | `1` |
+| serializer | 1 byte | `1` for `cloudpickle` |
+| flags | 2 bytes | reserved, must be zero |
+| payload_size | 8 bytes | little-endian unsigned byte count |
+
+The first implementation accepts only `(magic="SPYC", version=1,
+serializer=1, flags=0)`. Unknown magic, version, serializer, non-zero flags,
+payload size larger than the mapped shm, malformed bytes, or incompatible
+pickle data fail through the normal mailbox error field. The child treats
+`payload_size` as the authoritative byte count and ignores any trailing bytes
+in the shm object, because some platforms expose POSIX shm at a page-rounded
+size even when the parent requested the exact payload length.
+
+### Child Deserialization
+
+Each recipient child:
+
+1. Opens the shm by name.
+2. Validates the payload header.
+3. Copies the payload region into `bytes`.
+4. Verifies that `cid` is in `[0, MAX_REGISTERED_CALLABLE_IDS)`.
+5. Deserializes the callable with `cloudpickle.loads(payload_bytes)`.
+6. Verifies that the result is callable.
+7. Installs it into the child's local registry under the requested cid.
+8. Closes the shm and acknowledges `CONTROL_DONE`.
+
+The child intentionally copies the payload region before deserializing it.
+This avoids coupling `cloudpickle.loads(...)` to the lifetime rules of an
+active `SharedMemory.buf` memoryview and keeps shm close/unlink behavior simple.
+
+For cid reuse after partial unregister failures, Python registration should
+overwrite `registry[cid]` in the child. The parent only allocates free cids
+from its own registry, so an existing child entry at the same cid is residue
+from a prior best-effort failure and should be replaced.
+
+Because Python callables and `ChipCallable` objects share one cid space, the
+same cleanup rule also applies when a cid is reused across target types. A
+post-start `ChipCallable` registration performed after this feature lands must
+clear any stale Python dispatch entry for the same cid from Python-capable
+children owned by the same Worker before the cid is reported usable. This is a
+v2 integration hook on the existing `ChipCallable` register implementation, not
+a new binary payload protocol. Otherwise a failed Python unregister followed by
+`ChipCallable` reuse could leave a Worker-child dispatch loop resolving the old
+Python callable.
+
+---
+
+## 4. Control Plane
+
+Add new control subcommands rather than overloading the existing
+`CTRL_REGISTER` used for `ChipCallable`:
+
+```text
+CTRL_PY_REGISTER = 10
+CTRL_PY_UNREGISTER = 11
+```
+
+The mailbox layout for `CTRL_PY_REGISTER` mirrors binary register:
+
+| Offset | Field | Notes |
+| ------ | ----- | ----- |
+| `OFF_CALLABLE` | sub_cmd = `CTRL_PY_REGISTER` | uint64 |
+| `CTRL_OFF_ARG0` | cid | low 32 bits |
+| `OFF_ARGS[0..]` | NUL-terminated shm name | fixed-width slot |
+
+`CTRL_PY_UNREGISTER` carries only the cid in `CTRL_OFF_ARG0`.
+
+### Parent-Side Flow
+
+`Worker.register(target)` gains a Python-callable dynamic route:
+
+1. Reject non-callable Python targets.
+2. If the first hierarchical startup is in progress, wait for that startup to
+   either complete or fail without holding `_registry_lock`. A registration
+   must not return through the startup path after the fork-time registry
+   snapshot has already been taken.
+3. If this Worker has not started child processes, hold `_registry_lock`,
+   allocate the smallest free cid, insert
+   `self._callable_registry[cid] = target`, and return the cid; future children
+   will inherit the registry when they start.
+4. If no configured Python-capable child group exists, raise `RuntimeError`.
+5. Serialize the target with `cloudpickle.dumps(...)` and wrap it in the
+   complete Python-callable wire payload.
+6. Hold `_registry_lock`, allocate the smallest free cid, insert
+   `self._callable_registry[cid] = target`, and release `_registry_lock`.
+7. Broadcast `CTRL_PY_REGISTER` to required Python-capable worker groups.
+8. On any failure, reacquire `_registry_lock`, remove the parent registry entry
+   if it still points at this target, and raise.
+9. Return cid on success.
+
+The "configured Python-capable child group" check uses the Worker's own
+configuration, not child-process state:
+
+- `num_sub_workers > 0` means SUB children will consume this registry.
+- `len(_next_level_workers) > 0` means Worker children will consume this
+  registry.
+
+This check applies only after child processes have started. Before children
+start, including after `init()` but before the first `run()`, registration uses
+the parent-registry path and does not reject unused Python callables.
+
+If no free cid exists in `[0, MAX_REGISTERED_CALLABLE_IDS)`, register raises
+`RuntimeError` before mutating the parent registry or broadcasting to children.
+The caller can recover by unregistering unused callables and retrying.
+
+The startup race is handled by a one-time hierarchical startup state, not by a
+run-wide quiescent guard:
+
+- `_hierarchical_start_state` is protected by a dedicated
+  `_hierarchical_start_mu` / `_hierarchical_start_cv`; `_registry_lock`
+  protects only the registry contents.
+- Startup begins as `not_started`. `_start_hierarchical()` holds
+  `_hierarchical_start_cv`, then `_registry_lock`, while it moves to
+  `starting` and takes the registry snapshot. It releases `_registry_lock`
+  before any `os.fork()`, and moves to `started` only after child mailboxes
+  are registered with the C++ Worker.
+- A Python callable register/unregister that observes `starting` waits on a
+  condition variable without holding `_registry_lock`. After startup succeeds,
+  it uses the post-start control path; after startup fails, it raises.
+- `_start_hierarchical()` snapshots `self._callable_registry` while holding
+  `_registry_lock`, then forks children from that immutable snapshot. It must
+  not hold `_registry_lock` across `os.fork()`.
+
+Once children have started, dynamic Python registration is allowed while
+`Worker.run()` is actively submitting or draining tasks. The operation is still
+synchronous: the caller must wait for `register()` to return before submitting
+the new cid. Per-child `mailbox_mu_` serialization orders each
+`CTRL_PY_REGISTER` / `CTRL_PY_UNREGISTER` round trip against any in-flight
+`TASK_READY` on that same child mailbox.
+
+`_registry_lock` protects parent-side cid allocation and registry mutation
+only. It is not held while waiting for child ACKs from
+`broadcast_control_all`.
+
+This requires a generic C++ binding that can broadcast a control command to a
+selected worker pool:
+
+```python
+_Worker.broadcast_control_all(worker_type, sub_cmd, cid, payload=None,
+                              timeout_s=None)
+```
+
+`worker_type` selects `SUB` versus `NEXT_LEVEL`; `sub_cmd` is
+`CTRL_PY_REGISTER` or `CTRL_PY_UNREGISTER`. For register, `payload` is the
+complete Python-callable wire payload, passed as a Python buffer object. That
+means the bytes start with the `SPYC` header and the header's payload region is
+the exact `cloudpickle.dumps(target)` result. Passing raw `cloudpickle` bytes
+directly to `_Worker.broadcast_control_all(..., CTRL_PY_REGISTER, ...)` is
+invalid because the C++ binding is a generic staging layer and does not add or
+interpret Python-callable headers. For unregister, `payload` is absent. The
+binding owns shm creation, copying, fan-out, and unlink when a payload is
+present, matching `broadcast_register_all` for binary callables while avoiding
+four near-identical Python-specific bindings.
+
+For a selected worker pool, fan-out is parallel: C++ starts one worker thread
+per target child, each round trip holds that child's `mailbox_mu_`, and the
+binding waits for every child to publish `CONTROL_DONE` before returning the
+per-child results. Latency is bounded by the slowest child round trip, not by
+the sum of all child round trips.
+
+`timeout_s` is optional. When set, each child round trip that does not publish
+`CONTROL_DONE` before the deadline returns a failed result with a timeout error
+message. The timeout does not repair the wedged child or reclaim a mailbox
+that is still owned by a stuck control command; it only bounds the caller's
+wait and makes the partial failure visible to Python policy code.
+
+The Python `Worker` facade uses a finite default timeout for its own dynamic
+Python callable register/unregister broadcasts. The default is 30 seconds and
+can be overridden per Worker with `py_control_timeout_s`.
+
+Once a control request is staged and fan-out begins, the binding returns
+structured per-child results. It does not switch between "raise" and "return
+errors" based on `sub_cmd`. Python decides whether those results are strict or
+best-effort:
+
+```text
+ControlResult(worker_type, worker_index, ok, error_message)
+```
+
+- `Worker.register()` treats any failed `CTRL_PY_REGISTER` result as strict:
+  it removes the new parent registry entry and raises.
+- `Worker.unregister()` treats failed `CTRL_PY_UNREGISTER` results as
+  best-effort: it warns, then releases its parent cid slot after the broadcast
+  has returned.
+- The cross-type reuse hook treats failed Python-residue cleanup as strict: it
+  fails the `ChipCallable` registration before starting binary
+  `CTRL_REGISTER`.
+
+Argument conversion and setup failures that happen before a selected worker
+pool can be contacted, such as a non-buffer `payload` object, an empty payload
+buffer for a register command, or shm creation failure, may still raise
+directly from the binding. Once fan-out begins, child-side failures and
+timeouts are reported through `ControlResult`.
+
+The existing `mailbox_mu_` must be held for each child round trip, just like
+binary register. This serializes Python register/unregister against
+`TASK_READY` dispatch on the same child.
+
+Every child `CONTROL_REQUEST` handler, including existing chip-child handlers,
+must reject unknown subcommands by writing `OFF_ERROR` and publishing
+`CONTROL_DONE`. A misrouted Python control command must fail visibly, not ACK
+as a successful no-op.
+
+### Cross-Type Reuse Hook
+
+The existing post-start `ChipCallable` register path keeps its binary payload
+protocol, but gains one v2 hook before reporting a reused cid as usable:
+
+1. After allocating a cid for a `ChipCallable`, check whether that cid may have
+   held a Python callable in this Worker lifetime.
+2. If so, broadcast `CTRL_PY_UNREGISTER` to every Python-capable child group
+   owned by this Worker as an idempotent clear operation.
+3. If that clear operation reports any child error, fail the `ChipCallable`
+   registration, remove the new parent registry entry, and do not start the
+   binary `CTRL_REGISTER` broadcast.
+4. If the clear succeeds, continue through the existing binary
+   `broadcast_register_all` path.
+
+This hook is needed only for Python-capable child registries. Chip children
+continue to rely on the existing binary self-heal before
+`prepare_callable_from_blob`.
+
+### Parent-Side Unregister
+
+`Worker.unregister(cid)` uses the registered target type to select the
+unregister route:
+
+1. If the first hierarchical startup is in progress, wait for it to complete
+   without holding `_registry_lock`.
+2. Hold `_registry_lock`.
+3. Raise `KeyError` if `cid` is absent from the parent registry or already has
+   an unregister in progress.
+4. If the Worker has not started child processes yet, pop the parent entry and
+   return. Future children will inherit the already-removed registry.
+5. Mark `cid` as pending unregister, then release `_registry_lock`. A pending
+   cid remains unavailable for reuse until the broadcast finishes.
+6. For a post-start `ChipCallable`, keep the existing binary unregister path.
+7. If the target is a Python callable and this Worker has started child
+   processes, broadcast `CTRL_PY_UNREGISTER` to every Python-capable child
+   group configured for this Worker, regardless of when the callable was
+   originally registered.
+8. Warn on per-child unregister errors. Reacquire `_registry_lock`, pop the
+   parent registry entry unconditionally, clear the pending marker, and make
+   the cid slot reusable.
+
+Python callable unregister never cascades into `inner_worker.unregister(...)`.
+For L4+ Worker children it removes only the parent-owned dispatch registry entry
+inside `_child_worker_loop`, matching the `CTRL_PY_REGISTER` ownership rule.
+
+Unregister is still best-effort, but reuse must self-heal. Before any
+post-start `ChipCallable` registration for a cid that may have previously held a
+Python callable, the parent must clear that cid from all Python-capable child
+registries owned by the same Worker. This can reuse `CTRL_PY_UNREGISTER` as an
+idempotent "clear Python dispatch entry" command. If the clear step fails during
+registration, the new registration fails, the parent pops the newly allocated
+cid, and no reverse rollback is attempted.
+
+### SUB Child Handler
+
+`_sub_worker_loop` currently handles `TASK_READY` and `SHUTDOWN`. It gains a
+`CONTROL_REQUEST` branch:
+
+- `CTRL_PY_REGISTER`: deserialize the callable and store `registry[cid] = fn`.
+- `CTRL_PY_UNREGISTER`: `registry.pop(cid, None)`.
+- Any unknown control subcommand: write `OFF_ERROR`, publish `CONTROL_DONE`,
+  and leave the registry unchanged.
+
+The loop is single-threaded, and parent-side `mailbox_mu_` serializes control
+commands against task dispatch, so no child-side lock is required.
+
+### Worker-Child Handler
+
+`_child_worker_loop` already has a `CONTROL_REQUEST` branch for binary
+callable cascade. It gains Python subcommands with different semantics:
+
+- `CTRL_PY_REGISTER`: deserialize and store into the `registry` dict passed
+  to `_child_worker_loop`.
+- `CTRL_PY_UNREGISTER`: remove from that same `registry`.
+- Existing binary `CTRL_REGISTER`: before cascading the `ChipCallable` into
+  `inner_worker._register_at(...)`, remove `registry[cid]` from the
+  Worker-child dispatch registry. This self-heals stale Python callable residue
+  when a cid is reused as a `ChipCallable`.
+- Any unknown control subcommand: write `OFF_ERROR`, publish `CONTROL_DONE`,
+  and leave both the parent-owned dispatch registry and `inner_worker`
+  unchanged.
+
+This registry is the dispatch registry used when the parent submits a cid to
+the Worker child. It is distinct from `inner_worker._callable_registry`.
+Updating it makes a dynamically registered parent orch function visible to
+the already-started Worker child.
+
+The Python callable is not automatically cascaded into
+`inner_worker._callable_registry`. Registering callables owned by an inner
+Worker remains a separate operation on that Worker. This keeps cid ownership
+local and avoids unexpected collisions with entries the inner Worker already
+owns.
+
+Registry ownership in a Worker-child process is:
+
+- Parent `CTRL_PY_REGISTER`: mutates the parent dispatch `registry`, read by
+  `_child_worker_loop`; does not cascade.
+- Parent `CTRL_PY_UNREGISTER`: removes from the parent dispatch `registry`;
+  does not cascade.
+- Parent binary `CTRL_REGISTER`: mutates `inner_worker._callable_registry` and
+  cascades through the inner Worker's own register route.
+- Parent binary `CTRL_UNREGISTER`: mutates `inner_worker._callable_registry`
+  and cascades through the inner Worker's own unregister route.
+- Inner Worker register/unregister: mutates `inner_worker._callable_registry`
+  and is owned by the inner Worker.
+
+The parent dispatch registry and `inner_worker._callable_registry` may contain
+the same numeric cid for different owners. A parent Python unregister must not
+call `inner_worker.unregister(cid)`, because that could delete a callable that
+belongs to the inner Worker. Cross-type cleanup before parent `ChipCallable`
+reuse clears stale Python entries from SUB registries and from Worker-child
+parent dispatch registries. It does not clear
+`inner_worker._callable_registry`; the binary register then cascades into
+`inner_worker` through the normal binary route.
+
+---
+
+## 5. Failure Modes and Tests
+
+### Failure Semantics
+
+| Trigger | Handling |
+| ------- | -------- |
+| `cloudpickle` unavailable | `simpler.worker` import fails. |
+| Serializer cannot encode target | Parent raises before cid allocation |
+| Post-start no Python child group | Parent raises before cid allocation |
+| cid space exhausted | Parent raises before parent mutation |
+| Startup race | Wait, then use post-start route |
+| Duplicate unregister for same cid | Raise before second broadcast |
+| Child cannot open shm | Child writes `OFF_ERROR`; parent raises |
+| Child receives invalid cid | Child writes `OFF_ERROR`; parent raises |
+| Child deserialization fails | Child writes `OFF_ERROR`; parent raises |
+| Result is not callable | Child writes `OFF_ERROR`; parent raises |
+| Unknown control subcommand | Child writes `OFF_ERROR`; parent raises |
+| Some children succeed before another fails | Parent raises; no rollback |
+| Unregister fails on some children | Parent warns and pops its registry |
+| Cross-type cid reuse | New register clears or overwrites child residue |
+| Child `cloudpickle.loads` times out | Failed child result |
+| Child crashes during control | Timeout result, or hang if unset |
+
+No reverse rollback is attempted after partial register success. A successful
+child may retain a registry entry for a cid the parent reports as failed.
+Future cid reuse must overwrite it for Python registration, or clear it before
+`ChipCallable` registration, matching the best-effort unregister contract.
+
+Python deserialization has a larger liveness surface than binary callable
+prepare. `cloudpickle.loads(...)` may import modules or run user-defined object
+reconstruction hooks, and that code can block, spin, or wedge the child before
+it writes `CONTROL_DONE`. For Python callable broadcasts, callers should pass a
+finite `timeout_s` so `broadcast_control_all` can return a failed per-child
+result instead of waiting forever. Timeout does not make the child healthy; it
+only lets `Worker.register()` fail visibly and lets best-effort cleanup report
+which child did not respond. Child liveness detection, process replacement, and
+hierarchical recovery remain out of scope for this feature.
+
+### Concurrency
+
+- Parent registry mutation stays under `_registry_lock`.
+- The first `Worker.run()` marks hierarchical startup as `starting` before
+  taking the startup registry snapshot. Concurrent register/unregister callers
+  wait for startup to finish, then use the correct post-start route.
+- `_registry_lock` is released before any broadcast waits for child ACKs. The
+  parent registry entry, plus the pending-unregister marker for unregister,
+  keeps the cid unavailable for reuse while the IPC operation is in flight.
+- Each child mailbox round trip stays under `mailbox_mu_`, so post-start Python
+  register/unregister can run during `Worker.run()` and will serialize against
+  `TASK_READY` on each recipient mailbox.
+- `register()` is synchronous. A caller that races `register()` and
+  `Worker.run()` from different Python threads must still wait for
+  `register()` to return before submitting the new cid.
+- Child registry mutation is serialized by the mailbox state machine.
+- `unregister()` is synchronous from the caller's perspective. The user remains
+  responsible for not unregistering a cid with outstanding submitted work.
+
+### Test Plan
+
+Keep the first implementation's tests focused on behavior and ownership, not on
+format evolution:
+
+- Unit test `cloudpickle` round trip for the supported callable shapes.
+- Unit test that closures over serializable Python values work, and that
+  specific known-unpickleable captures fail before cid visibility.
+- Unit test that child-side deserialize and execute failures are reported
+  through the normal mailbox error path.
+- Unit test that Python register before children start uses the startup
+  registry path and performs no control broadcast.
+- Unit test that first-run startup is serialized against Python register, so a
+  racing register cannot miss the startup registry snapshot.
+- Unit test that post-start Python register during an active `Worker.run()`
+  succeeds after the relevant child mailbox reaches a safe control point.
+- Unit test that unregister keeps the cid unavailable for reuse until its
+  broadcast has completed, even though `_registry_lock` is not held across the
+  broadcast.
+- Unit test that post-start Python register rejects Workers with no SUB workers
+  and no next-level Worker children.
+- Unit test selected-pool routing: `worker_type=SUB` reaches only
+  `sub_threads_`, and `worker_type=NEXT_LEVEL` reaches only
+  `next_level_threads_`.
+- Unit test that `broadcast_control_all` returns the same structured
+  per-child result shape for register and unregister commands.
+- Unit test that `broadcast_control_all(timeout_s=...)` reports a timed-out
+  child as a failed per-child result without blocking indefinitely.
+- L3 integration test: start an L3 Worker with SUB workers, run once to start
+  children, dynamically register a Python sub callable, then
+  `submit_sub(cid, ...)`.
+- L4 integration test: start an L4 Worker with an L3 child, run once to start
+  children, dynamically register an L3 orchestration callable on the L4 parent,
+  then `submit_next_level(cid, ...)`.
+- Unregister test: once children have started, Python callable unregister
+  broadcasts `CTRL_PY_UNREGISTER`, pops the parent registry, and allows cid
+  reuse only after `unregister()` returns, regardless of whether the callable
+  was registered before or after children started.
+- Cross-type reuse test: stale Python dispatch residue from a failed
+  best-effort unregister is cleared when the same cid is reused for a
+  `ChipCallable`.
+- Failure test: unsupported or non-serializable callable raises without
+  consuming a parent cid slot.
+
+## Related
+
+- [task-flow.md](task-flow.md) explains how `Callable`, `TaskArgs`, and
+  `CallConfig` move through L3+ dispatch.
+- [worker-manager.md](worker-manager.md) explains WorkerThread mailbox
+  dispatch and forked Python child loops.
+- [callable-ipc-dynamic-register.md](callable-ipc-dynamic-register.md)
+  covers dynamic binary `ChipCallable` registration.
diff --git a/docs/python-packaging.md b/docs/python-packaging.md
index b7e936dd3..8e9bddff3 100644
--- a/docs/python-packaging.md
+++ b/docs/python-packaging.md
@@ -51,12 +51,19 @@ Internal coupling: `simpler_setup.toolchain`, `simpler_setup.kernel_compiler`, a

 | Category | Packages |
 | -------- | -------- |
-| `simpler` runtime | No third-party Python deps. Requires platform backend: simulation (`a*sim`) or NPU hardware (`a2a3`/`a5` with CANN toolkit) |
-| `simpler_setup` runtime | `torch` (tensor operations in golden scripts, test comparison) |
+| `simpler` runtime | `cloudpickle`; platform backend |
+| `simpler_setup` runtime | `torch` for golden/test tensor operations |
 | Build | `scikit-build-core`, `nanobind`, `cmake` |
 | Test | `pytest` (ut-py, st), `googletest` + `ctest` (ut-cpp) |

-`pyproject.toml` declares no `[project.dependencies]` — both `torch` and `pytest` are environment prerequisites, not pip-installed transitively. This is intentional: torch's index URL (`--index-url https://download.pytorch.org/whl/cpu`) and hardware-specific builds make automatic resolution impractical.
+`pyproject.toml` declares `cloudpickle` as a `[project.dependencies]` runtime
+dependency. `torch` and `pytest` remain environment prerequisites, not
+pip-installed transitively. This is intentional: torch's index URL
+(`--index-url https://download.pytorch.org/whl/cpu`) and hardware-specific
+builds make automatic resolution impractical.
+
+The `simpler` runtime also requires a platform backend: simulation (`a*sim`) or
+NPU hardware (`a2a3`/`a5` with CANN toolkit).
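
Reviewer note on the Concurrency section of the design doc: the rule that a cid stays unavailable for reuse while an unregister broadcast is in flight, even though `_registry_lock` is released during the broadcast, can be illustrated in isolation. The following is a toy sketch only. `CidTable`, `MAX_IDS`, and the `broadcast` callback are hypothetical names introduced for illustration; the callback stands in for the mailbox round trip, and none of this is the `worker.py` implementation.

```python
import threading

MAX_IDS = 8  # hypothetical small bound; the real limit is MAX_REGISTERED_CALLABLE_IDS


class CidTable:
    """Toy parent-side registry: smallest-free-cid allocation plus a pending set."""

    def __init__(self, max_ids=MAX_IDS):
        self._lock = threading.Lock()
        self._registry = {}
        self._pending = set()
        self._max_ids = max_ids

    def register(self, target):
        with self._lock:
            for cid in range(self._max_ids):
                # A cid is free only if it is neither registered nor mid-unregister.
                if cid not in self._registry and cid not in self._pending:
                    self._registry[cid] = target
                    return cid
        raise RuntimeError("no free cid")

    def unregister(self, cid, broadcast):
        with self._lock:
            if cid not in self._registry:
                raise KeyError(f"cid={cid} not registered")
            if cid in self._pending:
                raise KeyError(f"cid={cid} already pending unregister")
            self._pending.add(cid)
        try:
            # The lock is not held here, so other callers can register
            # concurrently, but they cannot be handed this cid.
            broadcast(cid)
        finally:
            with self._lock:
                self._registry.pop(cid, None)
                self._pending.discard(cid)


table = CidTable()
first = table.register(len)    # takes cid 0
second = table.register(print)  # takes cid 1
during = []
# A register that races the broadcast skips the reserved cid 0.
table.unregister(first, lambda cid: during.append(table.register(abs)))
after = table.register(str)    # cid 0 is reusable once unregister has returned
```

In this sketch the racing `register` lands on cid 2 and the later one reuses cid 0, which matches the ordering the unregister bullet in the Test Plan asks for.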
### `PROJECT_ROOT` resolution diff --git a/pyproject.toml b/pyproject.toml index 20296b40b..fb819b591 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,7 @@ build-backend = "scikit_build_core.build" name = "simpler" version = "0.1.0" requires-python = ">=3.9" +dependencies = ["cloudpickle>=2.2"] [project.optional-dependencies] # ``torch>=2.3`` is required by ``simpler_setup.torch_interop`` (uses diff --git a/python/bindings/worker_bind.h b/python/bindings/worker_bind.h index b950054e3..19c6598dd 100644 --- a/python/bindings/worker_bind.h +++ b/python/bindings/worker_bind.h @@ -77,6 +77,12 @@ inline void bind_worker(nb::module_ &m) { // --- WorkerType --- nb::enum_(m, "WorkerType").value("NEXT_LEVEL", WorkerType::NEXT_LEVEL).value("SUB", WorkerType::SUB); + nb::class_(m, "ControlResult") + .def_ro("worker_type", &ControlResult::worker_type) + .def_ro("worker_index", &ControlResult::worker_index) + .def_ro("ok", &ControlResult::ok) + .def_ro("error_message", &ControlResult::error_message); + // --- TaskState --- nb::enum_(m, "TaskState") .value("FREE", TaskState::FREE) @@ -246,6 +252,33 @@ inline void bind_worker(nb::module_ &m) { "Best-effort broadcast of CTRL_UNREGISTER to every NEXT_LEVEL child in parallel. " "Returns a list of per-child error strings (empty on full success)." ) + .def( + "broadcast_control_all", + [](Worker &self, WorkerType worker_type, uint64_t sub_cmd, int32_t cid, nb::object payload, + nb::object timeout_s) { + std::string payload_bytes; + const void *payload_ptr = nullptr; + size_t payload_size = 0; + if (!payload.is_none()) { + Py_buffer view; + if (PyObject_GetBuffer(payload.ptr(), &view, PyBUF_CONTIG_RO) != 0) { + throw nb::python_error(); + } + payload_bytes.assign(static_cast(view.buf), static_cast(view.len)); + PyBuffer_Release(&view); + payload_ptr = payload_bytes.data(); + payload_size = payload_bytes.size(); + } + double timeout_val = timeout_s.is_none() ? 
-1.0 : nb::cast(timeout_s); + nb::gil_scoped_release release; + return self.broadcast_control_all(worker_type, sub_cmd, cid, payload_ptr, payload_size, timeout_val); + }, + nb::arg("worker_type"), nb::arg("sub_cmd"), nb::arg("cid"), nb::arg("payload") = nb::none(), + nb::arg("timeout_s") = nb::none(), + "Broadcast an arbitrary CONTROL_REQUEST to the selected worker pool. " + "If payload is a Python buffer, C++ stages it in POSIX shm and writes the shm name " + "into the mailbox. Returns per-child ControlResult entries." + ) .def( "control_alloc_domain", &Worker::control_alloc_domain, nb::arg("worker_id"), nb::arg("request_shm_name"), nb::arg("reply_shm_name"), nb::call_guard(), diff --git a/python/simpler/worker.py b/python/simpler/worker.py index ae816c0c0..e4956e708 100644 --- a/python/simpler/worker.py +++ b/python/simpler/worker.py @@ -11,8 +11,9 @@ Callable identity is a ``cid`` (int), allocated exclusively by ``Worker.register(callable)``. ``Worker.run`` and the orchestrator's ``submit_next_level`` / ``submit_sub`` all take this cid — never the raw -``ChipCallable`` / Python function. L≥3 ``register()`` must run **before** -``init()`` so forked chip / sub children inherit the registry via COW. +``ChipCallable`` / Python function. L≥3 Python callables registered before +child startup are inherited through the fork-time snapshot; later +registrations are serialized and sent through the mailbox control plane. Usage:: @@ -64,9 +65,11 @@ def my_l4_orch(orch, args, config): from multiprocessing.shared_memory import SharedMemory from typing import Any, Optional +import cloudpickle from _task_interface import ( # pyright: ignore[reportMissingImports] MAX_REGISTERED_CALLABLE_IDS, RunTiming, + WorkerType, _mailbox_load_i32, _mailbox_store_i32, read_args_from_blob, @@ -93,6 +96,7 @@ def my_l4_orch(orch, args, config): # that a hung child fails the suite instead of the CI job timing out. 
_BOOTSTRAP_WAIT_TIMEOUT_S = 120.0 _BOOTSTRAP_POLL_INTERVAL_S = 0.001 +_PY_CONTROL_TIMEOUT_S = 30.0 # --------------------------------------------------------------------------- @@ -163,11 +167,18 @@ def my_l4_orch(orch, args, config): # rootinfo_path) and caches the handle on the ChipWorker so subsequent # CTRL_ALLOC_DOMAIN calls can find it. _CTRL_COMM_INIT = 9 +_CTRL_PY_REGISTER = 10 +_CTRL_PY_UNREGISTER = 11 # Layout of the CTRL_COMM_INIT request shm. _COMM_INIT_HEADER = struct.Struct(" bytes: + payload = cloudpickle.dumps(target) + return ( + _PY_CALLABLE_HEADER.pack( + _PY_CALLABLE_MAGIC, + _PY_CALLABLE_VERSION, + _PY_CALLABLE_SERIALIZER_CLOUDPICKLE, + 0, + len(payload), + ) + + payload + ) + + +def _load_py_callable_from_shm(shm_name: str): + shm = SharedMemory(name=shm_name) + try: + shm_buf = shm.buf + assert shm_buf is not None + if shm.size < _PY_CALLABLE_HEADER.size: + raise RuntimeError(f"python callable payload too small: {shm.size} bytes") + magic, version, serializer, flags, payload_size = _PY_CALLABLE_HEADER.unpack_from(shm_buf, 0) + if magic != _PY_CALLABLE_MAGIC: + raise RuntimeError(f"invalid python callable payload magic: {magic!r}") + if version != _PY_CALLABLE_VERSION: + raise RuntimeError(f"unsupported python callable payload version: {version}") + if serializer != _PY_CALLABLE_SERIALIZER_CLOUDPICKLE: + raise RuntimeError(f"unsupported python callable serializer: {serializer}") + if flags != 0: + raise RuntimeError(f"unsupported python callable payload flags: {flags}") + expected_size = _PY_CALLABLE_HEADER.size + int(payload_size) + if expected_size > shm.size: + raise RuntimeError(f"python callable payload size mismatch: header={payload_size}, shm={shm.size}") + payload = bytes(shm_buf[_PY_CALLABLE_HEADER.size : expected_size]) + finally: + shm.close() + + fn = cloudpickle.loads(payload) + if not callable(fn): + raise RuntimeError(f"python callable payload decoded to non-callable {type(fn).__name__}") + return fn + + +def 
_handle_py_callable_control(buf, registry: dict, sub_cmd: int, *, context: str) -> None: + cid = int(struct.unpack_from("Q", buf, _CTRL_OFF_ARG0)[0]) & 0xFFFFFFFF + if cid >= MAX_REGISTERED_CALLABLE_IDS: + raise RuntimeError(f"{context}: cid {cid} out of range") + if sub_cmd == _CTRL_PY_REGISTER: + shm_name = _read_shm_name(buf, _OFF_ARGS) + registry[cid] = _load_py_callable_from_shm(shm_name) + elif sub_cmd == _CTRL_PY_UNREGISTER: + registry.pop(cid, None) + else: + raise RuntimeError(f"{context}: unknown control sub-command {int(sub_cmd)}") + + def _mailbox_addr(shm: SharedMemory) -> int: buf = shm.buf assert buf is not None @@ -299,6 +366,17 @@ def _sub_worker_loop(buf, registry: dict) -> None: msg = _format_exc("sub_worker", e) _write_error(buf, code, msg) _mailbox_store_i32(state_addr, _TASK_DONE) + elif state == _CONTROL_REQUEST: + sub_cmd = struct.unpack_from("Q", buf, _OFF_CALLABLE)[0] + code = 0 + msg = "" + try: + _handle_py_callable_control(buf, registry, int(sub_cmd), context="sub_worker") + except Exception as e: # noqa: BLE001 + code = 1 + msg = _format_exc("sub_worker control", e) + _write_error(buf, code, msg) + _mailbox_store_i32(state_addr, _CONTROL_DONE) elif state == _SHUTDOWN: break @@ -575,6 +653,8 @@ def _run_chip_main_loop( # noqa: PLR0912 -- TASK_READY + 6 control sub-commands _handle_ctrl_release_domain(cw, buf) elif sub_cmd == _CTRL_COMM_INIT: _handle_ctrl_comm_init(cw, buf) + else: + raise RuntimeError(f"unknown control sub-command {int(sub_cmd)}") except Exception as e: # noqa: BLE001 code = 1 if sub_cmd in (_CTRL_REGISTER, _CTRL_UNREGISTER): @@ -709,16 +789,34 @@ def _child_worker_loop( # cid_val onto the registry slot keeps the inner-side # cid identical to the outer-side cid — both the L4 # scheduler and the L3 children index by the same int. 
+ registry.pop(int(cid_val), None) inner_worker._register_at(int(cid_val), callable_obj) elif sub_cmd == _CTRL_UNREGISTER: cid_val = int(struct.unpack_from("Q", buf, _CTRL_OFF_ARG0)[0]) & 0xFFFFFFFF inner_worker.unregister(int(cid_val)) + elif sub_cmd in (_CTRL_PY_REGISTER, _CTRL_PY_UNREGISTER): + _handle_py_callable_control( + buf, + registry, + int(sub_cmd), + context=f"child_worker level={inner_worker.level}", + ) + else: + raise RuntimeError(f"unknown control sub-command {int(sub_cmd)}") except Exception as e: # noqa: BLE001 code = 1 op = ( "register" if sub_cmd == _CTRL_REGISTER - else ("unregister" if sub_cmd == _CTRL_UNREGISTER else f"ctrl={int(sub_cmd)}") + else ( + "unregister" + if sub_cmd == _CTRL_UNREGISTER + else ( + "py_register" + if sub_cmd == _CTRL_PY_REGISTER + else ("py_unregister" if sub_cmd == _CTRL_PY_UNREGISTER else f"ctrl={int(sub_cmd)}") + ) + ) ) msg = _format_exc(f"child_worker level={inner_worker.level} {op}", e) _write_error(buf, code, msg) @@ -760,6 +858,12 @@ def __init__( # dispatch) is now handled at the C++ boundary via mailbox_mu_, so # no quiescent-state guard is needed. self._registry_lock = threading.Lock() + self._pending_unregister_cids: set[int] = set() + self._py_callable_cids_seen: set[int] = set() + self._py_control_timeout_s = float(config.get("py_control_timeout_s", _PY_CONTROL_TIMEOUT_S)) + self._hierarchical_start_state = "not_started" + self._hierarchical_start_mu = threading.Lock() + self._hierarchical_start_cv = threading.Condition(self._hierarchical_start_mu) # Level-2 internals self._chip_worker: Optional[ChipWorker] = None @@ -820,51 +924,128 @@ def register(self, target) -> int: ``orch.submit_sub(cid, …)``. Timing constraints: - - L3+: Python callables (sub fn / orch fn) must be registered - **before** ``init()`` so the COW-inherited registry is visible to - forked chip / sub children. 
ChipCallables may be registered either - before init (pre-warmed via ``_CTRL_PREPARE`` during ``init()``) - or after init (broadcast to chip children via - ``_Worker.broadcast_register_all``; see - docs/callable-ipc-dynamic-register.md). Post-init register at - L3+ is ChipCallable-only. + - L3+: registrations before child processes start are inherited + by forked children through the startup registry snapshot. + Registrations after child processes start use the mailbox + control plane: ChipCallables keep the binary path, while Python + callables are serialized with cloudpickle and broadcast to + Python-capable child groups. - L2: may be called either before or after ``init()`` (no fork, no COW constraint). When called post-init, ChipCallables are prepared on the device immediately; pre-init registrations are batched and prepared at the end of ``init()``. + + See docs/python-callable-serialization.md for the Python dynamic + register path and docs/callable-ipc-dynamic-register.md for the + ChipCallable binary path. 
""" + if self.level == 2 and not isinstance(target, ChipCallable): + raise TypeError("Worker.register: level 2 only supports ChipCallable targets") + if self.level >= 3: + if not isinstance(target, ChipCallable): + if not callable(target): + raise TypeError("Worker.register: non-ChipCallable target must be callable") + with self._hierarchical_start_cv: + while self._hierarchical_start_state == "starting": + self._hierarchical_start_cv.wait() + if self._hierarchical_start_state == "failed": + raise RuntimeError("Worker hierarchical startup failed; close this Worker and create a new one") + if self._hierarchical_start_state != "started" and not getattr(self, "_hierarchical_started", False): + with self._registry_lock: + cid = self._allocate_cid() + self._callable_registry[cid] = target + if not isinstance(target, ChipCallable): + self._py_callable_cids_seen.add(cid) + return cid + if not isinstance(target, ChipCallable): + return self._post_start_register_python(target) + with self._registry_lock: - if self.level >= 3 and self._initialized and not isinstance(target, ChipCallable): - # L3+ post-init: only ChipCallable can cross the process - # boundary. Python callables (sub fn / orch fn) must be - # registered before init() so forked children inherit them. - raise NotImplementedError( - "Worker.register() at level >= 3 must be called before init() " - "for non-ChipCallable targets; only ChipCallable is supported " - "post-init (see docs/callable-ipc-dynamic-register.md)" - ) cid = self._allocate_cid() self._callable_registry[cid] = target + if self.level >= 3 and not isinstance(target, ChipCallable): + self._py_callable_cids_seen.add(cid) + + # L3+ post-init ChipCallable: broadcast to chip / next-level children + # via C++ after parent-side cid allocation is complete. The registry + # entry keeps the cid reserved while mailbox_mu_ serializes the wire + # round trip against dispatch. 
+ if self.level >= 3 and self._initialized and isinstance(target, ChipCallable): + try: + self._post_init_register(cid, target) + except Exception: + with self._registry_lock: + if self._callable_registry.get(cid) is target: + self._callable_registry.pop(cid, None) + raise + return cid - # L3+ post-init ChipCallable: broadcast to chip / next-level - # children via C++. Done inside the registry lock so a concurrent - # register cannot allocate the same cid we are about to pop on - # failure. Per-WorkerThread mailbox_mu_ already provides the C++ - # serialisation against in-flight dispatch. - if self.level >= 3 and self._initialized and isinstance(target, ChipCallable): - try: - self._post_init_register(cid, target) - except Exception: + # L2 post-init: pre-warm immediately so the very first + # `Worker.run(cid, …)` is a clean cache hit. + if self.level == 2 and self._initialized and isinstance(target, ChipCallable): + assert self._chip_worker is not None + self._chip_worker.prepare_callable(cid, target) + return cid + + def _python_worker_types(self) -> list[WorkerType]: + worker_types: list[WorkerType] = [] + if self._config.get("num_sub_workers", 0) > 0: + worker_types.append(WorkerType.SUB) + if self._next_level_workers: + worker_types.append(WorkerType.NEXT_LEVEL) + return worker_types + + def _post_start_register_python(self, target) -> int: + worker_types = self._python_worker_types() + if not worker_types: + raise RuntimeError( + "Worker.register: no Python-capable child workers are configured " + "for dynamic Python callable registration" + ) + payload = _pack_py_callable_payload(target) + with self._registry_lock: + cid = self._allocate_cid() + self._callable_registry[cid] = target + self._py_callable_cids_seen.add(cid) + try: + self._broadcast_py_control(worker_types, _CTRL_PY_REGISTER, cid, payload=payload, strict=True) + except Exception: + with self._registry_lock: + if self._callable_registry.get(cid) is target: self._callable_registry.pop(cid, None) - 
raise - return cid + raise + return cid - # L2 post-init: pre-warm immediately so the very first - # `Worker.run(cid, …)` is a clean cache hit. - if self.level == 2 and self._initialized and isinstance(target, ChipCallable): - assert self._chip_worker is not None - self._chip_worker.prepare_callable(cid, target) - return cid + def _broadcast_py_control( + self, + worker_types: list[WorkerType], + sub_cmd: int, + cid: int, + *, + payload: Optional[bytes] = None, + strict: bool, + ) -> list[str]: + if not worker_types: + return [] + assert self._worker is not None + errors: list[str] = [] + for worker_type in worker_types: + results = self._worker.broadcast_control_all( + worker_type, + int(sub_cmd), + int(cid), + payload, + timeout_s=self._py_control_timeout_s, + ) + for result in results: + if not result.ok: + errors.append(f"{result.worker_type}[{result.worker_index}]: {result.error_message}") + if errors and strict: + raise RuntimeError( + f"Worker control broadcast cid={cid} sub_cmd={sub_cmd} failed on " + f"{len(errors)} child workers; first error: {errors[0]}" + ) + return errors def _allocate_cid(self) -> int: """Return the smallest unused cid in [0, MAX_REGISTERED_CALLABLE_IDS). @@ -875,7 +1056,7 @@ def _allocate_cid(self) -> int: would silently overwrite the next gap-after-the-hole. """ for i in range(MAX_REGISTERED_CALLABLE_IDS): - if i not in self._callable_registry: + if i not in self._callable_registry and i not in self._pending_unregister_cids: return i # The AICPU side keeps a fixed-size orch_so_table_ keyed by cid; # raise here so the failure surfaces at register-time with a @@ -897,18 +1078,21 @@ def _register_at(self, cid: int, target: ChipCallable) -> None: on a single integer key. Plain ``register`` allocates the next free slot and is therefore unsuitable here. 
""" + if not isinstance(target, ChipCallable): + raise TypeError("_register_at: target must be a ChipCallable") with self._registry_lock: if cid in self._callable_registry: raise RuntimeError(f"_register_at: cid={cid} already occupied") - if not isinstance(target, ChipCallable): - raise TypeError("_register_at: target must be a ChipCallable") self._callable_registry[cid] = target - if self.level >= 3 and self._initialized: - try: - self._post_init_register(cid, target) - except Exception: - self._callable_registry.pop(cid, None) - raise + + if self.level >= 3 and self._initialized: + try: + self._post_init_register(cid, target) + except Exception: + with self._registry_lock: + if self._callable_registry.get(cid) is target: + self._callable_registry.pop(cid, None) + raise def _post_init_register(self, cid: int, target: ChipCallable) -> None: """Broadcast a new ChipCallable to every NEXT_LEVEL child via C++. @@ -927,8 +1111,31 @@ def _post_init_register(self, cid: int, target: ChipCallable) -> None: if not getattr(self, "_hierarchical_started", False): return assert self._worker is not None + if cid in self._py_callable_cids_seen: + self._broadcast_py_control(self._python_worker_types(), _CTRL_PY_UNREGISTER, cid, strict=True) + self._py_callable_cids_seen.discard(cid) self._worker.broadcast_register_all(int(cid), int(target.buffer_ptr()), int(target.buffer_size())) + def _pre_start_unregister_if_needed(self, cid: int) -> bool: + if self.level < 3: + return False + with self._hierarchical_start_cv: + while self._hierarchical_start_state == "starting": + self._hierarchical_start_cv.wait() + if self._hierarchical_start_state == "failed": + raise RuntimeError("Worker hierarchical startup failed; close this Worker and create a new one") + if self._hierarchical_start_state == "started" or getattr(self, "_hierarchical_started", False): + return False + with self._registry_lock: + if cid not in self._callable_registry: + raise KeyError(f"Worker.unregister: cid={cid} not 
registered") + if cid in self._pending_unregister_cids: + raise KeyError(f"Worker.unregister: cid={cid} already pending unregister") + target = self._callable_registry.pop(cid) + if not isinstance(target, ChipCallable): + self._py_callable_cids_seen.discard(cid) + return True + def unregister(self, cid: int) -> None: """Drop *cid* from the registry and propagate to chip children. @@ -947,17 +1154,46 @@ def unregister(self, cid: int) -> None: Raises: KeyError: cid was never registered. """ + if self._pre_start_unregister_if_needed(cid): + return + target = None with self._registry_lock: if cid not in self._callable_registry: raise KeyError(f"Worker.unregister: cid={cid} not registered") + if cid in self._pending_unregister_cids: + raise KeyError(f"Worker.unregister: cid={cid} already pending unregister") + target = self._callable_registry[cid] if self.level >= 3 and self._initialized and getattr(self, "_hierarchical_started", False): - self._broadcast_unregister(cid) + self._pending_unregister_cids.add(cid) elif self.level == 2 and self._initialized: assert self._chip_worker is not None self._chip_worker.unregister_callable(cid) - # Drop the registry entry unconditionally — even if a chip child - # reported an error, holding the slot would just waste cid budget. - self._callable_registry.pop(cid, None) + self._callable_registry.pop(cid, None) + return + else: + self._callable_registry.pop(cid, None) + return + + try: + if isinstance(target, ChipCallable): + self._broadcast_unregister(cid) + else: + errors = self._broadcast_py_control( + self._python_worker_types(), + _CTRL_PY_UNREGISTER, + cid, + strict=False, + ) + if errors: + sys.stderr.write( + f"Worker.unregister(cid={cid}): {len(errors)} Python children reported errors " + f"(continuing best-effort). 
First error: {errors[0]}\n" + ) + sys.stderr.flush() + finally: + with self._registry_lock: + self._callable_registry.pop(cid, None) + self._pending_unregister_cids.discard(cid) def _broadcast_unregister(self, cid: int) -> None: """Broadcast _CTRL_UNREGISTER via C++ to every NEXT_LEVEL child. @@ -983,6 +1219,8 @@ def add_worker(self, worker: "Worker") -> None: """ if self.level < 4: raise RuntimeError("Worker.add_worker() requires level >= 4") + if self._config.get("device_ids", []): + raise RuntimeError("Worker.add_worker() cannot be combined with device_ids on the same Worker") if self._initialized: raise RuntimeError("Worker.add_worker() must be called before init()") if worker._initialized: @@ -1030,6 +1268,8 @@ def _init_hierarchical(self) -> None: device_ids = self._config.get("device_ids", []) n_sub = self._config.get("num_sub_workers", 0) heap_ring_size = self._config.get("heap_ring_size", None) + if self.level >= 4 and device_ids: + raise RuntimeError("Worker level >= 4 must use add_worker(); device_ids are only supported on L3 Workers") # 1. Allocate sub-worker mailboxes (unified layout, MAILBOX_SIZE each). for _ in range(n_sub): @@ -1081,97 +1321,118 @@ def _init_hierarchical(self) -> None: def _start_hierarchical(self) -> None: # noqa: PLR0912 -- three parallel fork loops (sub/chip/next) + bootstrap wait + scheduler register/init; branches track the fork order documented in the body """Fork child processes and start C++ scheduler. 
Called on first run().""" - if self._hierarchical_started: - return - self._hierarchical_started = True - device_ids = self._config.get("device_ids", []) n_sub = self._config.get("num_sub_workers", 0) - # Fork SubWorker processes (MUST be before any C++ threads) - registry = self._callable_registry - for i in range(n_sub): - pid = os.fork() - if pid == 0: - buf = self._sub_shms[i].buf - assert buf is not None - _sub_worker_loop(buf, registry) - os._exit(0) - else: - self._sub_pids.append(pid) - - # Fork ChipWorker processes (L3 with device_ids). Always use the - # plain task-loop variant; the base communicator is established - # lazily on first ``orch.allocate_domain`` via CTRL_COMM_INIT. - chip_log_level, chip_log_info_v = _simpler_log.get_current_config() - if device_ids: - for idx, dev_id in enumerate(device_ids): + try: + # Fork children from an immutable snapshot. The state transition + # and snapshot are one gate, so dynamic register/unregister callers + # cannot return through the pre-start path after this point. + with self._hierarchical_start_cv: + while self._hierarchical_start_state == "starting": + self._hierarchical_start_cv.wait() + if self._hierarchical_start_state == "started": + return + if self._hierarchical_start_state == "failed": + raise RuntimeError("Worker hierarchical startup failed; close this Worker and create a new one") + self._hierarchical_start_state = "starting" + with self._registry_lock: + registry = dict(self._callable_registry) + self._hierarchical_start_cv.notify_all() + + # Fork SubWorker processes (MUST be before any C++ threads) + for i in range(n_sub): pid = os.fork() if pid == 0: - buf = self._chip_shms[idx].buf + buf = self._sub_shms[i].buf assert buf is not None - _chip_process_loop( - buf, - self._l3_bins, - dev_id, - registry, - chip_log_level, - chip_log_info_v, - ) + _sub_worker_loop(buf, registry) os._exit(0) else: - self._chip_pids.append(pid) - - # Fork next-level Worker children (L4+ with Worker children). 
- # Each child process: init the inner Worker (which mmaps its own - # HeapRing and allocates its own child mailboxes), then enter - # _child_worker_loop. The inner Worker's own children are forked - # lazily on first run() inside _child_worker_loop, so the process - # tree nests correctly: L4 → L3 child → L3's chip/sub children. - for idx, inner_worker in enumerate(self._next_level_workers): - pid = os.fork() - if pid == 0: - buf = self._next_level_shms[idx].buf - assert buf is not None - inner_worker.init() - _child_worker_loop(buf, registry, inner_worker) - os._exit(0) - else: - self._next_level_pids.append(pid) - - # _Worker was constructed in _init_hierarchical (pre-fork) so - # children inherit the HeapRing MAP_SHARED mmap. Register PROCESS-mode - # workers via the unified mailbox. - dw = self._worker - assert dw is not None - - # Register chip workers as NEXT_LEVEL (L3) - if device_ids: - for shm in self._chip_shms: - dw.add_next_level_worker(_mailbox_addr(shm)) - - # Register Worker children as NEXT_LEVEL (L4+) - for shm in self._next_level_shms: - dw.add_next_level_worker(_mailbox_addr(shm)) + self._sub_pids.append(pid) + + # Fork ChipWorker processes (L3 with device_ids). Always use the + # plain task-loop variant; the base communicator is established + # lazily on first ``orch.allocate_domain`` via CTRL_COMM_INIT. + chip_log_level, chip_log_info_v = _simpler_log.get_current_config() + if device_ids: + for idx, dev_id in enumerate(device_ids): + pid = os.fork() + if pid == 0: + buf = self._chip_shms[idx].buf + assert buf is not None + _chip_process_loop( + buf, + self._l3_bins, + dev_id, + registry, + chip_log_level, + chip_log_info_v, + ) + os._exit(0) + else: + self._chip_pids.append(pid) + + # Fork next-level Worker children (L4+ with Worker children). + # Each child process: init the inner Worker (which mmaps its own + # HeapRing and allocates its own child mailboxes), then enter + # _child_worker_loop. 
The inner Worker's own children are forked + # lazily on first run() inside _child_worker_loop, so the process + # tree nests correctly: L4 → L3 child → L3's chip/sub children. + for idx, inner_worker in enumerate(self._next_level_workers): + pid = os.fork() + if pid == 0: + buf = self._next_level_shms[idx].buf + assert buf is not None + inner_worker.init() + _child_worker_loop(buf, registry, inner_worker) + os._exit(0) + else: + self._next_level_pids.append(pid) - for shm in self._sub_shms: - dw.add_sub_worker(_mailbox_addr(shm)) + # _Worker was constructed in _init_hierarchical (pre-fork) so + # children inherit the HeapRing MAP_SHARED mmap. Register PROCESS-mode + # workers via the unified mailbox. + dw = self._worker + assert dw is not None - # Start Scheduler + WorkerThreads (C++ threads start here, after fork) - dw.init() + # Register chip workers as NEXT_LEVEL (L3) + if device_ids: + for shm in self._chip_shms: + dw.add_next_level_worker(_mailbox_addr(shm)) - self._orch = Orchestrator(dw.get_orchestrator(), self) + # Register Worker children as NEXT_LEVEL (L4+) + for shm in self._next_level_shms: + dw.add_next_level_worker(_mailbox_addr(shm)) - # Pre-warm every chip child: for each registered ChipCallable cid, - # send `_CTRL_PREPARE` to all chip children so the first - # `submit_next_level` does not pay the H2D upload cost. Sub fns / - # orch fns do not need pre-warming — the registry is already - # COW-inherited. 
- if device_ids: - for cid, target in self._callable_registry.items(): - if isinstance(target, ChipCallable): - for worker_id in range(len(self._chip_shms)): - dw.control_prepare(worker_id, int(cid)) + for shm in self._sub_shms: + dw.add_sub_worker(_mailbox_addr(shm)) + + # Start Scheduler + WorkerThreads (C++ threads start here, after fork) + dw.init() + + self._orch = Orchestrator(dw.get_orchestrator(), self) + + # Pre-warm every chip child: for each registered ChipCallable cid, + # send `_CTRL_PREPARE` to all chip children so the first + # `submit_next_level` does not pay the H2D upload cost. Sub fns / + # orch fns do not need pre-warming — the registry is already + # COW-inherited. + if device_ids: + for cid, target in registry.items(): + if isinstance(target, ChipCallable): + for worker_id in range(len(self._chip_shms)): + dw.control_prepare(worker_id, int(cid)) + + self._hierarchical_started = True + with self._hierarchical_start_cv: + self._hierarchical_start_state = "started" + self._hierarchical_start_cv.notify_all() + except Exception: + with self._hierarchical_start_cv: + self._hierarchical_start_state = "failed" + self._hierarchical_start_cv.notify_all() + raise # ------------------------------------------------------------------ # Hierarchical abort diff --git a/src/common/hierarchical/worker.h b/src/common/hierarchical/worker.h index c90f05af6..3ff7ec1be 100644 --- a/src/common/hierarchical/worker.h +++ b/src/common/hierarchical/worker.h @@ -115,6 +115,11 @@ class Worker { manager_.broadcast_register_all(cid, reinterpret_cast(blob_ptr), static_cast(blob_size)); } std::vector broadcast_unregister_all(int32_t cid) { return manager_.broadcast_unregister_all(cid); } + std::vector broadcast_control_all( + WorkerType type, uint64_t sub_cmd, int32_t cid, const void *payload, size_t payload_size, double timeout_s + ) { + return manager_.broadcast_control_all(type, sub_cmd, cid, payload, payload_size, timeout_s); + } private: int32_t level_; diff --git 
a/src/common/hierarchical/worker_manager.cpp b/src/common/hierarchical/worker_manager.cpp index f50e4bc2d..c26f3e2fe 100644 --- a/src/common/hierarchical/worker_manager.cpp +++ b/src/common/hierarchical/worker_manager.cpp @@ -157,6 +157,9 @@ void WorkerThread::dispatch_process(TaskSlotState &s, int32_t group_index) { // orch thread waits for the dispatch to finish before claiming the // mailbox; without this they would race on MAILBOX_OFF_STATE. std::lock_guard lk(mailbox_mu_); + if (mailbox_control_timed_out_) { + throw std::runtime_error("WorkerThread::dispatch_process: mailbox has an unresolved timed-out control command"); + } // Clear the child-writable error fields so stale bytes from a prior // dispatch cannot masquerade as a fresh failure. @@ -338,12 +341,26 @@ static uint64_t read_control_result(const char *mbox) { // from the child, throws and leaves the mailbox in IDLE before unwinding // (so the next claim starts from a clean state). The `op_name` is used // only for the exception message. 
-void WorkerThread::run_control_command(const char *op_name) { +void WorkerThread::run_control_command(const char *op_name, double timeout_s) { + if (mailbox_control_timed_out_) { + throw std::runtime_error(std::string(op_name) + " failed: mailbox has an unresolved timed-out control command"); + } int32_t zero_err = 0; std::memcpy(mbox() + MAILBOX_OFF_ERROR, &zero_err, sizeof(int32_t)); std::memset(mbox() + MAILBOX_OFF_ERROR_MSG, 0, MAILBOX_ERROR_MSG_SIZE); write_mailbox_state(MailboxState::CONTROL_REQUEST); - while (read_mailbox_state() != MailboxState::CONTROL_DONE) {} + auto deadline = std::chrono::steady_clock::time_point::max(); + if (timeout_s >= 0.0) { + deadline = + std::chrono::steady_clock::now() + + std::chrono::duration_cast(std::chrono::duration(timeout_s)); + } + while (read_mailbox_state() != MailboxState::CONTROL_DONE) { + if (std::chrono::steady_clock::now() >= deadline) { + mailbox_control_timed_out_ = true; + throw std::runtime_error(std::string(op_name) + " timed out waiting for CONTROL_DONE"); + } + } int32_t err = 0; std::memcpy(&err, mbox() + MAILBOX_OFF_ERROR, sizeof(int32_t)); if (err != 0) { @@ -392,6 +409,21 @@ void WorkerThread::control_unregister(int32_t cid) { run_control_command("control_unregister"); } +void WorkerThread::control_generic(uint64_t sub_cmd, int32_t cid, const char *shm_name, double timeout_s) { + std::lock_guard lk(mailbox_mu_); + std::memcpy(mbox() + MAILBOX_OFF_CALLABLE, &sub_cmd, sizeof(uint64_t)); + uint64_t cid_v = static_cast(cid); + std::memcpy(mbox() + CTRL_OFF_ARG0, &cid_v, sizeof(uint64_t)); + const char *name = shm_name ? 
shm_name : ""; + size_t name_len = std::strlen(name); + if (name_len + 1 > CTRL_SHM_NAME_BYTES) { + throw std::runtime_error(std::string("control_generic: shm name too long: ") + name); + } + if (name_len > 0) std::memcpy(mbox() + MAILBOX_OFF_ARGS, name, name_len); + std::memset(mbox() + MAILBOX_OFF_ARGS + name_len, 0, CTRL_SHM_NAME_BYTES - name_len); + run_control_command("control_generic", timeout_s); +} + void WorkerThread::control_free(uint64_t ptr) { std::lock_guard lk(mailbox_mu_); write_control_args(mbox(), CTRL_FREE, ptr); @@ -661,3 +693,45 @@ std::vector<std::string> WorkerManager::broadcast_unregister_all(int32_t cid) { } return errors; } + +std::vector<ControlResult> WorkerManager::broadcast_control_all( + WorkerType type, uint64_t sub_cmd, int32_t cid, const void *payload, size_t payload_size, double timeout_s +) { + auto &threads = (type == WorkerType::NEXT_LEVEL) ? next_level_threads_ : sub_threads_; + const char *type_name = (type == WorkerType::NEXT_LEVEL) ? "NEXT_LEVEL" : "SUB"; + + std::vector<ControlResult> results; + results.reserve(threads.size()); + for (size_t i = 0; i < threads.size(); ++i) { + results.push_back(ControlResult{type_name, static_cast<int32_t>(i), true, ""}); + } + if (threads.empty()) return results; + + std::unique_ptr shm; + std::string shm_name; + if (payload != nullptr || payload_size != 0) { + if (payload == nullptr || payload_size == 0) { + throw std::runtime_error("broadcast_control_all: payload pointer and size must both be set"); + } + shm_name = make_shm_name(cid); + shm = std::make_unique(shm_name, payload_size); + std::memcpy(shm->addr(), payload, payload_size); + } + + std::vector<std::thread> workers; + workers.reserve(threads.size()); + for (size_t i = 0; i < threads.size(); ++i) { + workers.emplace_back([&, i]() { + try { + threads[i]->control_generic(sub_cmd, cid, shm_name.empty() ?
nullptr : shm_name.c_str(), timeout_s); + } catch (const std::exception &e) { + results[i].ok = false; + results[i].error_message = strip_control_prefix(e.what(), "control_generic"); + } + }); + } + for (auto &t : workers) + t.join(); + + return results; +} diff --git a/src/common/hierarchical/worker_manager.h b/src/common/hierarchical/worker_manager.h index f7b00b4ff..76a4bf2c7 100644 --- a/src/common/hierarchical/worker_manager.h +++ b/src/common/hierarchical/worker_manager.h @@ -121,6 +121,8 @@ static constexpr uint64_t CTRL_RELEASE_DOMAIN = 8; // Caches the comm handle on the chip's ChipWorker so subsequent // CTRL_ALLOC_DOMAIN calls can find it. static constexpr uint64_t CTRL_COMM_INIT = 9; +static constexpr uint64_t CTRL_PY_REGISTER = 10; +static constexpr uint64_t CTRL_PY_UNREGISTER = 11; // Control args reuse the task mailbox region (mutually exclusive with task dispatch): // offset 16: uint64 arg0 (size for malloc; ptr for free; dst for copy; cid for register) @@ -137,6 +139,13 @@ static constexpr ptrdiff_t CTRL_OFF_RESULT = 40; // of "simpler-cb---" with pid < 32-bit max. static constexpr size_t CTRL_SHM_NAME_BYTES = 32; +struct ControlResult { + std::string worker_type; + int32_t worker_index{0}; + bool ok{false}; + std::string error_message; +}; + // ============================================================================= // WorkerDispatch — per-dispatch handle handed to a WorkerThread. // ============================================================================= @@ -213,6 +222,7 @@ class WorkerThread { // for the in-flight TASK_DONE before claiming the mailbox. void control_register(int32_t cid, const char *shm_name); void control_unregister(int32_t cid); + void control_generic(uint64_t sub_cmd, int32_t cid, const char *shm_name, double timeout_s); // Dynamic CommDomain allocate / release. 
`request_shm_name` carries the // request payload (header + rank_ids + buffer_nbytes); for alloc the child @@ -244,6 +254,7 @@ class WorkerThread { // dispatch loop and the orch-thread control_* path. Per-WorkerThread, // so different workers can dispatch in parallel. std::mutex mailbox_mu_; + bool mailbox_control_timed_out_{false}; void loop(); void dispatch_process(TaskSlotState &s, int32_t group_index); @@ -251,7 +262,7 @@ class WorkerThread { // Common tail for the four control_* methods. Caller writes the args // region and holds `mailbox_mu_`; this helper signals the child, // spin-polls CONTROL_DONE, and throws on a non-zero child error code. - void run_control_command(const char *op_name); + void run_control_command(const char *op_name, double timeout_s = -1.0); char *mbox() const { return static_cast<char *>(mailbox_); } MailboxState read_mailbox_state() const; @@ -312,6 +323,9 @@ class WorkerManager { // worker in parallel. Returns a vector of per-worker error strings // (empty on full success). Caller decides whether to log / surface. std::vector<std::string> broadcast_unregister_all(int32_t cid); + std::vector<ControlResult> broadcast_control_all( + WorkerType type, uint64_t sub_cmd, int32_t cid, const void *payload, size_t payload_size, double timeout_s + ); // Write SHUTDOWN to every registered mailbox.
void shutdown_children(); diff --git a/tests/ut/py/test_worker/test_host_worker.py b/tests/ut/py/test_worker/test_host_worker.py index 4a5d11079..51db62f19 100644 --- a/tests/ut/py/test_worker/test_host_worker.py +++ b/tests/ut/py/test_worker/test_host_worker.py @@ -18,8 +18,28 @@ import pytest from _task_interface import MAX_REGISTERED_CALLABLE_IDS # pyright: ignore[reportMissingImports] -from simpler.task_interface import ChipCallable, DataType, TaskArgs, TensorArgType -from simpler.worker import Worker +from simpler.task_interface import ( + MAILBOX_SIZE, + ChipCallable, + DataType, + TaskArgs, + TensorArgType, + WorkerType, + _Worker, +) +from simpler.worker import ( + _CONTROL_REQUEST, + _CTRL_PY_REGISTER, + _CTRL_PY_UNREGISTER, + _IDLE, + _OFF_STATE, + Worker, + _buffer_field_addr, + _mailbox_addr, + _mailbox_load_i32, + _mailbox_store_i32, + _pack_py_callable_payload, +) # --------------------------------------------------------------------------- # Helpers @@ -44,6 +64,33 @@ def _increment_counter(buf) -> None: struct.pack_into("i", buf, 0, v + 1) +def _add_counter(buf, delta: int) -> None: + v = struct.unpack_from("i", buf, 0)[0] + struct.pack_into("i", buf, 0, v + delta) + + +def _set_flag(buf, offset: int, value: int) -> None: + struct.pack_into("i", buf, offset, value) + + +def _get_flag(buf, offset: int) -> int: + return struct.unpack_from("i", buf, offset)[0] + + +def _roundtrip_py_callable_payload(target): + from simpler.worker import _load_py_callable_from_shm, _pack_py_callable_payload # noqa: PLC0415 + + payload = _pack_py_callable_payload(target) + shm = SharedMemory(create=True, size=len(payload)) + try: + assert shm.buf is not None + shm.buf[: len(payload)] = payload + return _load_py_callable_from_shm(shm.name) + finally: + shm.close() + shm.unlink() + + # --------------------------------------------------------------------------- # Test: lifecycle (init / close without submitting any tasks) # 
--------------------------------------------------------------------------- @@ -65,15 +112,168 @@ def test_context_manager(self): hw.register(lambda args: None) # close() called by __exit__, no exception - def test_register_python_fn_after_init_raises(self): - # Post-init register of a non-ChipCallable (lambda / sub fn) is - # rejected because Python callables cannot cross the fork boundary. - # ChipCallable is the only post-init target — see the next test. + def test_l2_rejects_python_callable(self): + hw = Worker(level=2, device_id=0, platform="a2a3sim", runtime="tensormap_and_ringbuffer") + with pytest.raises(TypeError, match="level 2 only supports ChipCallable"): + hw.register(lambda args: None) + + def test_register_python_fn_after_init_before_start_succeeds(self): + # init() allocates mailboxes but does not fork children. Python + # callables registered in this window still land in the startup + # snapshot consumed by the first run(). hw = Worker(level=3, num_sub_workers=0) hw.init() - with pytest.raises(NotImplementedError, match="only ChipCallable is supported post-init"): - hw.register(lambda args: None) - hw.close() + try: + cid = hw.register(lambda args: None) + assert cid in hw._callable_registry + finally: + hw.close() + + def test_register_python_fn_after_init_before_start_does_not_broadcast(self): + class BroadcastTrap: + def broadcast_control_all(self, *args, **kwargs): + raise AssertionError("pre-start Python register must not broadcast") + + hw = Worker(level=3, num_sub_workers=1) + hw.init() + real_worker = hw._worker + try: + hw._worker = BroadcastTrap() + cid = hw.register(lambda args: None) + assert cid in hw._callable_registry + finally: + hw._worker = real_worker + hw.close() + + def test_register_python_fn_after_start_no_python_children_raises(self): + hw = Worker(level=3, num_sub_workers=0) + hw.init() + try: + hw.run(lambda orch, args, cfg: None) + with pytest.raises(RuntimeError, match="no Python-capable child"): + hw.register(lambda 
args: None) + finally: + hw.close() + + def test_register_waits_for_first_startup_then_uses_post_start_path(self): + hw = Worker(level=3, num_sub_workers=1) + hw.init() + try: + with hw._hierarchical_start_cv: + hw._hierarchical_start_state = "starting" + + observed = {} + + def fake_post_start_register(target): + observed["target"] = target + observed["state"] = hw._hierarchical_start_state + observed["hierarchical_started"] = hw._hierarchical_started + return 7 + + hw._post_start_register_python = fake_post_start_register + result: list[int] = [] + errors: list[BaseException] = [] + wait_entered = threading.Event() + original_wait = hw._hierarchical_start_cv.wait + + def wait_with_signal(timeout=None): + wait_entered.set() + return original_wait(timeout) + + hw._hierarchical_start_cv.wait = wait_with_signal + + def do_register(): + try: + result.append(hw.register(lambda args: None)) + except BaseException as exc: # noqa: BLE001 + errors.append(exc) + + t = threading.Thread(target=do_register) + t.start() + assert wait_entered.wait(timeout=2.0) + with hw._hierarchical_start_cv: + hw._hierarchical_started = True + hw._hierarchical_start_state = "started" + hw._hierarchical_start_cv.notify_all() + t.join(timeout=2.0) + + assert not t.is_alive() + assert errors == [] + assert result == [7] + assert observed["state"] == "started" + assert observed["hierarchical_started"] is True + finally: + if "original_wait" in locals(): + hw._hierarchical_start_cv.wait = original_wait + hw.close() + + def test_register_blocks_startup_snapshot_from_not_started_window(self): + hw = Worker(level=3, num_sub_workers=0) + hw.init() + + real_registry_lock = hw._registry_lock + register_waiting = threading.Event() + release_register = threading.Event() + startup_snapshot_attempted = threading.Event() + result: list[int] = [] + errors: list[BaseException] = [] + + class BlockingRegistryLock: + def __enter__(self): + thread_name = threading.current_thread().name + if thread_name == 
"register-thread": + register_waiting.set() + if not release_register.wait(timeout=2.0): + raise TimeoutError("test timed out waiting to release register") + elif thread_name == "startup-thread": + startup_snapshot_attempted.set() + return real_registry_lock.__enter__() + + def __exit__(self, exc_type, exc, tb): + return real_registry_lock.__exit__(exc_type, exc, tb) + + def locked(self): + return real_registry_lock.locked() + + hw._registry_lock = BlockingRegistryLock() + + def do_register(): + try: + result.append(hw.register(lambda args: None)) + except BaseException as exc: # noqa: BLE001 + errors.append(exc) + + def do_startup(): + try: + hw._start_hierarchical() + except BaseException as exc: # noqa: BLE001 + errors.append(exc) + + register_thread = threading.Thread(target=do_register, name="register-thread") + startup_thread = threading.Thread(target=do_startup, name="startup-thread") + try: + register_thread.start() + assert register_waiting.wait(timeout=2.0) + + startup_thread.start() + assert not startup_snapshot_attempted.wait(timeout=0.2) + + release_register.set() + register_thread.join(timeout=2.0) + startup_thread.join(timeout=2.0) + + assert not register_thread.is_alive() + assert not startup_thread.is_alive() + assert errors == [] + assert result == [0] + assert startup_snapshot_attempted.is_set() + assert hw._hierarchical_start_state == "started" + finally: + release_register.set() + register_thread.join(timeout=2.0) + startup_thread.join(timeout=2.0) + hw._registry_lock = real_registry_lock + hw.close() def test_register_chip_callable_after_init_no_chips_succeeds(self): # With no chip children (device_ids unset), the C++ broadcast is a @@ -136,6 +336,621 @@ def test_unregister_chip_callable_after_init_no_chips_succeeds(self): finally: hw.close() + def test_register_chip_callable_broadcast_runs_without_registry_lock(self): + hw = Worker(level=3, num_sub_workers=0) + hw._initialized = True + hw._hierarchical_started = True + callable_obj = 
ChipCallable.build(signature=[], func_name="x", binary=b"\x00", children=[]) + observed = {} + + def fake_post_init_register(cid, target): + observed["cid"] = cid + observed["target"] = target + observed["locked"] = hw._registry_lock.locked() + + hw._post_init_register = fake_post_init_register + + cid = hw.register(callable_obj) + + assert observed == {"cid": cid, "target": callable_obj, "locked": False} + assert hw._callable_registry[cid] is callable_obj + + def test_register_at_broadcast_runs_without_registry_lock(self): + hw = Worker(level=3, num_sub_workers=0) + hw._initialized = True + callable_obj = ChipCallable.build(signature=[], func_name="x", binary=b"\x00", children=[]) + observed = {} + + def fake_post_init_register(cid, target): + observed["cid"] = cid + observed["target"] = target + observed["locked"] = hw._registry_lock.locked() + + hw._post_init_register = fake_post_init_register + + hw._register_at(7, callable_obj) + + assert observed == {"cid": 7, "target": callable_obj, "locked": False} + assert hw._callable_registry[7] is callable_obj + + def test_python_control_broadcast_passes_default_timeout(self): + from simpler.worker import _CTRL_PY_UNREGISTER, _PY_CONTROL_TIMEOUT_S # noqa: PLC0415 + + class FakeControlWorker: + def __init__(self): + self.calls = [] + + def broadcast_control_all(self, worker_type, sub_cmd, cid, payload, timeout_s=None): + self.calls.append((worker_type, sub_cmd, cid, payload, timeout_s)) + return [] + + fake = FakeControlWorker() + hw = Worker(level=3, num_sub_workers=1) + hw._worker = fake + + errors = hw._broadcast_py_control([WorkerType.SUB], _CTRL_PY_UNREGISTER, 3, strict=False) + + assert errors == [] + assert fake.calls == [(WorkerType.SUB, _CTRL_PY_UNREGISTER, 3, None, _PY_CONTROL_TIMEOUT_S)] + + def test_cloudpickle_payload_roundtrip_supported_callable_shapes(self): + class AddValue: + def __init__(self, value): + self.value = value + + def __call__(self, arg): + return arg + self.value + + scale = 3 + + def 
nested(arg): + return arg * scale + + cases = [ + (lambda arg: arg + 1, 4, 5), + (nested, 4, 12), + (AddValue(7), 4, 11), + ] + for target, arg, expected in cases: + loaded = _roundtrip_py_callable_payload(target) + assert callable(loaded) + assert loaded(arg) == expected + + def test_python_unregister_child_failure_warns_pops_and_allows_reuse(self, capsys): + from simpler.worker import _CTRL_PY_REGISTER, _CTRL_PY_UNREGISTER # noqa: PLC0415 + + hw = Worker(level=3, num_sub_workers=1) + cid = hw.register(lambda args: None) + hw._initialized = True + hw._hierarchical_started = True + calls = [] + + def fake_broadcast(worker_types, sub_cmd, broadcast_cid, *, payload=None, strict): + calls.append((list(worker_types), sub_cmd, broadcast_cid, strict)) + if sub_cmd == _CTRL_PY_UNREGISTER: + return ["SUB[0]: injected unregister failure"] + if sub_cmd == _CTRL_PY_REGISTER: + return [] + raise AssertionError(f"unexpected sub_cmd={sub_cmd}") + + hw._broadcast_py_control = fake_broadcast + + hw.unregister(cid) + + captured = capsys.readouterr() + assert "Python children reported errors" in captured.err + assert "injected unregister failure" in captured.err + assert cid not in hw._callable_registry + assert cid not in hw._pending_unregister_cids + + reused = hw.register(lambda args: None) + assert reused == cid + assert calls[0] == ([WorkerType.SUB], _CTRL_PY_UNREGISTER, cid, False) + assert calls[1] == ([WorkerType.SUB], _CTRL_PY_REGISTER, cid, True) + + def test_pending_unregister_cid_is_not_reused_until_broadcast_returns(self): + from simpler.worker import _CTRL_PY_REGISTER, _CTRL_PY_UNREGISTER # noqa: PLC0415 + + hw = Worker(level=3, num_sub_workers=1) + cid = hw.register(lambda args: None) + hw._initialized = True + hw._hierarchical_started = True + + broadcast_started = threading.Event() + release_broadcast = threading.Event() + errors: list[BaseException] = [] + + def fake_broadcast(worker_types, sub_cmd, broadcast_cid, *, payload=None, strict): + if sub_cmd == 
_CTRL_PY_UNREGISTER: + broadcast_started.set() + assert release_broadcast.wait(timeout=2.0) + elif sub_cmd == _CTRL_PY_REGISTER: + return [] + else: + raise AssertionError(f"unexpected sub_cmd={sub_cmd}") + return [] + + hw._broadcast_py_control = fake_broadcast + + def do_unregister(): + try: + hw.unregister(cid) + except BaseException as exc: # noqa: BLE001 + errors.append(exc) + + t = threading.Thread(target=do_unregister) + t.start() + assert broadcast_started.wait(timeout=2.0) + + cid_during_unregister = hw.register(lambda args: None) + assert cid_during_unregister != cid + assert cid in hw._pending_unregister_cids + + release_broadcast.set() + t.join(timeout=2.0) + assert not t.is_alive() + assert errors == [] + + cid_after_unregister = hw.register(lambda args: None) + assert cid_after_unregister == cid + + def test_register_python_sub_callable_after_start_succeeds(self): + counter_shm, counter_buf = _make_shared_counter() + try: + hw = Worker(level=3, num_sub_workers=1) + bootstrap_cid = hw.register(lambda args: None) + hw.init() + + def bootstrap(orch, args, cfg): + orch.submit_sub(bootstrap_cid) + + hw.run(bootstrap) + counter_name = counter_shm.name + + def dynamic_sub(args): + shm = SharedMemory(name=counter_name) + try: + _increment_counter(shm.buf) + finally: + shm.close() + + dynamic_cid = hw.register(dynamic_sub) + + def run_dynamic(orch, args, cfg): + orch.submit_sub(dynamic_cid) + + hw.run(run_dynamic) + hw.close() + + assert _read_counter(counter_buf) == 1 + finally: + counter_shm.close() + counter_shm.unlink() + + def test_post_start_python_register_waits_for_active_sub_mailbox(self): + import time # noqa: PLC0415 + + control_shm = SharedMemory(create=True, size=8) + counter_shm, counter_buf = _make_shared_counter() + hw = Worker(level=3, num_sub_workers=1) + run_errors: list[BaseException] = [] + register_errors: list[BaseException] = [] + dynamic_cids: list[int] = [] + run_thread = None + register_thread = None + try: + assert control_shm.buf 
is not None + _set_flag(control_shm.buf, 0, 0) # started + _set_flag(control_shm.buf, 4, 0) # release + control_name = control_shm.name + counter_name = counter_shm.name + + def blocking_sub(args): + import time as child_time # noqa: PLC0415 + + shm = SharedMemory(name=control_name) + try: + _set_flag(shm.buf, 0, 1) + while _get_flag(shm.buf, 4) == 0: + child_time.sleep(0.001) + finally: + shm.close() + + blocking_cid = hw.register(blocking_sub) + hw.init() + + def run_blocking(): + try: + hw.run(lambda orch, args, cfg: orch.submit_sub(blocking_cid)) + except BaseException as exc: # noqa: BLE001 + run_errors.append(exc) + + run_thread = threading.Thread(target=run_blocking) + run_thread.start() + + deadline = time.monotonic() + 2.0 + while _get_flag(control_shm.buf, 0) == 0 and time.monotonic() < deadline: + time.sleep(0.001) + assert _get_flag(control_shm.buf, 0) == 1 + + def dynamic_sub(args): + shm = SharedMemory(name=counter_name) + try: + _increment_counter(shm.buf) + finally: + shm.close() + + def do_register(): + try: + dynamic_cids.append(hw.register(dynamic_sub)) + except BaseException as exc: # noqa: BLE001 + register_errors.append(exc) + + register_thread = threading.Thread(target=do_register) + register_thread.start() + register_thread.join(timeout=0.05) + assert register_thread.is_alive() + + _set_flag(control_shm.buf, 4, 1) + run_thread.join(timeout=2.0) + register_thread.join(timeout=2.0) + + assert not run_thread.is_alive() + assert not register_thread.is_alive() + assert run_errors == [] + assert register_errors == [] + assert len(dynamic_cids) == 1 + + hw.run(lambda orch, args, cfg: orch.submit_sub(dynamic_cids[0])) + assert _read_counter(counter_buf) == 1 + finally: + if control_shm.buf is not None: + _set_flag(control_shm.buf, 4, 1) + if run_thread is not None: + run_thread.join(timeout=2.0) + if register_thread is not None: + register_thread.join(timeout=2.0) + hw.close() + control_shm.close() + control_shm.unlink() + counter_shm.close() + 
counter_shm.unlink() + + def test_post_start_unregister_pre_start_python_callable_removes_child_entry(self): + counter_shm, counter_buf = _make_shared_counter() + try: + hw = Worker(level=3, num_sub_workers=1) + cid = hw.register(lambda args: _increment_counter(counter_buf)) + hw.init() + + hw.run(lambda orch, args, cfg: orch.submit_sub(cid)) + assert _read_counter(counter_buf) == 1 + + hw.unregister(cid) + assert cid not in hw._callable_registry + with pytest.raises(RuntimeError, match="not registered"): + hw.run(lambda orch, args, cfg: orch.submit_sub(cid)) + + counter_name = counter_shm.name + + def replacement(args): + shm = SharedMemory(name=counter_name) + try: + _add_counter(shm.buf, 10) + finally: + shm.close() + + reused = hw.register(replacement) + assert reused == cid + hw.run(lambda orch, args, cfg: orch.submit_sub(reused)) + hw.close() + + assert _read_counter(counter_buf) == 11 + finally: + counter_shm.close() + counter_shm.unlink() + + def test_post_start_unregister_post_start_python_callable_removes_child_entry(self): + counter_shm, counter_buf = _make_shared_counter() + try: + hw = Worker(level=3, num_sub_workers=1) + bootstrap_cid = hw.register(lambda args: None) + hw.init() + hw.run(lambda orch, args, cfg: orch.submit_sub(bootstrap_cid)) + + counter_name = counter_shm.name + + def dynamic(args): + shm = SharedMemory(name=counter_name) + try: + _increment_counter(shm.buf) + finally: + shm.close() + + cid = hw.register(dynamic) + hw.run(lambda orch, args, cfg: orch.submit_sub(cid)) + assert _read_counter(counter_buf) == 1 + + hw.unregister(cid) + assert cid not in hw._callable_registry + with pytest.raises(RuntimeError, match="not registered"): + hw.run(lambda orch, args, cfg: orch.submit_sub(cid)) + + reused = hw.register(dynamic) + assert reused == cid + hw.run(lambda orch, args, cfg: orch.submit_sub(reused)) + hw.close() + + assert _read_counter(counter_buf) == 2 + finally: + counter_shm.close() + counter_shm.unlink() + + def 
test_post_start_dynamic_python_callable_execute_failure_propagates(self): + hw = Worker(level=3, num_sub_workers=1) + bootstrap_cid = hw.register(lambda args: None) + hw.init() + try: + hw.run(lambda orch, args, cfg: orch.submit_sub(bootstrap_cid)) + + def boom(args): + raise RuntimeError("dynamic callable boom") + + cid = hw.register(boom) + with pytest.raises(RuntimeError, match="dynamic callable boom"): + hw.run(lambda orch, args, cfg: orch.submit_sub(cid)) + finally: + hw.close() + + def test_broadcast_control_all_accepts_memoryview_payload(self): + counter_shm, counter_buf = _make_shared_counter() + try: + hw = Worker(level=3, num_sub_workers=1) + bootstrap_cid = hw.register(lambda args: None) + hw.init() + + def bootstrap(orch, args, cfg): + orch.submit_sub(bootstrap_cid) + + hw.run(bootstrap) + counter_name = counter_shm.name + + def dynamic_sub(args): + shm = SharedMemory(name=counter_name) + try: + _increment_counter(shm.buf) + finally: + shm.close() + + cid = 5 + worker_impl = hw._worker + assert worker_impl is not None + results = worker_impl.broadcast_control_all( + WorkerType.SUB, + _CTRL_PY_REGISTER, + cid, + memoryview(_pack_py_callable_payload(dynamic_sub)), + ) + assert len(results) == 1 + assert results[0].ok + + def run_dynamic(orch, args, cfg): + orch.submit_sub(cid) + + hw.run(run_dynamic) + hw.close() + + assert _read_counter(counter_buf) == 1 + finally: + counter_shm.close() + counter_shm.unlink() + + def test_broadcast_control_all_reports_malformed_payload(self): + hw = Worker(level=3, num_sub_workers=1) + bootstrap_cid = hw.register(lambda args: None) + hw.init() + try: + hw.run(lambda orch, args, cfg: orch.submit_sub(bootstrap_cid)) + worker_impl = hw._worker + assert worker_impl is not None + results = worker_impl.broadcast_control_all(WorkerType.SUB, _CTRL_PY_REGISTER, 5, b"bad") + assert len(results) == 1 + assert not results[0].ok + assert "payload" in results[0].error_message + finally: + hw.close() + + def 
test_broadcast_control_all_empty_payload_raises_before_fanout(self): + hw = Worker(level=3, num_sub_workers=1) + bootstrap_cid = hw.register(lambda args: None) + hw.init() + try: + hw.run(lambda orch, args, cfg: orch.submit_sub(bootstrap_cid)) + worker_impl = hw._worker + assert worker_impl is not None + with pytest.raises(RuntimeError, match="payload pointer and size"): + worker_impl.broadcast_control_all(WorkerType.SUB, _CTRL_PY_REGISTER, 5, b"") + finally: + hw.close() + + def test_broadcast_control_all_timeout_reports_failed_child(self): + shm = SharedMemory(create=True, size=MAILBOX_SIZE) + dw = _Worker(3) + try: + assert shm.buf is not None + _mailbox_store_i32(_buffer_field_addr(shm.buf, _OFF_STATE), _IDLE) + dw.add_sub_worker(_mailbox_addr(shm)) + dw.init() + results = dw.broadcast_control_all( + WorkerType.SUB, + _CTRL_PY_UNREGISTER, + 0, + None, + timeout_s=0.001, + ) + assert len(results) == 1 + assert not results[0].ok + assert "timed out" in results[0].error_message + finally: + dw.close() + shm.close() + shm.unlink() + + def test_broadcast_control_all_selected_pool_routing(self): + def make_mailbox(): + shm = SharedMemory(create=True, size=MAILBOX_SIZE) + assert shm.buf is not None + _mailbox_store_i32(_buffer_field_addr(shm.buf, _OFF_STATE), _IDLE) + return shm + + for selected_type, selected_kind in ( + (WorkerType.SUB, "SUB"), + (WorkerType.NEXT_LEVEL, "NEXT_LEVEL"), + ): + sub_shm = make_mailbox() + next_shm = make_mailbox() + dw = _Worker(3) + try: + dw.add_sub_worker(_mailbox_addr(sub_shm)) + dw.add_next_level_worker(_mailbox_addr(next_shm)) + dw.init() + results = dw.broadcast_control_all( + selected_type, + _CTRL_PY_UNREGISTER, + 0, + None, + timeout_s=0.001, + ) + assert len(results) == 1 + assert results[0].worker_type == selected_kind + sub_state = _mailbox_load_i32(_buffer_field_addr(sub_shm.buf, _OFF_STATE)) + next_state = _mailbox_load_i32(_buffer_field_addr(next_shm.buf, _OFF_STATE)) + if selected_type == WorkerType.SUB: + assert 
sub_state == _CONTROL_REQUEST + assert next_state == _IDLE + else: + assert sub_state == _IDLE + assert next_state == _CONTROL_REQUEST + finally: + dw.close() + sub_shm.close() + sub_shm.unlink() + next_shm.close() + next_shm.unlink() + + def test_broadcast_control_all_result_shape_for_register_and_unregister(self): + hw = Worker(level=3, num_sub_workers=1) + bootstrap_cid = hw.register(lambda args: None) + hw.init() + try: + hw.run(lambda orch, args, cfg: orch.submit_sub(bootstrap_cid)) + worker_impl = hw._worker + assert worker_impl is not None + register_results = worker_impl.broadcast_control_all( + WorkerType.SUB, + _CTRL_PY_REGISTER, + 5, + b"bad", + ) + unregister_results = worker_impl.broadcast_control_all( + WorkerType.SUB, + _CTRL_PY_UNREGISTER, + bootstrap_cid, + None, + ) + + for result in (register_results[0], unregister_results[0]): + assert isinstance(result.worker_type, str) + assert isinstance(result.worker_index, int) + assert isinstance(result.ok, bool) + assert isinstance(result.error_message, str) + assert not register_results[0].ok + assert unregister_results[0].ok + finally: + hw.close() + + def test_nonserializable_dynamic_python_callable_does_not_consume_cid(self): + lock = threading.Lock() + hw = Worker(level=3, num_sub_workers=1) + bootstrap_cid = hw.register(lambda args: None) + hw.init() + try: + hw.run(lambda orch, args, cfg: orch.submit_sub(bootstrap_cid)) + before = dict(hw._callable_registry) + + def captures_lock(args): + lock.acquire(False) + + with pytest.raises(TypeError, match="lock"): + hw.register(captures_lock) + assert hw._callable_registry == before + finally: + hw.close() + + def test_chip_register_reuse_clears_seen_python_cid_before_binary_register(self): + from simpler.worker import _CTRL_PY_UNREGISTER # noqa: PLC0415 + + calls = [] + + class FakeWorker: + def broadcast_register_all(self, cid, blob_ptr, blob_size): + calls.append(("binary_register", cid, blob_size)) + + hw = Worker(level=3, num_sub_workers=1) + 
hw._initialized = True + hw._hierarchical_started = True + hw._worker = FakeWorker() + hw._py_callable_cids_seen.add(0) + + def fake_py_control(worker_types, sub_cmd, cid, *, payload=None, strict): + calls.append(("py_clear", list(worker_types), sub_cmd, cid, strict)) + return [] + + hw._broadcast_py_control = fake_py_control + callable_obj = ChipCallable.build(signature=[], func_name="x", binary=b"\x00", children=[]) + + cid = hw.register(callable_obj) + + assert cid == 0 + assert calls[0] == ("py_clear", [WorkerType.SUB], _CTRL_PY_UNREGISTER, 0, True) + assert calls[1][0] == "binary_register" + assert 0 not in hw._py_callable_cids_seen + + hw._callable_registry.pop(0) + calls.clear() + + cid = hw.register(ChipCallable.build(signature=[], func_name="y", binary=b"\x00", children=[])) + + assert cid == 0 + assert len(calls) == 1 + assert calls[0][0:2] == ("binary_register", 0) + + def test_chip_register_reuse_fails_before_binary_register_when_python_clear_fails(self): + calls = [] + + class FakeWorker: + def broadcast_register_all(self, cid, blob_ptr, blob_size): + calls.append(("binary_register", cid)) + + hw = Worker(level=3, num_sub_workers=1) + hw._initialized = True + hw._hierarchical_started = True + hw._worker = FakeWorker() + hw._py_callable_cids_seen.add(0) + + def fake_py_control(worker_types, sub_cmd, cid, *, payload=None, strict): + calls.append(("py_clear", cid, strict)) + raise RuntimeError("clear failed") + + hw._broadcast_py_control = fake_py_control + callable_obj = ChipCallable.build(signature=[], func_name="x", binary=b"\x00", children=[]) + + with pytest.raises(RuntimeError, match="clear failed"): + hw.register(callable_obj) + + assert calls == [("py_clear", 0, True)] + assert hw._callable_registry == {} + def test_unregister_middle_cid_reuses_hole(self): # `_allocate_cid` must fill the smallest hole, not append at # len(registry). 
The bug it guards against: register 0/1/2, diff --git a/tests/ut/py/test_worker/test_l4_recursive.py b/tests/ut/py/test_worker/test_l4_recursive.py index d52b019b6..4f1ed290c 100644 --- a/tests/ut/py/test_worker/test_l4_recursive.py +++ b/tests/ut/py/test_worker/test_l4_recursive.py @@ -104,6 +104,17 @@ def test_add_initialized_child_raises(self): child.close() w4.close() + def test_l4_device_ids_rejected(self): + w4 = Worker(level=4, device_ids=[0], num_sub_workers=0) + with pytest.raises(RuntimeError, match="device_ids are only supported on L3"): + w4.init() + + def test_add_worker_with_device_ids_rejected(self): + w4 = Worker(level=4, device_ids=[0], num_sub_workers=0) + child = Worker(level=3, num_sub_workers=0) + with pytest.raises(RuntimeError, match="cannot be combined with device_ids"): + w4.add_worker(child) + def test_malloc_on_l4_raises_index_error(self): # L4 has no chip mailboxes — `Worker.malloc` must surface IndexError # rather than silently dispatch CTRL_MALLOC to a next_level (L3 worker) @@ -175,6 +186,38 @@ def test_l4_register_then_unregister_recycles_cid(self): finally: w4.close() + def test_l4_register_python_orch_after_start_succeeds(self): + counter_shm, counter_buf = _make_shared_counter() + try: + l3 = Worker(level=3, num_sub_workers=1) + l3_sub_cid = l3.register(lambda args: _increment_counter(counter_buf)) + + w4 = Worker(level=4, num_sub_workers=0) + bootstrap_cid = w4.register(lambda orch, args, config: None) + w4.add_worker(l3) + w4.init() + + def bootstrap(orch, args, config): + orch.submit_next_level(bootstrap_cid, TaskArgs(), CallConfig()) + + w4.run(bootstrap) + + def dynamic_l3_orch(orch, args, config): + orch.submit_sub(l3_sub_cid) + + dynamic_cid = w4.register(dynamic_l3_orch) + + def l4_orch(orch, args, config): + orch.submit_next_level(dynamic_cid, TaskArgs(), CallConfig()) + + w4.run(l4_orch) + w4.close() + + assert _read_counter(counter_buf) == 1 + finally: + counter_shm.close() + counter_shm.unlink() + # 
--------------------------------------------------------------------------- # Test: L4 → L3 PROCESS mode — single dispatch