feat(ir): Add tile.comm_notify, tile.comm_wait, tile.comm_test cross-rank signal ops#1301
feat(ir): Add tile.comm_notify, tile.comm_wait, tile.comm_test cross-rank signal ops#1301Little-oil wants to merge 6 commits into
Conversation
|
Note Reviews pausedIt looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the Use the following commands to manage reviews:
Use the checkboxes below for quick actions:
📝 WalkthroughWalkthroughAdds tile.comm_notify and tile.comm_wait: IR ops, IR-level Python bindings, language DSL wrappers, PTO backend lowering to pto.comm.tnotify/twait, English/Chinese docs, and unit + codegen + runtime tests validating INT32 signal semantics and attributes. ChangesTile Signal Operations (tile.comm_notify / tile.comm_wait)
Sequence Diagram(s)sequenceDiagram
participant Program as User Program
participant Lang as language.system_ops
participant IR as ir.tile_ops
participant PTO as PTO Codegen
participant GM as Device GM
Program->>Lang: comm_notify(signal, value, op)
Lang->>IR: normalized signal/value -> tile.comm_notify Call
IR->>PTO: tile.comm_notify
PTO->>GM: emit pto.comm.tnotify (partition_view, i32 value, notifyOp)
Program->>Lang: comm_wait(signal, cmp_value, cmp)
Lang->>IR: normalized cmp_value -> tile.comm_wait Call
IR->>PTO: tile.comm_wait
PTO->>GM: emit pto.comm.twait (partition_view, i32 cmp, cmp attr)
Estimated code review effort🎯 4 (Complex) | ⏱️ ~45 minutes Possibly related PRs
Suggested reviewers
Poem
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
|
wait for PTOAS'new version |
There was a problem hiding this comment.
Code Review
This pull request introduces cross-rank signal operations, specifically tile.notify and tile.wait, to support synchronization between different ranks. The changes include documentation in both English and Chinese, Python IR and language-level API definitions, C++ backend codegen for PTO operations, and comprehensive unit and system tests. The feedback suggests refactoring the argument conversion logic in python/pypto/language/op/system_ops.py into a shared helper function to improve maintainability and ensure consistent validation of IntLike arguments across both operations.
There was a problem hiding this comment.
Actionable comments posted: 2
🧹 Nitpick comments (1)
tests/ut/codegen/test_pto_codegen_ops.py (1)
1608-1673: 💤 Low value
TestTileNotifyPtoCodegen— LGTM with one optional improvementThe three tests cover the key paths (set, atomic_add, bad dtype). One gap worth noting: there is no rejection test for an unsupported
opstring (e.g.op="invalid"). If input validation is enforced at the IR construction layer rather than codegen, add the equivalent test totests/ut/ir/operators/test_tile_ops.py; if it's enforced in codegen, a smallpytest.raisescase here would complete the contract coverage.Optionally,
test_tile_notify_set_codegenandtest_tile_notify_atomic_add_codegencan be collapsed into a single@pytest.mark.parametrize("op,attr", [("set", "set"), ("atomic_add", "atomic_add")])test to reduce duplication.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/ut/codegen/test_pto_codegen_ops.py` around lines 1608 - 1673, Add a rejection test for unsupported op strings so tile.notify validates op values: add a new test (e.g. in TestTileNotifyPtoCodegen or in tests/ut/ir/operators/test_tile_ops.py depending on where validation lives) that constructs a program using pl.tile.notify(signal, 1, op="invalid") and asserts it raises (pytest.raises) with an appropriate message; reference the existing helper _generate_mlir and the test names test_tile_notify_set_codegen/test_tile_notify_atomic_add_codegen to locate similar test patterns and mirror their structure (or convert the two positive tests into a single parametric `@pytest.mark.parametrize` if you prefer to reduce duplication).
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@src/ir/op/tile_ops/cross_core.cpp`:
- Around line 90-114: Add IR-level operand validation to the REGISTER_OP
declarations for "tile.notify" and "tile.wait": implement .f_validate handlers
that check the "signal" operand is an INT32 tensor with exactly one element
(shape == 1) and that the secondary operand ("value" for tile.notify,
"cmp_value" for tile.wait) is an INT32 scalar; emit a clear validation error
when these conditions fail so invalid uses fail during IR construction rather
than backend lowering. Ensure the validators reference the op names
("tile.notify", "tile.wait") and the operand names ("signal", "value",
"cmp_value") so reviewers can locate the checks.
In `@tests/st/runtime/test_notify_wait.py`:
- Around line 268-297: The test suite unconditionally exercises PTOAS-only APIs
pto.comm.tnotify / pto.comm.twait causing infra-driven failures when PTOAS is
not present; modify the TestNotifyWait tests to be skipped when the capability
is absent by checking the PTOAS capability at import/runtime (e.g., a helper
like has_ptoas_capability() or checking pto.comm for tnotify/twait) and applying
pytest.skip or pytest.mark.skipif to the whole TestNotifyWait class or
individual test methods (referencing TestNotifyWait, test_notify_* methods, and
pto.comm.tnotify/twait) so the suite only runs when those APIs are available.
---
Nitpick comments:
In `@tests/ut/codegen/test_pto_codegen_ops.py`:
- Around line 1608-1673: Add a rejection test for unsupported op strings so
tile.notify validates op values: add a new test (e.g. in
TestTileNotifyPtoCodegen or in tests/ut/ir/operators/test_tile_ops.py depending
on where validation lives) that constructs a program using
pl.tile.notify(signal, 1, op="invalid") and asserts it raises (pytest.raises)
with an appropriate message; reference the existing helper _generate_mlir and
the test names test_tile_notify_set_codegen/test_tile_notify_atomic_add_codegen
to locate similar test patterns and mirror their structure (or convert the two
positive tests into a single parametric `@pytest.mark.parametrize` if you prefer
to reduce duplication).
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 2a67f8d0-4681-494b-8c02-d82da2702789
📒 Files selected for processing (10)
docs/en/dev/ir/05-operators.mddocs/zh-cn/dev/ir/05-operators.mdpython/pypto/ir/op/tile_ops.pypython/pypto/language/op/system_ops.pypython/pypto/language/op/tile_ops.pysrc/backend/common/pto_ops_common.cppsrc/ir/op/tile_ops/cross_core.cpptests/st/runtime/test_notify_wait.pytests/ut/codegen/test_pto_codegen_ops.pytests/ut/ir/operators/test_tile_ops.py
There was a problem hiding this comment.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@src/backend/common/pto_ops_common.cpp`:
- Around line 1924-1937: CheckCommSignalType currently only rejects rank-0
tensors but must enforce the single-slot contract: verify the tensor contains
exactly one element or reject statically-known non-singleton shapes before
lowering. In CheckCommSignalType (and using span/op_name for diagnostics) keep
the rank>=1 check, then inspect signal_tensor_type->shape_: if all extents are
statically-known, compute the product and REQUIRE it equals 1 (emit a clear
CHECK/INTERNAL_CHECK_SPAN failure referencing op_name and the shape); if any
extent is dynamic/unknown, allow it (since it could be singleton at runtime) but
still reject any statically-known extent >1 early. Return the same
signal_tensor_type on success.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 9b35e237-2d6f-4661-af36-5784836ec7cb
📒 Files selected for processing (10)
docs/en/dev/ir/05-operators.mddocs/zh-cn/dev/ir/05-operators.mdpython/pypto/ir/op/tile_ops.pypython/pypto/language/op/system_ops.pypython/pypto/language/op/tile_ops.pysrc/backend/common/pto_ops_common.cppsrc/ir/op/tile_ops/cross_core.cpptests/st/runtime/test_notify_wait.pytests/ut/codegen/test_pto_codegen_ops.pytests/ut/ir/operators/test_tile_ops.py
✅ Files skipped from review due to trivial changes (2)
- docs/en/dev/ir/05-operators.md
- docs/zh-cn/dev/ir/05-operators.md
- system_ops: extract _value_to_int32_expr helper shared by comm_notify/ comm_wait/comm_test; rewrap INDEX ConstInt as INT32 so the literal int case satisfies the IR contract (gemini-code-assist) - ir/cross_core: add f_deduce_type validators for tile.comm_notify/ comm_wait/comm_test enforcing 1-element INT32 signal + INT32 scalar value at IR construction (coderabbitai) - backend/pto_ops_common: extend CheckCommSignalType to reject statically-known non-singleton signal shapes before PTO lowering (coderabbitai) - tests/st: gate test_notify_wait suite on PTOAS_HAS_COMM_NOTIFY_WAIT=1 env var so infra without the staged PTOAS build skips cleanly (coderabbitai) - tests/ut: relocate pytest.raises around the @pl.program class body in the reject_non_int32_signal cases, since the new IR-level validators now fire during decoration instead of during codegen
Introduces a pair of cross-rank signaling operations on AIV: - tile.notify(signal, value, op): write or atomic-add an INT32 value to a remote rank's signal slot (1-element INT32 GM tensor). Lowers to pto.comm.tnotify with notifyOp = #pto.notify_op<set|atomic_add>. - tile.wait(signal, cmp_value, cmp): block until a local INT32 signal slot satisfies a comparison. Lowers to pto.comm.twait with cmp = #pto.wait_cmp<eq|ne|gt|ge|lt|le>. All five layers updated: - C++ op registrations in src/ir/op/tile_ops/cross_core.cpp - PTO codegen in src/backend/common/pto_ops_common.cpp - Python IR wrappers in python/pypto/ir/op/tile_ops.py - DSL wrappers in python/pypto/language/op/system_ops.py with re-export through python/pypto/language/op/tile_ops.py - Tests: UT for IR + PTO codegen, ST loopback covering all six cmp variants and both notify ops - Docs: Cross-Rank Signal Operations sections in docs/en/dev/ir/05-operators.md and docs/zh-cn/dev/ir/05-operators.md Note: pto.comm.tnotify / pto.comm.twait require a PTOAS build that exposes those custom ops; the on-board ST will only run on a PTOAS that has the comm dialect enabled.
…ps to tile.comm_{notify,wait}
The previous codegen for tile.notify/tile.wait was broken — PTOAS rejected
the emitted MLIR. Two bugs:
1. Wrong operand type. Codegen emitted the signal as !pto.ptr<i32> (or a
raw tensor_view), but pto.comm.tnotify / pto.comm.twait require
!pto.partition_tensor_view<Nxi32>. Fix: lower the signal Var through
make_tensor_view → partition_view to build a partition view covering
the full signal shape.
2. Wrong assembly syntax. Codegen used the custom format
"pto.comm.tnotify %sig, %v {...} : <type>, i32", but PTOAS's TNotifyOp /
TWaitOp have no custom assemblyFormat — only generic MLIR op syntax is
accepted. Fix: emit "pto.comm.tnotify"(%sig, %v) {...} : (<type>, i32) -> ().
Also rename the ops from tile.notify/tile.wait to tile.comm_notify/tile.comm_wait
for namespace consistency with the pto.comm.* MLIR ops and to keep cross-rank
signaling ops grouped under a comm_* prefix.
ST tests reshaped to mirror the two real usage patterns from simpler's
ep_dispatch_combine kernels (count exchange via atomic_add, done barrier
via atomic_add + wait ge), instead of exhaustively covering every cmp op.
…y/wait ST
Add `tile.comm_test` (non-blocking signal check, returns i1) alongside the
existing `tile.comm_notify`/`tile.comm_wait` cross-rank signal ops. Emits
`pto.comm.ttest(... : !pto.partition_tensor_view<Nxi32>, i32) {cmp = ...} -> i1`
using PTOAS custom assembly syntax.
Also fix `tests/st/runtime/test_notify_wait.py` orchestrators: replaced
`return self.kernel(signal)` with `signal = self.kernel(signal); return signal`
so the kernel call lands in an AssignStmt and gets emitted (previously
KERNELS=[] and the signal stayed at its init value). Add a dedicated
wait-only ST case to isolate the twait codegen path.
- system_ops: extract _value_to_int32_expr helper shared by comm_notify/ comm_wait/comm_test; rewrap INDEX ConstInt as INT32 so the literal int case satisfies the IR contract (gemini-code-assist) - ir/cross_core: add f_deduce_type validators for tile.comm_notify/ comm_wait/comm_test enforcing 1-element INT32 signal + INT32 scalar value at IR construction (coderabbitai) - backend/pto_ops_common: extend CheckCommSignalType to reject statically-known non-singleton signal shapes before PTO lowering (coderabbitai) - tests/st: gate test_notify_wait suite on PTOAS_HAS_COMM_NOTIFY_WAIT=1 env var so infra without the staged PTOAS build skips cleanly (coderabbitai) - tests/ut: relocate pytest.raises around the @pl.program class body in the reject_non_int32_signal cases, since the new IR-level validators now fire during decoration instead of during codegen
Summary
Add three cross-rank signaling ops on AIV that wrap PTOAS's
pto.comm.*custom dialect:pl.tile.comm_notify(signal, value, *, op)— write or atomic-add an INT32 value into a remote rank's signal slot.op ∈ {"set", "atomic_add"}. Lowers topto.comm.tnotify.pl.tile.comm_wait(signal, cmp_value, *, cmp)— block until a local INT32 signal slot satisfies a comparison.cmp ∈ {"eq","ne","gt","ge","lt","le"}. Lowers topto.comm.twait.pl.tile.comm_test(signal, cmp_value, *, cmp)— non-blocking poll of the same comparison; returns apl.Scalar(BOOL). Lowers topto.comm.ttest(returns i1).signalis a 1-element INT32pl.Tensorviewing a slot in the rank's HCCL window (typically obtained viaimport_peer_buffer).value/cmp_valueaccepts Pythonint,pl.Scalar, orpl.Exprand is normalised to INT32.Mirrors the two real usage patterns from simpler's
ep_dispatch_combinekernels (count exchange via atomic_add, done barrier via atomic_add + wait ge).Layers updated
src/ir/op/tile_ops/cross_core.cpp!pto.partition_tensor_view<Nxi32>, generic MLIR op syntax) —src/backend/common/pto_ops_common.cpppython/pypto/ir/op/tile_ops.pypl.tile.*) —python/pypto/language/op/system_ops.py,python/pypto/language/op/tile_ops.pytests/ut/ir/operators/test_tile_ops.py,tests/ut/codegen/test_pto_codegen_ops.pytests/st/runtime/test_notify_wait.pydocs/en/dev/ir/05-operators.md,docs/zh-cn/dev/ir/05-operators.mdpto.comm.tnotify/twait/ttestcustom ops) —.github/workflows/ci.ymlKey design / fix notes
comm_*prefix groups these under cross-rank signaling, parallel to thepto.comm.*MLIR namespace.make_tensor_view → partition_viewto produce a!pto.partition_tensor_view<Nxi32>covering the full signal shape. Assembly uses generic MLIR op syntax ("pto.comm.tnotify"(%sig, %v) {...} : (...) -> ()) because PTOAS defines no customassemblyFormatfor these ops.f_deduce_typeenforces 1-element INT32 signal + INT32 scalar at IR construction, so misuse fails at@pl.programdecoration time.signal = self.kernel(signal); return signal); a barereturn self.kernel(...)is silently dropped (KERNELS=[]).Testing
TestTileCommNotifyOp,TestTileCommWaitOp,TestTileCommTestOp,TestTileNotifyPtoCodegen,TestTileWaitPtoCodegen,TestTileTestPtoCodegen.tests/st/runtime/test_notify_wait.py— gated onPTOAS_HAS_COMM_NOTIFY_WAIT=1so infra without the upgraded PTOAS skips cleanly. Three programs:CountExchangeProgram(atomic_add),WaitOnlyProgram(wait ge),DoneBarrierProgram(notify + wait combined).Notes
tpush_*/tpop_*/tfree_*).