From 776a1028a94ea1ceba7b2cd33956510f849f4eb7 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Mon, 29 Dec 2025 15:56:42 -0800 Subject: [PATCH 1/9] Reference implementation in torch --- tests/python/multidevice/test_overlap.py | 99 +++++++++++++++++++++++- 1 file changed, 98 insertions(+), 1 deletion(-) diff --git a/tests/python/multidevice/test_overlap.py b/tests/python/multidevice/test_overlap.py index 795217515a7..3b55a9e7e3b 100644 --- a/tests/python/multidevice/test_overlap.py +++ b/tests/python/multidevice/test_overlap.py @@ -3,11 +3,15 @@ # SPDX-License-Identifier: BSD-3-Clause import pytest -import torch import os +import torch +import torch.distributed as dist +from torch.distributed.tensor import distribute_tensor, Shard + import nvfuser_direct as nvfuser from nvfuser_direct import DataType, FusionDefinition, CommunicatorBackend, TensorView +from benchmark_utils import get_benchmark_fns @pytest.mark.mpi @@ -101,6 +105,99 @@ def test_row_parallel_linear_forward(multidevice_test): assert event.input_shapes == [[m, k], [k, n], [m, n]] +def row_parallel_linear_forward_reference( + inp_shard: torch.Tensor, weight_shard: torch.Tensor, num_chunks: int +) -> torch.Tensor: + out = torch.empty( + inp_shard.size(0), + weight_shard.size(0), + device="cuda", + dtype=inp_shard.dtype, + ) + inp_chunks = inp_shard.chunk(num_chunks) + out_chunks = out.chunk(num_chunks) + + def wait_stream(stream: torch.cuda.Stream) -> None: + event = torch.cuda.Event() + stream.record_event(event) + torch.cuda.current_stream().wait_event(event) + + main_stream = torch.cuda.current_stream() + worker_streams = [torch.cuda.Stream() for _ in range(num_chunks)] + for inp_chunk, out_chunk, worker_stream in zip( + inp_chunks, out_chunks, worker_streams + ): + with torch.cuda.stream(worker_stream): + wait_stream(main_stream) + torch.matmul(inp_chunk, weight_shard.T, out=out_chunk) + work = dist.all_reduce(out_chunk, op=dist.ReduceOp.SUM, async_op=True) + work.wait() + + for worker_stream in worker_streams: + wait_stream(worker_stream) + + return out + + +@pytest.mark.mpi +def test_row_parallel_linear_forward_reference(setup_default_process_group): + h, s, t = 2, 3, 6 + d = dist.get_world_size() + if (h * 4) % d != 0: + pytest.skip( + f"Row-parallel linear requires {h * 4} to be divisible by world size {d}." + ) + assert t % s == 0 + + torch.manual_seed(0) + inp_ref = torch.testing.make_tensor(t, h * 4, dtype=torch.int32, device="cpu").to( + torch.bfloat16 + ) + weight_ref = torch.testing.make_tensor( + h, h * 4, dtype=torch.int32, device="cpu" + ).to(torch.bfloat16) + out_ref = torch.nn.functional.linear(inp_ref.cuda(), weight_ref.cuda()).cpu() + + mesh = dist.device_mesh.init_device_mesh("cuda", [d]) + inp_shard = distribute_tensor(inp_ref, mesh, placements=[Shard(-1)]).to_local() + weight_shard = distribute_tensor( + weight_ref, mesh, placements=[Shard(-1)] + ).to_local() + out = row_parallel_linear_forward_reference(inp_shard, weight_shard, s) + + torch.testing.assert_close(out.cpu(), out_ref) + + +@pytest.mark.mpi +@pytest.mark.benchmark +def test_row_parallel_linear_forward_reference_benchmark( + setup_default_process_group, benchmark +): + h, s, t = 8192, 2, 8192 + d = dist.get_world_size() + if (h * 4) % d != 0: + pytest.skip( + f"Row-parallel linear requires {h * 4} to be divisible by world size {d}." 
+ ) + assert t % s == 0 + + torch.manual_seed(0) + inp_ref = torch.randn(t, h * 4, dtype=torch.bfloat16) + weight_ref = torch.randn(h, h * 4, dtype=torch.bfloat16) + + mesh = dist.device_mesh.init_device_mesh("cuda", [d]) + inp_shard = distribute_tensor(inp_ref, mesh, placements=[Shard(-1)]).to_local() + weight_shard = distribute_tensor( + weight_ref, mesh, placements=[Shard(-1)] + ).to_local() + + warmup_fn, benchmark_fn = get_benchmark_fns( + lambda: row_parallel_linear_forward_reference(inp_shard, weight_shard, s) + ) + warmup_fn() + benchmark.pedantic(benchmark_fn, rounds=5) + + @pytest.mark.mpi @pytest.mark.parametrize("backend_type", [CommunicatorBackend.nccl]) @pytest.mark.parametrize("s", [1, 8]) From 2a86e82e407810ba8ccb3c202fa6342432998526 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Thu, 1 Jan 2026 18:26:48 -0800 Subject: [PATCH 2/9] Simplify getCUDAStream --- csrc/host_ir/evaluator.cpp | 14 ++++---------- .../test_multidevice_lower_communication_cuda.cpp | 4 +--- 2 files changed, 5 insertions(+), 13 deletions(-) diff --git a/csrc/host_ir/evaluator.cpp b/csrc/host_ir/evaluator.cpp index 861bdfe1ed5..849c0c35463 100644 --- a/csrc/host_ir/evaluator.cpp +++ b/csrc/host_ir/evaluator.cpp @@ -147,16 +147,10 @@ c10::cuda::CUDAStream HostIrEvaluator::getCUDAStream(Stream* stream) { NVF_ERROR(value.hasValue() && value.is()); stream_key = value.as(); } - if (streams_.find(stream_key) == streams_.end()) { - auto i = (communicator_ != nullptr && communicator_->is_available()) - ? communicator_->deviceId() - : 0; - streams_.insert( - {stream_key, - c10::cuda::getStreamFromPool( - /*isHighPriority=*/false, static_cast(i))}); - } - return streams_.at(stream_key); + + auto [it, inserted] = + streams_.try_emplace(stream_key, c10::cuda::getStreamFromPool()); + return it->second; } void HostIrEvaluator::handle(SetCurrentStream* set_current_stream) { diff --git a/tests/cpp/test_multidevice_lower_communication_cuda.cpp b/tests/cpp/test_multidevice_lower_communication_cuda.cpp index 3164f193132..971c075c018 100644 --- a/tests/cpp/test_multidevice_lower_communication_cuda.cpp +++ b/tests/cpp/test_multidevice_lower_communication_cuda.cpp @@ -330,9 +330,7 @@ INSTANTIATE_TEST_SUITE_P( testing::Values( 2 * 1024 * 1024LL, // 2 MB 8 * 1024 * 1024LL, // 8 MB - 32 * 1024 * 1024LL, // 32 MB - 128 * 1024 * 1024LL, // 128 MB - 256 * 1024 * 1024LL // 256 MB + 32 * 1024 * 1024LL // 32 MB ), testing::Values( CommunicationProtocol::kMemcpy, From 2ddc32376737c5675f2bff8d22309240ce374eaf Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Fri, 2 Jan 2026 13:42:00 -0800 Subject: [PATCH 3/9] StreamTest no longer needs the HostIrLowering knob. --- tests/cpp/test_stream.cpp | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/tests/cpp/test_stream.cpp b/tests/cpp/test_stream.cpp index 49a4653c69c..bf7355deef4 100644 --- a/tests/cpp/test_stream.cpp +++ b/tests/cpp/test_stream.cpp @@ -24,12 +24,10 @@ namespace nvfuser { -class StreamTest : public NVFuserTest { - public: - StreamTest() { - EnableOptionsGuard::getCurOptions().set(EnableOption::HostIrLowering); - } -}; +// The tests in this file verify building blocks for stream parallelism, e.g., +// sharding propagation and KernelExecutor. End-to-end tests have been moved to +// tests/python/direct/test_stream.py because the Python API is sufficient. 
+using StreamTest = NVFuserTest; TEST_F(StreamTest, AddPerStream) { constexpr int64_t c = 3; From f5954065d29994dbe127eecb30eefcf3fefc3c45 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Fri, 2 Jan 2026 14:56:45 -0800 Subject: [PATCH 4/9] Minor cleanup to Communicator --- csrc/host_ir/evaluator.cpp | 7 +++---- csrc/multidevice/communicator.h | 1 - 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/csrc/host_ir/evaluator.cpp b/csrc/host_ir/evaluator.cpp index 849c0c35463..3acba5309bf 100644 --- a/csrc/host_ir/evaluator.cpp +++ b/csrc/host_ir/evaluator.cpp @@ -46,13 +46,12 @@ HostIrEvaluator::HostIrEvaluator( communicator_(communicator), params_(params), expr_evaluator_(), - my_local_device_index_(communicator_ ? communicator_->local_rank() : 0), + my_local_device_index_( + communicator_ == nullptr ? 0 : communicator_->local_rank()), ipc_handle_cache_(expr_evaluator_), multicast_handle_cache_() { const DeviceIdxType device_index = - (communicator_ != nullptr && communicator_->is_available()) - ? communicator_->deviceId() - : 0; + communicator_ == nullptr ? 0 : communicator_->deviceId(); if (isDebugDumpEnabled(DebugDumpOption::HostIr) && device_index == 0) { container_->print(debug()); } diff --git a/csrc/multidevice/communicator.h b/csrc/multidevice/communicator.h index d41cb91ceb8..b56e6fee3aa 100644 --- a/csrc/multidevice/communicator.h +++ b/csrc/multidevice/communicator.h @@ -19,7 +19,6 @@ #include "multidevice/c10d_mock.h" #endif -#include "exceptions.h" #include "multidevice/multidevice.h" #include "visibility.h" From 10955984fd6f3970b3675cc1e1551f3f91ef9072 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Thu, 1 Jan 2026 19:45:38 -0800 Subject: [PATCH 5/9] New pass --- CMakeLists.txt | 3 +- csrc/host_ir/allocate_and_deallocate.h | 1 - csrc/host_ir/assign_streams.cpp | 64 ++++++++++++++++++++++++++ csrc/host_ir/assign_streams.h | 26 +++++++++++ csrc/host_ir/ir.h | 1 - csrc/host_ir/passes.cpp | 2 + csrc/ir/internal_nodes.h | 1 - 7 files changed, 94 insertions(+), 4 deletions(-) create mode 100644 csrc/host_ir/assign_streams.cpp create mode 100644 csrc/host_ir/assign_streams.h diff --git a/CMakeLists.txt b/CMakeLists.txt index ccda3d89fb5..c36c129b2fd 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -347,9 +347,10 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/preseg_passes/remove_empty.cpp ${NVFUSER_SRCS_DIR}/preseg_passes/reorder_sharded_axis.cpp ${NVFUSER_SRCS_DIR}/preseg_passes/segment_inplace_update.cpp + ${NVFUSER_SRCS_DIR}/host_ir/allocate_and_deallocate.cpp + ${NVFUSER_SRCS_DIR}/host_ir/assign_streams.cpp ${NVFUSER_SRCS_DIR}/host_ir/pass/convert_op_to_communication.cpp ${NVFUSER_SRCS_DIR}/host_ir/pass/stream_parallel_type.cpp - ${NVFUSER_SRCS_DIR}/host_ir/allocate_and_deallocate.cpp ${NVFUSER_SRCS_DIR}/preseg_passes/translate_no_reduction_matmul_to_mul_squeeze.cpp ${NVFUSER_SRCS_DIR}/preseg_passes/translate_repeat_to_expand.cpp ${NVFUSER_SRCS_DIR}/preseg_passes/translate_scatter_accumulate.cpp diff --git a/csrc/host_ir/allocate_and_deallocate.h b/csrc/host_ir/allocate_and_deallocate.h index fbf7c88470c..491e2000d91 100644 --- a/csrc/host_ir/allocate_and_deallocate.h +++ b/csrc/host_ir/allocate_and_deallocate.h @@ -7,7 +7,6 @@ // clang-format on #pragma once -#include "host_ir/container.h" #include "optimization_pass.h" namespace nvfuser::hir { diff --git a/csrc/host_ir/assign_streams.cpp b/csrc/host_ir/assign_streams.cpp new file mode 100644 index 00000000000..c3e81c19321 --- /dev/null +++ b/csrc/host_ir/assign_streams.cpp @@ -0,0 +1,64 @@ +// clang-format off +/* + 
* SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ */
+// clang-format on
+
+#include "host_ir/assign_streams.h"
+
+#include "host_ir/container.h"
+#include "ir/builder.h"
+
+namespace nvfuser::hir {
+
+void AssignStreams::runPass(Fusion* fusion) {
+  auto* hic = dynamic_cast<HostIrContainer*>(fusion);
+  NVF_CHECK(hic != nullptr);
+
+  // For each stream-parallel loop, insert to the beginning a SetCurrentStream
+  // and a Synchronize (to the main stream). Right after the loop exits, insert
+  // another loop that joins all the worker streams.
+
+  for (auto it = hic->topLevel().exprs().begin();
+       it != hic->topLevel().exprs().end();
+       ++it) {
+    auto* for_loop = dynamic_cast<ForLoop*>(*it);
+    if (!for_loop) {
+      continue;
+    }
+
+    // FIXME: should have checked that the loop is stream-parallel
+
+    auto* get_current_stream = IrBuilder::create<GetCurrentStream>();
+    Stream* main_stream = get_current_stream->stream();
+    hic->topLevel().insert(it, get_current_stream);
+
+    // At the beginning of each iteration: set stream and synchronize with main
+    // stream
+    auto* worker_stream = IrBuilder::create<Stream>(for_loop->index());
+    auto* set_stream = IrBuilder::create<SetCurrentStream>(worker_stream);
+    auto* sync_main = IrBuilder::create<Synchronize>(main_stream);
+
+    // Insert at the beginning of the loop body
+    auto body_it =
+        for_loop->body().insert(for_loop->body().exprs().begin(), set_stream);
+    for_loop->body().insert(std::next(body_it), sync_main);
+
+    // After the loop: create a joining loop to synchronize all worker streams
+    auto* join_loop = IrBuilder::create<ForLoop>(
+        for_loop->index(), for_loop->start(), for_loop->stop());
+
+    // In the joining loop: synchronize each worker stream
+    auto* join_worker_stream = IrBuilder::create<Stream>(join_loop->index());
+    auto* sync_worker = IrBuilder::create<Synchronize>(join_worker_stream);
+    join_loop->body().push_back(sync_worker);
+
+    // Insert join_loop after the current for_loop
+    auto next_it = std::next(it);
+    hic->topLevel().insert(next_it, join_loop);
+  }
+}
+
+} // namespace nvfuser::hir
diff --git a/csrc/host_ir/assign_streams.h b/csrc/host_ir/assign_streams.h
new file mode 100644
index 00000000000..cd14fbad95f
--- /dev/null
+++ b/csrc/host_ir/assign_streams.h
@@ -0,0 +1,26 @@
+// clang-format off
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ */
+// clang-format on
+#pragma once
+
+#include "optimization_pass.h"
+
+namespace nvfuser::hir {
+
+// A host IR pass that assigns streams to stream-parallel loops.
+class AssignStreams : public OptimizationPass<AssignStreams> {
+  friend class OptimizationPass<AssignStreams>;
+
+ protected:
+  static void runPass(Fusion* fusion);
+
+  static constexpr std::string_view name() {
+    return "AssignStreams";
+  }
+};
+
+} // namespace nvfuser::hir
diff --git a/csrc/host_ir/ir.h b/csrc/host_ir/ir.h
index a3b36284e36..6c1c9c61ce9 100644
--- a/csrc/host_ir/ir.h
+++ b/csrc/host_ir/ir.h
@@ -15,7 +15,6 @@
 #include "ir/base_nodes.h"
 #include "ir/builder.h"
 #include "multidevice/communication.h"
-#include "scheduler/heuristic.h"
 
 namespace nvfuser {
 // This works around a circular dependency: compiled_kernel.h ==>
diff --git a/csrc/host_ir/passes.cpp b/csrc/host_ir/passes.cpp
index 98078520145..65a8327ba58 100644
--- a/csrc/host_ir/passes.cpp
+++ b/csrc/host_ir/passes.cpp
@@ -9,11 +9,13 @@
 #include "host_ir/passes.h"
 
 #include "host_ir/allocate_and_deallocate.h"
+#include "host_ir/assign_streams.h"
 
 namespace nvfuser::hir {
 
 void runPasses(HostIrContainer& hic) {
   OptimizationPass<AllocateAndDeallocate>::runPass(&hic);
+  OptimizationPass<AssignStreams>::runPass(&hic);
 }
 
 } // namespace nvfuser::hir
diff --git a/csrc/ir/internal_nodes.h b/csrc/ir/internal_nodes.h
index 9393dc3016b..ac825c01c6b 100644
--- a/csrc/ir/internal_nodes.h
+++ b/csrc/ir/internal_nodes.h
@@ -35,7 +35,6 @@ namespace nvfuser {
 
 
 class ViewTransform;
-class Scope;
 class IrCloner;
 
 struct AnalyzeViewResult;

From 2588deb89b690d141a7ac0050499db310db302df Mon Sep 17 00:00:00 2001
From: Jingyue Wu
Date: Fri, 2 Jan 2026 14:29:48 -0800
Subject: [PATCH 6/9] Remove dependency on nvfuser_direct_test

It's never used by these tests.
---
 tests/python/direct/test_stream.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/python/direct/test_stream.py b/tests/python/direct/test_stream.py
index 32013ae7c21..3e84c82ddc2 100644
--- a/tests/python/direct/test_stream.py
+++ b/tests/python/direct/test_stream.py
@@ -7,7 +7,7 @@
 from nvfuser_direct import FusionDefinition, ParallelType, DataType
 
 
-def test_matmul(nvfuser_direct_test):
+def test_matmul():
     c = 3
 
     with FusionDefinition() as fd:
@@ -46,7 +46,7 @@ def test_matmul(nvfuser_direct_test):
     assert event.input_shapes == [[5, 7], [7, 2], [5, 2]]
 
 
-def test_two_matmuls_inlinable(nvfuser_direct_test):
+def test_two_matmuls_inlinable():
     c = 3
 
     with FusionDefinition() as fd:
@@ -97,7 +97,7 @@ def test_two_matmuls_inlinable(nvfuser_direct_test):
     assert event.input_shapes[0][0] == 2
 
 
-def test_two_matmuls_not_inlinable(nvfuser_direct_test):
+def test_two_matmuls_not_inlinable():
     c = 3
 
     with FusionDefinition() as fd:

From b09ab67ca254aa4ec3ce6a41ab0a3c2214012801 Mon Sep 17 00:00:00 2001
From: Jingyue Wu
Date: Fri, 2 Jan 2026 21:08:34 -0800
Subject: [PATCH 7/9] Print Stream properly

---
 csrc/host_ir/ir.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/csrc/host_ir/ir.cpp b/csrc/host_ir/ir.cpp
index 5b46b7f0abb..3ac90190b0c 100644
--- a/csrc/host_ir/ir.cpp
+++ b/csrc/host_ir/ir.cpp
@@ -197,7 +197,7 @@ std::string Stream::toString(int indent_size) const {
   std::stringstream ss;
   indent(ss, indent_size) << "Stream ";
   if (index() == nullptr) {
-    ss << name();
+    ss << static_cast<const void*>(this);
   } else {
     ss << index()->toInlineString();
   }

From 7ae4e9ee0c40e54f4cfcc4c75478ba4be1abd80c Mon Sep 17 00:00:00 2001
From: Jingyue Wu
Date: Fri, 2 Jan 2026 14:30:35 -0800
Subject: [PATCH 8/9] WIP

---
 csrc/host_ir/assign_streams.cpp             | 27 +++---
 csrc/ir/internal_nodes.h                    |  1 +
 tests/python/multidevice/benchmark_utils.py |  9 +-
 tests/python/multidevice/test_overlap.py    | 91 +++++++++++++++------
 4 files changed, 85 insertions(+), 43 deletions(-)

diff --git a/csrc/host_ir/assign_streams.cpp b/csrc/host_ir/assign_streams.cpp
index c3e81c19321..ce404f0aff1 100644
--- a/csrc/host_ir/assign_streams.cpp
+++ b/csrc/host_ir/assign_streams.cpp
@@ -16,20 +16,21 @@ namespace nvfuser::hir {
 void AssignStreams::runPass(Fusion* fusion) {
   auto* hic = dynamic_cast<HostIrContainer*>(fusion);
   NVF_CHECK(hic != nullptr);
-
-  // For each stream-parallel loop, insert to the beginning a SetCurrentStream
-  // and a Synchronize (to the main stream). Right after the loop exits, insert
-  // another loop that joins all the worker streams.
+  FusionGuard fg(hic);
 
   for (auto it = hic->topLevel().exprs().begin();
-       it != hic->topLevel().exprs().end();
-       ++it) {
+       it != hic->topLevel().exprs().end();) {
+    auto next_it = std::next(it);
+
     auto* for_loop = dynamic_cast<ForLoop*>(*it);
-    if (!for_loop) {
+    if (for_loop == nullptr) {
+      it = next_it;
       continue;
     }
 
-    // FIXME: should have checked that the loop is stream-parallel
+    // We should check that the loop is stream-parallel. This is not necessary
+    // at this moment because all loops are stream-parallel. This is also hard
+    // to do because hir::ForLoop doesn't point to the source IterDomain.
 
     auto* get_current_stream = IrBuilder::create<GetCurrentStream>();
     Stream* main_stream = get_current_stream->stream();
@@ -40,11 +41,9 @@ void AssignStreams::runPass(Fusion* fusion) {
     auto* worker_stream = IrBuilder::create<Stream>(for_loop->index());
     auto* set_stream = IrBuilder::create<SetCurrentStream>(worker_stream);
     auto* sync_main = IrBuilder::create<Synchronize>(main_stream);
-
-    // Insert at the beginning of the loop body
-    auto body_it =
-        for_loop->body().insert(for_loop->body().exprs().begin(), set_stream);
-    for_loop->body().insert(std::next(body_it), sync_main);
+    auto old_begin = for_loop->body().exprs().begin();
+    for_loop->body().insert(old_begin, set_stream);
+    for_loop->body().insert(old_begin, sync_main);
 
     // After the loop: create a joining loop to synchronize all worker streams
     auto* join_loop = IrBuilder::create<ForLoop>(
@@ -56,8 +55,8 @@ void AssignStreams::runPass(Fusion* fusion) {
     join_loop->body().push_back(sync_worker);
 
     // Insert join_loop after the current for_loop
-    auto next_it = std::next(it);
     hic->topLevel().insert(next_it, join_loop);
+    it = next_it;
   }
 }
 
diff --git a/csrc/ir/internal_nodes.h b/csrc/ir/internal_nodes.h
index ac825c01c6b..4fc707c163e 100644
--- a/csrc/ir/internal_nodes.h
+++ b/csrc/ir/internal_nodes.h
@@ -2505,6 +2505,7 @@ class Scope {
     return std::ssize(exprs_);
   }
 
+  // Returns an iterator pointing to the inserted expression.
   Iterator insert(Iterator pos, Expr* expr);
 
   Iterator push_back(Expr* e) {
diff --git a/tests/python/multidevice/benchmark_utils.py b/tests/python/multidevice/benchmark_utils.py
index c585d68c452..0421a01dc23 100644
--- a/tests/python/multidevice/benchmark_utils.py
+++ b/tests/python/multidevice/benchmark_utils.py
@@ -28,12 +28,11 @@ def wrapper(*args, **kwargs):
 
 # Returns two functors, the first with profiler off and the second with profiler
 # on. The first functor is usually used for warmup and the second for actual
-# benchmarking. This way, one
-# can collect stats of the first few non-warmup benchmark iterations using
+# benchmarking. 
This way, one can collect stats of the first few non-warmup +# benchmark iterations using # ```bash -# mpirun -np 1 nsys profile --capture-range=cudaProfilerApi --capture-range-end=repeat: pytest tests/python/multidevice/.py -k --only-mpi : -np pytest tests/python/multidevice/.py -k --only-mpi +# nsys profile --capture-range=cudaProfilerApi --capture-range-end=repeat: mpirun -np pytest tests/python/multidevice/.py -k --only-mpi # ``` -# and then display the stats using e.g. `nsys stats --report=cuda_gpu_kern_sum -# report1.nsys-rep`. +# and then display the stats using `nsys stats`. def get_benchmark_fns(func): return get_benchmark_fn(func, profile=False), get_benchmark_fn(func, profile=True) diff --git a/tests/python/multidevice/test_overlap.py b/tests/python/multidevice/test_overlap.py index 3b55a9e7e3b..17922346679 100644 --- a/tests/python/multidevice/test_overlap.py +++ b/tests/python/multidevice/test_overlap.py @@ -14,19 +14,7 @@ from benchmark_utils import get_benchmark_fns -@pytest.mark.mpi -def test_row_parallel_linear_forward(multidevice_test): - # This is a port of CollectiveBasedOverlapTest.RowParallelLinear_Forward. - h, s, t = 2, 3, 6 - d = multidevice_test.size - if (h * 4) % d != 0: - pytest.skip( - f"Row-parallel linear requires {h * 4} to be divisible by world size {d}." - ) - assert t % s == 0 - - mesh = nvfuser.multidevice.DeviceMesh(range(d)) - +def row_parallel_linear_forward(h, mesh, num_chunks): with FusionDefinition() as fd: inp = fd.define_tensor( shape=[-1, h * 4], contiguity=True, dtype=DataType.BFloat16 @@ -40,11 +28,11 @@ def test_row_parallel_linear_forward(multidevice_test): for tv in (inp, weight): tv.set_device_mesh(mesh) - inp.split(0, s, inner_split=False) + inp.outer_split(0, num_chunks) inp.axis(0).parallelize(nvfuser.ParallelType.stream) - inp.split(2, d, inner_split=False) + inp.outer_split(2, mesh.size) inp.axis(2).parallelize(nvfuser.ParallelType.mesh_x) - weight.split(1, d, inner_split=False) + weight.outer_split(1, mesh.size) weight.axis(1).parallelize(nvfuser.ParallelType.mesh_x) # Expected pre-segmentation IR: @@ -67,22 +55,49 @@ def test_row_parallel_linear_forward(multidevice_test): # /\. 
# s* - # Expected host IR: + # The host IR dumped with NVFUSER_DUMP=host_ir is similar to `row_parallel_linear_forward_reference`: # # %HostIrContainer { (T0_g___bfloat[istreamIdx7{3}, ideviceIdx.x9{2}, iS8{( ceilDiv(i0, 3) )}, iS10{4}] (DeviceMesh{0 1}), T1_g___bfloat[ideviceIdx.x11{2}, iS2{2}, iS12{4}] (DeviceMesh{0 1})) -> (T2_g___bfloat[istreamIdx27{3}, rdeviceIdx.x26{2}, iS28{( ceilDiv(i0, 3) )}, iS25{2}] (DeviceMesh{0 1})) : # T2_g___bfloat[istreamIdx27{3}, rdeviceIdx.x26{2}, iS28{( ceilDiv(i0, 3) )}, iS25{2}] (DeviceMesh{0 1}) = ALLOCATE(buffer=T2_g___bfloat[istreamIdx27{3}, rdeviceIdx.x26{2}, iS28{( ceilDiv(i0, 3) )}, iS25{2}] (DeviceMesh{0 1}), mem_type=global, size=( i0 * 2 ), zero_init=false, resets_to_zero=false) + # GetCurrentStream into Stream 0x174e5c80 # FOR i535 from 0 to 3: - # T4_l___bfloat[istreamIdx31{3}, ideviceIdx.x33{2}, iS32{( ceilDiv(i0, 3) )}, iS34{4}] (DeviceMesh{0 1}) = ShardByStream(T0_g___bfloat[istreamIdx7{3}, ideviceIdx.x9{2}, iS8{( ceilDiv(i0, 3) )}, iS10{4}] (DeviceMesh{0 1}), stream_index = i535) + # SetCurrentStream to Stream i535 + # Synchronize Stream 0x174e5c80 + # T4_l___bfloat[istreamIdx37{3}, iS38{( ceilDiv(i0, 3) )}, ideviceIdx.x35{2}, iS36{4}] (DeviceMesh{0 1}) = ShardByStream(T0_g___bfloat[istreamIdx7{3}, ideviceIdx.x9{2}, iS8{( ceilDiv(i0, 3) )}, iS10{4}] (DeviceMesh{0 1}), stream_index = i535) + # T3_g___bfloat[istreamIdx20{3}, ideviceIdx.x22{2}rf, iS21{( ceilDiv(i0, 3) )}, iS18{2}, rS23{4}rf] (DeviceMesh{0 1}) = ALLOCATE(buffer=T3_g___bfloat[istreamIdx20{3}, ideviceIdx.x22{2}rf, iS21{( ceilDiv(i0, 3) )}, iS18{2}, rS23{4}rf] (DeviceMesh{0 1}), mem_type=global, size=( ( ceilDiv(i0, 3) ) * 12 ), zero_init=false, resets_to_zero=false) # T3_g___bfloat[istreamIdx20{3}, ideviceIdx.x22{2}rf, iS21{( ceilDiv(i0, 3) )}, iS18{2}, rS23{4}rf] (DeviceMesh{0 1}) - # = linear(T4_l___bfloat[istreamIdx31{3}, ideviceIdx.x33{2}, iS32{( ceilDiv(i0, 3) )}, iS34{4}] (DeviceMesh{0 1}), + # = linear(T4_l___bfloat[istreamIdx37{3}, iS38{( ceilDiv(i0, 3) )}, ideviceIdx.x35{2}, iS36{4}] (DeviceMesh{0 1}), # T1_g___bfloat[ideviceIdx.x11{2}, iS2{2}, iS12{4}] (DeviceMesh{0 1}) ) - # T5_l___bfloat[istreamIdx37{3}, iS38{( ceilDiv(i0, 3) )}, iS36{2}] (DeviceMesh{0 1}) = ShardByStream(T2_g___bfloat[istreamIdx27{3}, rdeviceIdx.x26{2}, iS28{( ceilDiv(i0, 3) )}, iS25{2}] (DeviceMesh{0 1}), stream_index = i535) - # Communication 250 (type=Allreduce, team=(0 1), input=T3_g___bfloat[istreamIdx20{3}, ideviceIdx.x22{2}rf, iS21{( ceilDiv(i0, 3) )}, iS18{2}, rS23{4}rf] (DeviceMesh{0 1}), output=T5_l___bfloat[istreamIdx37{3}, iS38{( ceilDiv(i0, 3) )}, iS36{2}] (DeviceMesh{0 1}), backend=NCCL) - # Wait Communication 250 + # T5_l___bfloat[istreamIdx41{3}, iS42{( ceilDiv(i0, 3) )}, iS40{2}] (DeviceMesh{0 1}) = ShardByStream(T2_g___bfloat[istreamIdx27{3}, rdeviceIdx.x26{2}, iS28{( ceilDiv(i0, 3) )}, iS25{2}] (DeviceMesh{0 1}), stream_index = i535) + # Communication 272 (type=Allreduce, team=(0 1), input=T3_g___bfloat[istreamIdx20{3}, ideviceIdx.x22{2}rf, iS21{( ceilDiv(i0, 3) )}, iS18{2}, rS23{4}rf] (DeviceMesh{0 1}), output=T5_l___bfloat[istreamIdx41{3}, iS42{( ceilDiv(i0, 3) )}, iS40{2}] (DeviceMesh{0 1}), backend=NCCL) + # Wait Communication 272 + # FOR i535 from 0 to 3: + # Synchronize Stream i535 # } // %HostIrContainer - inp_ref = torch.randint(-2, 3, (t, h * 4), dtype=torch.int32).to(torch.bfloat16) - weight_ref = torch.randint(-2, 3, (h, h * 4), dtype=torch.int32).to(torch.bfloat16) + return fd + + +@pytest.mark.mpi +def test_row_parallel_linear_forward(multidevice_test): + # This is 
a port of CollectiveBasedOverlapTest.RowParallelLinear_Forward. + h, s, t = 2, 3, 6 + d = multidevice_test.size + if (h * 4) % d != 0: + pytest.skip( + f"Row-parallel linear requires {h * 4} to be divisible by world size {d}." + ) + assert t % s == 0 + + mesh = nvfuser.multidevice.DeviceMesh(range(d)) + fd = row_parallel_linear_forward(h, mesh, s) + + inp_ref = torch.testing.make_tensor(t, h * 4, dtype=torch.int32, device="cpu").to( + torch.bfloat16 + ) + weight_ref = torch.testing.make_tensor( + h, h * 4, dtype=torch.int32, device="cpu" + ).to(torch.bfloat16) out_ref = torch.nn.functional.linear(inp_ref, weight_ref) inp = multidevice_test.shard_tensor(inp_ref, -1, mesh) @@ -105,6 +120,34 @@ def test_row_parallel_linear_forward(multidevice_test): assert event.input_shapes == [[m, k], [k, n], [m, n]] +@pytest.mark.mpi +@pytest.mark.benchmark +def test_row_parallel_linear_forward_benchmark(multidevice_test, benchmark): + # This is a port of CollectiveBasedOverlapTest.RowParallelLinear_Forward. + h, s, t = 8192, 2, 8192 + d = multidevice_test.size + if (h * 4) % d != 0: + pytest.skip( + f"Row-parallel linear requires {h * 4} to be divisible by world size {d}." + ) + assert t % s == 0 + + mesh = nvfuser.multidevice.DeviceMesh(range(d)) + fd = row_parallel_linear_forward(h, mesh, s) + + inp_ref = torch.randn(t, h * 4, dtype=torch.bfloat16, device="cpu") + weight_ref = torch.randn(h, h * 4, dtype=torch.bfloat16, device="cpu") + + inp = multidevice_test.shard_tensor(inp_ref, -1, mesh) + weight = multidevice_test.shard_tensor(weight_ref, -1, mesh) + + warmup_fn, benchmark_fn = get_benchmark_fns( + lambda: fd.execute([inp, weight], _enable_options=["host_ir_lowering"]) + ) + warmup_fn() + benchmark.pedantic(benchmark_fn, rounds=5) + + def row_parallel_linear_forward_reference( inp_shard: torch.Tensor, weight_shard: torch.Tensor, num_chunks: int ) -> torch.Tensor: From b4adbf63db0d7e637053cdaa2613b6240ac33389 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Sat, 3 Jan 2026 08:19:02 -0800 Subject: [PATCH 9/9] Reset the current stream to the main stream --- csrc/host_ir/assign_streams.cpp | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/csrc/host_ir/assign_streams.cpp b/csrc/host_ir/assign_streams.cpp index ce404f0aff1..556de39ed8e 100644 --- a/csrc/host_ir/assign_streams.cpp +++ b/csrc/host_ir/assign_streams.cpp @@ -22,7 +22,7 @@ void AssignStreams::runPass(Fusion* fusion) { it != hic->topLevel().exprs().end();) { auto next_it = std::next(it); - auto* for_loop = dynamic_cast(*it); + auto* for_loop = dynamic_cast(*it); if (for_loop == nullptr) { it = next_it; continue; @@ -46,16 +46,17 @@ void AssignStreams::runPass(Fusion* fusion) { for_loop->body().insert(old_begin, sync_main); // After the loop: create a joining loop to synchronize all worker streams - auto* join_loop = IrBuilder::create( + hic->topLevel().insert( + next_it, IrBuilder::create(main_stream)); + auto* join_loop = IrBuilder::create( for_loop->index(), for_loop->start(), for_loop->stop()); + hic->topLevel().insert(next_it, join_loop); // In the joining loop: synchronize each worker stream auto* join_worker_stream = IrBuilder::create(join_loop->index()); auto* sync_worker = IrBuilder::create(join_worker_stream); join_loop->body().push_back(sync_worker); - // Insert join_loop after the current for_loop - hic->topLevel().insert(next_it, join_loop); it = next_it; } }
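
For reference, the fork/join structure that AssignStreams builds around a stream-parallel loop can also be written directly with PyTorch streams, in the same spirit as the torch reference added in patch 1. The sketch below is illustrative only and is not part of the patch series: `run_chunk` is a hypothetical callback that launches one iteration's work on the current stream, and the comments map each step to the host IR ops the pass emits (GetCurrentStream, SetCurrentStream, Synchronize, and the joining loop).

```python
import torch


def fork_join_sketch(num_chunks: int, run_chunk) -> None:
    # GetCurrentStream: remember the main stream before entering the loop.
    main_stream = torch.cuda.current_stream()
    worker_streams = [torch.cuda.Stream() for _ in range(num_chunks)]

    for i, worker_stream in enumerate(worker_streams):
        # Synchronize(main): a worker must not start before work already
        # queued on the main stream.
        worker_stream.wait_stream(main_stream)
        # SetCurrentStream(worker): run this iteration on the worker stream.
        with torch.cuda.stream(worker_stream):
            run_chunk(i)
    # Leaving the `with` block restores the previous stream, which plays the
    # role of resetting the current stream to the main stream after the loop.

    # The joining loop: Synchronize(worker) for every worker stream.
    for worker_stream in worker_streams:
        main_stream.wait_stream(worker_stream)
```

This mirrors row_parallel_linear_forward_reference in test_overlap.py, which expresses the same synchronization with CUDA events instead of wait_stream.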