
Dynamic stride support through waveasm #1091

Merged
xintin merged 16 commits into iree-org:main from suryajasper:dynamic-strides-waveasm
Apr 8, 2026

Conversation

@suryajasper
Contributor

@suryajasper suryajasper commented Mar 10, 2026

This PR adds support for dynamic strides through the waveasm backend.

4 Cases for Dynamic Strides

There are 4 main cases that needed to be addressed to ensure complete support.

Case 1: waveasm + dynamic strides

The loads were already using correct static offsets. The only issue was num_records being too small for non-contiguous layouts, causing the hardware to return zero for valid accesses. Fix: computeBufferSizeFromMemRef.
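This sizing rule can be sketched in Python (a hypothetical model of what `computeBufferSizeFromMemRef` computes for static shapes; the formula is inferred from the description above, not lifted from the C++):

```python
def buffer_size_bytes(shape, strides, elem_bytes):
    """Smallest num_records covering every addressable element.

    For a non-contiguous layout the last element lives
    sum((dim - 1) * stride) elements from the base, so the buffer
    must extend one element past that. Using the contiguous product
    of dims would under-size it, and the hardware would return zero
    for valid accesses past num_records.
    """
    last_elem = sum((d - 1) * s for d, s in zip(shape, strides))
    return (last_elem + 1) * elem_bytes

# Contiguous 4x8 f32: 128 bytes, same as 4*8*4.
assert buffer_size_bytes([4, 8], [8, 1], 4) == 128
# Padded rows (row stride 16 instead of 8): last element is
# 3*16 + 7 = 55 elements in, so 56 * 4 = 224 bytes, not 128.
assert buffer_size_bytes([4, 8], [16, 1], 4) == 224
```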

Case 2: waveasm + dynamic strides + dynamic dims

Before the loads could use dynamic strides, the runtime stride values had to actually reach the kernel. This is where the cascade of prologue issues hit:

  1. Preload SGPR limit exceeded — the 9 arguments couldn't all be preloaded, so the kernel got garbage values. Fix: partial preloading strategy.
  2. Scalar reservations DCE'd — even after loading remaining scalars via s_load_dword, the register allocator reused those SGPRs. Fix: scalar_sgpr_base/scalar_sgpr_count attributes + explicit reservation in LinearScanPass.
  3. Mask ops aborted translation — vector<Nxindex> arithmetic in the loop body caused a hard failure. Fix: non-fatal loop body failures in RegionBuilder.
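
The partial preloading strategy can be modeled roughly as follows (hypothetical Python; the real split is ABI- and hardware-specific, and the limit of 6 used below is illustrative, not the actual preload-SGPR limit):

```python
def split_preload(num_args, preload_limit):
    """Split kernel arguments into preloaded vs. explicitly loaded.

    Arguments beyond the preload-SGPR limit must be fetched from the
    kernarg segment (e.g. via s_load_dword) in the prologue; asking
    the hardware to preload all of them exceeds the limit and hands
    the kernel garbage values.
    """
    preloaded = list(range(min(num_args, preload_limit)))
    loaded = list(range(len(preloaded), num_args))
    return preloaded, loaded

# 9 args with room to preload only 6: the last 3 are loaded explicitly.
pre, load = split_preload(9, 6)
assert pre == [0, 1, 2, 3, 4, 5]
assert load == [6, 7, 8]
```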

Case 3: waveasm + dynamic strides + buffer ops

Even with correct stride math, the loads were broken because:

  1. RawOps dropped in loop bodies — the SRD fill instructions inside loops were silently discarded. Fix: AssemblyEmitter handles RawOp in nested regions.
  2. SRD adjustment inside the loop — lazy emission placed it in the loop body; LICM partially hoisted it, corrupting registers on iteration > 0. Fix: eager adjustment in handleFatRawBufferCast.
  3. handleMemRefLoad hardcoded the wrong SRD — it used s[8:11] regardless of which buffer was being accessed, and fell back to literal 0 for voffset. Fix: rewrite to use lookupSRD + computeVOffsetFromIndices.
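The voffset computation that replaces the literal-0 fallback amounts to an index-times-stride dot product, sketched here in Python (a model of what computeVOffsetFromIndices presumably produces; units and names are assumptions):

```python
def voffset_bytes(indices, strides, elem_bytes):
    """Per-lane byte offset into the buffer for a strided access.

    Falling back to a literal 0 (the old handleMemRefLoad behavior)
    is only correct for the very first element; every other access
    needs the full index-times-stride dot product.
    """
    return sum(i * s for i, s in zip(indices, strides)) * elem_bytes

# Element (2, 3) of an f32 buffer with runtime strides (128, 1):
assert voffset_bytes([2, 3], [128, 1], 4) == (2 * 128 + 3) * 4
```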

SRD construction throughout (emitSRDPrologue, emitSRDBaseAdjustment) now uses origin/main's PackOp-based pattern: ExtractOp to decompose the source SRD, typed SALU ops (S_ADD_U32/S_ADDC_U32, S_MOV_B32/S_MOV_B64) on virtual registers, and PackOp to reassemble — letting the register allocator handle physical assignment. Prologue SRD fills use DCEProtectOp to prevent elimination. This eliminates all manual SGPR index arithmetic (getNextSwizzleSRDIndex) and RawOp string manipulation for SRD words.
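The S_ADD_U32/S_ADDC_U32 pair implements a 64-bit base add over the two 32-bit SRD base words; a Python model of the carry chain (illustrative only — the actual lowering emits typed SALU ops on virtual registers and lets the allocator assign SGPRs):

```python
MASK32 = 0xFFFFFFFF

def s_add_u32(a, b):
    """32-bit add returning (result, carry), like S_ADD_U32 setting SCC."""
    full = a + b
    return full & MASK32, full >> 32

def s_addc_u32(a, b, carry_in):
    """32-bit add with carry-in, like S_ADDC_U32 consuming SCC."""
    full = a + b + carry_in
    return full & MASK32, full >> 32

def adjust_srd_base(base_lo, base_hi, offset):
    """64-bit base adjustment split across SRD words 0 and 1."""
    off_lo = offset & MASK32
    off_hi = (offset >> 32) & MASK32
    lo, carry = s_add_u32(base_lo, off_lo)
    hi, _ = s_addc_u32(base_hi, off_hi, carry)
    return lo, hi

# Adding 8 to base 0x00000001_FFFFFFFC carries into the high word.
lo, hi = adjust_srd_base(0xFFFFFFFC, 0x1, 8)
assert (lo, hi) == (0x4, 0x2)
```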

Case 4: waveasm + dynamic strides + dynamic dims + buffer ops

The Python codegen emitted per-element memref.load with vector<8xindex> OOB-index-selection, which waveasm has no handlers for. The loads were dead on arrival. Fix: allow vector.maskedload with buffer ops on the waveasm backend in read_write.py.
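For reference, vector.maskedload semantics can be modeled as a per-lane select (plain Python, not the waveasm lowering). The equivalence with hardware OOB zeroing that the backend relies on holds only when the passthrough is all zeros:

```python
def masked_load(memory, base, mask, passthru):
    """Reference semantics of vector.maskedload.

    Inactive lanes take the passthrough value. A lowering that leans
    on hardware out-of-bounds zeroing instead of per-lane selects is
    only equivalent when passthru is all zeros.
    """
    return [memory[base + i] if m else p
            for i, (m, p) in enumerate(zip(mask, passthru))]

mem = [10, 20, 30, 40]
assert masked_load(mem, 0, [1, 1, 0, 1], [0, 0, 0, 0]) == [10, 20, 0, 40]
```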

Full Picture

Frontend generates translatable IR          → Case 4 fix (read_write.py)
  ↓
Kernel arguments reach the GPU correctly    → Case 2 fixes (preloading, SGPR reservation)
  ↓
SRD prologue fills correct registers        → Case 2 fixes (typed ops + DCEProtectOp)
  ↓
SRD base adjustment builds SSA SRD value    → Case 3 fixes (ExtractOp + SALU + PackOp)
  ↓
Loads/stores use correct SRD + voffset      → Case 3 fix (lookupSRD checks getSRDValue first)
  ↓
num_records doesn't falsely reject accesses → Case 1 fix (computeBufferSizeFromMemRef)

@suryajasper suryajasper force-pushed the dynamic-strides-waveasm branch from b51b0b3 to 2bb545c Compare March 10, 2026 01:35
Comment thread wave_lang/kernel/wave/compile.py Outdated
@suryajasper suryajasper force-pushed the dynamic-strides-waveasm branch 2 times, most recently from 136a9f0 to 432dcaa Compare March 17, 2026 22:01
@suryajasper suryajasper force-pushed the dynamic-strides-waveasm branch from 47f2ef2 to 1995206 Compare March 25, 2026 21:00
suryajasper added a commit to suryajasper/wave that referenced this pull request Mar 25, 2026
Squashed cherry-pick of suryajasper/dynamic-strides-waveasm onto
4waveasm-256x192x256. Merges partial kernel argument preloading,
extract_strided_metadata handler, and dynamic stride test updates.

Commits included:
- Handle memref.extract_strided_metadata in waveasm backend
- Update dynamic strides test & compile options to include waveasm
- xfail waveasm dynamic strides tests w/ dynamic dims or buffer ops
- Fix dynamic strides + dynamic dims through waveasm & accumulator bitcast
- Fixed dynamic strides with bufops w/ waveasm
- Fix mxfp waveasm example to use (2,2) wave shape
- Fixed waveasm dynamic strides to use partial kernel argument preloading

Made-with: Cursor
panditsa added a commit to panditsa/wave that referenced this pull request Mar 30, 2026
Squash-merge of iree-org#1091 (dynamic-strides-waveasm) onto
reduce_reg_pres branch, adapted for SGPR-based scalar arg mapping.

Key changes from PR iree-org#1091:
- Handle memref.extract_strided_metadata in waveasm backend
- Handle memref.cast propagation for dynamic strides
- Handle memref.reinterpret_cast for dynamic stride tracking
- Update dynamic strides test & compile options for waveasm
- Fix dynamic strides + dynamic dims + buffer ops through waveasm
- Partial kernel argument preloading for gfx950

Conflict resolutions (kept our branch's approach):
- Scalar kernel args mapped to dedicated SGPRs (not VGPRs)
- AGPR-to-VGPR copy before buffer stores preserved
- getUserSgprCount() helper adopted from PR header changes
- program.getSymName() kept (getKernelName not available)

Signed-off-by: Sanket Pandit <sanket.pandit@amd.com>
Made-with: Cursor
@suryajasper suryajasper force-pushed the dynamic-strides-waveasm branch from fbbf5ba to 0d47c33 Compare March 31, 2026 01:32
@suryajasper suryajasper marked this pull request as ready for review March 31, 2026 01:32
Comment thread waveasm/lib/Transforms/handlers/HandlerUtils.cpp Outdated
Comment thread waveasm/lib/Transforms/AssemblyEmitter.cpp Outdated
Comment thread tests/kernel/dynamic_strides_test.py Outdated
Comment thread tests/kernel/dynamic_strides_test.py Outdated
Comment thread wave_lang/kernel/wave/compile_options.py Outdated
Collaborator

@harsh-nod harsh-nod left a comment


See below.

suryajasper pushed a commit to suryajasper/wave that referenced this pull request Apr 2, 2026
Squash-merge of suryajasper/dynamic-strides-waveasm branch that adds
support for dynamic strides and dynamic dims through the WaveASM backend.
Required for hipblaslt integration where strides are runtime values.

Made-with: Cursor
@harsh-nod harsh-nod dismissed stale reviews from themself April 2, 2026 15:46

Superseded by updated review

Collaborator

@harsh-nod harsh-nod left a comment


I went through the full diff (348 additions, 218 deletions across 9 files). The dynamic stride propagation design is solid — lookupSRD + computeVOffsetFromIndices replacing the hardcoded SRD at s[8:11] is a big step forward, and the static-shape computeBufferSizeFromMemRef formula is correct. The typed-op migration away from RawOp string manipulation is also the right call.

That said, two things need to be fixed before this can land:

Must fix

1. computeBufferSizeFromMemRef dynamic fallback is too small (HandlerUtils.cpp:91)

The fallback was reduced from 0x7FFFFFFE (~2 GB) to 0x20000000 (512 MB). The old value was specifically chosen to sit one byte below the Python frontend's OOB sentinel at 0x7FFFFFFF. With the new value, any dynamic-stride buffer over 512 MB silently gets zeros for the upper portion — e.g., a memref<16384x16384xf32> with dynamic strides uses 1 GB, so the upper half reads as zero. No warning, no error, just wrong GEMM results.

The static-shape path above this line is strictly better and should stay. Just revert the dynamic fallback to 0x7FFFFFFE.
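
The arithmetic behind this, spelled out (the memref size is the review's own example):

```python
# The review's example: memref<16384x16384xf32> with dynamic strides.
size = 16384 * 16384 * 4      # 1 GiB of reachable data

# The new fallback silently zeroes everything past 512 MB...
assert size > 0x20000000
# ...while the old fallback, one byte below the frontend's 0x7FFFFFFF
# OOB sentinel, covers it.
assert size <= 0x7FFFFFFE
```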

2. handleVectorMaskedLoad is missing getPendingSRDBaseAdjust (TranslateFromMLIR.cpp)

The three sibling handlers (handleMemRefLoad, handleMemRefStore, handleVectorStore) all call getPendingSRDBaseAdjust before using the SRD. This handler doesn't. When a masked load is the first consumer of a base-adjusted memref, it'll use the unadjusted SRD base address — reads go to the wrong memory region.

Should fix

3. RegionBuilder.cpp — blanket non-fatal catch is too broad

Converting ALL loop body translation failures from fatal to silent skip was needed for the mask computation chains, but the catch-all scope means a failed arith.muli for stride computation gets silently dropped too. Either whitelist the specific ops that are safe to skip, or at least emitWarning() so failures are visible.
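
A whitelist-plus-warning shape for this could look like the following sketch (hypothetical Python stand-in for the C++ RegionBuilder logic; the op names and `translate` callback are made up for illustration):

```python
def translate_body(ops, translate, safe_to_skip, warnings):
    """Skip only whitelisted untranslatable ops; surface everything else.

    A failed translation of an op outside the whitelist is recorded
    instead of being silently dropped, so a broken stride computation
    stays visible.
    """
    out = []
    for op in ops:
        try:
            out.append(translate(op))
        except NotImplementedError:
            if op not in safe_to_skip:
                warnings.append(op)  # would be emitWarning() in MLIR
    return out

def translate(op):
    if op.startswith("vector."):
        raise NotImplementedError
    return op.upper()

warnings = []
result = translate_body(["arith.muli", "vector.create_mask"], translate,
                        {"vector.create_mask"}, warnings)
assert result == ["ARITH.MULI"]
assert warnings == []          # the whitelisted skip stays quiet
```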

4. Assert passthrough is zero in the masked load handler

The exec-mask path relies on hardware OOB zeroing for inactive lanes. That only works if the passthrough value is zero. Adding an assertion documents this invariant and catches future regressions.

Cleanup

5. _xfail_waveasm_dynamic_strides in dynamic_strides_test.py is dead code (empty pass body), and import pytest is unused (will fail F401 linting). Either wire up actual xfail conditions or remove both.

6. Comment in compile_options.py says "Enable dynamic strides through Wave runtime and LLVM backend" but the property now returns self.wave_runtime unconditionally. Update the comment to match.


Overall this is good work — the architecture is sound and the net effect is a major correctness improvement for dynamic stride handling. Just need the two P0s addressed before merge.

Comment thread tests/kernel/dynamic_strides_test.py
Comment thread waveasm/lib/Transforms/handlers/HandlerUtils.cpp Outdated
Comment thread waveasm/lib/Transforms/TranslateFromMLIR.cpp
@suryajasper suryajasper force-pushed the dynamic-strides-waveasm branch from 0d47c33 to 7f0151a Compare April 3, 2026 00:58
@suryajasper suryajasper requested a review from harsh-nod April 3, 2026 01:00
Collaborator

@harsh-nod harsh-nod left a comment


Verdict: APPROVE (confidence 92/100)

All P0 correctness issues have been verified as resolved in the latest commit:

  • num_records dynamic fallback — restored to safe 0x7FFFFFFE limit (HandlerUtils.cpp has zero net diff)
  • handleVectorMaskedLoad SRD adjustment — now includes getPendingSRDBaseAdjust at lines 1276–1281, matching all sibling handlers
  • Passthrough zero assertion — comprehensive check at lines 1252–1266 covering both float and int splat values
  • Dead code / stale comments — all cleaned up

Three low-priority follow-up items (all P3, none blocking):

  1. Narrow RegionBuilder.cpp catch scope from catch-all to a whitelist of safe-to-skip ops
  2. Monitor CI time from doubled test matrix (32 cases vs 16)
  3. Consider adding waveasm toolchain availability guards in tests

The architecture is sound — the lookupSRD + computeVOffsetFromIndices pattern replacing hardcoded s[8:11], and PackOp-based SRD construction replacing RawOp string manipulation, are meaningful correctness and maintainability improvements. Ready to merge.

@suryajasper suryajasper force-pushed the dynamic-strides-waveasm branch 2 times, most recently from 2dfca42 to 09be915 Compare April 7, 2026 21:00
Replace manual SGPR index arithmetic (getNextSwizzleSRDIndex) and RawOp
string-based SRD fills with origin/main's typed-op pattern:
- emitSRDBaseAdjustment: ExtractOp to decompose source SRD, typed SALU
  ops on virtual registers, PackOp to reassemble, setSRDValue instead
  of setSRDIndex
- emitSRDPrologue: S_MOV_B64/S_MOV_B32 with DCEProtectOp for both
  GFX95 and non-GFX95 SRD word fill
- lookupSRD: check getSRDValue before getSRDIndex fallback
- handleVectorStore: use lookupSRD() instead of inline SRD lookup
- Remove getNextSwizzleSRDIndex method, nextSwizzleSRDIndex member,
  and prologue seeding

Signed-off-by: Surya Jasper <45545431+suryajasper@users.noreply.github.com>
@suryajasper suryajasper force-pushed the dynamic-strides-waveasm branch from 09be915 to 9a145d0 Compare April 7, 2026 21:03
@suryajasper suryajasper force-pushed the dynamic-strides-waveasm branch from 9a145d0 to 419c338 Compare April 8, 2026 00:39
@xintin xintin merged commit dfa45bb into iree-org:main Apr 8, 2026
18 of 19 checks passed
panditsa pushed a commit to panditsa/wave that referenced this pull request Apr 10, 2026
panditsa pushed a commit to panditsa/wave that referenced this pull request Apr 10, 2026
