diff --git a/mlx/distributed/jaccl/lib/jaccl/mesh_impl.h b/mlx/distributed/jaccl/lib/jaccl/mesh_impl.h index 0201f7ffff..6327f41d35 100644 --- a/mlx/distributed/jaccl/lib/jaccl/mesh_impl.h +++ b/mlx/distributed/jaccl/lib/jaccl/mesh_impl.h @@ -89,9 +89,10 @@ class MeshImpl { // // If a send was completed mark how many completions we have received // for that buffer. If we have sent the buffer to all peers we can - // reuse the buffer so copy the next chunk of data and send it to all. - // Also copy the next chunk into the staging area and advance our - // completed "receives". + // reuse the send buffer so copy the next chunk of data into it and + // post the next send. local_staging is refilled later in the reduce + // loop, since the chunk in it may still be needed for the own-rank + // reduction step. // // If a receive is completed then advance the pointer of completed // receives. @@ -108,15 +109,10 @@ class MeshImpl { completed_send_count[buff]++; if (completed_send_count[buff] == num_peers) { int64_t elems = std::min(N, total - read_offset); - std::copy( - in + read_offset, - in + read_offset + elems, - local_staging(buff)); std::copy( in + read_offset, in + read_offset + elems, send_buffer(sz, buff).begin()); - recv_end[rank_]++; post_send_all(sz, buff); completed_send_count[buff] = 0; @@ -155,6 +151,17 @@ class MeshImpl { } else { reduce_op(local_staging(b), out + w, elems); } + + // Refill local_staging(b) for the next own-rank chunk and bump + // recv_end[rank_] to unblock its reduction. + int64_t next_local_chunk = + static_cast<int64_t>(reduce_chunk) + PIPELINE; + if (next_local_chunk < total_chunks) { + int64_t off = next_local_chunk * N; + int64_t next_elems = std::min(N, total - off); + std::copy(in + off, in + off + next_elems, local_staging(b)); + recv_end[rank_]++; + } } // Data is read from the recv buffers