diff --git a/iris/x/all_gather.py b/iris/x/all_gather.py index 5f2c734f..a8c84bde 100644 --- a/iris/x/all_gather.py +++ b/iris/x/all_gather.py @@ -64,7 +64,12 @@ def all_gather( # Scatter along N dimension: write to [:, ctx.rank * N_local : (ctx.rank+1) * N_local] dst_ptr, combined_mask = dst_view.offset_tile_ptr(tile, offset_n=ctx.rank * N_local, src_mask=None) - # Use iris.store to write to dest_rank's memory + # Use iris.store to write to dest_rank's memory. + # hint=(1, tile.block_n) asserts per-row contiguity only (BLOCK_N consecutive + # elements within each row). Using (tile.block_m, tile.block_n) would + # assert cross-row contiguity, which is false when BLOCK_N < N (stride_m > BLOCK_N); + # getOrderFromContiguity then chooses dim-0 for vectorization, resulting in + # scalar buffer_store_short writes to wrong addresses. iris.store( dst_ptr, tile.data, @@ -72,5 +77,5 @@ def all_gather( dest_rank, # to_rank (destination rank) ctx.heap_bases, mask=combined_mask, - hint=(tile.block_m, tile.block_n), + hint=(1, tile.block_n), )