
Fix x.all_gather hint: use (1, block_n) for correct vectorization with partial-width tiles #434

Merged
mawad-amd merged 3 commits into muhaawad/hints from copilot/sub-pr-432-again
Mar 7, 2026
Conversation

Contributor

Copilot AI commented Mar 7, 2026

iris.x.all_gather produced incorrect results (max diff 1.0) for shapes where the tile width is narrower than the tensor width (e.g. M=128, N=64, BLOCK_M=64, BLOCK_N=32).

Root Cause

The hint=(tile.block_m, tile.block_n) passed to iris.store applies tl.max_contiguous to the translated pointer block. Triton's getOrderFromContiguity([BLOCK_M, BLOCK_N]) picks the dimension with the highest contiguity for vectorization. When BLOCK_M > BLOCK_N (e.g. 64 > 32), dim-0 is selected — but for a row-major tensor with stride_m = N > BLOCK_N, rows are not adjacent in memory. The compiler emits scalar buffer_store_short (1 float16/instruction) with addresses derived from the incorrect dim-0 assumption, writing to wrong locations.

This also explains why other shapes were unaffected: when BLOCK_M == BLOCK_N, both dims have equal contiguity so stable_sort preserves reverse order and dim-1 is selected correctly.
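The dimension-selection behavior described above can be modeled in a few lines of plain Python. This is a sketch of Triton's `getOrderFromContiguity` logic as described in this PR (stable sort by descending contiguity, starting from the reversed dimension order), not the actual Triton source:

```python
def order_from_contiguity(contig):
    """Approximate model of Triton's getOrderFromContiguity: start from the
    reversed dim order (innermost dim first for row-major) and stable-sort
    by contiguity, descending. order[0] is the vectorization axis."""
    order = list(range(len(contig)))[::-1]             # default: [1, 0] for 2-D
    order.sort(key=lambda d: contig[d], reverse=True)  # Python's sort is stable
    return order

# Broken hint: BLOCK_M=64 > BLOCK_N=32, so dim-0 wins the vector axis
assert order_from_contiguity([64, 32])[0] == 0
# Square tiles: equal contiguity, stable sort preserves [1, 0] -> dim-1 (correct by luck)
assert order_from_contiguity([128, 128])[0] == 1
# Fixed hint (1, BLOCK_N): dim-1 is always selected
assert order_from_contiguity([1, 32])[0] == 1
```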

Assembly evidence (M=128, N=64, BLOCK_M=64, BLOCK_N=32, float16)

| Hint | ISA instruction | Elements/instr | Correct? |
|------|-----------------|----------------|----------|
| `(BLOCK_M, BLOCK_N)` = (64, 32) | `buffer_store_short` ×8 | 1 float16 (scalar) | ✗ |
| `(1, BLOCK_N)` = (1, 32) | `buffer_store_dwordx4` ×1 | 8 float16 (vectorized) | ✓ |
| no hint | `buffer_store_dwordx4` ×1 | 8 float16 (vectorized) | ✓ |

Fix

Replace hint=(tile.block_m, tile.block_n) with hint=(1, tile.block_n) in iris/x/all_gather.py.

(1, tile.block_n) correctly asserts only per-row contiguity (BLOCK_N consecutive elements in dim-1). getOrderFromContiguity([1, BLOCK_N]) always selects dim-1, producing buffer_store_dwordx4 — 8× wider than the broken hint, identical to the no-hint case, and always correct for any BLOCK_N/N relationship in row-major tensors.
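The "8× wider" claim follows from the widths involved: `buffer_store_dwordx4` writes 16 bytes (4 dwords) per instruction, i.e. 8 float16 elements, while `buffer_store_short` writes one. A rough model of how asserted contiguity bounds the vector width (a sketch of the arithmetic, not Triton's actual codegen):

```python
def max_vector_elems(contig, elem_bytes, max_bytes=16):
    """Widest power-of-two vector store (in elements) that the asserted
    contiguity along the vector axis permits; 16 bytes = dwordx4, the widest
    AMD buffer store mentioned in this PR."""
    elems = max_bytes // elem_bytes
    while elems > 1 and contig % elems != 0:
        elems //= 2
    return elems

# (1, 32) hint, float16: 32 contiguous elements -> full 8-element dwordx4
assert max_vector_elems(32, 2) == 8
# dim-0 path with no real contiguity -> scalar store (buffer_store_short)
assert max_vector_elems(1, 2) == 1
```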



…8-64-64-32

Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com>
Copilot AI changed the title from "[WIP] Add hint parameters to IRIS collective communication operations" to "Fix x.all_gather correctness failure for partial-width tiles (BLOCK_N < N)" Mar 7, 2026
@mawad-amd
Collaborator

@copilot Do not merge the current fix as-is. Removing the hint is not acceptable: without the hint, loads/stores become non-vectorized and we cannot ship that.

Please investigate this issue more deeply:

  1. Look at the generated assembly – Compare the store codegen with the hint vs without. Confirm how vectorization is lost without the hint and what the compiler is doing when the hint is present (and why that breaks for BLOCK_N < N).

  2. Propose a solution that keeps stores vectorized – We need both correctness (for partial-width tiles, e.g. BLOCK_N < N) and vectorized stores. Options to explore:

    • Only pass the hint when the tile is full-width (tile.block_n == N or equivalent) so we get vectorization where valid and safe behavior otherwise.
    • Or use a different hint shape / layout that is valid for strided tiles (if Triton supports it).
    • Or fix the pointer/address computation so that the contiguous hint is valid for the actual memory layout.

Report back with what you find in the assembly and a concrete, reasonable solution that preserves vectorization.
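The first option above (pass the full hint only when the tile is full-width) can be sketched as a small helper. The function name and structure are hypothetical, for illustration only; the PR ultimately went with the unconditional `(1, tile.block_n)` hint instead:

```python
def choose_hint(block_m, block_n, n):
    """Option-1 sketch (hypothetical): assert 2-D contiguity only when the
    tile spans the whole row, so consecutive rows really are adjacent;
    otherwise fall back to per-row contiguity only."""
    if block_n == n:                 # full-width tile: rows are adjacent in memory
        return (block_m, block_n)
    return (1, block_n)              # partial-width tile: only dim-1 is contiguous

# The failing case from this PR (M=128, N=64, BLOCK_M=64, BLOCK_N=32):
assert choose_hint(64, 32, 64) == (1, 32)
# Full-width tile keeps the stronger hint:
assert choose_hint(128, 128, 128) == (128, 128)
```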

… for correct vectorization

Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com>
Contributor Author

Copilot AI commented Mar 7, 2026

> @copilot Do not merge the current fix as-is. Removing the hint is not acceptable: without the hint, loads/stores become non-vectorized and we cannot ship that.
>
> Please investigate this issue more deep...

Investigated in assembly — here's what was found:

Assembly comparison (M=128, N=64, BLOCK_M=64, BLOCK_N=32, float16):

| Hint | ISA instruction | Elements/instr | Correct? |
|------|-----------------|----------------|----------|
| `(BLOCK_M, BLOCK_N)` = (64, 32) | `buffer_store_short` ×8 | 1 float16 (scalar!) | ✗ |
| `(1, BLOCK_N)` = (1, 32) | `buffer_store_dwordx4` ×1 | 8 float16 (vectorized) | ✓ |
| no hint | `buffer_store_dwordx4` ×1 | 8 float16 (vectorized) | ✓ |

Why (BLOCK_M, BLOCK_N) causes both de-vectorization AND wrong addresses:
Triton's getOrderFromContiguity([64, 32]) sorts by descending contiguity, choosing dim-0 first (order[0]=0). The vectorization width is then computed for dim-0. But dim-0 is the row direction — with stride_m=64 > BLOCK_N=32, consecutive rows aren't adjacent in memory. The compiler emits scalar buffer_store_short with addresses derived from the incorrect dim-0 assumption.

Why it accidentally worked for other shapes (e.g., BLOCK_M=BLOCK_N=128):
When both dims have equal contiguity, stable_sort preserves the reversed order → dim-1 is selected → stores are correct and vectorized.

Fix (a0c7506): hint=(1, tile.block_n) claims only per-row contiguity. getOrderFromContiguity([1, BLOCK_N]) always selects dim-1 → buffer_store_dwordx4 (8× wider than the broken hint, same as no hint) and always correct for row-major storage.

Copilot AI changed the title from "Fix x.all_gather correctness failure for partial-width tiles (BLOCK_N < N)" to "Fix x.all_gather hint: use (1, block_n) for correct vectorization with partial-width tiles" Mar 7, 2026
@mawad-amd mawad-amd marked this pull request as ready for review March 7, 2026 15:27
@mawad-amd mawad-amd merged commit c50f442 into muhaawad/hints Mar 7, 2026
@mawad-amd mawad-amd deleted the copilot/sub-pr-432-again branch March 7, 2026 15:28
@mawad-amd
Collaborator

@copilot I merged this. Can you suggest tests to make sure all APIs are tested against this behavior? Maybe the existing tests are enough but we need some new shapes in the parametrized tests.
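One concrete way to answer this: extend the parametrized shape grid so it is guaranteed to cover partial-width tiles (BLOCK_N < N) and rectangular tiles (BLOCK_M > BLOCK_N), the cases the old hint miscompiled. A sketch with hypothetical shape choices; the actual test-suite parameter names may differ:

```python
# Hypothetical shape grid (M, N, BLOCK_M, BLOCK_N) for the parametrized
# all_gather tests. It must include the regression case from this PR and
# at least one shape where BLOCK_N < N while BLOCK_M > BLOCK_N.
SHAPES = [
    (M, N, BM, BN)
    for M, N in [(128, 64), (128, 128), (256, 128)]
    for BM, BN in [(64, 32), (64, 64), (128, 32)]
    if BM <= M and BN <= N
]

# The exact failing configuration from this PR is covered:
assert (128, 64, 64, 32) in SHAPES
# At least one shape exercises the broken-hint conditions:
assert any(BN < N and BM > BN for _, N, BM, BN in SHAPES)
```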
