Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions tests/x/test_all_gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,9 @@ def x_all_gather_kernel(
@pytest.mark.parametrize(
"M, N, BLOCK_SIZE_M, BLOCK_SIZE_N",
[
(128, 64, 64, 32), # Small
(128, 64, 64, 32), # Small – BLOCK_N < N (partial-width, was the original failing case)
(128, 128, 64, 32), # Multiple N blocks per rank – BLOCK_N < N/world_size (2 tiles in N per rank)
(256, 128, 64, 16), # Very small BLOCK_N to stress 16-bit vectorization with partial-width tiles
(1024, 256, 128, 128), # Medium
(2048, 2048, 256, 256), # Large
# TODO: Fix non-aligned dimension handling in all_gather for irregular tiling
Expand Down Expand Up @@ -258,7 +260,9 @@ def x_all_gather_ctx_api_kernel(
(torch.float32, 1e-5, 1e-5),
],
)
@pytest.mark.parametrize("M, N, BLOCK_SIZE_M, BLOCK_SIZE_N", [(256, 128, 64, 64)])
@pytest.mark.parametrize(
"M, N, BLOCK_SIZE_M, BLOCK_SIZE_N", [(256, 128, 64, 64), (128, 64, 64, 32), (128, 128, 64, 32)]
)
def test_all_gather_ctx_api(gather_dim, dtype, atol, rtol, M, N, BLOCK_SIZE_M, BLOCK_SIZE_N):
"""Test tile-level all-gather using direct function call (ctx methods removed)."""
if not dist.is_initialized():
Expand Down