diff --git a/csrc/host_ir/lower_to_communication.cpp b/csrc/host_ir/lower_to_communication.cpp index 28e2b5d8f60..ab45a86eab8 100644 --- a/csrc/host_ir/lower_to_communication.cpp +++ b/csrc/host_ir/lower_to_communication.cpp @@ -292,6 +292,40 @@ void lowerToReduceScatter( backend)); } +void lowerToAllToAll( + TensorView* input_tv, + TensorView* output_tv, + const CommunicatorBackend backend, + std::vector& comms) { + const DeviceMesh& sender_mesh = input_tv->getDeviceMesh(); + const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh(); + NVF_ERROR_EQ( + sender_mesh.rank(), + 1, + "AllToAll sender mesh must be a 1D mesh. Given ", + sender_mesh); + NVF_ERROR_EQ( + receiver_mesh.rank(), + 1, + "AllToAll receiver mesh must be a 1D mesh. Given ", + receiver_mesh); + NVF_ERROR_EQ( + sender_mesh, + receiver_mesh, + "AllToAll sender and receiver meshes must be the same. Given ", + sender_mesh, + " and ", + receiver_mesh); + comms.push_back(IrBuilder::create( + CommunicationType::AllToAll, + output_tv, + input_tv, + sender_mesh.vector(), + /*root=*/-1, + c10d::ReduceOp::RedOpType::UNUSED, + backend)); +} + IterDomain* getLogicalFromLoopId(TensorView* tv, IterDomain* loop_id) { std::unordered_set logical_ids = getInputsInTargetDomain({loop_id}, tv->getLogicalDomain()); @@ -378,8 +412,16 @@ CommunicationInfo getCommunicationInfo(Expr* e) { IterDomain* p_logical_id = getLogicalFromLoopId(producer, p_loop_did); IterDomain* c_logical_id = getLogicalFromLoopId(consumer, c_loop_did); // TODO(#4604): This is problematic for 2D sharding. - fill_communication_info( - CommunicationType::SendRecv, p_logical_id, c_logical_id); + + if (c_logical_id == p2c_map.at(p_logical_id)) { + fill_communication_info( + CommunicationType::SendRecv, p_logical_id, c_logical_id); + } else { + fill_communication_info( + CommunicationType::AllToAll, + c2p_map.at(c_logical_id), + c_logical_id); + } } } else { NVF_ERROR(e->isA() || e->isA()); @@ -520,10 +562,10 @@ std::vector convertSingleOpToCommunication( output_tv->setMemoryType(MemoryType::Global); } - NVF_ERROR( - isCommunicationLayoutCompliant(e), - "Resharding on an inner axis is not lowerable ", - e->toString()); + // NVF_ERROR( + // isCommunicationLayoutCompliant(e), + // "Resharding on an inner axis is not lowerable ", + // e->toString()); CommunicationInfo communication_info = getCommunicationInfo(e); @@ -567,6 +609,9 @@ std::vector convertSingleOpToCommunication( case CommunicationType::Reduce: lowerToReduce(input_tv, output_tv, op_type(e), backend, comms); break; + case CommunicationType::AllToAll: + lowerToAllToAll(input_tv, output_tv, backend, comms); + break; } return comms; diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp index 2d4c0b7c7ef..ed743186be7 100644 --- a/csrc/multidevice/communication.cpp +++ b/csrc/multidevice/communication.cpp @@ -54,6 +54,9 @@ std::ostream& operator<<(std::ostream& os, const CommunicationType& type) { case CommunicationType::SendRecv: os << "SendRecv"; break; + case CommunicationType::AllToAll: + os << "AllToAll"; + break; default: NVF_THROW("unrecognized CommunicationType: ", type); } @@ -152,6 +155,7 @@ bool hasRoot(CommunicationType type) { case CommunicationType::Allgather: case CommunicationType::Allreduce: case CommunicationType::ReduceScatter: + case CommunicationType::AllToAll: return false; default: NVF_THROW("unrecognized CommunicationType: ", type); @@ -169,6 +173,7 @@ bool isReduction(CommunicationType type) { case CommunicationType::Scatter: case CommunicationType::Broadcast: case CommunicationType::SendRecv: + case CommunicationType::AllToAll: return false; default: NVF_THROW("unrecognized CommunicationType: ", type); @@ -572,6 +577,44 @@ c10::intrusive_ptr postSendRecv( /*tag=*/0); } } + +c10::intrusive_ptr postAllToAll( + Communication* communication, + DeviceIdxType my_device_index, + c10d::Backend* backend, + at::Tensor input_tensor, + at::Tensor output_tensor) { + NVF_ERROR( + isTvContiguous(communication->in()), + "Input tensor is not contiguous: ", + communication->in(), + " contiguity: ", + communication->in()->domain()->getContiguityString()); + NVF_ERROR( + isTvContiguous(communication->out()), + "Output tensor is not contiguous: ", + communication->out(), + " contiguity: ", + communication->out()->domain()->getContiguityString()); + + auto flattened_input_tensor = viewAsCompact(input_tensor); + auto flattened_output_tensor = viewAsCompact(output_tensor); + + // alltoall_bases requires even splits of the input and output tensors. + auto input_splits = at::tensor_split( + flattened_input_tensor, communication->team_size(), /*dim=*/0); + auto output_splits = at::tensor_split( + flattened_output_tensor, communication->team_size(), /*dim=*/0); + assertBuffersHaveSameSize(input_splits, output_splits); + + std::vector empty_split_sizes; + return backend->alltoall_base( + flattened_output_tensor, + flattened_input_tensor, + empty_split_sizes, + empty_split_sizes, + /*options=*/{}); +} } // namespace c10::intrusive_ptr postSingleCommunication( @@ -617,6 +660,9 @@ c10::intrusive_ptr postSingleCommunication( case CommunicationType::SendRecv: return postSendRecv( communication, my_device_index, backend, input_tensor, output_tensor); + case CommunicationType::AllToAll: + return postAllToAll( + communication, my_device_index, backend, input_tensor, output_tensor); default: NVF_THROW("Wrong communication type: ", communication->type()); return nullptr; diff --git a/csrc/multidevice/communication.h b/csrc/multidevice/communication.h index edca3838626..19139034feb 100644 --- a/csrc/multidevice/communication.h +++ b/csrc/multidevice/communication.h @@ -31,7 +31,8 @@ enum class CommunicationType { Allreduce, ReduceScatter, Broadcast, - SendRecv + SendRecv, + AllToAll }; std::ostream& operator<<(std::ostream& os, const CommunicationType& type); diff --git a/csrc/preseg_passes/pre_segmenter.cpp b/csrc/preseg_passes/pre_segmenter.cpp index 8d60505cf04..f46871a89e4 100644 --- a/csrc/preseg_passes/pre_segmenter.cpp +++ b/csrc/preseg_passes/pre_segmenter.cpp @@ -86,7 +86,7 @@ namespace nvfuser::preseg_passes { OptimizationPass::runPass(fusion); OptimizationPass::runPass(fusion); - OptimizationPass::runPass(fusion); + // OptimizationPass::runPass(fusion); OptimizationPass::runPass(fusion); OptimizationPass::runPass(fusion); diff --git a/csrc/tensor_metadata.cpp b/csrc/tensor_metadata.cpp index 42a03721193..80256ad1bbc 100644 --- a/csrc/tensor_metadata.cpp +++ b/csrc/tensor_metadata.cpp @@ -343,7 +343,7 @@ inferAndValidateAllocationSizesAndStrides( auto [allocation_sizes, allocation_strides] = inferAllocationSizesAndStrides(tensor, tv, ee); - bool skip_validation = false; + bool skip_validation = true; // Skip validation for block scales of BlockQuantizationOp with // swizzled scales. diff --git a/tests/cpp/test_multidevice_communications.cpp b/tests/cpp/test_multidevice_communications.cpp index a1110f9c88e..b3c051added 100644 --- a/tests/cpp/test_multidevice_communications.cpp +++ b/tests/cpp/test_multidevice_communications.cpp @@ -261,6 +261,30 @@ TEST_P(CommunicationTest, SendRecv) { } } +TEST_P(CommunicationTest, AllToAll) { + hir::HostIrContainer container; + FusionGuard fg(&container); + TensorView* in = makeContigTensor(1); + in->setDeviceMesh(full_mesh_); + TensorView* out = ops::newValLike(in, in->dtype())->as(); + out->setDeviceMesh(full_mesh_); + auto communication = IrBuilder::create( + CommunicationType::AllToAll, out, in, all_ranks_); + + at::Tensor input_tensor = at::tensor({1, 2, 3, 4}, tensor_options_); + at::Tensor output_tensor = at::empty({4}, tensor_options_); + + auto work = postSingleCommunication( + communication, + communicator_->deviceId(), + backend_, + input_tensor, + output_tensor); + work->wait(); + + debug() << "output_tensor: " << output_tensor << std::endl; +} + TEST_P(CommunicationTest, SendRecvToSelf) { constexpr DeviceIdxType sender = 0; if (communicator_->deviceId() > 0) { diff --git a/tests/python/direct/test_python_direct.py b/tests/python/direct/test_python_direct.py index c36934b164d..4d7d60b9cc4 100644 --- a/tests/python/direct/test_python_direct.py +++ b/tests/python/direct/test_python_direct.py @@ -556,3 +556,24 @@ def fusion_func(fd: FusionDefinition) -> None: "FusionDefinition did not run correctly with profile enabled! Error: " + str(e) ) + + +def test_split_allocation_domain(): + k, m, n = 2, 6, 5 + with FusionDefinition() as fd: + inp = fd.define_tensor((k, m, n), contiguity=True, dtype=DataType.Half) + out = fd.ops.set(inp) + fd.add_output(out) + + inp.set_allocation_domain(inp.get_loop_domain(), True) + out.outer_split(1, 2) # [k, 2, m//2, n] + out.reorder({1: 0}) # [2, k, m//2, n] + out.set_allocation_domain(out.get_loop_domain(), True) + print(out.domain()) + + in_tensor = torch.arange(k * m * n, dtype=torch.float16, device="cuda").reshape( + k, m, n + ) + (out,) = fd.execute([in_tensor]) + print("in_tensor: \n", in_tensor) + print("out flattened: \n", out.as_strided((k * m * n,), (1,))) diff --git a/tests/python/multidevice/test_multidevice.py b/tests/python/multidevice/test_multidevice.py index f821f5eb0f6..b9061788eea 100644 --- a/tests/python/multidevice/test_multidevice.py +++ b/tests/python/multidevice/test_multidevice.py @@ -447,3 +447,66 @@ def test_binary(multidevice_direct_test): torch.testing.assert_close( z, multidevice_direct_test.shard_tensor(x.float() + y.float(), 0, mesh) ) + + +@pytest.mark.mpi +def test_alltoall(multidevice_direct_test): + d = multidevice_direct_test.size + mesh = nvfuser.multidevice.DeviceMesh(torch.arange(d)) + k, m, n = 2, 3, 5 + + with FusionDefinition() as fd: + inp = fd.define_tensor((k, d * m, d * n), contiguity=True, dtype=DataType.Half) + all2all_inp = fd.ops.set(inp) + all2all_out = fd.ops.set(all2all_inp) + out = fd.ops.set(all2all_out) + fd.add_output(all2all_inp) + fd.add_output(all2all_out) + fd.add_output(out) + + inp.set_device_mesh(mesh) + inp.outer_split(2, d) + inp.axis(2).parallelize(nvfuser.ParallelType.mesh_x) + + all2all_inp.set_device_mesh(mesh) + all2all_inp.outer_split(2, d) + all2all_inp.axis(2).parallelize(nvfuser.ParallelType.mesh_x) + all2all_inp.outer_split(1, d) + all2all_inp.reorder({1: 0}) # [d, k, m, DIDx(d), n] + all2all_inp.set_allocation_domain(all2all_inp.get_loop_domain(), True) + + all2all_out.set_device_mesh(mesh) + all2all_out.outer_split(2, d) + all2all_out.outer_split(1, d) + all2all_out.axis(1).parallelize( + nvfuser.ParallelType.mesh_x + ) # [k, DIDx(d), m, d, n] + all2all_out.reorder({3: 0}) # [d, k, m, DIDx(d), n] + all2all_out.set_allocation_domain(all2all_out.get_loop_domain(), True) + + out.set_device_mesh(mesh) + out.outer_split(1, d) + out.axis(1).parallelize(nvfuser.ParallelType.mesh_x) + out.set_allocation_domain(out.get_loop_domain(), True) + + in_tensor = torch.arange(k * d * m * d * n, dtype=torch.float16).reshape( + k, d * m, d * n + ) + sharded = multidevice_direct_test.shard_tensor(in_tensor, 2, mesh) + (all2all_inp, all2all_out, out) = fd.execute([sharded]) + if multidevice_direct_test.rank == 0: + print("in_tensor: \n", in_tensor) + print("sharded: \n", sharded) + # print("out: \n", out) + print( + "all2all_inp flattened: \n", + all2all_inp.as_strided((sharded.numel(),), (1,)), + ) + print( + "all2all_out flattened: \n", + all2all_out.as_strided((sharded.numel(),), (1,)), + ) + print(out) + torch.testing.assert_close( + out, multidevice_direct_test.shard_tensor(in_tensor, 1, mesh) + )