diff --git a/tests/ccl/test_all_gather.py b/tests/ccl/test_all_gather.py index f2c50e2f..7858ed18 100644 --- a/tests/ccl/test_all_gather.py +++ b/tests/ccl/test_all_gather.py @@ -21,14 +21,16 @@ ], ) @pytest.mark.parametrize( - "M, N", + "M, N, block_size_m, block_size_n", [ - (128, 64), # Small - (1024, 256), # Medium - (8192, 8192), # Large + (128, 64, 32, 64), # Small + (128, 128, 32, 32), # BLOCK_N < N/world_size (partial-width, multi-block per rank) + (256, 128, 32, 16), # Minimum BLOCK_N=16 (16-bit vectorization path) + (1024, 256, 32, 64), # Medium + (8192, 8192, 32, 64), # Large ], ) -def test_all_gather(dtype, M, N): +def test_all_gather(dtype, M, N, block_size_m, block_size_n): """Test all-gather functionality by comparing against PyTorch's implementation.""" # Ensure torch.distributed is initialized (should be done by test runner) if not dist.is_initialized(): @@ -62,7 +64,7 @@ def test_all_gather(dtype, M, N): # Run Iris all_gather shmem.barrier() - config = Config() + config = Config(block_size_m=block_size_m, block_size_n=block_size_n) shmem.ccl.all_gather(iris_output_tensor, iris_input_tensor, config=config) torch.cuda.synchronize() @@ -97,14 +99,16 @@ def test_all_gather(dtype, M, N): ], ) @pytest.mark.parametrize( - "M, N", + "M, N, block_size_m, block_size_n", [ - (128, 64), # Small - (1024, 256), # Medium - (8192, 8192), # Large + (128, 64, 32, 64), # Small + (128, 128, 32, 32), # BLOCK_N < N/world_size (partial-width, multi-block per rank) + (256, 128, 32, 16), # Minimum BLOCK_N=16 (16-bit vectorization path) + (1024, 256, 32, 64), # Medium + (8192, 8192, 32, 64), # Large ], ) -def test_all_gather_partitioned(dtype, M, N): +def test_all_gather_partitioned(dtype, M, N, block_size_m, block_size_n): """Test all-gather with partitioned variant by comparing against PyTorch's implementation.""" # Ensure torch.distributed is initialized (should be done by test runner) if not dist.is_initialized(): @@ -140,7 +144,9 @@ def test_all_gather_partitioned(dtype, M, N): # COMM_SMS must be divisible by world_size for partitioned variant comm_sms = 64 # Assuming world_size divides 64 (e.g., 2, 4, 8) shmem.barrier() - config = Config(all_gather_variant="partitioned", comm_sms=comm_sms) + config = Config( + block_size_m=block_size_m, block_size_n=block_size_n, all_gather_variant="partitioned", comm_sms=comm_sms + ) shmem.ccl.all_gather(iris_output_tensor, iris_input_tensor, config=config) torch.cuda.synchronize() diff --git a/tests/ccl/test_all_reduce.py b/tests/ccl/test_all_reduce.py index ffd55e9d..1862e0e3 100644 --- a/tests/ccl/test_all_reduce.py +++ b/tests/ccl/test_all_reduce.py @@ -32,14 +32,16 @@ ], ) @pytest.mark.parametrize( - "M, N", + "M, N, block_size_m, block_size_n", [ - (128, 64), # Small - (1024, 256), # Medium - (8192, 8192), # Large + (128, 64, 32, 64), # Small + (128, 128, 32, 32), # BLOCK_N < N/world_size (partial-width, multi-block per rank) + (256, 128, 32, 16), # Minimum BLOCK_N=16 (16-bit vectorization path) + (1024, 256, 32, 64), # Medium + (8192, 8192, 32, 64), # Large ], ) -def test_all_reduce(variant, dtype, M, N): +def test_all_reduce(variant, dtype, M, N, block_size_m, block_size_n): """Test all-reduce functionality by comparing against PyTorch's implementation.""" # Ensure torch.distributed is initialized (should be done by test runner) if not dist.is_initialized(): @@ -70,7 +72,7 @@ def test_all_reduce(variant, dtype, M, N): # Run Iris all_reduce with specified variant shmem.barrier() - config = Config(all_reduce_variant=variant) + config = Config(all_reduce_variant=variant, block_size_m=block_size_m, block_size_n=block_size_n) if variant == "two_shot": # Test both distribution modes for two_shot config.all_reduce_distribution = 0 # striding diff --git a/tests/ccl/test_all_to_all.py b/tests/ccl/test_all_to_all.py index 76478f5a..99e9bf19 100644 --- a/tests/ccl/test_all_to_all.py +++ b/tests/ccl/test_all_to_all.py @@ -21,14 +21,16 @@ ], ) @pytest.mark.parametrize( - "M, N", + "M, N, block_size_m, block_size_n", [ - (128, 64), # Small - (1024, 256), # Medium - (8192, 8192), # Large + (128, 64, 32, 64), # Small + (128, 128, 32, 32), # BLOCK_N < N/world_size (partial-width, multi-block per rank) + (256, 128, 32, 16), # Minimum BLOCK_N=16 (16-bit vectorization path) + (1024, 256, 32, 64), # Medium + (8192, 8192, 32, 64), # Large ], ) -def test_all_to_all(dtype, M, N): +def test_all_to_all(dtype, M, N, block_size_m, block_size_n): """Test all-to-all functionality by comparing against PyTorch's implementation.""" # Ensure torch.distributed is initialized (should be done by test runner) if not dist.is_initialized(): @@ -74,7 +76,7 @@ def test_all_to_all(dtype, M, N): # Run Iris all_to_all shmem.barrier() - config = Config() + config = Config(block_size_m=block_size_m, block_size_n=block_size_n) shmem.ccl.all_to_all(iris_output_concat, iris_input_concat, config=config) torch.cuda.synchronize() diff --git a/tests/x/test_all_reduce.py b/tests/x/test_all_reduce.py index 30549a50..6e8934c1 100644 --- a/tests/x/test_all_reduce.py +++ b/tests/x/test_all_reduce.py @@ -223,6 +223,8 @@ def x_all_reduce_spinlock_kernel( "M, N, BLOCK_SIZE_M, BLOCK_SIZE_N", [ (128, 64, 64, 32), # Small + (128, 128, 64, 32), # BLOCK_N < N/world_size (partial-width, multi-block per rank) + (256, 128, 64, 16), # Minimum BLOCK_N=16 (16-bit vectorization path) (1024, 256, 128, 128), # Medium (2048, 2048, 256, 256), # Large # (100, 100, 64, 64), # Non-aligned dimensions - DISABLED: other=0.0 not supported