94 changes: 60 additions & 34 deletions cpp/src/coll/allgather.cpp
@@ -436,14 +436,44 @@ ProgressThread::ProgressState AllGather::event_loop() {
*/
Rank const dst = (comm_->rank() + 1) % comm_->nranks();
Rank const src = (comm_->rank() + comm_->nranks() - 1) % comm_->nranks();
// GPU data sends and metadata sends can be arbitrarily interleaved. To allow reuse of
// `op_id` once `wait_and_extract()` returns, we rely on a number of invariants
// enforced by the communication scheme.
//
// Suppose we have two successive allgathers separated by a wait_and_extract "barrier"
// that reuse the op_id:
//
// AG1(op_id)
// AG1.wait_and_extract()
// AG2(op_id)
//
// The requirements for safe reuse of the tag are that:
// 1. all metadata sends/receives from AG1 are posted before wait_and_extract returns
// 2. all data sends/receives are posted before wait_and_extract returns
//
// Messages can be arbitrarily interleaved (e.g. finish messages with normal
// metadata messages, or data messages with metadata messages), as long as these
// two invariants are upheld.
//
// The communication scheme in this loop enforces this in the following way.
// The finish condition requires that:
// - we have received finish messages from all ranks, defining the final extraction
// goalpost;
// - the extraction postbox has reached a size equal to the advertised goalpost.
//
// Posting receives for more metadata is gated on both of these conditions, so we only
// post exactly the correct number of receives.
//
// To ensure that data sends/receives are correctly posted, note that data is only
// put in the extraction postbox _after_ it has been posted for send; therefore
// `wait_and_extract()` cannot return until all sends/receives have at least been
// posted, upholding the required invariants.
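The posting-before-publication invariant described above can be sketched with a minimal single-threaded model. All names here (`PostboxModel`, `send_and_publish`, the `int` chunks) are illustrative stand-ins, not the library's API:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical model of the invariant: a chunk enters the extraction
// postbox only *after* its send has been posted, so once the postbox
// reaches the advertised goalpost, at least that many sends were posted.
struct PostboxModel {
    std::size_t posted_sends = 0;  // sends handed to the transport
    std::vector<int> postbox;      // chunks visible to wait_and_extract()

    void send_and_publish(int chunk) {
        posted_sends++;            // 1. post the send first
        postbox.push_back(chunk);  // 2. only then publish for extraction
    }

    // Because publication follows posting, reaching the goalpost implies
    // every corresponding send was posted, so the op_id/tag may be reused.
    std::vector<int> wait_and_extract(std::size_t goalpost) {
        assert(postbox.size() >= goalpost);
        assert(posted_sends >= goalpost);  // the invariant from the comment
        return postbox;
    }
};
```

In the real event loop the "publish after post" ordering plays the same role: a caller that observes the goalpost in `wait_and_extract()` has a proof that no send of the current collective is still unposted.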
  Tag metadata_tag{op_id_, 0};
- Tag finish_tag{op_id_, 1};
- Tag gpu_data_tag{op_id_, 2};
+ Tag gpu_data_tag{op_id_, 1};
if (comm_->nranks() == 1) {
- // Note that we don't need to use extract_ready because there is
- // no message passing and our promise to the consumer is that
- // extracted data are valid on the stream used to construct
- // the allgather instance.
+ // Note that we don't need to use extract_ready because there is no message
+ // passing and our promise to the consumer is that extracted data chunks are valid
+ // on their respective streams.
for (auto&& chunk : inserted_.extract()) {
if (chunk->is_finish()) {
mark_finish(chunk->sequence());
@@ -454,10 +484,13 @@ ProgressThread::ProgressState AllGather::event_loop() {
} else {
// Chunks that are ready to send
for (auto&& chunk : inserted_.extract_ready()) {
- // Tell the destination about them. Finish messages use a separate tag so they
- // can be received independently of data metadata.
- Tag const send_tag = chunk->is_finish() ? finish_tag : metadata_tag;
- fire_and_forget_.push_back(comm_->send(chunk->serialize(), dst, send_tag));
+ // Tell the destination about them. All messages (data + finish) share
+ // metadata_tag so the no-overtaking guarantee on a single (src, tag) pair
+ // ensures current-collective messages arrive before any new-collective
+ // messages that reuse the same op_id.
+ fire_and_forget_.push_back(
+     comm_->send(chunk->serialize(), dst, metadata_tag)
+ );
if (chunk->is_finish()) {
// Finish chunk contains as sequence number the number
// of insertions from that rank.
@@ -470,28 +503,10 @@ ProgressThread::ProgressState AllGather::event_loop() {
);
}
}
- // Poll for any remaining finish messages. As soon as all finish messages are
- // received we stop polling for more finish messages so that a subsequent
- // collective using the same OpID doesn't get matched here, even if the event loop
- // continues to process actual data messages.
- while (remote_finish_counter_ > 0) {
-     auto const msg = comm_->recv_from(src, finish_tag);
-     if (!msg) {
-         break;
-     }
-     auto chunk = detail::Chunk::deserialize(*msg, br_);
-     remote_finish_counter_--;
-     num_expected_messages_ += chunk->sequence();
-     if (chunk->origin() != dst) {
-         fire_and_forget_.push_back(
-             comm_->send(chunk->serialize(), dst, finish_tag)
-         );
-     }
-     mark_finish(chunk->sequence());
- }
- // Poll for the data messages we expect. We might receive more data as part of
- // this collective if we either haven't received all finish messages, or we have
- // and haven't yet processed all the data messages.
+ // Receive metadata messages. All messages (data + finish) share metadata_tag, so
+ // the no-overtaking guarantee ensures current-collective messages arrive before
+ // any new-collective messages that reuse the same op_id. While either of these
+ // conditions is true, this allgather needs to consume more metadata messages.
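Why sharing one tag makes termination well defined can be sketched with a small channel model. All names here (`Msg`, `drain_current_collective`) are assumptions for illustration, not the library's API; the model only relies on the channel being FIFO per (src, tag):

```cpp
#include <cassert>
#include <deque>
#include <utility>

// Hypothetical model: on one (src, tag) channel, messages arrive in the
// order they were sent. Finish messages advertise how many data messages
// their sender posted, so the receiver can stop after exactly the current
// collective's traffic, leaving any reused-op_id messages queued behind.
using Msg = std::pair<bool, int>;  // {is_finish, count-or-payload}

int drain_current_collective(std::deque<Msg>& channel, int finishes_left) {
    int expected = 0;  // data messages advertised by finish messages
    int received = 0;  // data messages consumed so far
    // Mirrors the loop condition in the diff: keep consuming while finish
    // messages are outstanding or advertised data is still in flight.
    while (finishes_left > 0 || received < expected) {
        Msg m = channel.front();
        channel.pop_front();
        if (m.first) {  // finish: carries the sender's insertion count
            finishes_left--;
            expected += m.second;
        } else {        // ordinary data-metadata message
            received++;
        }
    }
    return received;
}
```

Because no message of the next collective can overtake the current one on the shared tag, the loop never consumes a successor's message: it exits exactly when the last advertised message of the current collective has been seen.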
while (remote_finish_counter_ > 0
|| num_received_messages_ < num_expected_messages_)
{
@@ -500,8 +515,19 @@
break;
}
auto chunk = detail::Chunk::deserialize(*msg, br_);
- num_received_messages_++;
- to_receive_.emplace_back(std::move(chunk));
+ if (chunk->is_finish()) {
+     remote_finish_counter_--;
+     num_expected_messages_ += chunk->sequence();
+     if (chunk->origin() != dst) {
+         fire_and_forget_.push_back(
+             comm_->send(chunk->serialize(), dst, metadata_tag)
+         );
+     }
+     mark_finish(chunk->sequence());
+ } else {
+     num_received_messages_++;
+     to_receive_.emplace_back(std::move(chunk));
+ }
}
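The forwarding rule above (`chunk->origin() != dst`) is what lets every rank learn every finish count while each finish message traverses the ring exactly once: a finish stops at the rank just before its origin. A minimal sketch, assuming a simple inbox-per-rank model (all names here are illustrative, and the model assumes nranks >= 2, since the real code short-circuits the single-rank case):

```cpp
#include <cassert>
#include <deque>
#include <vector>

// Hypothetical ring model: rank r sends to (r + 1) % nranks. Each rank
// receives a finish chunk, records it, and forwards it to its neighbour
// unless the chunk originated at that neighbour, so each finish visits
// every rank other than its origin exactly once and is then dropped.
struct Finish { int origin; };

std::vector<int> seen_counts(int nranks) {
    std::vector<std::deque<Finish>> inbox(nranks);
    std::vector<int> seen(nranks, 0);  // finishes observed per rank
    // Every rank injects its own finish at its downstream neighbour.
    for (int r = 0; r < nranks; ++r)
        inbox[(r + 1) % nranks].push_back({r});
    // Process inboxes until no rank makes progress.
    bool progress = true;
    while (progress) {
        progress = false;
        for (int r = 0; r < nranks; ++r) {
            int const dst = (r + 1) % nranks;
            while (!inbox[r].empty()) {
                Finish f = inbox[r].front();
                inbox[r].pop_front();
                seen[r]++;
                if (f.origin != dst) inbox[dst].push_back(f);  // forward on
                progress = true;
            }
        }
    }
    return seen;
}
```

Each rank therefore observes nranks - 1 remote finishes, matching what `remote_finish_counter_` counts down from, and the total traffic is nranks * (nranks - 1) hops rather than unbounded circulation.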
// Post receives if the chunk is ready
for (auto&& chunk : to_receive_) {