From 6fe8e5bbe497fd24836310315b80be73d1f9de95 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 7 Mar 2026 13:33:57 +0000 Subject: [PATCH 1/3] Initial plan From 2d09df523f2f0fcb572c1e836806d115326da350 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 7 Mar 2026 13:45:45 +0000 Subject: [PATCH 2/3] Fix incorrect hint in x.all_gather causing test failures for shape 128-64-64-32 Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com> --- iris/x/all_gather.py | 1 - 1 file changed, 1 deletion(-) diff --git a/iris/x/all_gather.py b/iris/x/all_gather.py index 5f2c734f..a357ab7f 100644 --- a/iris/x/all_gather.py +++ b/iris/x/all_gather.py @@ -72,5 +72,4 @@ def all_gather( dest_rank, # to_rank (destination rank) ctx.heap_bases, mask=combined_mask, - hint=(tile.block_m, tile.block_n), ) From a0c7506f4328a8e82f371df00b08e63634f59270 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 7 Mar 2026 15:03:28 +0000 Subject: [PATCH 3/3] Fix x.all_gather hint: use (1, block_n) instead of (block_m, block_n) for correct vectorization Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com> --- iris/x/all_gather.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/iris/x/all_gather.py b/iris/x/all_gather.py index a357ab7f..a8c84bde 100644 --- a/iris/x/all_gather.py +++ b/iris/x/all_gather.py @@ -64,7 +64,12 @@ def all_gather( # Scatter along N dimension: write to [:, ctx.rank * N_local : (ctx.rank+1) * N_local] dst_ptr, combined_mask = dst_view.offset_tile_ptr(tile, offset_n=ctx.rank * N_local, src_mask=None) - # Use iris.store to write to dest_rank's memory + # Use iris.store to write to dest_rank's memory. + # hint=(1, tile.block_n) asserts per-row contiguity only (BLOCK_N consecutive + # elements within each row). 
Using (tile.block_m, tile.block_n) would + # assert cross-row contiguity, which is false when BLOCK_N < N (stride_m > BLOCK_N), + # causing getOrderFromContiguity to choose dim-0 for vectorization and emitting + # scalar buffer_store_short writes to wrong addresses. iris.store( dst_ptr, tile.data, @@ -72,4 +77,5 @@ def all_gather( dest_rank, # to_rank (destination rank) ctx.heap_bases, mask=combined_mask, + hint=(1, tile.block_n), )