diff --git a/tests/python/multidevice/test_overlap.py b/tests/python/multidevice/test_overlap.py
index 795217515a7..3b55a9e7e3b 100644
--- a/tests/python/multidevice/test_overlap.py
+++ b/tests/python/multidevice/test_overlap.py
@@ -3,11 +3,15 @@
 # SPDX-License-Identifier: BSD-3-Clause
 
 import pytest
-import torch
 import os
 
+import torch
+import torch.distributed as dist
+from torch.distributed.tensor import distribute_tensor, Shard
+
 import nvfuser_direct as nvfuser
 from nvfuser_direct import DataType, FusionDefinition, CommunicatorBackend, TensorView
+from benchmark_utils import get_benchmark_fns
 
 
 @pytest.mark.mpi
@@ -101,6 +105,99 @@ def test_row_parallel_linear_forward(multidevice_test):
         assert event.input_shapes == [[m, k], [k, n], [m, n]]
 
 
+def row_parallel_linear_forward_reference(
+    inp_shard: torch.Tensor, weight_shard: torch.Tensor, num_chunks: int
+) -> torch.Tensor:
+    out = torch.empty(
+        inp_shard.size(0),
+        weight_shard.size(0),
+        device="cuda",
+        dtype=inp_shard.dtype,
+    )
+    inp_chunks = inp_shard.chunk(num_chunks)
+    out_chunks = out.chunk(num_chunks)
+
+    def wait_stream(stream: torch.cuda.Stream) -> None:
+        event = torch.cuda.Event()
+        stream.record_event(event)
+        torch.cuda.current_stream().wait_event(event)
+
+    main_stream = torch.cuda.current_stream()
+    worker_streams = [torch.cuda.Stream() for _ in range(num_chunks)]
+    for inp_chunk, out_chunk, worker_stream in zip(
+        inp_chunks, out_chunks, worker_streams
+    ):
+        with torch.cuda.stream(worker_stream):
+            wait_stream(main_stream)
+            torch.matmul(inp_chunk, weight_shard.T, out=out_chunk)
+            work = dist.all_reduce(out_chunk, op=dist.ReduceOp.SUM, async_op=True)
+            work.wait()
+
+    for worker_stream in worker_streams:
+        wait_stream(worker_stream)
+
+    return out
+
+
+@pytest.mark.mpi
+def test_row_parallel_linear_forward_reference(setup_default_process_group):
+    h, s, t = 2, 3, 6
+    d = dist.get_world_size()
+    if (h * 4) % d != 0:
+        pytest.skip(
+            f"Row-parallel linear requires {h * 4} to be divisible by world size {d}."
+        )
+    assert t % s == 0
+
+    torch.manual_seed(0)
+    inp_ref = torch.testing.make_tensor(t, h * 4, dtype=torch.int32, device="cpu").to(
+        torch.bfloat16
+    )
+    weight_ref = torch.testing.make_tensor(
+        h, h * 4, dtype=torch.int32, device="cpu"
+    ).to(torch.bfloat16)
+    out_ref = torch.nn.functional.linear(inp_ref.cuda(), weight_ref.cuda()).cpu()
+
+    mesh = dist.device_mesh.init_device_mesh("cuda", [d])
+    inp_shard = distribute_tensor(inp_ref, mesh, placements=[Shard(-1)]).to_local()
+    weight_shard = distribute_tensor(
+        weight_ref, mesh, placements=[Shard(-1)]
+    ).to_local()
+    out = row_parallel_linear_forward_reference(inp_shard, weight_shard, s)
+
+    torch.testing.assert_close(out.cpu(), out_ref)
+
+
+@pytest.mark.mpi
+@pytest.mark.benchmark
+def test_row_parallel_linear_forward_reference_benchmark(
+    setup_default_process_group, benchmark
+):
+    h, s, t = 8192, 2, 8192
+    d = dist.get_world_size()
+    if (h * 4) % d != 0:
+        pytest.skip(
+            f"Row-parallel linear requires {h * 4} to be divisible by world size {d}."
+        )
+    assert t % s == 0
+
+    torch.manual_seed(0)
+    inp_ref = torch.randn(t, h * 4, dtype=torch.bfloat16)
+    weight_ref = torch.randn(h, h * 4, dtype=torch.bfloat16)
+
+    mesh = dist.device_mesh.init_device_mesh("cuda", [d])
+    inp_shard = distribute_tensor(inp_ref, mesh, placements=[Shard(-1)]).to_local()
+    weight_shard = distribute_tensor(
+        weight_ref, mesh, placements=[Shard(-1)]
+    ).to_local()
+
+    warmup_fn, benchmark_fn = get_benchmark_fns(
+        lambda: row_parallel_linear_forward_reference(inp_shard, weight_shard, s)
+    )
+    warmup_fn()
+    benchmark.pedantic(benchmark_fn, rounds=5)
+
+
 @pytest.mark.mpi
 @pytest.mark.parametrize("backend_type", [CommunicatorBackend.nccl])
 @pytest.mark.parametrize("s", [1, 8])