fix(codegen): lower scatter_update to tile.scatter, remove orch memcpy #1537
hw-native-sys#1490) tensor.scatter_update now expands to per-element flat-index tile.scatter (pto.tscatter) + select preserve-blend during ConvertTensorToTileOps, so it runs on device and respects TensorMap producer sync. Removes the orch raw-memcpy path and the dead GetTensorDataPtr helper, drops the scratch arg from tile.scatter_update, and deletes the textract/tinsert PTO codegen.
Code Review
This pull request refactors the scatter_update operation (at both the tensor and tile levels) to lower directly to a whole-row tile.scatter using flat indices, eliminating the need for a temporary scratch tile. Consequently, the scratch parameter has been removed from the Python APIs, IR definitions, and C++ codegen. The review feedback suggests improving the C++ IR transformation in op_conversion_registry.cpp by using INTERNAL_CHECK_SPAN instead of CHECK to conform to project conventions, and adding an explicit safety check to verify that the number of source rows matches the index size.
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@tests/st/runtime/cross_core/test_chained_matmul_cast.py`:
- Around line 66-68: The expected-value path in compute_expected currently does
a pure FP32 matmul but the tested kernel performs FP32→BF16→FP32 (CubeVecCast),
so modify compute_expected to mimic that round-trip: cast tensors["a"] and
tensors["w"] to torch.bfloat16 then back to torch.float32 before calling
torch.matmul, and store the result into tensors["y"]; this will make the
reference follow the same FP32→BF16→FP32 behavior as the implementation under
test.
In `@tests/ut/ir/transforms/test_convert_tensor_to_tile_ops.py`:
- Around line 1989-1993: Replace the fragile substring assertions with exact-op
checks: after running passes.convert_tensor_to_tile_ops() and getting After,
either (A) inspect the IR Call nodes in After (e.g., traverse After to collect
call.op.name values) and assert "tile.scatter" appears the expected number of
times and "tile.scatter_mask" is not present, or (B) if staying with the printed
text from ir.python_print(After), use a word-boundary regex like
r"\btile\.scatter\b" to assert exact matches and separately assert
r"\btile\.scatter_mask\b" is absent; update the assertions accordingly (refer to
passes.PassContext, passes.convert_tensor_to_tile_ops, ir.python_print, and the
After variable).
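The word-boundary option (B) above can be sketched with the standard `re` module. The `printed_ir` string is a made-up stand-in for the output of `ir.python_print(After)`, and the helper name is hypothetical; only the two regex patterns come from the review comment.

```python
import re

# Hypothetical printed IR; the real text from ir.python_print(After) will differ.
printed_ir = """
t0 = pl.tile.scatter(base, flat_idx, src)
t1 = pl.tile.scatter(zero, flat_idx, ones)
"""

# "\b" after "scatter" requires a non-word character next, so "_mask" and
# "_update" suffixes do not match the exact-op pattern.
EXACT_SCATTER = re.compile(r"\btile\.scatter\b")
SCATTER_MASK = re.compile(r"\btile\.scatter_mask\b")


def count_exact_scatter(text):
    """Count occurrences of the exact op name tile.scatter in printed text."""
    return len(EXACT_SCATTER.findall(text))
```

A plain substring test such as `"tile.scatter" in text` would also succeed on `tile.scatter_mask`, which is the regression the comment wants to catch.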
📒 Files selected for processing (16)
- docs/en/user/02-operation_reference.md
- docs/zh-cn/user/02-operation_reference.md
- include/pypto/codegen/codegen_base.h
- python/pypto/ir/op/tile_ops.py
- python/pypto/language/op/tile_ops.py
- src/backend/common/pto_ops_common.cpp
- src/codegen/orchestration/orchestration_codegen.cpp
- src/codegen/pto/pto_codegen.cpp
- src/codegen/tensor_op_codegen.cpp
- src/ir/op/tile_ops/transform.cpp
- src/ir/transforms/op_conversion_registry.cpp
- tests/st/runtime/cross_core/test_chained_matmul_cast.py
- tests/st/runtime/ops/test_scatter_update.py
- tests/ut/codegen/test_pto_codegen_ops.py
- tests/ut/ir/operators/test_tile_ops.py
- tests/ut/ir/transforms/test_convert_tensor_to_tile_ops.py
💤 Files with no reviewable changes (6)
- include/pypto/codegen/codegen_base.h
- tests/ut/ir/operators/test_tile_ops.py
- src/backend/common/pto_ops_common.cpp
- src/codegen/orchestration/orchestration_codegen.cpp
- src/codegen/tensor_op_codegen.cpp
- tests/ut/codegen/test_pto_codegen_ops.py
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- op_conversion_registry: use INTERNAL_CHECK_SPAN for scatter_update post-bridge invariants (Span in scope) and add explicit check that src rows == index size (b * s)
- test_convert_tensor_to_tile_ops: tighten scatter_update assertion to exact pl.tile.scatter( (index form), and assert pl.tile.scatter_mask( is absent so a regression to mask form is caught
Replace `return pl.store(result, ...)` with explicit `dst_t = pl.store(result, ...)` / `return dst_t` across all kernel programs in the scatter_update system tests. Returning the store op result directly was triggering a device-side AICPU stream-sync timeout (hang) on the FP16 case; materializing the store as a statement before the return clears the hang.
…he end The FP16 scatter_update lowering built the flat destination indices in the tscatter-required i16 width, which forced an i32->i16 tile.cast on the col_major [n,1] index view. That cast mis-ordered the indices, scattering whole src rows in reversed order (dst row 0 received src[15]). Keep the entire flat-index computation in i32 (identical to the FP32 path) and narrow only the finished row-major [n,d] flat_idx to the tscatter index width. The final cast runs on a 32-byte-aligned, row-major tile, which is both alignment-legal and correct.
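The alignment argument can be checked with simple arithmetic. The sketch below assumes the rule is "each tile row must span a multiple of 32 bytes"; that reading, the constant, and the helper names are assumptions made for illustration, not the backend's actual check.

```python
ALIGN_BYTES = 32  # assumed alignment unit, taken from the "32-byte-aligned" wording


def row_bytes(cols, itemsize):
    """Bytes occupied by one row of a row-major tile."""
    return cols * itemsize


def is_row_aligned(cols, itemsize):
    """True when one row spans a whole number of 32-byte units."""
    return row_bytes(cols, itemsize) % ALIGN_BYTES == 0
```

With the shapes used in this PR, an i32 `[2, 8]` index row is 32 bytes and an i16 `[16, 32]` flat index row is 64 bytes, so both pass, while an i16 `[2, 8]` row is only 16 bytes. Under this assumption, narrowing has to wait until the `[n, d]` tile exists.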
Add a Scatter Update Lowering section to the ConvertTensorToTileOps pass doc (en + zh-cn): the flat-index expansion (flat_idx = index*d + c), the i32-compute / narrow-at-the-end rule, and the generated pto.tscatter + sel preserve-blend op sequence.
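As a host-side reference for the documented expansion, the sketch below models `flat_idx = index*d + c`, a per-element scatter into a zeroed base, and the mask + select blend using plain Python lists. All function names are hypothetical and nothing here calls PyPTO; it only mirrors the semantics described in this PR.

```python
def flat_indices(index_flat, d):
    """flat_idx[k][c] = index_flat[k] * d + c."""
    return [[row * d + c for c in range(d)] for row in index_flat]


def flat_scatter(base_flat, idx, values):
    """Per-element scatter: out[idx[k][c]] = values[k][c]."""
    out = list(base_flat)
    for idx_row, val_row in zip(idx, values):
        for i, v in zip(idx_row, val_row):
            out[i] = v
    return out


def scatter_update_via_flat(input_tile, index_flat, src):
    """Whole-row scatter rebuilt from a flat scatter plus a mask/select blend."""
    m, d = len(input_tile), len(input_tile[0])
    idx = flat_indices(index_flat, d)
    zeros = [0] * (m * d)
    scattered = flat_scatter(zeros, idx, src)        # written = src, unwritten = 0
    ones = [[1] * d for _ in index_flat]
    mask = flat_scatter(zeros, idx, ones)            # written = 1, unwritten = 0
    inp_flat = [v for row in input_tile for v in row]
    # sel(pred, scattered, input): keep the input wherever nothing was written.
    out = [s if mk != 0 else x for s, mk, x in zip(scattered, mask, inp_flat)]
    return [out[r * d:(r + 1) * d] for r in range(m)]


def scatter_update_rows(input_tile, index_flat, src):
    """Direct row-wise reference: out[index_flat[k]] = src[k]."""
    out = [row[:] for row in input_tile]
    for k, row in enumerate(index_flat):
        out[row] = list(src[k])
    return out
```

The separate mask is what distinguishes "written with the value 0" from "never written"; selecting on `scattered != 0` alone would wrongly restore the input wherever `src` contains a zero.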
- CHECK that m*d fits in i16 for 2-byte dst (the tscatter index width), so an oversized FP16/BF16/INT16 dst raises a clear error instead of silently scattering to wrong rows on flat-index overflow.
- Reject 4D input/src in the lowering with a user-facing CHECK: 4D type-checks via the op's deduction but is not yet lowered, so it would otherwise hit an internal error.
- Add convert-pass tests covering both rejections.
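The range check in the first bullet amounts to the following arithmetic. This is a Python sketch of the rule only: the function name, the `ValueError`, and the message are assumptions, and the real check is a C++ CHECK in the lowering.

```python
INT16_MAX = 2**15 - 1  # largest value an i16 flat index can hold


def check_flat_index_fits(m, d, dst_itemsize):
    """Reject a 2-byte dst whose m * d flat-index range does not fit in i16.

    Only 2-byte dtypes are checked, matching the bullet above; wider dtypes
    are left alone in this sketch.
    """
    if dst_itemsize == 2 and m * d > INT16_MAX:
        raise ValueError(
            f"scatter_update: dst [{m}, {d}] needs flat indices up to "
            f"{m * d - 1}, which overflow the i16 tscatter index"
        )
```

For example, an FP16 `[32, 32]` dst has 1024 elements and passes, while an FP16 `[256, 128]` dst has 32768 elements and is rejected.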
## Summary

`pto.tcvt` (the lowering of `tile.cast`) silently **mis-orders elements when its source tile is `col_major`**, e.g. a reshaped `[n, 1]` index vector narrowed `i32 -> i16`. The same cast on a `row_major` source is correct, so the failure is silent wrong output with no diagnostic. This is what produced reversed scatter rows in the FP16 `tensor.scatter_update` lowering (issue #1549).

PyPTO already drives this exact class of ISA constraint through the `ResolveBackendOpLayouts` pass, which reshapes a `[n, 1] col_major` vector to `[1, n] row_major` around a constrained op and restores the layout afterwards. `tile.cast` simply had **no layout spec**, so it was never repaired.

### Changes

- **`src/backend/common/pto_ops_common.cpp`**: register `tile.cast` with `set_input_layout(0, row_major)` + `set_output_layout(row_major)`, mirroring `tile.rsqrt` / `tile.cmps` / `tile.sort32`. `ResolveBackendOpLayouts` now repairs every `col_major` caller generically. Row-major callers are unaffected (no repair, zero overhead).
- **`tests/ut/ir/transforms/test_resolve_backend_op_layouts_pass.py`**: pass-level regression for a `col_major [16, 1]` `i32 -> i16` cast being repaired through a `[1, 16] row_major` reshape.
- **`tests/st/runtime/ops/test_cast.py`**: new end-to-end ST: a `col_major [N, 1]` `i32 -> i16` narrow (the #1549 regression, must preserve element order) plus a `row_major [1, N]` control case.
- **`docs/en` + `docs/zh-cn` `20-resolve_backend_op_layouts.md`**: list `tile.cast` among the constrained ops.
- **`src/ir/transforms/op_conversion_registry.cpp`**: comment-only trim of the `scatter_update` lowering; its i32-compute / narrow-at-the-end design is kept for the alignment benefit. No behavior change: this path narrows only the **row-major `[n, d]`** flat index, so it never feeds a `col_major` source to `tile.cast`, and the new layout spec is a no-op for it.
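One way to see why the `[n, 1] col_major` to `[1, n] row_major` reshape is order-preserving is to compare linear memory offsets. The sketch below uses the textbook offset formulas for the two layouts; it does not model the `ResolveBackendOpLayouts` pass itself, and the helper names are hypothetical.

```python
def linear_offset(r, c, rows, cols, layout):
    """Element offset of logical coordinate (r, c) in a dense 2-D tile."""
    if layout == "row_major":
        return r * cols + c
    if layout == "col_major":
        return c * rows + r
    raise ValueError(f"unknown layout: {layout}")


def memory_order(rows, cols, layout):
    """Logical (r, c) coordinates listed in increasing memory offset."""
    coords = [(r, c) for r in range(rows) for c in range(cols)]
    return sorted(coords, key=lambda rc: linear_offset(rc[0], rc[1], rows, cols, layout))
```

For a `[16, 1] col_major` vector, element `r` sits at offset `r`; for a `[1, 16] row_major` vector, element `c` sits at offset `c`. The two views walk the same buffer in the same order, so the reshape relabels the tile without moving data.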
## Testing

- [x] New + existing `ResolveBackendOpLayouts` UTs pass (5/5)
- [x] `tests/ut/ir/transforms/` + `tests/ut/codegen/`: 1687 passed, 26 skipped
- [x] `tests/ut/ir/operators/test_tile_ops.py`: 237 passed
- [x] Pre-commit hooks (clang-format, cpplint, ruff, pyright, markdownlint) pass
- [ ] On-device ST `tests/st/runtime/ops/test_cast.py::TestCast::test_tile_cast_col_major_narrow` (hardware, to be confirmed by reviewer): the direct col_major-cast regression for this fix.

> Note: `test_scatter_update.py::...::test_tile_scatter_update_fp16` is **not** a target of this PR. That path was already worked around in #1537 (narrow only the row-major `[n, d]` flat index), so this PR does not change its codegen and it needs no re-confirmation here.

## Related Issues

Fixes #1549

---------

Co-authored-by: Youhezhen <youhezhen@huawei.com>
Summary
- `tensor.scatter_update` / `tile.scatter_update` now lower to a whole-row `tile.scatter` (`pto.tscatter`) + select preserve-blend in `ConvertTensorToTileOps`, so the op runs on a device kernel and respects `TensorMap` producer sync.
- Removed the orch raw-`memcpy` path (which dereferenced tensor data pointers, racing with producer tasks) and the now-dead `GetTensorDataPtr` helper.
- Dropped the `scratch` arg from `tile.scatter_update` (4→3 args) and deleted the textract/tinsert PTO codegen.

Approach
Whole-row scatter is expressed as a per-element flat scatter: `flat_idx[k, c] = index.flat[k] * d + c`. The conversion builds the flat index (column arange broadcast via `col_expand` + an `index * d` row offset), scatters `src` into a zeroed base, and reconstructs the DPS row-preserve via a mask scatter + `cmps` + `sel` blend, mirroring `tensor.scatter`.

The flat-index arithmetic is computed entirely in i32, and only the finished row-major `[n, d]` flat index is narrowed to the tscatter-required width (i16 for 2-byte data). This keeps every intermediate index tile in a canonical, 32-byte-aligned, row-major layout.

Lowering (generated PTO)
Hardware `pto.tscatter` writes per element (`dst.flat[idx[k, c]] = src[k, c]`) and treats `dst` as write-only, so the preserve semantics are rebuilt on the PyPTO side. Generated kernel for FP32 `[32, 32]` input / `[2, 8]` index / `[16, 32]` src:

| PTO op | Result |
| --- | --- |
| `pto.tload` ×3 | `input_tile`, `index_tile`, `src_tile` |
| `pto.tci` | `[1, d]` = `0..d-1` |
| `pto.texpands` | `[n, d]` |
| `pto.tcolexpand` | `col_nd[k, c] = c` |
| `pto.tmuls` | `row_base[k] = index.flat[k] * d` |
| `pto.trowexpandadd` | `flat_idx = col_nd + row_base` → `[n, d]` |
| `pto.tcvt` | `flat_idx` i32→i16 (2-byte dtypes only) |
| `pto.texpands` | `[m, d]` |
| `pto.tscatter` | `scattered` = src into zeroed base (written = src, unwritten = 0) |
| `pto.texpands` ×2 | |
| `pto.tscatter` | `mask` = ones into zeroed base (written = 1, unwritten = 0) |
| `pto.tcmps` | `pred = (mask != 0)` |
| `pto.tsel` | `out = sel(pred, scattered, input_tile)` |
| `pto.tstore` | `out` to output |

`tile.sel` (not `input * mask`) avoids emitting `pto.tmul`, which A2/A3 reject for bf16/i8. The index `reshape [b, s] → [n, 1]` is a buffer-view realias, not a separate PTO op. Full walkthrough: `docs/en/dev/passes/12-convert_tensor_to_tile_ops.md` (Scatter Update Lowering).

FP16 root cause
Two FP16-specific defects surfaced on device and are fixed here:

- The system-test kernels returned the `pl.store(...)` op result directly (`return pl.store(result, ...)`). On the FP16 case this triggered an AICPU stream-sync timeout that cascaded into unrelated tests. Materializing the store as a statement before the return (`dst_t = pl.store(...)` / `return dst_t`) clears the hang.
- The lowering narrowed the index `i32→i16` on a col_major `[n, 1]` view; `tile.cast` mis-orders elements on a col_major source, so whole `src` rows scattered in reverse (`dst` row 0 received `src[15]`). Computing the indices in i32 and narrowing only the final row-major `[n, d]` flat index fixes it. (Narrowing earlier on the `[b, s]` tile is also invalid: an i16 `[b, s]` row is `cols * 2` bytes and breaks 32-byte alignment.) The companion `tensor.scatter` FP16/BF16 path is unaffected: it takes an i16 index from the caller and never casts.

Testing

- `test_tile_scatter_update_fp16` passes on device (previously hung / produced incorrect output)

Related Issues
Fixes #1490