Skip to content

Commit c188781

Browse files
committed
Fused all-gather matmul
1 parent 29ce019 commit c188781

1 file changed

Lines changed: 149 additions & 0 deletions

File tree

samples/all_gather_matmul.py

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
# SPDX-FileCopyrightText: Copyright (c) <2025> NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
"""
6+
Example demonstrating all-gather and matrix multiplication in a single kernel.
7+
8+
Run with:
9+
torchrun --nproc-per-node 4 --standalone all_gather_matmul.py
10+
"""
11+
12+
import torch
13+
import torch.distributed as dist
14+
import torch.distributed._symmetric_memory as symm_mem
15+
import cuda.tile as ct
16+
17+
# cuTile kernel for gather then matmul
18+
# Limitation:
19+
# Only 4 ranks are supported because cuTile does not support a tuple as a kernel input
20+
@ct.kernel
def gather_matmul_kernel(
    inp_0, inp_1, inp_2, inp_3,
    w,
    out,
    tile_m: ct.Constant[int],
    tile_n: ct.Constant[int],
    tile_k: ct.Constant[int],
):
    # Fused "gather + matmul": each block produces one (tile_m, tile_n) tile of
    # `out`, reading its rows from whichever of the four per-peer inputs owns
    # them and multiplying by `w`.
    #
    # inp_0..inp_3: per-peer input arrays, one per rank (all indexed the same way).
    # w:            weight array, tiled as (tile_k, tile_n).
    # out:          output array, tiled as (tile_m, tile_n).
    # tile_m/n/k:   compile-time tile sizes.

    # Block coordinates: axis 0 walks output row tiles, axis 1 walks column tiles.
    row_tile = ct.bid(0)
    col_tile = ct.bid(1)

    # Row tiles contributed by each peer, and tiles along the reduction axis.
    row_tiles_per_peer = ct.cdiv(inp_0.shape[0], tile_m)
    k_tiles = ct.num_tiles(w, axis=0, shape=(tile_k, tile_n))

    # Translate the global row tile into (owning peer, row tile within that peer).
    peer = row_tile // row_tiles_per_peer
    local_row_tile = row_tile % row_tiles_per_peer
    src = inp_0 if peer == 0 else inp_1 if peer == 1 else inp_2 if peer == 2 else inp_3

    # float32 operands are fed to the MMA as tf32 so tensor cores can be used;
    # any other dtype is passed through unchanged.
    mma_dtype = ct.tfloat32 if src.dtype == ct.float32 else src.dtype

    # Accumulate in float32 regardless of the operand dtype.
    acc = ct.full((tile_m, tile_n), 0, dtype=ct.float32)

    for k_tile in range(k_tiles):
        # Tile of the (possibly remote) peer input; out-of-range elements read as zero.
        lhs = ct.load(
            src,
            index=(local_row_tile, k_tile),
            shape=(tile_m, tile_k),
            padding_mode=ct.PaddingMode.ZERO,
        ).astype(mma_dtype)
        # Matching tile of the weights.
        rhs = ct.load(
            w,
            index=(k_tile, col_tile),
            shape=(tile_k, tile_n),
            padding_mode=ct.PaddingMode.ZERO,
        ).astype(mma_dtype)
        acc = ct.mma(lhs, rhs, acc)

    # Convert to the output dtype and write the finished tile at its global position.
    ct.store(out, index=(row_tile, col_tile), tile=ct.astype(acc, out.dtype))
62+
63+
64+
# Host-side launcher for the fused all-gather + matmul kernel
65+
def cutile_gather_matmul(
    inp: torch.Tensor,
    w: torch.Tensor,
    group: dist.ProcessGroup,
) -> torch.Tensor:
    """Launch the fused gather + matmul cuTile kernel and return its output.

    Every rank contributes its local ``inp``; the result on each rank has
    ``world_size * inp.shape[0]`` rows and ``w.shape[1]`` columns.

    Args:
        inp: Local 2-D input shard. It is passed to ``symm_mem.rendezvous``,
            so it is expected to be a symmetric-memory allocation
            (e.g. created with ``symm_mem.empty``).
        w: 2-D weight with ``w.shape[0] == inp.shape[1]``.
        group: Process group whose ranks take part in the gather.

    Returns:
        A new tensor of shape ``(world_size * M, N)`` on ``inp.device`` with
        the same dtype as ``inp``.

    Raises:
        AssertionError: If the group does not have exactly 4 ranks, the
            reduction dimensions differ, or M/N/K are not multiples of the
            128-element tile size.
    """
    handle = symm_mem.rendezvous(inp, group.group_name)
    world_size = handle.world_size
    # The kernel signature hard-codes four peer inputs, so check this before
    # building the argument list that is unpacked into it.
    assert world_size == 4, "gather_matmul_kernel supports exactly 4 ranks"
    assert inp.shape[1] == w.shape[0], "reduction dimension mismatch"

    # One view per rank onto that rank's copy of `inp`.
    inp_tuple = tuple(
        handle.get_buffer(rank, inp.shape, inp.dtype, 0) for rank in range(world_size)
    )

    M = inp.shape[0]
    M_out = M * world_size
    N = w.shape[1]
    K = inp.shape[1]

    tile_m = 128
    tile_n = 128
    tile_k = 128
    # The kernel maps output row tiles to peers by integer division, so each
    # peer's rows must fill a whole number of tiles.
    assert M % tile_m == 0, "inp.shape[0] must be a multiple of tile_m"
    assert N % tile_n == 0, "w.shape[1] must be a multiple of tile_n"
    assert K % tile_k == 0, "inp.shape[1] must be a multiple of tile_k"

    # Allocate the output in the input's dtype rather than torch's default
    # dtype, so non-float32 inputs produce a matching result.
    out = torch.empty(M_out, N, dtype=inp.dtype, device=inp.device)

    # Make sure every peer has finished writing its `inp` before any rank
    # starts reading remote buffers.
    handle.barrier()

    # Map each output tile to a block
    grid = (ct.cdiv(M_out, tile_m), ct.cdiv(N, tile_n),)
    ct.launch(
        torch.cuda.current_stream(),
        grid,
        gather_matmul_kernel,
        (*inp_tuple, w, out, tile_m, tile_n, tile_k),
    )

    # Keep peers from reusing/overwriting their `inp` while other ranks may
    # still be reading it.
    handle.barrier()

    return out
102+
103+
104+
# Reference gather then matmul implementation
105+
def ref_gather_matmul(
106+
inp: torch.Tensor,
107+
w: torch.Tensor,
108+
group: dist.ProcessGroup,
109+
):
110+
world_size = dist.get_world_size(group)
111+
ag_scratch = torch.empty((world_size * inp.shape[0], inp.shape[1]), device=inp.device)
112+
dist.all_gather_into_tensor(ag_scratch, inp, group=group)
113+
out = ag_scratch @ w
114+
return out
115+
116+
117+
def main():
    """Run the fused kernel on every rank and check it against the reference.

    Initializes an NCCL process group, builds a random input shard and weight
    on this rank's GPU, and compares ``cutile_gather_matmul`` with
    ``ref_gather_matmul``. The process group is destroyed even if the
    comparison fails.

    Raises:
        AssertionError: If the fused result is not within tolerance of the
            reference.
    """
    import os  # local import: only needed to read the launcher's LOCAL_RANK

    dist.init_process_group(backend="nccl")
    try:
        rank = dist.get_rank()
        world_size = dist.get_world_size()
        # Pick the GPU by node-local rank when the launcher provides it; fall
        # back to the global rank (identical for a single-node run).
        local_rank = int(os.environ.get("LOCAL_RANK", rank))
        device = torch.device(f"cuda:{local_rank}")
        # Make this the current device so torch.cuda.current_stream() in the
        # launcher refers to the same GPU that holds the tensors.
        torch.cuda.set_device(device)
        group = dist.group.WORLD
        # Different seed per rank so each shard holds different data.
        torch.manual_seed(rank + 52)
        print(f"Rank {rank} of {world_size} is initializing")

        bs = 256
        hid = 1024
        out_hid = 512
        ref_inp = torch.rand((bs, hid), device=device)
        # The fused path needs its input in symmetric memory; copy the same
        # values there so both paths see identical data.
        inp = symm_mem.empty(bs, hid, device=device).copy_(ref_inp)
        w = torch.rand((hid, out_hid), device=device)

        expected_out = ref_gather_matmul(ref_inp, w, group)

        out = cutile_gather_matmul(inp, w, group)

        torch.testing.assert_close(
            out,
            expected_out,
            atol=1e-3,
            rtol=1e-3,
        )

        print(f"Rank {rank} of {world_size}: correct")
    finally:
        dist.destroy_process_group()
146+
147+
148+
# Run the demo only when executed as a script (e.g. via torchrun), not on import.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)