Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions iris/x/all_gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,18 @@ def all_gather(
# Scatter along N dimension: write to [:, ctx.rank * N_local : (ctx.rank+1) * N_local]
dst_ptr, combined_mask = dst_view.offset_tile_ptr(tile, offset_n=ctx.rank * N_local, src_mask=None)

# Use iris.store to write to dest_rank's memory
# Use iris.store to write to dest_rank's memory.
# hint=(1, tile.block_n) asserts per-row contiguity only (BLOCK_N consecutive
# elements within each row). Using (tile.block_m, tile.block_n) would
# assert cross-row contiguity, which is false when BLOCK_N < N (stride_m > BLOCK_N);
# that false hint causes getOrderFromContiguity to choose dim-0 for vectorization
# and emit scalar buffer_store_short writes to wrong addresses.
iris.store(
dst_ptr,
tile.data,
ctx.rank, # from_rank (current rank)
dest_rank, # to_rank (destination rank)
ctx.heap_bases,
mask=combined_mask,
hint=(tile.block_m, tile.block_n),
hint=(1, tile.block_n),
)