Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 15 additions & 8 deletions mlx/distributed/jaccl/lib/jaccl/mesh_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,10 @@ class MeshImpl {
//
// If a send was completed mark how many completions we have received
// for that buffer. If we have sent the buffer to all peers we can
// reuse the buffer so copy the next chunk of data and send it to all.
// Also copy the next chunk into the staging area and advance our
// completed "receives".
// reuse the send buffer so copy the next chunk of data into it and
// post the next send. local_staging is refilled later in the reduce
// loop, since the chunk in it may still be needed for the own-rank
// reduction step.
//
// If a receive is completed then advance the pointer of completed
// receives.
Expand All @@ -108,15 +109,10 @@ class MeshImpl {
completed_send_count[buff]++;
if (completed_send_count[buff] == num_peers) {
int64_t elems = std::min(N, total - read_offset);
std::copy(
in + read_offset,
in + read_offset + elems,
local_staging(buff));
std::copy(
in + read_offset,
in + read_offset + elems,
send_buffer(sz, buff).begin<T>());
recv_end[rank_]++;
post_send_all(sz, buff);

completed_send_count[buff] = 0;
Expand Down Expand Up @@ -155,6 +151,17 @@ class MeshImpl {
} else {
reduce_op(local_staging(b), out + w, elems);
}

// Refill local_staging(b) for the next own-rank chunk and bump
// recv_end[rank_] to unblock its reduction.
int64_t next_local_chunk =
static_cast<int64_t>(reduce_chunk) + PIPELINE;
if (next_local_chunk < total_chunks) {
int64_t off = next_local_chunk * N;
int64_t next_elems = std::min(N, total - off);
std::copy(in + off, in + off + next_elems, local_staging(b));
recv_end[rank_]++;
}
}

// Data is read from the recv buffers
Expand Down