Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
6a4a494
Integrate `MetadataPayloadExchange` in `Shuffler`
pentschev Nov 10, 2025
095c14e
Fix send of zero-sized payloads
pentschev Nov 10, 2025
5cf22a1
Add a helper function to convert Chunks into Messages
pentschev Nov 11, 2025
817e8de
Merge remote-tracking branch 'upstream/main' into shuffle-integrate-c…
pentschev Mar 12, 2026
843c33c
Merge remote-tracking branch 'upstream/main' into shuffle-integrate-c…
pentschev Mar 12, 2026
0023495
Fix control message
pentschev Mar 12, 2026
2ea7c58
Merge remote-tracking branch 'upstream/main' into shuffle-integrate-c…
pentschev Mar 12, 2026
3a8163b
Merge remote-tracking branch 'upstream/main' into shuffle-integrate-c…
pentschev Mar 13, 2026
72bf896
Remove unused _op_id_ attribute
pentschev Mar 13, 2026
18a46f3
Prevent test_some call when using single comms
pentschev Mar 13, 2026
98e8ed4
Merge remote-tracking branch 'upstream/main' into shuffle-integrate-c…
pentschev Mar 30, 2026
4a30128
Merge remote-tracking branch 'upstream/main' into shuffle-integrate-c…
pentschev Mar 30, 2026
0c9df2f
Partial fix for op_id reuse
pentschev Mar 30, 2026
efe0361
Support op_id reuse via MPE
pentschev Mar 30, 2026
793119b
Add polling mode UCXX progress to RapidsMPF's ProgressThread
pentschev Mar 31, 2026
eb0868b
Only send/expect receive termination from peers that communicated
pentschev Mar 31, 2026
931bef4
Merge remote-tracking branch 'upstream/main' into shuffle-integrate-c…
pentschev Mar 31, 2026
2e6cef1
Pre-distribute listener addresses during `barrier()`
pentschev Mar 31, 2026
e3500b5
Restructure control messages
pentschev Mar 31, 2026
1f05038
Bulk-pack listener addresses into a single AM per rank
pentschev Mar 31, 2026
f8ea148
Remove duplicate barrier() implementation
pentschev Mar 31, 2026
a1a902c
Add optional `data` arg to `Chunk::deserialize()`
pentschev Apr 2, 2026
e9d97ad
Merge remote-tracking branch 'upstream/main' into shuffle-integrate-c…
pentschev Apr 7, 2026
35b86ae
Merge branch 'main' into shuffle-integrate-comms-interface
pentschev Apr 14, 2026
792cd0f
Prevent serialized metadata copying
pentschev Apr 14, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*/
#pragma once
Expand Down Expand Up @@ -182,14 +182,28 @@ class MetadataPayloadExchange {
*/
[[nodiscard]] virtual std::vector<std::unique_ptr<Message>> recv() = 0;

/**
* @brief Signal that no more messages will be sent.
*
* After calling this method, no further calls to send() are permitted.
* The implementation sends protocol-level termination markers to all peers so
* that each receiver knows the exact number of application messages to expect.
* This enables safe reuse of operation IDs: once all termination markers have
* been received and all expected messages processed, the communication layer
* considers itself idle and the tag/op_id can be reused.
*
* @throws std::logic_error If called more than once.
*/
virtual void finish() = 0;

/**
* @brief Check if the communication layer is currently idle.
*
* Indicates whether there are any active or pending communication operations.
* A return value of `true` means the exchange is idling, i.e. no operations
* are currently in progress. However, new send/receive requests may still be
* submitted in the future; this does not imply that all communication has been
* fully finalized or globally synchronized.
* Before finish() is called, a return value of `true` means no I/O operations
* are in progress. After finish() is called, `true` additionally requires that
* all peers have sent their termination markers and all expected messages have
* been received, meaning the op_id can safely be reused.
*
* @return `true` if the communication layer is idle; `false` if activity is ongoing.
*/
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*/
#pragma once
Expand Down Expand Up @@ -74,6 +74,11 @@ class TagMetadataPayloadExchange : public MetadataPayloadExchange {
*/
std::vector<std::unique_ptr<Message>> recv() override;

/**
* @copydoc MetadataPayloadExchange::finish
*/
void finish() override;

/**
* @copydoc MetadataPayloadExchange::is_idle
*/
Expand All @@ -100,10 +105,29 @@ class TagMetadataPayloadExchange : public MetadataPayloadExchange {

// Core communication infrastructure
std::shared_ptr<Communicator> comm_;
Rank const nranks_;
Rank const rank_;
Tag const metadata_tag_;
Tag const gpu_data_tag_;
std::function<std::unique_ptr<Buffer>(std::size_t)> allocate_buffer_fn_;

// Per-peer tracking for op_id reuse (see rapidsai/rapidsmpf#927).
// After finish() is called, termination markers are exchanged so each peer
// knows exactly how many application messages to expect. recv_from is used
// instead of recv_any to avoid consuming messages from a future collective.
//
// NOTE(review): the per-peer vectors below are documented as indexed by Rank;
// presumably each is sized to nranks_ at construction -- confirm in the
// implementation file.
bool finished_{false}; ///< Starts false. NOTE(review): presumably set by finish() -- confirm.
std::vector<std::size_t>
messages_sent_to_; ///< Application messages sent per peer, indexed by Rank.
std::vector<std::size_t>
peer_received_; ///< Application messages received per peer, indexed by Rank.
std::vector<std::size_t>
peer_expected_; ///< Expected application messages per peer (0 = unknown).
                ///< NOTE(review): a count of 0 alone cannot distinguish "unknown"
                ///< from a peer that sent nothing; presumably peer_terminated_
                ///< disambiguates -- confirm.
std::vector<bool>
peer_terminated_; ///< Whether we received the termination marker from each peer.

/// Sentinel message_id value used to identify protocol-level termination markers.
/// NOTE(review): this assumes application messages never carry a message_id equal
/// to UINT64_MAX -- confirm against the sender side.
static constexpr std::uint64_t termination_sentinel_ = UINT64_MAX;

// Communication state containers
std::vector<std::unique_ptr<Communicator::Future>>
fire_and_forget_; ///< Ongoing "fire-and-forget" operations (non-blocking sends).
Expand Down
19 changes: 13 additions & 6 deletions cpp/include/rapidsmpf/shuffler/chunk.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -228,17 +228,24 @@ class Chunk {
* @brief Create a chunk by deserializing a metadata message.
*
* @param msg The metadata message received from another rank.
* @param br Buffer resource for allocating a the data buffer of the deserialized
* message.
* @param br Buffer resource for allocating the data buffer of the deserialized
* message. Ignored when @p data is provided.
* @param validate Whether to validate the metadata buffer.
* @param data Optional pre-existing data buffer to use instead of allocating a new
* one. When non-null, the buffer resource allocation is skipped entirely, avoiding
* unnecessary memory pressure from a temporary allocation.
* @return The chunk.
*
* @throws std::logic_error if the chunk is not a control message and no buffer
* resource is provided. @throws std::runtime_error if the metadata buffer does not
* follow the expected format and `validate` is true.
* @throws std::logic_error if the chunk is not a control message and neither @p data
* nor @p br is provided.
* @throws std::runtime_error if the metadata buffer does not follow the expected
* format and @p validate is true.
*/
static Chunk deserialize(
std::vector<std::uint8_t> const& msg, BufferResource* br, bool validate = true
std::vector<std::uint8_t> const& msg,
BufferResource* br,
bool validate = true,
std::unique_ptr<Buffer> data = nullptr
);

/**
Expand Down
23 changes: 19 additions & 4 deletions cpp/include/rapidsmpf/shuffler/shuffler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include <vector>

#include <rapidsmpf/communicator/communicator.hpp>
#include <rapidsmpf/communicator/metadata_payload_exchange/tag.hpp>
#include <rapidsmpf/error.hpp>
#include <rapidsmpf/memory/buffer_resource.hpp>
#include <rapidsmpf/memory/packed_data.hpp>
Expand Down Expand Up @@ -118,6 +119,8 @@ class Shuffler {
* @param br Buffer resource used to allocate temporary buffers and the shuffle result.
* @param finished_callback Callback to notify when all partitions are finished.
* @param partition_owner Function to determine partition ownership.
* @param mpe Optional custom metadata payload exchange. If not provided,
* uses the default tag-based implementation.
*
* @note It is safe to reuse the `op_id` as soon as `wait` has completed
* locally.
Expand All @@ -132,7 +135,8 @@ class Shuffler {
PartID total_num_partitions,
BufferResource* br,
FinishedCallback&& finished_callback,
PartitionOwner partition_owner = round_robin
PartitionOwner partition_owner = round_robin,
std::unique_ptr<communicator::MetadataPayloadExchange> mpe = nullptr
);

/**
Expand All @@ -144,6 +148,8 @@ class Shuffler {
* @param total_num_partitions Total number of partitions in the shuffle.
* @param br Buffer resource used to allocate temporary buffers and the shuffle result.
* @param partition_owner Function to determine partition ownership.
* @param mpe Optional custom metadata payload exchange. If not provided,
* uses the default tag-based implementation.
*
* @note The caller promises that inserted buffers are stream-ordered with respect
* to their own stream, and extracted buffers are likewise guaranteed to be stream-
Expand All @@ -154,9 +160,18 @@ class Shuffler {
OpID op_id,
PartID total_num_partitions,
BufferResource* br,
PartitionOwner partition_owner = round_robin
PartitionOwner partition_owner = round_robin,
std::unique_ptr<communicator::MetadataPayloadExchange> mpe = nullptr
)
: Shuffler(comm, op_id, total_num_partitions, br, nullptr, partition_owner) {}
: Shuffler(
comm,
op_id,
total_num_partitions,
br,
nullptr,
partition_owner,
std::move(mpe)
) {}

~Shuffler();

Expand Down Expand Up @@ -313,12 +328,12 @@ class Shuffler {
// Flipped to true exactly once, when partitions are ready for extraction and we've
// posted all the sends we are going to post.
bool can_extract_{false};
OpID const op_id_;
detail::ChunksToSend to_send_; ///< Storage for chunks to send to other ranks.
detail::ReceivedChunks received_; ///< Storage for received chunks that are
///< ready to be extracted by the user.

std::shared_ptr<Communicator> comm_;
std::unique_ptr<communicator::MetadataPayloadExchange> mpe_;
ProgressThread::FunctionID progress_thread_function_id_;

SpillManager::SpillFunctionID spill_function_id_;
Expand Down
Loading
Loading