From e63b4f4268dea6820481b0e3d3f97bbe7819b50b Mon Sep 17 00:00:00 2001 From: Tarjei Mandt Date: Sun, 26 Apr 2026 09:02:51 +1000 Subject: [PATCH] [jaccl] Fix race on local_staging in MeshImpl::all_reduce The SEND-completion handler refilled local_staging(buff) the moment all peers ACK'd the previous send, regardless of whether that chunk had been consumed by the own-rank reduction step. RDMA timing made this non-deterministic, producing wrong sums for messages spanning multiple PIPELINE chunks. Decouple the two: SEND completion only refills send_buffer (free once on the wire) and posts the next send. local_staging(b) is refilled in the reduce loop right after the own-rank reduction reads it, which also bumps recv_end[rank_] to gate the next step. --- mlx/distributed/jaccl/lib/jaccl/mesh_impl.h | 23 ++++++++++++++------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/mlx/distributed/jaccl/lib/jaccl/mesh_impl.h b/mlx/distributed/jaccl/lib/jaccl/mesh_impl.h index 0201f7ffff..6327f41d35 100644 --- a/mlx/distributed/jaccl/lib/jaccl/mesh_impl.h +++ b/mlx/distributed/jaccl/lib/jaccl/mesh_impl.h @@ -89,9 +89,10 @@ class MeshImpl { // // If a send was completed mark how many completions we have received // for that buffer. If we have sent the buffer to all peers we can - // reuse the buffer so copy the next chunk of data and send it to all. - // Also copy the next chunk into the staging area and advance our - // completed "receives". + // reuse the send buffer so copy the next chunk of data into it and + // post the next send. local_staging is refilled later in the reduce + // loop, since the chunk in it may still be needed for the own-rank + // reduction step. // // If a receive is completed then advance the pointer of completed // receives. @@ -108,15 +109,10 @@ class MeshImpl { completed_send_count[buff]++; if (completed_send_count[buff] == num_peers) { int64_t elems = std::min(N, total - read_offset); - std::copy( - in + read_offset, - in + read_offset + elems, - local_staging(buff)); std::copy( in + read_offset, in + read_offset + elems, send_buffer(sz, buff).begin()); - recv_end[rank_]++; post_send_all(sz, buff); completed_send_count[buff] = 0; @@ -155,6 +151,17 @@ class MeshImpl { } else { reduce_op(local_staging(b), out + w, elems); } + + // Refill local_staging(b) for the next own-rank chunk and bump + // recv_end[rank_] to unblock its reduction. + int64_t next_local_chunk = + static_cast(reduce_chunk) + PIPELINE; + if (next_local_chunk < total_chunks) { + int64_t off = next_local_chunk * N; + int64_t next_elems = std::min(N, total - off); + std::copy(in + off, in + off + next_elems, local_staging(b)); + recv_end[rank_]++; + } } // Data is read from the recv buffers