Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 51 additions & 6 deletions csrc/host_ir/lower_to_communication.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,40 @@ void lowerToReduceScatter(
backend));
}

void lowerToAllToAll(
TensorView* input_tv,
TensorView* output_tv,
const CommunicatorBackend backend,
std::vector<Expr*>& comms) {
const DeviceMesh& sender_mesh = input_tv->getDeviceMesh();
const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh();
NVF_ERROR_EQ(
sender_mesh.rank(),
1,
"AllToAll sender mesh must be a 1D mesh. Given ",
sender_mesh);
NVF_ERROR_EQ(
receiver_mesh.rank(),
1,
"AllToAll receiver mesh must be a 1D mesh. Given ",
receiver_mesh);
NVF_ERROR_EQ(
sender_mesh,
receiver_mesh,
"AllToAll sender and receiver meshes must be the same. Given ",
sender_mesh,
" and ",
receiver_mesh);
comms.push_back(IrBuilder::create<Communication>(
CommunicationType::AllToAll,
output_tv,
input_tv,
sender_mesh.vector(),
/*root=*/-1,
c10d::ReduceOp::RedOpType::UNUSED,
backend));
}

IterDomain* getLogicalFromLoopId(TensorView* tv, IterDomain* loop_id) {
std::unordered_set<IterDomain*> logical_ids =
getInputsInTargetDomain({loop_id}, tv->getLogicalDomain());
Expand Down Expand Up @@ -378,8 +412,16 @@ CommunicationInfo getCommunicationInfo(Expr* e) {
IterDomain* p_logical_id = getLogicalFromLoopId(producer, p_loop_did);
IterDomain* c_logical_id = getLogicalFromLoopId(consumer, c_loop_did);
// TODO(#4604): This is problematic for 2D sharding.
fill_communication_info(
CommunicationType::SendRecv, p_logical_id, c_logical_id);

if (c_logical_id == p2c_map.at(p_logical_id)) {
fill_communication_info(
CommunicationType::SendRecv, p_logical_id, c_logical_id);
} else {
fill_communication_info(
CommunicationType::AllToAll,
c2p_map.at(c_logical_id),
c_logical_id);
}
}
} else {
NVF_ERROR(e->isA<ReductionOp>() || e->isA<SqueezeOp>());
Expand Down Expand Up @@ -520,10 +562,10 @@ std::vector<Expr*> convertSingleOpToCommunication(
output_tv->setMemoryType(MemoryType::Global);
}

NVF_ERROR(
isCommunicationLayoutCompliant(e),
"Resharding on an inner axis is not lowerable ",
e->toString());
// NVF_ERROR(
// isCommunicationLayoutCompliant(e),
// "Resharding on an inner axis is not lowerable ",
// e->toString());

CommunicationInfo communication_info = getCommunicationInfo(e);

Expand Down Expand Up @@ -567,6 +609,9 @@ std::vector<Expr*> convertSingleOpToCommunication(
case CommunicationType::Reduce:
lowerToReduce(input_tv, output_tv, op_type(e), backend, comms);
break;
case CommunicationType::AllToAll:
lowerToAllToAll(input_tv, output_tv, backend, comms);
break;
}

return comms;
Expand Down
46 changes: 46 additions & 0 deletions csrc/multidevice/communication.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ std::ostream& operator<<(std::ostream& os, const CommunicationType& type) {
case CommunicationType::SendRecv:
os << "SendRecv";
break;
case CommunicationType::AllToAll:
os << "AllToAll";
break;
default:
NVF_THROW("unrecognized CommunicationType: ", type);
}
Expand Down Expand Up @@ -152,6 +155,7 @@ bool hasRoot(CommunicationType type) {
case CommunicationType::Allgather:
case CommunicationType::Allreduce:
case CommunicationType::ReduceScatter:
case CommunicationType::AllToAll:
return false;
default:
NVF_THROW("unrecognized CommunicationType: ", type);
Expand All @@ -169,6 +173,7 @@ bool isReduction(CommunicationType type) {
case CommunicationType::Scatter:
case CommunicationType::Broadcast:
case CommunicationType::SendRecv:
case CommunicationType::AllToAll:
return false;
default:
NVF_THROW("unrecognized CommunicationType: ", type);
Expand Down Expand Up @@ -572,6 +577,44 @@ c10::intrusive_ptr<c10d::Work> postSendRecv(
/*tag=*/0);
}
}

c10::intrusive_ptr<c10d::Work> postAllToAll(
Communication* communication,
DeviceIdxType my_device_index,
c10d::Backend* backend,
at::Tensor input_tensor,
at::Tensor output_tensor) {
NVF_ERROR(
isTvContiguous(communication->in()),
"Input tensor is not contiguous: ",
communication->in(),
" contiguity: ",
communication->in()->domain()->getContiguityString());
NVF_ERROR(
isTvContiguous(communication->out()),
"Output tensor is not contiguous: ",
communication->out(),
" contiguity: ",
communication->out()->domain()->getContiguityString());

auto flattened_input_tensor = viewAsCompact(input_tensor);
auto flattened_output_tensor = viewAsCompact(output_tensor);

// alltoall_bases requires even splits of the input and output tensors.
auto input_splits = at::tensor_split(
flattened_input_tensor, communication->team_size(), /*dim=*/0);
auto output_splits = at::tensor_split(
flattened_output_tensor, communication->team_size(), /*dim=*/0);
assertBuffersHaveSameSize(input_splits, output_splits);

std::vector<int64_t> empty_split_sizes;
return backend->alltoall_base(
flattened_output_tensor,
flattened_input_tensor,
empty_split_sizes,
empty_split_sizes,
/*options=*/{});
}
} // namespace

c10::intrusive_ptr<c10d::Work> postSingleCommunication(
Expand Down Expand Up @@ -617,6 +660,9 @@ c10::intrusive_ptr<c10d::Work> postSingleCommunication(
case CommunicationType::SendRecv:
return postSendRecv(
communication, my_device_index, backend, input_tensor, output_tensor);
case CommunicationType::AllToAll:
return postAllToAll(
communication, my_device_index, backend, input_tensor, output_tensor);
default:
NVF_THROW("Wrong communication type: ", communication->type());
return nullptr;
Expand Down
3 changes: 2 additions & 1 deletion csrc/multidevice/communication.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ enum class CommunicationType {
Allreduce,
ReduceScatter,
Broadcast,
SendRecv
SendRecv,
AllToAll
};

std::ostream& operator<<(std::ostream& os, const CommunicationType& type);
Expand Down
2 changes: 1 addition & 1 deletion csrc/preseg_passes/pre_segmenter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ namespace nvfuser::preseg_passes {

OptimizationPass<PropagateShardingsPass>::runPass(fusion);
OptimizationPass<DecomposeReshardingsPass>::runPass(fusion);
OptimizationPass<ReorderShardedAxisPass>::runPass(fusion);
// OptimizationPass<ReorderShardedAxisPass>::runPass(fusion);

OptimizationPass<MarkAliasesPreparePass>::runPass(fusion);
OptimizationPass<AllocationDomainPass>::runPass(fusion);
Expand Down
2 changes: 1 addition & 1 deletion csrc/tensor_metadata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ inferAndValidateAllocationSizesAndStrides(
auto [allocation_sizes, allocation_strides] =
inferAllocationSizesAndStrides(tensor, tv, ee);

bool skip_validation = false;
bool skip_validation = true;

// Skip validation for block scales of BlockQuantizationOp with
// swizzled scales.
Expand Down
24 changes: 24 additions & 0 deletions tests/cpp/test_multidevice_communications.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,30 @@ TEST_P(CommunicationTest, SendRecv) {
}
}

TEST_P(CommunicationTest, AllToAll) {
hir::HostIrContainer container;
FusionGuard fg(&container);
TensorView* in = makeContigTensor(1);
in->setDeviceMesh(full_mesh_);
TensorView* out = ops::newValLike(in, in->dtype())->as<TensorView>();
out->setDeviceMesh(full_mesh_);
auto communication = IrBuilder::create<Communication>(
CommunicationType::AllToAll, out, in, all_ranks_);

at::Tensor input_tensor = at::tensor({1, 2, 3, 4}, tensor_options_);
at::Tensor output_tensor = at::empty({4}, tensor_options_);

auto work = postSingleCommunication(
communication,
communicator_->deviceId(),
backend_,
input_tensor,
output_tensor);
work->wait();

debug() << "output_tensor: " << output_tensor << std::endl;
}

TEST_P(CommunicationTest, SendRecvToSelf) {
constexpr DeviceIdxType sender = 0;
if (communicator_->deviceId() > 0) {
Expand Down
21 changes: 21 additions & 0 deletions tests/python/direct/test_python_direct.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,3 +556,24 @@ def fusion_func(fd: FusionDefinition) -> None:
"FusionDefinition did not run correctly with profile enabled! Error: "
+ str(e)
)


def test_split_allocation_domain():
k, m, n = 2, 6, 5
with FusionDefinition() as fd:
inp = fd.define_tensor((k, m, n), contiguity=True, dtype=DataType.Half)
out = fd.ops.set(inp)
fd.add_output(out)

inp.set_allocation_domain(inp.get_loop_domain(), True)
out.outer_split(1, 2) # [k, 2, m//2, n]
out.reorder({1: 0}) # [2, k, m//2, n]
out.set_allocation_domain(out.get_loop_domain(), True)
print(out.domain())

in_tensor = torch.arange(k * m * n, dtype=torch.float16, device="cuda").reshape(
k, m, n
)
(out,) = fd.execute([in_tensor])
print("in_tensor: \n", in_tensor)
print("out flattened: \n", out.as_strided((k * m * n,), (1,)))
63 changes: 63 additions & 0 deletions tests/python/multidevice/test_multidevice.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,3 +447,66 @@ def test_binary(multidevice_direct_test):
torch.testing.assert_close(
z, multidevice_direct_test.shard_tensor(x.float() + y.float(), 0, mesh)
)


@pytest.mark.mpi
def test_alltoall(multidevice_direct_test):
d = multidevice_direct_test.size
mesh = nvfuser.multidevice.DeviceMesh(torch.arange(d))
k, m, n = 2, 3, 5

with FusionDefinition() as fd:
inp = fd.define_tensor((k, d * m, d * n), contiguity=True, dtype=DataType.Half)
all2all_inp = fd.ops.set(inp)
all2all_out = fd.ops.set(all2all_inp)
out = fd.ops.set(all2all_out)
fd.add_output(all2all_inp)
fd.add_output(all2all_out)
fd.add_output(out)

inp.set_device_mesh(mesh)
inp.outer_split(2, d)
inp.axis(2).parallelize(nvfuser.ParallelType.mesh_x)

all2all_inp.set_device_mesh(mesh)
all2all_inp.outer_split(2, d)
all2all_inp.axis(2).parallelize(nvfuser.ParallelType.mesh_x)
all2all_inp.outer_split(1, d)
all2all_inp.reorder({1: 0}) # [d, k, m, DIDx(d), n]
all2all_inp.set_allocation_domain(all2all_inp.get_loop_domain(), True)

all2all_out.set_device_mesh(mesh)
all2all_out.outer_split(2, d)
all2all_out.outer_split(1, d)
all2all_out.axis(1).parallelize(
nvfuser.ParallelType.mesh_x
) # [k, DIDx(d), m, d, n]
all2all_out.reorder({3: 0}) # [d, k, m, DIDx(d), n]
all2all_out.set_allocation_domain(all2all_out.get_loop_domain(), True)

out.set_device_mesh(mesh)
out.outer_split(1, d)
out.axis(1).parallelize(nvfuser.ParallelType.mesh_x)
out.set_allocation_domain(out.get_loop_domain(), True)

in_tensor = torch.arange(k * d * m * d * n, dtype=torch.float16).reshape(
k, d * m, d * n
)
sharded = multidevice_direct_test.shard_tensor(in_tensor, 2, mesh)
(all2all_inp, all2all_out, out) = fd.execute([sharded])
if multidevice_direct_test.rank == 0:
print("in_tensor: \n", in_tensor)
print("sharded: \n", sharded)
# print("out: \n", out)
print(
"all2all_inp flattened: \n",
all2all_inp.as_strided((sharded.numel(),), (1,)),
)
print(
"all2all_out flattened: \n",
all2all_out.as_strided((sharded.numel(),), (1,)),
)
print(out)
torch.testing.assert_close(
out, multidevice_direct_test.shard_tensor(in_tensor, 1, mesh)
)