Skip to content

Commit fb84e10

Browse files
authored (author name missing from page extraction)
Add partial-width tile shape tests for hint-exercising paths in CCL and X collectives (#436)
2 parents 4324247 + a2b6f90 commit fb84e10

4 files changed

Lines changed: 36 additions & 24 deletions

File tree

tests/ccl/test_all_gather.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,16 @@
2121
],
2222
)
2323
@pytest.mark.parametrize(
24-
"M, N",
24+
"M, N, block_size_m, block_size_n",
2525
[
26-
(128, 64), # Small
27-
(1024, 256), # Medium
28-
(8192, 8192), # Large
26+
(128, 64, 32, 64), # Small
27+
(128, 128, 32, 32), # BLOCK_N < N/world_size (partial-width, multi-block per rank)
28+
(256, 128, 32, 16), # Minimum BLOCK_N=16 (16-bit vectorization path)
29+
(1024, 256, 32, 64), # Medium
30+
(8192, 8192, 32, 64), # Large
2931
],
3032
)
31-
def test_all_gather(dtype, M, N):
33+
def test_all_gather(dtype, M, N, block_size_m, block_size_n):
3234
"""Test all-gather functionality by comparing against PyTorch's implementation."""
3335
# Ensure torch.distributed is initialized (should be done by test runner)
3436
if not dist.is_initialized():
@@ -62,7 +64,7 @@ def test_all_gather(dtype, M, N):
6264

6365
# Run Iris all_gather
6466
shmem.barrier()
65-
config = Config()
67+
config = Config(block_size_m=block_size_m, block_size_n=block_size_n)
6668
shmem.ccl.all_gather(iris_output_tensor, iris_input_tensor, config=config)
6769
torch.cuda.synchronize()
6870

@@ -97,14 +99,16 @@ def test_all_gather(dtype, M, N):
9799
],
98100
)
99101
@pytest.mark.parametrize(
100-
"M, N",
102+
"M, N, block_size_m, block_size_n",
101103
[
102-
(128, 64), # Small
103-
(1024, 256), # Medium
104-
(8192, 8192), # Large
104+
(128, 64, 32, 64), # Small
105+
(128, 128, 32, 32), # BLOCK_N < N/world_size (partial-width, multi-block per rank)
106+
(256, 128, 32, 16), # Minimum BLOCK_N=16 (16-bit vectorization path)
107+
(1024, 256, 32, 64), # Medium
108+
(8192, 8192, 32, 64), # Large
105109
],
106110
)
107-
def test_all_gather_partitioned(dtype, M, N):
111+
def test_all_gather_partitioned(dtype, M, N, block_size_m, block_size_n):
108112
"""Test all-gather with partitioned variant by comparing against PyTorch's implementation."""
109113
# Ensure torch.distributed is initialized (should be done by test runner)
110114
if not dist.is_initialized():
@@ -140,7 +144,9 @@ def test_all_gather_partitioned(dtype, M, N):
140144
# COMM_SMS must be divisible by world_size for partitioned variant
141145
comm_sms = 64 # Assuming world_size divides 64 (e.g., 2, 4, 8)
142146
shmem.barrier()
143-
config = Config(all_gather_variant="partitioned", comm_sms=comm_sms)
147+
config = Config(
148+
block_size_m=block_size_m, block_size_n=block_size_n, all_gather_variant="partitioned", comm_sms=comm_sms
149+
)
144150
shmem.ccl.all_gather(iris_output_tensor, iris_input_tensor, config=config)
145151
torch.cuda.synchronize()
146152

tests/ccl/test_all_reduce.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,16 @@
3232
],
3333
)
3434
@pytest.mark.parametrize(
35-
"M, N",
35+
"M, N, block_size_m, block_size_n",
3636
[
37-
(128, 64), # Small
38-
(1024, 256), # Medium
39-
(8192, 8192), # Large
37+
(128, 64, 32, 64), # Small
38+
(128, 128, 32, 32), # BLOCK_N < N/world_size (partial-width, multi-block per rank)
39+
(256, 128, 32, 16), # Minimum BLOCK_N=16 (16-bit vectorization path)
40+
(1024, 256, 32, 64), # Medium
41+
(8192, 8192, 32, 64), # Large
4042
],
4143
)
42-
def test_all_reduce(variant, dtype, M, N):
44+
def test_all_reduce(variant, dtype, M, N, block_size_m, block_size_n):
4345
"""Test all-reduce functionality by comparing against PyTorch's implementation."""
4446
# Ensure torch.distributed is initialized (should be done by test runner)
4547
if not dist.is_initialized():
@@ -70,7 +72,7 @@ def test_all_reduce(variant, dtype, M, N):
7072

7173
# Run Iris all_reduce with specified variant
7274
shmem.barrier()
73-
config = Config(all_reduce_variant=variant)
75+
config = Config(all_reduce_variant=variant, block_size_m=block_size_m, block_size_n=block_size_n)
7476
if variant == "two_shot":
7577
# Test both distribution modes for two_shot
7678
config.all_reduce_distribution = 0 # striding

tests/ccl/test_all_to_all.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,16 @@
2121
],
2222
)
2323
@pytest.mark.parametrize(
24-
"M, N",
24+
"M, N, block_size_m, block_size_n",
2525
[
26-
(128, 64), # Small
27-
(1024, 256), # Medium
28-
(8192, 8192), # Large
26+
(128, 64, 32, 64), # Small
27+
(128, 128, 32, 32), # BLOCK_N < N/world_size (partial-width, multi-block per rank)
28+
(256, 128, 32, 16), # Minimum BLOCK_N=16 (16-bit vectorization path)
29+
(1024, 256, 32, 64), # Medium
30+
(8192, 8192, 32, 64), # Large
2931
],
3032
)
31-
def test_all_to_all(dtype, M, N):
33+
def test_all_to_all(dtype, M, N, block_size_m, block_size_n):
3234
"""Test all-to-all functionality by comparing against PyTorch's implementation."""
3335
# Ensure torch.distributed is initialized (should be done by test runner)
3436
if not dist.is_initialized():
@@ -74,7 +76,7 @@ def test_all_to_all(dtype, M, N):
7476

7577
# Run Iris all_to_all
7678
shmem.barrier()
77-
config = Config()
79+
config = Config(block_size_m=block_size_m, block_size_n=block_size_n)
7880
shmem.ccl.all_to_all(iris_output_concat, iris_input_concat, config=config)
7981
torch.cuda.synchronize()
8082

tests/x/test_all_reduce.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,8 @@ def x_all_reduce_spinlock_kernel(
223223
"M, N, BLOCK_SIZE_M, BLOCK_SIZE_N",
224224
[
225225
(128, 64, 64, 32), # Small
226+
(128, 128, 64, 32), # BLOCK_N < N/world_size (partial-width, multi-block per rank)
227+
(256, 128, 64, 16), # Minimum BLOCK_N=16 (16-bit vectorization path)
226228
(1024, 256, 128, 128), # Medium
227229
(2048, 2048, 256, 256), # Large
228230
# (100, 100, 64, 64), # Non-aligned dimensions - DISABLED: other=0.0 not supported

0 commit comments

Comments (0)