fix(codegen): require row_major layout for tile.cast (pto.tcvt)#1559
Conversation
📝 WalkthroughWalkthroughThis PR constrains Changestile.cast row-major layout constraint and repair verification
Estimated code review effort🎯 2 (Simple) | ⏱️ ~12 minutes Possibly related PRs
Suggested labels
Poem
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Code Review
This pull request resolves an issue where tile.cast (implemented via pto.tcvt) mis-orders elements when its source tile is column-major (e.g., a reshaped [n, 1] index vector). It registers row-major layout constraints for both input and output of tile.cast in pto_ops_common.cpp, allowing the ResolveBackendOpLayouts pass to automatically repair column-major callers by reshaping them to row-major before the cast and back to column-major afterward. The PR also updates relevant documentation, refactors code comments in op_conversion_registry.cpp, and adds a regression unit test to verify the fix. There are no review comments to address.
Fixes hw-native-sys#1549 pto.tcvt silently mis-orders elements when its source tile is col_major (e.g. a reshaped [n, 1] index vector narrowed i32 -> i16): the same cast on a row_major source is correct. This produced reversed scatter rows in the FP16 tensor.scatter_update lowering. Register tile.cast with set_input_layout(0, row_major) and set_output_layout(row_major), mirroring tile.rsqrt / tile.cmps / tile.sort32. ResolveBackendOpLayouts then repairs every col_major caller generically by reshaping [n, 1] -> [1, n] row_major around the cast and restoring the original layout afterwards. Row-major callers are unaffected (no repair). Add a ResolveBackendOpLayouts regression test for a col_major [16, 1] i32 -> i16 cast, and list tile.cast among the constrained ops in the pass docs (en + zh-cn). Also refresh the scatter_update lowering comment: its i32-compute / narrow-at-the-end design is retained for the alignment and canonical-layout benefit, while the col_major mis-ordering it used to dodge is now covered by the general tile.cast layout spec.
… sources Covers hw-native-sys#1549: narrows i32 -> i16 on a col_major [N, 1] view (reshaped from [1, N]) and, as a control, on a row_major [1, N] tile. The col_major case is the regression — element order must be preserved. NOTE: still verifying that the tile.cast row_major layout spec actually engages for this reshape-sourced col_major path in the full pipeline; the repair was not observed to fire in a device-free compile, which needs follow-up before relying on this ST as a passing gate.
a96c187 to
dc183de
Compare
Condense the tile.cast row_major rationale in pto_ops_common.cpp to match the neighboring tile.rsqrt comment style, and drop the historical narration from the scatter_update flat-index comment in op_conversion_registry.cpp, keeping only the alignment rationale that still describes the code.
Summary
pto.tcvt(the lowering oftile.cast) silently mis-orders elements when its source tile iscol_major— e.g. a reshaped[n, 1]index vector narrowedi32 -> i16. The same cast on arow_majorsource is correct, so the failure is silent wrong output with no diagnostic. This is what produced reversed scatter rows in the FP16tensor.scatter_updatelowering (issue #1549).PyPTO already drives this exact class of ISA constraint through the
ResolveBackendOpLayoutspass, which reshapes a[n, 1] col_majorvector to[1, n] row_majoraround a constrained op and restores the layout afterwards.tile.castsimply had no layout spec, so it was never repaired.Changes
src/backend/common/pto_ops_common.cpp: registertile.castwithset_input_layout(0, row_major)+set_output_layout(row_major), mirroringtile.rsqrt/tile.cmps/tile.sort32.ResolveBackendOpLayoutsnow repairs everycol_majorcaller generically. Row-major callers are unaffected (no repair, zero overhead).tests/ut/ir/transforms/test_resolve_backend_op_layouts_pass.py: pass-level regression for acol_major [16, 1]i32 -> i16cast being repaired through a[1, 16] row_majorreshape.tests/st/runtime/ops/test_cast.py: new end-to-end ST — acol_major [N, 1]i32 -> i16narrow (the [Bug] tile.cast (pto.tcvt) narrowing mis-orders elements when the source tile is col_major #1549 regression, must preserve element order) plus arow_major [1, N]control case.docs/en+docs/zh-cn20-resolve_backend_op_layouts.md: listtile.castamong the constrained ops.src/ir/transforms/op_conversion_registry.cpp: comment-only trim of thescatter_updatelowering — its i32-compute / narrow-at-the-end design is kept for the alignment benefit. No behavior change: this path narrows only the row-major[n, d]flat index, so it never feeds acol_majorsource totile.cast, and the new layout spec is a no-op for it.Testing
ResolveBackendOpLayoutsUTs pass (5/5)tests/ut/ir/transforms/+tests/ut/codegen/: 1687 passed, 26 skippedtests/ut/ir/operators/test_tile_ops.py: 237 passedtests/st/runtime/ops/test_cast.py::TestCast::test_tile_cast_col_major_narrow(hardware, to be confirmed by reviewer) — the direct col_major-cast regression for this fix.Related Issues
Fixes #1549