Dynamic stride support through waveasm #1091
Conversation
Force-pushed from b51b0b3 to 2bb545c
Force-pushed from 136a9f0 to 432dcaa
Force-pushed from 47f2ef2 to 1995206
Squashed cherry-pick of suryajasper/dynamic-strides-waveasm onto 4waveasm-256x192x256. Merges partial kernel argument preloading, extract_strided_metadata handler, and dynamic stride test updates.

Commits included:
- Handle memref.extract_strided_metadata in waveasm backend
- Update dynamic strides test & compile options to include waveasm
- xfail waveasm dynamic strides tests w/ dynamic dims or buffer ops
- Fix dynamic strides + dynamic dims through waveasm & accumulator bitcast
- Fixed dynamic strides with bufops w/ waveasm
- Fix mxfp waveasm example to use (2,2) wave shape
- Fixed waveasm dynamic strides to use partial kernel argument preloading

Made-with: Cursor
Squash-merge of iree-org#1091 (dynamic-strides-waveasm) onto reduce_reg_pres branch, adapted for SGPR-based scalar arg mapping.

Key changes from PR iree-org#1091:
- Handle memref.extract_strided_metadata in waveasm backend
- Handle memref.cast propagation for dynamic strides
- Handle memref.reinterpret_cast for dynamic stride tracking
- Update dynamic strides test & compile options for waveasm
- Fix dynamic strides + dynamic dims + buffer ops through waveasm
- Partial kernel argument preloading for gfx950

Conflict resolutions (kept our branch's approach):
- Scalar kernel args mapped to dedicated SGPRs (not VGPRs)
- AGPR-to-VGPR copy before buffer stores preserved
- getUserSgprCount() helper adopted from PR header changes
- program.getSymName() kept (getKernelName not available)

Signed-off-by: Sanket Pandit <sanket.pandit@amd.com>
Made-with: Cursor
Force-pushed from fbbf5ba to 0d47c33
Squash-merge of suryajasper/dynamic-strides-waveasm branch that adds support for dynamic strides and dynamic dims through the WaveASM backend. Required for hipblaslt integration where strides are runtime values. Made-with: Cursor
I went through the full diff (348 additions, 218 deletions across 9 files). The dynamic stride propagation design is solid — lookupSRD + computeVOffsetFromIndices replacing the hardcoded SRD at s[8:11] is a big step forward, and the static-shape computeBufferSizeFromMemRef formula is correct. The typed-op migration away from RawOp string manipulation is also the right call.
That said, two things need to be fixed before this can land:
Must fix
1. computeBufferSizeFromMemRef dynamic fallback is too small (HandlerUtils.cpp:91)
The fallback was reduced from 0x7FFFFFFE (~2 GB) to 0x20000000 (512 MB). The old value was specifically chosen to sit one byte below the Python frontend's OOB sentinel at 0x7FFFFFFF. With the new value, any dynamic-stride buffer over 512 MB silently gets zeros for the upper portion — e.g., a memref<16384x16384xf32> with dynamic strides uses 1 GB, so the upper half reads as zero. No warning, no error, just wrong GEMM results.
The static-shape path above this line is strictly better and should stay. Just revert the dynamic fallback to 0x7FFFFFFE.
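To make the sizing rule concrete, here is a minimal Python sketch of the behavior described above; the function name and signature are illustrative, not the actual HandlerUtils.cpp code.

```python
# Illustrative sketch of the num_records sizing rule; not the actual
# HandlerUtils.cpp implementation.
def compute_buffer_size(shape, strides, elem_bytes):
    """Byte size to program into the SRD's num_records field."""
    if all(s is not None for s in strides):
        # Static path: byte offset of the last addressable element, plus one element.
        last_elem = sum((dim - 1) * stride for dim, stride in zip(shape, strides))
        return (last_elem + 1) * elem_bytes
    # Dynamic-stride fallback: one byte below the 0x7FFFFFFF OOB sentinel,
    # so no in-bounds access can be falsely rejected by the hardware.
    return 0x7FFFFFFE

# A contiguous 16384x16384 f32 buffer needs 1 GB of num_records coverage:
print(compute_buffer_size((16384, 16384), (16384, 1), 4))  # 1073741824
```

With the 512 MB fallback, the dynamic-stride branch would return a value smaller than the 1 GB static answer above, which is exactly the silent-zeros failure mode the review describes.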
2. handleVectorMaskedLoad is missing getPendingSRDBaseAdjust (TranslateFromMLIR.cpp)
The three sibling handlers (handleMemRefLoad, handleMemRefStore, handleVectorStore) all call getPendingSRDBaseAdjust before using the SRD. This handler doesn't. When a masked load is the first consumer of a base-adjusted memref, it'll use the unadjusted SRD base address — reads go to the wrong memory region.
Should fix
3. RegionBuilder.cpp — blanket non-fatal catch is too broad
Converting ALL loop body translation failures from fatal to silent skip was needed for the mask computation chains, but the catch-all scope means a failed arith.muli for stride computation gets silently dropped too. Either whitelist the specific ops that are safe to skip, or at least emitWarning() so failures are visible.
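A sketch of the suggested narrowing, in Python pseudologic; the op names in the whitelist and the helper itself are hypothetical, not RegionBuilder.cpp's actual API.

```python
# Hypothetical sketch of the narrower skip policy suggested above.
SAFE_TO_SKIP = {"vector.create_mask", "vector.constant_mask"}  # assumed mask ops

def translate_loop_body_op(op_name, translate):
    try:
        return translate(op_name)
    except NotImplementedError:
        if op_name in SAFE_TO_SKIP:
            # Visible and non-fatal: the failure is logged instead of swallowed.
            print(f"warning: skipped untranslated op {op_name}")
            return None
        raise  # e.g. a failed arith.muli for stride math stays fatal
```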
4. Assert passthrough is zero in the masked load handler
The exec-mask path relies on hardware OOB zeroing for inactive lanes. That only works if the passthrough value is zero. Adding an assertion documents this invariant and catches future regressions.
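The invariant can be stated as a small model (pure Python, not waveasm code): hardware OOB zeroing agrees with `vector.maskedload` semantics only when the passthrough is a zero splat.

```python
# Toy model of the invariant. vector.maskedload substitutes the passthrough
# value for inactive lanes, while the hardware OOB path returns 0 for them.
def maskedload_semantics(data, mask, passthrough):
    return [d if m else p for d, m, p in zip(data, mask, passthrough)]

def hardware_oob_zeroing(data, mask):
    return [d if m else 0 for d, m in zip(data, mask)]

data, mask = [1.0, 2.0, 3.0], [True, False, True]
# The two paths agree exactly when the passthrough is a zero splat:
assert maskedload_semantics(data, mask, [0.0] * 3) == hardware_oob_zeroing(data, mask)
```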
Cleanup
5. _xfail_waveasm_dynamic_strides in dynamic_strides_test.py is dead code (empty pass body), and import pytest is unused (will fail F401 linting). Either wire up actual xfail conditions or remove both.
6. Comment in compile_options.py says "Enable dynamic strides through Wave runtime and LLVM backend" but the property now returns self.wave_runtime unconditionally. Update the comment to match.
Overall this is good work — the architecture is sound and the net effect is a major correctness improvement for dynamic stride handling. Just need the two P0s addressed before merge.
Force-pushed from 0d47c33 to 7f0151a
harsh-nod left a comment:
Verdict: APPROVE (confidence 92/100)
All P0 correctness issues have been verified as resolved in the latest commit:
- `num_records` dynamic fallback — restored to the safe `0x7FFFFFFE` limit (HandlerUtils.cpp has zero net diff)
- `handleVectorMaskedLoad` SRD adjustment — now includes `getPendingSRDBaseAdjust` at lines 1276–1281, matching all sibling handlers
- Passthrough zero assertion — comprehensive check at lines 1252–1266 covering both float and int splat values
- Dead code / stale comments — all cleaned up
Three low-priority follow-up items (all P3, none blocking):
- Narrow `RegionBuilder.cpp` catch scope from catch-all to a whitelist of safe-to-skip ops
- Monitor CI time from the doubled test matrix (32 cases vs 16)
- Consider adding waveasm toolchain availability guards in tests
The architecture is sound — the lookupSRD + computeVOffsetFromIndices pattern replacing hardcoded s[8:11], and PackOp-based SRD construction replacing RawOp string manipulation, are meaningful correctness and maintainability improvements. Ready to merge.
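For context, the per-lane offset math behind `computeVOffsetFromIndices` amounts to a strided dot product scaled to bytes. This is an illustrative sketch only; the real handler emits VALU instructions, not Python.

```python
# Illustrative sketch of the voffset computation: dot(indices, strides),
# scaled by the element size in bytes.
def compute_voffset(indices, strides, elem_bytes):
    return sum(i * s for i, s in zip(indices, strides)) * elem_bytes

# Byte offset of element (2, 3) in a row-major 2D buffer with row stride 16:
print(compute_voffset((2, 3), (16, 1), 4))  # 140
```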
Force-pushed from 2dfca42 to 09be915
Signed-off-by: Surya Jasper <45545431+suryajasper@users.noreply.github.com>
…ast to output buffer Signed-off-by: Surya Jasper <45545431+suryajasper@users.noreply.github.com>
Replace manual SGPR index arithmetic (getNextSwizzleSRDIndex) and RawOp string-based SRD fills with origin/main's typed-op pattern:
- emitSRDBaseAdjustment: ExtractOp to decompose source SRD, typed SALU ops on virtual registers, PackOp to reassemble, setSRDValue instead of setSRDIndex
- emitSRDPrologue: S_MOV_B64/S_MOV_B32 with DCEProtectOp for both GFX95 and non-GFX95 SRD word fill
- lookupSRD: check getSRDValue before getSRDIndex fallback
- handleVectorStore: use lookupSRD() instead of inline SRD lookup
- Remove getNextSwizzleSRDIndex method, nextSwizzleSRDIndex member, and prologue seeding

Signed-off-by: Surya Jasper <45545431+suryajasper@users.noreply.github.com>
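The S_ADD_U32/S_ADDC_U32 pair referenced above performs a 64-bit base-address add in two 32-bit halves. A minimal model of that carry chain:

```python
# Minimal model of the S_ADD_U32 / S_ADDC_U32 pair: a 32-bit add that produces
# a carry (SCC on real hardware), then an add-with-carry on the high word of
# the 64-bit SRD base address.
MASK32 = 0xFFFFFFFF

def add64(base_lo, base_hi, offset):
    lo_sum = base_lo + (offset & MASK32)                   # S_ADD_U32: sets carry
    lo, carry = lo_sum & MASK32, lo_sum >> 32
    hi = (base_hi + (offset >> 32) + carry) & MASK32       # S_ADDC_U32: consumes carry
    return lo, hi

assert add64(0xFFFFFFFF, 0x0, 0x1) == (0x0, 0x1)  # carry propagates to the high word
```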
Force-pushed from 09be915 to 9a145d0
Force-pushed from 9a145d0 to 419c338
This PR adds support for dynamic strides through the waveasm backend.

## 4 Cases for Dynamic Strides

There are 4 main cases that needed to be addressed to ensure complete support.

### Case 1: waveasm + dynamic strides

The loads were already using correct static offsets. The only issue was `num_records` being too small for non-contiguous layouts, causing the hardware to return zero for valid accesses. Fix: `computeBufferSizeFromMemRef`.

### Case 2: waveasm + dynamic strides + dynamic dims

Before the loads could use dynamic strides, the runtime stride values had to actually *reach* the kernel. This is where the cascade of prologue issues hit:

1. **Preload SGPR limit exceeded** — the 9 arguments couldn't all be preloaded, so the kernel got garbage values. Fix: partial preloading strategy.
2. **Scalar reservations DCE'd** — even after loading remaining scalars via `s_load_dword`, the register allocator reused those SGPRs. Fix: `scalar_sgpr_base`/`scalar_sgpr_count` attributes + explicit reservation in `LinearScanPass`.
3. **Mask ops aborted translation** — `vector<Nxindex>` arithmetic in the loop body caused a hard failure. Fix: non-fatal loop body failures in `RegionBuilder`.

### Case 3: waveasm + dynamic strides + buffer ops

Even with correct stride math, the loads were broken because:

1. **RawOps dropped in loop bodies** — the SRD fill instructions inside loops were silently discarded. Fix: `AssemblyEmitter` handles `RawOp` in nested regions.
2. **SRD adjustment inside the loop** — lazy emission placed it in the loop body; LICM partially hoisted it, corrupting registers on iteration > 0. Fix: eager adjustment in `handleFatRawBufferCast`.
3. **`handleMemRefLoad` hardcoded the wrong SRD** — it used `s[8:11]` regardless of which buffer was being accessed, and fell back to literal `0` for voffset. Fix: rewrite to use `lookupSRD` + `computeVOffsetFromIndices`.

SRD construction throughout (`emitSRDPrologue`, `emitSRDBaseAdjustment`) now uses origin/main's `PackOp`-based pattern: `ExtractOp` to decompose the source SRD, typed SALU ops (`S_ADD_U32`/`S_ADDC_U32`, `S_MOV_B32`/`S_MOV_B64`) on virtual registers, and `PackOp` to reassemble — letting the register allocator handle physical assignment. Prologue SRD fills use `DCEProtectOp` to prevent elimination. This eliminates all manual SGPR index arithmetic (`getNextSwizzleSRDIndex`) and `RawOp` string manipulation for SRD words.

### Case 4: waveasm + dynamic strides + dynamic dims + buffer ops

The Python codegen emitted per-element `memref.load` with `vector<8xindex>` OOB-index-selection, which waveasm has no handlers for. The loads were dead on arrival. Fix: allow `vector.maskedload` with buffer ops on the waveasm backend in `read_write.py`.

## Full Picture

```
Frontend generates translatable IR          → Case 4 fix (read_write.py)
    ↓
Kernel arguments reach the GPU correctly    → Case 2 fixes (preloading, SGPR reservation)
    ↓
SRD prologue fills correct registers        → Case 2 fixes (typed ops + DCEProtectOp)
    ↓
SRD base adjustment builds SSA SRD value    → Case 3 fixes (ExtractOp + SALU + PackOp)
    ↓
Loads/stores use correct SRD + voffset      → Case 3 fix (lookupSRD checks getSRDValue first)
    ↓
num_records doesn't falsely reject accesses → Case 1 fix (computeBufferSizeFromMemRef)
```

---------

Signed-off-by: Surya Jasper <45545431+suryajasper@users.noreply.github.com>
Co-authored-by: Gaurav Verma <48321602+xintin@users.noreply.github.com>
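The partial preloading strategy from Case 2 can be sketched as a simple split of the argument list. The budget of 6 below is an assumed example for illustration, not the actual gfx950 preload limit.

```python
# Hypothetical sketch of partial kernel argument preloading: arguments within
# the hardware preload budget arrive in SGPRs at launch; the remainder are
# fetched from the kernarg segment at runtime (s_load_dword in the real
# prologue). The budget value of 6 is an assumption for illustration.
def split_preload(num_args, preload_budget=6):
    preloaded = min(num_args, preload_budget)
    return list(range(preloaded)), list(range(preloaded, num_args))

preloaded, loaded_at_runtime = split_preload(9)
print(preloaded, loaded_at_runtime)  # [0, 1, 2, 3, 4, 5] [6, 7, 8]
```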