diff --git a/CMakeLists.txt b/CMakeLists.txt
index ccda3d89fb5..c36c129b2fd 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -347,9 +347,10 @@ list(APPEND NVFUSER_SRCS
     ${NVFUSER_SRCS_DIR}/preseg_passes/remove_empty.cpp
     ${NVFUSER_SRCS_DIR}/preseg_passes/reorder_sharded_axis.cpp
     ${NVFUSER_SRCS_DIR}/preseg_passes/segment_inplace_update.cpp
+    ${NVFUSER_SRCS_DIR}/host_ir/allocate_and_deallocate.cpp
+    ${NVFUSER_SRCS_DIR}/host_ir/assign_streams.cpp
     ${NVFUSER_SRCS_DIR}/host_ir/pass/convert_op_to_communication.cpp
     ${NVFUSER_SRCS_DIR}/host_ir/pass/stream_parallel_type.cpp
-    ${NVFUSER_SRCS_DIR}/host_ir/allocate_and_deallocate.cpp
     ${NVFUSER_SRCS_DIR}/preseg_passes/translate_no_reduction_matmul_to_mul_squeeze.cpp
     ${NVFUSER_SRCS_DIR}/preseg_passes/translate_repeat_to_expand.cpp
     ${NVFUSER_SRCS_DIR}/preseg_passes/translate_scatter_accumulate.cpp
diff --git a/csrc/host_ir/allocate_and_deallocate.h b/csrc/host_ir/allocate_and_deallocate.h
index fbf7c88470c..491e2000d91 100644
--- a/csrc/host_ir/allocate_and_deallocate.h
+++ b/csrc/host_ir/allocate_and_deallocate.h
@@ -7,7 +7,6 @@
 // clang-format on
 #pragma once
 
-#include "host_ir/container.h"
 #include "optimization_pass.h"
 
 namespace nvfuser::hir {
diff --git a/csrc/host_ir/assign_streams.cpp b/csrc/host_ir/assign_streams.cpp
new file mode 100644
index 00000000000..556de39ed8e
--- /dev/null
+++ b/csrc/host_ir/assign_streams.cpp
@@ -0,0 +1,64 @@
+// clang-format off
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ */
+// clang-format on
+
+#include "host_ir/assign_streams.h"
+
+#include "host_ir/container.h"
+#include "ir/builder.h"
+
+namespace nvfuser::hir {
+
+void AssignStreams::runPass(Fusion* fusion) {
+  auto* hic = dynamic_cast<HostIrContainer*>(fusion);
+  NVF_CHECK(hic != nullptr);
+  FusionGuard fg(hic);
+
+  for (auto it = hic->topLevel().exprs().begin();
+       it != hic->topLevel().exprs().end();) {
+    auto next_it = std::next(it);
+
+    auto* for_loop = dynamic_cast<ForLoop*>(*it);
+    if (for_loop == nullptr) {
+      it = next_it;
+      continue;
+    }
+
+    // We should check that the loop is stream-parallel. This is not necessary
+    // at this moment because all loops are stream-parallel. This is also hard
+    // to do because hir::ForLoop doesn't point to the source IterDomain.
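+    //
+    // Illustrative sketch of the rewrite (pseudo host IR; the stream and loop
+    // names below are made up for exposition): a stream-parallel top-level
+    // loop
+    //
+    //   FOR i from 0 to N:
+    //     <body>
+    //
+    // becomes
+    //
+    //   GetCurrentStream into main
+    //   FOR i from 0 to N:
+    //     SetCurrentStream to Stream i
+    //     Synchronize main
+    //     <body>
+    //   SetCurrentStream to main
+    //   FOR i from 0 to N:
+    //     Synchronize Stream i
+    //
+    // so that each iteration runs on its own worker stream and the main
+    // stream joins all worker streams after the loop.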
+
+    auto* get_current_stream = IrBuilder::create<GetCurrentStream>();
+    Stream* main_stream = get_current_stream->stream();
+    hic->topLevel().insert(it, get_current_stream);
+
+    // At the beginning of each iteration: set stream and synchronize with main
+    // stream
+    auto* worker_stream = IrBuilder::create<Stream>(for_loop->index());
+    auto* set_stream = IrBuilder::create<SetCurrentStream>(worker_stream);
+    auto* sync_main = IrBuilder::create<Synchronize>(main_stream);
+    auto old_begin = for_loop->body().exprs().begin();
+    for_loop->body().insert(old_begin, set_stream);
+    for_loop->body().insert(old_begin, sync_main);
+
+    // After the loop: create a joining loop to synchronize all worker streams
+    hic->topLevel().insert(
+        next_it, IrBuilder::create<SetCurrentStream>(main_stream));
+    auto* join_loop = IrBuilder::create<ForLoop>(
+        for_loop->index(), for_loop->start(), for_loop->stop());
+    hic->topLevel().insert(next_it, join_loop);
+
+    // In the joining loop: synchronize each worker stream
+    auto* join_worker_stream = IrBuilder::create<Stream>(join_loop->index());
+    auto* sync_worker = IrBuilder::create<Synchronize>(join_worker_stream);
+    join_loop->body().push_back(sync_worker);
+
+    it = next_it;
+  }
+}
+
+} // namespace nvfuser::hir
diff --git a/csrc/host_ir/assign_streams.h b/csrc/host_ir/assign_streams.h
new file mode 100644
index 00000000000..cd14fbad95f
--- /dev/null
+++ b/csrc/host_ir/assign_streams.h
@@ -0,0 +1,26 @@
+// clang-format off
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ */
+// clang-format on
+#pragma once
+
+#include "optimization_pass.h"
+
+namespace nvfuser::hir {
+
+// A host IR pass that assigns streams to stream-parallel loops.
+class AssignStreams : public OptimizationPass<AssignStreams> {
+  friend class OptimizationPass<AssignStreams>;
+
+ protected:
+  static void runPass(Fusion* fusion);
+
+  static constexpr std::string_view name() {
+    return "AssignStreams";
+  }
+};
+
+} // namespace nvfuser::hir
diff --git a/csrc/host_ir/evaluator.cpp b/csrc/host_ir/evaluator.cpp
index 861bdfe1ed5..3acba5309bf 100644
--- a/csrc/host_ir/evaluator.cpp
+++ b/csrc/host_ir/evaluator.cpp
@@ -46,13 +46,12 @@ HostIrEvaluator::HostIrEvaluator(
       communicator_(communicator),
       params_(params),
       expr_evaluator_(),
-      my_local_device_index_(communicator_ ? communicator_->local_rank() : 0),
+      my_local_device_index_(
+          communicator_ == nullptr ? 0 : communicator_->local_rank()),
       ipc_handle_cache_(expr_evaluator_),
       multicast_handle_cache_() {
   const DeviceIdxType device_index =
-      (communicator_ != nullptr && communicator_->is_available())
-      ? communicator_->deviceId()
-      : 0;
+      communicator_ == nullptr ? 0 : communicator_->deviceId();
   if (isDebugDumpEnabled(DebugDumpOption::HostIr) && device_index == 0) {
     container_->print(debug());
   }
@@ -147,16 +146,10 @@ c10::cuda::CUDAStream HostIrEvaluator::getCUDAStream(Stream* stream) {
     NVF_ERROR(value.hasValue() && value.is<int64_t>());
     stream_key = value.as<int64_t>();
   }
-  if (streams_.find(stream_key) == streams_.end()) {
-    auto i = (communicator_ != nullptr && communicator_->is_available())
-        ? communicator_->deviceId()
-        : 0;
-    streams_.insert(
-        {stream_key,
-         c10::cuda::getStreamFromPool(
-             /*isHighPriority=*/false, static_cast<c10::DeviceIndex>(i))});
-  }
-  return streams_.at(stream_key);
+
+  auto [it, inserted] =
+      streams_.try_emplace(stream_key, c10::cuda::getStreamFromPool());
+  return it->second;
 }
 
 void HostIrEvaluator::handle(SetCurrentStream* set_current_stream) {
diff --git a/csrc/host_ir/ir.cpp b/csrc/host_ir/ir.cpp
index 5b46b7f0abb..3ac90190b0c 100644
--- a/csrc/host_ir/ir.cpp
+++ b/csrc/host_ir/ir.cpp
@@ -197,7 +197,7 @@ std::string Stream::toString(int indent_size) const {
   std::stringstream ss;
   indent(ss, indent_size) << "Stream ";
   if (index() == nullptr) {
-    ss << name();
+    ss << static_cast<const void*>(this);
   } else {
     ss << index()->toInlineString();
   }
diff --git a/csrc/host_ir/ir.h b/csrc/host_ir/ir.h
index a3b36284e36..6c1c9c61ce9 100644
--- a/csrc/host_ir/ir.h
+++ b/csrc/host_ir/ir.h
@@ -15,7 +15,6 @@
 #include "ir/base_nodes.h"
 #include "ir/builder.h"
 #include "multidevice/communication.h"
-#include "scheduler/heuristic.h"
 
 namespace nvfuser {
 // This works around a circular dependency: compiled_kernel.h ==>
diff --git a/csrc/host_ir/passes.cpp b/csrc/host_ir/passes.cpp
index 98078520145..65a8327ba58 100644
--- a/csrc/host_ir/passes.cpp
+++ b/csrc/host_ir/passes.cpp
@@ -9,11 +9,13 @@
 #include "host_ir/passes.h"
 
 #include "host_ir/allocate_and_deallocate.h"
+#include "host_ir/assign_streams.h"
 
 namespace nvfuser::hir {
 
 void runPasses(HostIrContainer& hic) {
   OptimizationPass<AllocateAndDeallocate>::runPass(&hic);
+  OptimizationPass<AssignStreams>::runPass(&hic);
 }
 
 } // namespace nvfuser::hir
diff --git a/csrc/ir/internal_nodes.h b/csrc/ir/internal_nodes.h
index 9393dc3016b..4fc707c163e 100644
--- a/csrc/ir/internal_nodes.h
+++ b/csrc/ir/internal_nodes.h
@@ -35,7 +35,6 @@ namespace nvfuser {
 
 class ViewTransform;
-class Scope;
 class IrCloner;
 
 struct AnalyzeViewResult;
 
@@ -2506,6 +2505,7 @@ class Scope {
     return std::ssize(exprs_);
   }
 
+  // Returns an iterator pointing to the inserted expression.
   Iterator insert(Iterator pos, Expr* expr);
 
   Iterator push_back(Expr* e) {
diff --git a/csrc/multidevice/communicator.h b/csrc/multidevice/communicator.h
index d41cb91ceb8..b56e6fee3aa 100644
--- a/csrc/multidevice/communicator.h
+++ b/csrc/multidevice/communicator.h
@@ -19,7 +19,6 @@
 #include "multidevice/c10d_mock.h"
 #endif
 
-#include "exceptions.h"
 #include "multidevice/multidevice.h"
 #include "visibility.h"
 
diff --git a/tests/cpp/test_multidevice_lower_communication_cuda.cpp b/tests/cpp/test_multidevice_lower_communication_cuda.cpp
index 3164f193132..971c075c018 100644
--- a/tests/cpp/test_multidevice_lower_communication_cuda.cpp
+++ b/tests/cpp/test_multidevice_lower_communication_cuda.cpp
@@ -330,9 +330,7 @@ INSTANTIATE_TEST_SUITE_P(
     testing::Values(
         2 * 1024 * 1024LL, // 2 MB
         8 * 1024 * 1024LL, // 8 MB
-        32 * 1024 * 1024LL, // 32 MB
-        128 * 1024 * 1024LL, // 128 MB
-        256 * 1024 * 1024LL // 256 MB
+        32 * 1024 * 1024LL // 32 MB
         ),
     testing::Values(
         CommunicationProtocol::kMemcpy,
diff --git a/tests/cpp/test_stream.cpp b/tests/cpp/test_stream.cpp
index 49a4653c69c..bf7355deef4 100644
--- a/tests/cpp/test_stream.cpp
+++ b/tests/cpp/test_stream.cpp
@@ -24,12 +24,10 @@
 
 namespace nvfuser {
 
-class StreamTest : public NVFuserTest {
- public:
-  StreamTest() {
-    EnableOptionsGuard::getCurOptions().set(EnableOption::HostIrLowering);
-  }
-};
+// The tests in this file verify building blocks for stream parallelism, e.g.,
+// sharding propagation and KernelExecutor. End-to-end tests have been moved to
+// tests/python/direct/test_stream.py because the Python API is sufficient.
+using StreamTest = NVFuserTest;
 
 TEST_F(StreamTest, AddPerStream) {
   constexpr int64_t c = 3;
diff --git a/tests/python/direct/test_stream.py b/tests/python/direct/test_stream.py
index 32013ae7c21..3e84c82ddc2 100644
--- a/tests/python/direct/test_stream.py
+++ b/tests/python/direct/test_stream.py
@@ -7,7 +7,7 @@
 from nvfuser_direct import FusionDefinition, ParallelType, DataType
 
 
-def test_matmul(nvfuser_direct_test):
+def test_matmul():
     c = 3
 
     with FusionDefinition() as fd:
@@ -46,7 +46,7 @@ def test_matmul(nvfuser_direct_test):
     assert event.input_shapes == [[5, 7], [7, 2], [5, 2]]
 
 
-def test_two_matmuls_inlinable(nvfuser_direct_test):
+def test_two_matmuls_inlinable():
     c = 3
 
     with FusionDefinition() as fd:
@@ -97,7 +97,7 @@ def test_two_matmuls_inlinable(nvfuser_direct_test):
     assert event.input_shapes[0][0] == 2
 
 
-def test_two_matmuls_not_inlinable(nvfuser_direct_test):
+def test_two_matmuls_not_inlinable():
     c = 3
 
     with FusionDefinition() as fd:
diff --git a/tests/python/multidevice/benchmark_utils.py b/tests/python/multidevice/benchmark_utils.py
index c585d68c452..0421a01dc23 100644
--- a/tests/python/multidevice/benchmark_utils.py
+++ b/tests/python/multidevice/benchmark_utils.py
@@ -28,12 +28,11 @@ def wrapper(*args, **kwargs):
 
 # Returns two functors, the first with profiler off and the second with profiler
 # on. The first functor is usually used for warmup and the second for actual
-# benchmarking. This way, one
-# can collect stats of the first few non-warmup benchmark iterations using
+# benchmarking. This way, one can collect stats of the first few non-warmup
+# benchmark iterations using
 # ```bash
-# mpirun -np 1 nsys profile --capture-range=cudaProfilerApi --capture-range-end=repeat:<n> pytest tests/python/multidevice/<test_file>.py -k <test_name> --only-mpi : -np <d - 1> pytest tests/python/multidevice/<test_file>.py -k <test_name> --only-mpi
+# nsys profile --capture-range=cudaProfilerApi --capture-range-end=repeat:<n> mpirun -np <d> pytest tests/python/multidevice/<test_file>.py -k <test_name> --only-mpi
 # ```
-# and then display the stats using e.g. `nsys stats --report=cuda_gpu_kern_sum
-# report1.nsys-rep`.
+# and then display the stats using `nsys stats`.
 def get_benchmark_fns(func):
     return get_benchmark_fn(func, profile=False), get_benchmark_fn(func, profile=True)
diff --git a/tests/python/multidevice/test_overlap.py b/tests/python/multidevice/test_overlap.py
index 795217515a7..17922346679 100644
--- a/tests/python/multidevice/test_overlap.py
+++ b/tests/python/multidevice/test_overlap.py
@@ -3,26 +3,18 @@
 # SPDX-License-Identifier: BSD-3-Clause
 
 import pytest
-import torch
 import os
 
+import torch
+import torch.distributed as dist
+from torch.distributed.tensor import distribute_tensor, Shard
+
 import nvfuser_direct as nvfuser
 from nvfuser_direct import DataType, FusionDefinition, CommunicatorBackend, TensorView
+from benchmark_utils import get_benchmark_fns
 
 
-@pytest.mark.mpi
-def test_row_parallel_linear_forward(multidevice_test):
-    # This is a port of CollectiveBasedOverlapTest.RowParallelLinear_Forward.
-    h, s, t = 2, 3, 6
-    d = multidevice_test.size
-    if (h * 4) % d != 0:
-        pytest.skip(
-            f"Row-parallel linear requires {h * 4} to be divisible by world size {d}."
-        )
-    assert t % s == 0
-
-    mesh = nvfuser.multidevice.DeviceMesh(range(d))
-
+def row_parallel_linear_forward(h, mesh, num_chunks):
     with FusionDefinition() as fd:
         inp = fd.define_tensor(
             shape=[-1, h * 4], contiguity=True, dtype=DataType.BFloat16
@@ -36,11 +28,11 @@ def test_row_parallel_linear_forward(multidevice_test):
 
     for tv in (inp, weight):
         tv.set_device_mesh(mesh)
-    inp.split(0, s, inner_split=False)
+    inp.outer_split(0, num_chunks)
     inp.axis(0).parallelize(nvfuser.ParallelType.stream)
-    inp.split(2, d, inner_split=False)
+    inp.outer_split(2, mesh.size)
     inp.axis(2).parallelize(nvfuser.ParallelType.mesh_x)
-    weight.split(1, d, inner_split=False)
+    weight.outer_split(1, mesh.size)
     weight.axis(1).parallelize(nvfuser.ParallelType.mesh_x)
 
     # Expected pre-segmentation IR:
@@ -63,22 +55,49 @@ def test_row_parallel_linear_forward(multidevice_test):
     #   /\.
     #  s*
 
-    # Expected host IR:
+    # The host IR dumped with NVFUSER_DUMP=host_ir is similar to `row_parallel_linear_forward_reference`:
     #
     # %HostIrContainer { (T0_g___bfloat[istreamIdx7{3}, ideviceIdx.x9{2}, iS8{( ceilDiv(i0, 3) )}, iS10{4}] (DeviceMesh{0 1}), T1_g___bfloat[ideviceIdx.x11{2}, iS2{2}, iS12{4}] (DeviceMesh{0 1})) -> (T2_g___bfloat[istreamIdx27{3}, rdeviceIdx.x26{2}, iS28{( ceilDiv(i0, 3) )}, iS25{2}] (DeviceMesh{0 1})) :
     #   T2_g___bfloat[istreamIdx27{3}, rdeviceIdx.x26{2}, iS28{( ceilDiv(i0, 3) )}, iS25{2}] (DeviceMesh{0 1}) = ALLOCATE(buffer=T2_g___bfloat[istreamIdx27{3}, rdeviceIdx.x26{2}, iS28{( ceilDiv(i0, 3) )}, iS25{2}] (DeviceMesh{0 1}), mem_type=global, size=( i0 * 2 ), zero_init=false, resets_to_zero=false)
+    #   GetCurrentStream into Stream 0x174e5c80
     #   FOR i535 from 0 to 3:
-    #     T4_l___bfloat[istreamIdx31{3}, ideviceIdx.x33{2}, iS32{( ceilDiv(i0, 3) )}, iS34{4}] (DeviceMesh{0 1}) = ShardByStream(T0_g___bfloat[istreamIdx7{3}, ideviceIdx.x9{2}, iS8{( ceilDiv(i0, 3) )}, iS10{4}] (DeviceMesh{0 1}), stream_index = i535)
+    #     SetCurrentStream to Stream i535
+    #     Synchronize Stream 0x174e5c80
+    #     T4_l___bfloat[istreamIdx37{3}, iS38{( ceilDiv(i0, 3) )}, ideviceIdx.x35{2}, iS36{4}] (DeviceMesh{0 1}) = ShardByStream(T0_g___bfloat[istreamIdx7{3}, ideviceIdx.x9{2}, iS8{( ceilDiv(i0, 3) )}, iS10{4}] (DeviceMesh{0 1}), stream_index = i535)
+    #     T3_g___bfloat[istreamIdx20{3}, ideviceIdx.x22{2}rf, iS21{( ceilDiv(i0, 3) )}, iS18{2}, rS23{4}rf] (DeviceMesh{0 1}) = ALLOCATE(buffer=T3_g___bfloat[istreamIdx20{3}, ideviceIdx.x22{2}rf, iS21{( ceilDiv(i0, 3) )}, iS18{2}, rS23{4}rf] (DeviceMesh{0 1}), mem_type=global, size=( ( ceilDiv(i0, 3) ) * 12 ), zero_init=false, resets_to_zero=false)
     #     T3_g___bfloat[istreamIdx20{3}, ideviceIdx.x22{2}rf, iS21{( ceilDiv(i0, 3) )}, iS18{2}, rS23{4}rf] (DeviceMesh{0 1})
-    #        = linear(T4_l___bfloat[istreamIdx31{3}, ideviceIdx.x33{2}, iS32{( ceilDiv(i0, 3) )}, iS34{4}] (DeviceMesh{0 1}),
+    #        = linear(T4_l___bfloat[istreamIdx37{3}, iS38{( ceilDiv(i0, 3) )}, ideviceIdx.x35{2}, iS36{4}] (DeviceMesh{0 1}),
     #                 T1_g___bfloat[ideviceIdx.x11{2}, iS2{2}, iS12{4}] (DeviceMesh{0 1}) )
-    #     T5_l___bfloat[istreamIdx37{3}, iS38{( ceilDiv(i0, 3) )}, iS36{2}] (DeviceMesh{0 1}) = ShardByStream(T2_g___bfloat[istreamIdx27{3}, rdeviceIdx.x26{2}, iS28{( ceilDiv(i0, 3) )}, iS25{2}] (DeviceMesh{0 1}), stream_index = i535)
-    #     Communication 250 (type=Allreduce, team=(0 1), input=T3_g___bfloat[istreamIdx20{3}, ideviceIdx.x22{2}rf, iS21{( ceilDiv(i0, 3) )}, iS18{2}, rS23{4}rf] (DeviceMesh{0 1}), output=T5_l___bfloat[istreamIdx37{3}, iS38{( ceilDiv(i0, 3) )}, iS36{2}] (DeviceMesh{0 1}), backend=NCCL)
-    #     Wait Communication 250
+    #     T5_l___bfloat[istreamIdx41{3}, iS42{( ceilDiv(i0, 3) )}, iS40{2}] (DeviceMesh{0 1}) = ShardByStream(T2_g___bfloat[istreamIdx27{3}, rdeviceIdx.x26{2}, iS28{( ceilDiv(i0, 3) )}, iS25{2}] (DeviceMesh{0 1}), stream_index = i535)
+    #     Communication 272 (type=Allreduce, team=(0 1), input=T3_g___bfloat[istreamIdx20{3}, ideviceIdx.x22{2}rf, iS21{( ceilDiv(i0, 3) )}, iS18{2}, rS23{4}rf] (DeviceMesh{0 1}), output=T5_l___bfloat[istreamIdx41{3}, iS42{( ceilDiv(i0, 3) )}, iS40{2}] (DeviceMesh{0 1}), backend=NCCL)
+    #     Wait Communication 272
+    #   FOR i535 from 0 to 3:
+    #     Synchronize Stream i535
     # } // %HostIrContainer
 
-    inp_ref = torch.randint(-2, 3, (t, h * 4), dtype=torch.int32).to(torch.bfloat16)
-    weight_ref = torch.randint(-2, 3, (h, h * 4), dtype=torch.int32).to(torch.bfloat16)
+    return fd
+
+
+@pytest.mark.mpi
+def test_row_parallel_linear_forward(multidevice_test):
+    # This is a port of CollectiveBasedOverlapTest.RowParallelLinear_Forward.
+    h, s, t = 2, 3, 6
+    d = multidevice_test.size
+    if (h * 4) % d != 0:
+        pytest.skip(
+            f"Row-parallel linear requires {h * 4} to be divisible by world size {d}."
+        )
+    assert t % s == 0
+
+    mesh = nvfuser.multidevice.DeviceMesh(range(d))
+    fd = row_parallel_linear_forward(h, mesh, s)
+
+    inp_ref = torch.testing.make_tensor(t, h * 4, dtype=torch.int32, device="cpu").to(
+        torch.bfloat16
+    )
+    weight_ref = torch.testing.make_tensor(
+        h, h * 4, dtype=torch.int32, device="cpu"
+    ).to(torch.bfloat16)
     out_ref = torch.nn.functional.linear(inp_ref, weight_ref)
 
     inp = multidevice_test.shard_tensor(inp_ref, -1, mesh)
@@ -101,6 +120,127 @@ def test_row_parallel_linear_forward(multidevice_test):
     assert event.input_shapes == [[m, k], [k, n], [m, n]]
 
 
+@pytest.mark.mpi
+@pytest.mark.benchmark
+def test_row_parallel_linear_forward_benchmark(multidevice_test, benchmark):
+    # This is a port of CollectiveBasedOverlapTest.RowParallelLinear_Forward.
+    h, s, t = 8192, 2, 8192
+    d = multidevice_test.size
+    if (h * 4) % d != 0:
+        pytest.skip(
+            f"Row-parallel linear requires {h * 4} to be divisible by world size {d}."
+        )
+    assert t % s == 0
+
+    mesh = nvfuser.multidevice.DeviceMesh(range(d))
+    fd = row_parallel_linear_forward(h, mesh, s)
+
+    inp_ref = torch.randn(t, h * 4, dtype=torch.bfloat16, device="cpu")
+    weight_ref = torch.randn(h, h * 4, dtype=torch.bfloat16, device="cpu")
+
+    inp = multidevice_test.shard_tensor(inp_ref, -1, mesh)
+    weight = multidevice_test.shard_tensor(weight_ref, -1, mesh)
+
+    warmup_fn, benchmark_fn = get_benchmark_fns(
+        lambda: fd.execute([inp, weight], _enable_options=["host_ir_lowering"])
+    )
+    warmup_fn()
+    benchmark.pedantic(benchmark_fn, rounds=5)
+
+
+def row_parallel_linear_forward_reference(
+    inp_shard: torch.Tensor, weight_shard: torch.Tensor, num_chunks: int
+) -> torch.Tensor:
+    out = torch.empty(
+        inp_shard.size(0),
+        weight_shard.size(0),
+        device="cuda",
+        dtype=inp_shard.dtype,
+    )
+    inp_chunks = inp_shard.chunk(num_chunks)
+    out_chunks = out.chunk(num_chunks)
+
+    def wait_stream(stream: torch.cuda.Stream) -> None:
+        event = torch.cuda.Event()
+        stream.record_event(event)
+        torch.cuda.current_stream().wait_event(event)
+
+    main_stream = torch.cuda.current_stream()
+    worker_streams = [torch.cuda.Stream() for _ in range(num_chunks)]
+    for inp_chunk, out_chunk, worker_stream in zip(
+        inp_chunks, out_chunks, worker_streams
+    ):
+        with torch.cuda.stream(worker_stream):
+            wait_stream(main_stream)
+            torch.matmul(inp_chunk, weight_shard.T, out=out_chunk)
+            work = dist.all_reduce(out_chunk, op=dist.ReduceOp.SUM, async_op=True)
+            work.wait()
+
+    for worker_stream in worker_streams:
+        wait_stream(worker_stream)
+
+    return out
+
+
+@pytest.mark.mpi
+def test_row_parallel_linear_forward_reference(setup_default_process_group):
+    h, s, t = 2, 3, 6
+    d = dist.get_world_size()
+    if (h * 4) % d != 0:
+        pytest.skip(
+            f"Row-parallel linear requires {h * 4} to be divisible by world size {d}."
+        )
+    assert t % s == 0
+
+    torch.manual_seed(0)
+    inp_ref = torch.testing.make_tensor(t, h * 4, dtype=torch.int32, device="cpu").to(
+        torch.bfloat16
+    )
+    weight_ref = torch.testing.make_tensor(
+        h, h * 4, dtype=torch.int32, device="cpu"
+    ).to(torch.bfloat16)
+    out_ref = torch.nn.functional.linear(inp_ref.cuda(), weight_ref.cuda()).cpu()
+
+    mesh = dist.device_mesh.init_device_mesh("cuda", [d])
+    inp_shard = distribute_tensor(inp_ref, mesh, placements=[Shard(-1)]).to_local()
+    weight_shard = distribute_tensor(
+        weight_ref, mesh, placements=[Shard(-1)]
+    ).to_local()
+    out = row_parallel_linear_forward_reference(inp_shard, weight_shard, s)
+
+    torch.testing.assert_close(out.cpu(), out_ref)
+
+
+@pytest.mark.mpi
+@pytest.mark.benchmark
+def test_row_parallel_linear_forward_reference_benchmark(
+    setup_default_process_group, benchmark
+):
+    h, s, t = 8192, 2, 8192
+    d = dist.get_world_size()
+    if (h * 4) % d != 0:
+        pytest.skip(
+            f"Row-parallel linear requires {h * 4} to be divisible by world size {d}."
+        )
+    assert t % s == 0
+
+    torch.manual_seed(0)
+    inp_ref = torch.randn(t, h * 4, dtype=torch.bfloat16)
+    weight_ref = torch.randn(h, h * 4, dtype=torch.bfloat16)
+
+    mesh = dist.device_mesh.init_device_mesh("cuda", [d])
+    inp_shard = distribute_tensor(inp_ref, mesh, placements=[Shard(-1)]).to_local()
+    weight_shard = distribute_tensor(
+        weight_ref, mesh, placements=[Shard(-1)]
+    ).to_local()
+
+    warmup_fn, benchmark_fn = get_benchmark_fns(
+        lambda: row_parallel_linear_forward_reference(inp_shard, weight_shard, s)
+    )
+    warmup_fn()
+    benchmark.pedantic(benchmark_fn, rounds=5)
+
+
 @pytest.mark.mpi
 @pytest.mark.parametrize("backend_type", [CommunicatorBackend.nccl])
 @pytest.mark.parametrize("s", [1, 8])