     ],
 )
 @pytest.mark.parametrize(
-    "M, N",
+    "M, N, block_size_m, block_size_n",
     [
-        (128, 64),  # Small
-        (1024, 256),  # Medium
-        (8192, 8192),  # Large
+        (128, 64, 32, 64),  # Small
+        (128, 128, 32, 32),  # BLOCK_N < N/world_size (partial-width, multi-block per rank)
+        (256, 128, 32, 16),  # Minimum BLOCK_N=16 (16-bit vectorization path)
+        (1024, 256, 32, 64),  # Medium
+        (8192, 8192, 32, 64),  # Large
     ],
 )
-def test_all_gather(dtype, M, N):
+def test_all_gather(dtype, M, N, block_size_m, block_size_n):
     """Test all-gather functionality by comparing against PyTorch's implementation."""
     # Ensure torch.distributed is initialized (should be done by test runner)
     if not dist.is_initialized():
@@ -62,7 +64,7 @@ def test_all_gather(dtype, M, N):
 
     # Run Iris all_gather
     shmem.barrier()
-    config = Config()
+    config = Config(block_size_m=block_size_m, block_size_n=block_size_n)
     shmem.ccl.all_gather(iris_output_tensor, iris_input_tensor, config=config)
     torch.cuda.synchronize()
 
@@ -97,14 +99,16 @@ def test_all_gather(dtype, M, N):
     ],
 )
 @pytest.mark.parametrize(
-    "M, N",
+    "M, N, block_size_m, block_size_n",
     [
-        (128, 64),  # Small
-        (1024, 256),  # Medium
-        (8192, 8192),  # Large
+        (128, 64, 32, 64),  # Small
+        (128, 128, 32, 32),  # BLOCK_N < N/world_size (partial-width, multi-block per rank)
+        (256, 128, 32, 16),  # Minimum BLOCK_N=16 (16-bit vectorization path)
+        (1024, 256, 32, 64),  # Medium
+        (8192, 8192, 32, 64),  # Large
     ],
 )
-def test_all_gather_partitioned(dtype, M, N):
+def test_all_gather_partitioned(dtype, M, N, block_size_m, block_size_n):
     """Test all-gather with partitioned variant by comparing against PyTorch's implementation."""
     # Ensure torch.distributed is initialized (should be done by test runner)
     if not dist.is_initialized():
@@ -140,7 +144,9 @@ def test_all_gather_partitioned(dtype, M, N):
     # COMM_SMS must be divisible by world_size for partitioned variant
     comm_sms = 64  # Assuming world_size divides 64 (e.g., 2, 4, 8)
     shmem.barrier()
-    config = Config(all_gather_variant="partitioned", comm_sms=comm_sms)
+    config = Config(
+        block_size_m=block_size_m, block_size_n=block_size_n, all_gather_variant="partitioned", comm_sms=comm_sms
+    )
     shmem.ccl.all_gather(iris_output_tensor, iris_input_tensor, config=config)
     torch.cuda.synchronize()
 
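For readers skimming the diff, a rough standalone sketch of the call pattern the new parameters drive is below. It reuses only the names visible in the diff (Config, shmem.barrier, shmem.ccl.all_gather, torch.cuda.synchronize); the shmem handle, tensor allocation, and the Config import path belong to the elided test setup and are assumed here, and run_all_gather is a hypothetical helper, not part of the change.

    import torch

    # Sketch only: Config and shmem come from the Iris test setup elided above;
    # their import path and construction are assumptions, not part of this diff.
    def run_all_gather(shmem, output_tensor, input_tensor,
                       block_size_m=32, block_size_n=64,
                       partitioned=False, comm_sms=64):
        if partitioned:
            # comm_sms must be divisible by world_size for the partitioned variant
            config = Config(
                block_size_m=block_size_m,
                block_size_n=block_size_n,
                all_gather_variant="partitioned",
                comm_sms=comm_sms,
            )
        else:
            config = Config(block_size_m=block_size_m, block_size_n=block_size_n)
        shmem.barrier()
        shmem.ccl.all_gather(output_tensor, input_tensor, config=config)
        torch.cuda.synchronize()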
|
|