From 626fa7b0be4a947daae68089c6cf56cea8a44259 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 7 Mar 2026 15:35:41 +0000 Subject: [PATCH 1/2] Initial plan From 177aec9374869348dfdb0b012d67ddcdb1a48db9 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 7 Mar 2026 15:40:54 +0000 Subject: [PATCH 2/2] Add partial-width tile shapes to all_gather tests Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com> --- tests/x/test_all_gather.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/x/test_all_gather.py b/tests/x/test_all_gather.py index 93dff4ad..17cb74f3 100644 --- a/tests/x/test_all_gather.py +++ b/tests/x/test_all_gather.py @@ -78,7 +78,9 @@ def x_all_gather_kernel( @pytest.mark.parametrize( "M, N, BLOCK_SIZE_M, BLOCK_SIZE_N", [ - (128, 64, 64, 32), # Small + (128, 64, 64, 32), # Small – BLOCK_N < N (partial-width, was the original failing case) + (128, 128, 64, 32), # Multiple N blocks per rank – BLOCK_N < N/world_size (2 tiles in N per rank) + (256, 128, 64, 16), # Very small BLOCK_N to stress 16-bit vectorization with partial-width tiles (1024, 256, 128, 128), # Medium (2048, 2048, 256, 256), # Large # TODO: Fix non-aligned dimension handling in all_gather for irregular tiling @@ -258,7 +260,9 @@ def x_all_gather_ctx_api_kernel( (torch.float32, 1e-5, 1e-5), ], ) -@pytest.mark.parametrize("M, N, BLOCK_SIZE_M, BLOCK_SIZE_N", [(256, 128, 64, 64)]) +@pytest.mark.parametrize( + "M, N, BLOCK_SIZE_M, BLOCK_SIZE_N", [(256, 128, 64, 64), (128, 64, 64, 32), (128, 128, 64, 32)] +) def test_all_gather_ctx_api(gather_dim, dtype, atol, rtol, M, N, BLOCK_SIZE_M, BLOCK_SIZE_N): """Test tile-level all-gather using direct function call (ctx methods removed).""" if not dist.is_initialized():