diff --git a/tests/x/test_all_gather.py b/tests/x/test_all_gather.py index 93dff4ad..17cb74f3 100644 --- a/tests/x/test_all_gather.py +++ b/tests/x/test_all_gather.py @@ -78,7 +78,9 @@ def x_all_gather_kernel( @pytest.mark.parametrize( "M, N, BLOCK_SIZE_M, BLOCK_SIZE_N", [ - (128, 64, 64, 32), # Small + (128, 64, 64, 32), # Small – BLOCK_N < N (partial-width, was the original failing case) + (128, 128, 64, 32), # Multiple N blocks per rank – BLOCK_N < N/world_size (2 tiles in N per rank) + (256, 128, 64, 16), # Very small BLOCK_N to stress 16-bit vectorization with partial-width tiles (1024, 256, 128, 128), # Medium (2048, 2048, 256, 256), # Large # TODO: Fix non-aligned dimension handling in all_gather for irregular tiling @@ -258,7 +260,9 @@ def x_all_gather_ctx_api_kernel( (torch.float32, 1e-5, 1e-5), ], ) -@pytest.mark.parametrize("M, N, BLOCK_SIZE_M, BLOCK_SIZE_N", [(256, 128, 64, 64)]) +@pytest.mark.parametrize( + "M, N, BLOCK_SIZE_M, BLOCK_SIZE_N", [(256, 128, 64, 64), (128, 64, 64, 32), (128, 128, 64, 32)] +) def test_all_gather_ctx_api(gather_dim, dtype, atol, rtol, M, N, BLOCK_SIZE_M, BLOCK_SIZE_N): """Test tile-level all-gather using direct function call (ctx methods removed).""" if not dist.is_initialized():