|
| 1 | +# SPDX-FileCopyrightText: Copyright (c) <2025> NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 2 | +# |
| 3 | +# SPDX-License-Identifier: Apache-2.0 |
| 4 | + |
| 5 | +""" |
| 6 | +Example demonstrating all-gather and matrix multiplication in a single kernel. |
| 7 | +
|
| 8 | +Run with: |
| 9 | + torchrun --nproc-per-node 4 --standalone all_gather_matmul.py |
| 10 | +""" |
| 11 | + |
| 12 | +import torch |
| 13 | +import torch.distributed as dist |
| 14 | +import torch.distributed._symmetric_memory as symm_mem |
| 15 | +import cuda.tile as ct |
| 16 | + |
| 17 | +# cuTile kernel for gather then matmul |
| 18 | +# Limitation: |
| 19 | +# Only support 4 ranks because cuTile does not support tuple as input |
| 20 | +@ct.kernel |
| 21 | +def gather_matmul_kernel( |
| 22 | + inp_0, inp_1, inp_2, inp_3, |
| 23 | + w, |
| 24 | + out, |
| 25 | + tile_m: ct.Constant[int], |
| 26 | + tile_n: ct.Constant[int], |
| 27 | + tile_k: ct.Constant[int], |
| 28 | +): |
| 29 | + # Number of m tiles per peer |
| 30 | + num_tiles_m_per_peer = ct.cdiv(inp_0.shape[0], tile_m) |
| 31 | + num_tiles_k = ct.num_tiles(w, axis=0, shape=(tile_k, tile_n)) |
| 32 | + |
| 33 | + # 0-dim maps to m_tile_idx, 1-dim maps to n_tile_idx |
| 34 | + m_tile_idx = ct.bid(0) |
| 35 | + n_tile_idx = ct.bid(1) |
| 36 | + |
| 37 | + # Which peer is this tile in? |
| 38 | + peer = m_tile_idx // num_tiles_m_per_peer |
| 39 | + peer_inp = inp_0 if peer == 0 else inp_1 if peer == 1 else inp_2 if peer == 2 else inp_3 |
| 40 | + m_tile_idx_in_peer = m_tile_idx % num_tiles_m_per_peer |
| 41 | + |
| 42 | + # Initialize accumulator |
| 43 | + accumulator = ct.full((tile_m, tile_n), 0, dtype=ct.float32) |
| 44 | + zero_pad = ct.PaddingMode.ZERO |
| 45 | + |
| 46 | + # Convert fp32 to tf32 to use tensorcore |
| 47 | + dtype = ct.tfloat32 if peer_inp.dtype == ct.float32 else peer_inp.dtype |
| 48 | + |
| 49 | + for k in range(num_tiles_k): |
| 50 | + # Load remote input tile |
| 51 | + a = ct.load(peer_inp, index=(m_tile_idx_in_peer, k), shape=(tile_m, tile_k), padding_mode=zero_pad).astype(dtype) |
| 52 | + # Load weight tile |
| 53 | + b = ct.load(w, index=(k, n_tile_idx), shape=(tile_k, tile_n), padding_mode=zero_pad).astype(dtype) |
| 54 | + # Perform matrix multiplication |
| 55 | + accumulator = ct.mma(a, b, accumulator) |
| 56 | + |
| 57 | + # Cast result back to output dtype |
| 58 | + accumulator = ct.astype(accumulator, out.dtype) |
| 59 | + |
| 60 | + # Store result tile |
| 61 | + ct.store(out, index=(m_tile_idx, n_tile_idx), tile=accumulator) |
| 62 | + |
| 63 | + |
| 64 | +# Host-side launcher for all-gather |
| 65 | +def cutile_gather_matmul( |
| 66 | + inp: torch.Tensor, |
| 67 | + w: torch.Tensor, |
| 68 | + group: dist.ProcessGroup, |
| 69 | +): |
| 70 | + handle = symm_mem.rendezvous(inp, group.group_name) |
| 71 | + world_size = handle.world_size |
| 72 | + inp_tuple = tuple( |
| 73 | + handle.get_buffer(rank, inp.shape, inp.dtype, 0) for rank in range(world_size) |
| 74 | + ) |
| 75 | + assert world_size == 4 |
| 76 | + |
| 77 | + # Allocate output tensor |
| 78 | + M = inp.shape[0] |
| 79 | + M_out = M * world_size |
| 80 | + N = w.shape[1] |
| 81 | + out = torch.empty(M_out, N, device=inp.device) |
| 82 | + |
| 83 | + assert inp.shape[1] == w.shape[0], "reduction dimension mismatch" |
| 84 | + K = inp.shape[1] |
| 85 | + tile_m = 128 |
| 86 | + tile_n = 128 |
| 87 | + tile_k = 128 |
| 88 | + assert M % tile_m == 0 |
| 89 | + assert N % tile_n == 0 |
| 90 | + assert K % tile_k == 0 |
| 91 | + |
| 92 | + # Map each output tile to a block |
| 93 | + grid = (ct.cdiv(M_out, tile_m), ct.cdiv(N, tile_n),) |
| 94 | + ct.launch( |
| 95 | + torch.cuda.current_stream(), |
| 96 | + grid, |
| 97 | + gather_matmul_kernel, |
| 98 | + (*inp_tuple, w, out, tile_m, tile_n, tile_k), |
| 99 | + ) |
| 100 | + |
| 101 | + return out |
| 102 | + |
| 103 | + |
| 104 | +# Reference gather then matmul implementation |
| 105 | +def ref_gather_matmul( |
| 106 | + inp: torch.Tensor, |
| 107 | + w: torch.Tensor, |
| 108 | + group: dist.ProcessGroup, |
| 109 | +): |
| 110 | + world_size = dist.get_world_size(group) |
| 111 | + ag_scratch = torch.empty((world_size * inp.shape[0], inp.shape[1]), device=inp.device) |
| 112 | + dist.all_gather_into_tensor(ag_scratch, inp, group=group) |
| 113 | + out = ag_scratch @ w |
| 114 | + return out |
| 115 | + |
| 116 | + |
| 117 | +def main(): |
| 118 | + dist.init_process_group(backend="nccl") |
| 119 | + rank = dist.get_rank() |
| 120 | + world_size = dist.get_world_size() |
| 121 | + device = torch.device(f"cuda:{rank}") |
| 122 | + group = dist.group.WORLD |
| 123 | + torch.manual_seed(rank + 52) |
| 124 | + print(f"Rank {rank} of {world_size} is initializing") |
| 125 | + |
| 126 | + bs = 256 |
| 127 | + hid = 1024 |
| 128 | + out_hid = 512 |
| 129 | + ref_inp = torch.rand((bs, hid), device=device) |
| 130 | + inp = symm_mem.empty(bs, hid, device=device).copy_(ref_inp) |
| 131 | + w = torch.rand((hid, out_hid), device=device) |
| 132 | + |
| 133 | + expected_out = ref_gather_matmul(ref_inp, w, group) |
| 134 | + |
| 135 | + out = cutile_gather_matmul(inp, w, group) |
| 136 | + |
| 137 | + torch.testing.assert_close( |
| 138 | + out, |
| 139 | + expected_out, |
| 140 | + atol=1e-3, |
| 141 | + rtol=1e-3, |
| 142 | + ) |
| 143 | + |
| 144 | + print(f"Rank {rank} of {world_size}: correct") |
| 145 | + dist.destroy_process_group() |
| 146 | + |
| 147 | + |
| 148 | +if __name__ == "__main__": |
| 149 | + main() |
0 commit comments