Skip to content

Commit 8149bb9

Browse files
Support PushMergedData in CppClient
1 parent fddb817 commit 8149bb9

20 files changed

Lines changed: 1456 additions & 4 deletions

cpp/celeborn/client/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ add_library(
2020
writer/PushStrategy.cpp
2121
writer/ReviveManager.cpp
2222
writer/PushDataCallback.cpp
23+
writer/PushMergedDataCallback.cpp
2324
ShuffleClient.cpp
2425
compress/Decompressor.cpp
2526
compress/Lz4Decompressor.cpp

cpp/celeborn/client/ShuffleClient.cpp

Lines changed: 371 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,377 @@ int ShuffleClientImpl::pushData(
269269
return body->remainingSize();
270270
}
271271

272+
int ShuffleClientImpl::mergeData(
273+
int shuffleId,
274+
int mapId,
275+
int attemptId,
276+
int partitionId,
277+
const uint8_t* data,
278+
size_t offset,
279+
size_t length,
280+
int numMappers,
281+
int numPartitions) {
282+
const auto mapKey = utils::makeMapKey(shuffleId, mapId, attemptId);
283+
if (checkMapperEnded(shuffleId, mapId, mapKey)) {
284+
return 0;
285+
}
286+
287+
auto partitionLocationMap =
288+
getPartitionLocation(shuffleId, numMappers, numPartitions);
289+
CELEBORN_CHECK_NOT_NULL(partitionLocationMap);
290+
auto partitionLocationOptional = partitionLocationMap->get(partitionId);
291+
if (!partitionLocationOptional.has_value()) {
292+
if (!revive(
293+
shuffleId,
294+
mapId,
295+
attemptId,
296+
partitionId,
297+
-1,
298+
nullptr,
299+
protocol::StatusCode::PUSH_DATA_FAIL_NON_CRITICAL_CAUSE)) {
300+
CELEBORN_FAIL(fmt::format(
301+
"Revive for shuffleId {} partitionId {} failed.",
302+
shuffleId,
303+
partitionId));
304+
}
305+
partitionLocationOptional = partitionLocationMap->get(partitionId);
306+
}
307+
if (checkMapperEnded(shuffleId, mapId, mapKey)) {
308+
return 0;
309+
}
310+
311+
CELEBORN_CHECK(partitionLocationOptional.has_value());
312+
auto partitionLocation = partitionLocationOptional.value();
313+
auto pushState = getPushState(mapKey);
314+
const int nextBatchId = pushState->nextBatchId();
315+
316+
CELEBORN_CHECK(
317+
length <= static_cast<size_t>(std::numeric_limits<int>::max()),
318+
fmt::format(
319+
"Data length {} exceeds maximum supported size {}",
320+
length,
321+
std::numeric_limits<int>::max()));
322+
323+
// Compression support
324+
const uint8_t* dataToWrite = data + offset;
325+
int lengthToWrite = static_cast<int>(length);
326+
std::unique_ptr<uint8_t[]> compressedBuffer;
327+
328+
if (shuffleCompressionEnabled_ && compressorFactory_) {
329+
auto compressor = compressorFactory_();
330+
const size_t compressedCapacity =
331+
compressor->getDstCapacity(static_cast<int>(length));
332+
compressedBuffer = std::make_unique<uint8_t[]>(compressedCapacity);
333+
334+
const size_t compressedSize = compressor->compress(
335+
dataToWrite, 0, static_cast<int>(length), compressedBuffer.get(), 0);
336+
337+
CELEBORN_CHECK(
338+
compressedSize <= static_cast<size_t>(std::numeric_limits<int>::max()),
339+
fmt::format(
340+
"Compressed size {} exceeds maximum supported size {}",
341+
compressedSize,
342+
std::numeric_limits<int>::max()));
343+
344+
lengthToWrite = static_cast<int>(compressedSize);
345+
dataToWrite = compressedBuffer.get();
346+
}
347+
348+
CELEBORN_CHECK(
349+
static_cast<size_t>(lengthToWrite) <=
350+
std::numeric_limits<size_t>::max() - kBatchHeaderSize,
351+
fmt::format(
352+
"Buffer size {} + header {} would overflow",
353+
lengthToWrite,
354+
kBatchHeaderSize));
355+
356+
auto writeBuffer = memory::ByteBuffer::createWriteOnly(
357+
kBatchHeaderSize + static_cast<size_t>(lengthToWrite));
358+
writeBuffer->writeLE<int>(mapId);
359+
writeBuffer->writeLE<int>(attemptId);
360+
writeBuffer->writeLE<int>(nextBatchId);
361+
writeBuffer->writeLE<int>(lengthToWrite);
362+
writeBuffer->writeFromBuffer(
363+
dataToWrite, 0, static_cast<size_t>(lengthToWrite));
364+
365+
auto body = memory::ByteBuffer::toReadOnly(std::move(writeBuffer));
366+
int bodySize = static_cast<int>(body->remainingSize());
367+
368+
// Build address pair for the target worker.
369+
std::string primaryAddr = partitionLocation->hostAndPushPort();
370+
std::string replicaAddr;
371+
if (partitionLocation->hasPeer()) {
372+
replicaAddr = partitionLocation->getPeer()->hostAndPushPort();
373+
}
374+
PushState::AddressPair addressPair(primaryAddr, replicaAddr);
375+
376+
int pushBufferMaxSize = conf_->clientPushBufferMaxSize();
377+
bool exceedsThreshold = pushState->addBatchData(
378+
addressPair, partitionLocation, nextBatchId, std::move(body),
379+
pushBufferMaxSize);
380+
381+
if (exceedsThreshold) {
382+
auto hostAndPushPort = partitionLocation->hostAndPushPort();
383+
limitMaxInFlight(mapKey, *pushState, hostAndPushPort);
384+
385+
auto dataBatches = pushState->takeDataBatches(addressPair);
386+
if (dataBatches) {
387+
auto batches = dataBatches->requireBatches();
388+
if (!batches.empty()) {
389+
doPushMergedData(
390+
addressPair,
391+
shuffleId,
392+
mapId,
393+
attemptId,
394+
numMappers,
395+
numPartitions,
396+
std::move(batches),
397+
pushState,
398+
conf_->clientPushMaxReviveTimes());
399+
}
400+
}
401+
}
402+
403+
return bodySize;
404+
}
405+
406+
void ShuffleClientImpl::pushMergedData(
407+
int shuffleId,
408+
int mapId,
409+
int attemptId) {
410+
const auto mapKey = utils::makeMapKey(shuffleId, mapId, attemptId);
411+
auto pushState = getPushState(mapKey);
412+
int pushBufferMaxSize = conf_->clientPushBufferMaxSize();
413+
414+
// Collect all address pairs first to avoid modifying map during iteration.
415+
std::vector<PushState::AddressPair> addressPairs;
416+
pushState->forEachBatchEntry(
417+
[&](const PushState::AddressPair& addr,
418+
std::shared_ptr<DataBatches> /* unused */) {
419+
addressPairs.push_back(addr);
420+
});
421+
422+
for (const auto& addressPair : addressPairs) {
423+
auto dataBatches = pushState->takeDataBatches(addressPair);
424+
if (!dataBatches) {
425+
continue;
426+
}
427+
// Drain in chunks of pushBufferMaxSize.
428+
while (dataBatches->getTotalSize() > 0) {
429+
auto hostAndPushPort = addressPair.first;
430+
limitMaxInFlight(mapKey, *pushState, hostAndPushPort);
431+
auto batches = dataBatches->requireBatches(pushBufferMaxSize);
432+
if (batches.empty()) {
433+
break;
434+
}
435+
doPushMergedData(
436+
addressPair,
437+
shuffleId,
438+
mapId,
439+
attemptId,
440+
0, // numMappers not needed for flush
441+
0, // numPartitions not needed for flush
442+
std::move(batches),
443+
pushState,
444+
conf_->clientPushMaxReviveTimes());
445+
}
446+
}
447+
}
448+
449+
void ShuffleClientImpl::doPushMergedData(
450+
const PushState::AddressPair& addressPair,
451+
int shuffleId,
452+
int mapId,
453+
int attemptId,
454+
int numMappers,
455+
int numPartitions,
456+
std::vector<DataBatch> batches,
457+
std::shared_ptr<PushState> pushState,
458+
int remainReviveTimes) {
459+
int groupedBatchId = pushState->nextBatchId();
460+
461+
// Build partitionUniqueIds, batchOffsets, and concatenated body.
462+
std::vector<std::string> partitionUniqueIds;
463+
std::vector<int> batchOffsets;
464+
int totalBytes = 0;
465+
for (const auto& batch : batches) {
466+
partitionUniqueIds.push_back(batch.loc->uniqueId());
467+
batchOffsets.push_back(totalBytes);
468+
totalBytes += static_cast<int>(batch.body->remainingSize());
469+
}
470+
471+
// Register in-flight.
472+
auto hostAndPushPort = addressPair.first;
473+
pushState->addBatch(groupedBatchId, totalBytes, hostAndPushPort);
474+
475+
// Concatenate all bodies into a single buffer.
476+
auto combinedBuffer = memory::ByteBuffer::createWriteOnly(totalBytes);
477+
for (const auto& batch : batches) {
478+
auto bodyClone = batch.body->clone();
479+
size_t batchSize = bodyClone->remainingSize();
480+
auto tmpBuf = std::make_unique<uint8_t[]>(batchSize);
481+
bodyClone->readToBuffer(tmpBuf.get(), batchSize);
482+
combinedBuffer->writeFromBuffer(tmpBuf.get(), 0, batchSize);
483+
}
484+
auto combinedBody =
485+
memory::ByteBuffer::toReadOnly(std::move(combinedBuffer));
486+
487+
// Get host/port from the first batch's location before moving batches.
488+
const auto& firstLoc = batches[0].loc;
489+
const auto sendHost = firstLoc->host;
490+
const auto sendPort = firstLoc->pushPort;
491+
492+
// Build PushMergedData message.
493+
const auto shuffleKey = utils::makeShuffleKey(appUniqueId_, shuffleId);
494+
network::PushMergedData pushMergedData(
495+
network::Message::nextRequestId(),
496+
protocol::PartitionLocation::Mode::PRIMARY,
497+
shuffleKey,
498+
std::move(partitionUniqueIds),
499+
std::move(batchOffsets),
500+
std::move(combinedBody));
501+
502+
// Build callback.
503+
auto callback = PushMergedDataCallback::create(
504+
shuffleId,
505+
mapId,
506+
attemptId,
507+
numMappers,
508+
numPartitions,
509+
utils::makeMapKey(shuffleId, mapId, attemptId),
510+
groupedBatchId,
511+
std::move(batches),
512+
pushState,
513+
weak_from_this(),
514+
remainReviveTimes,
515+
addressPair);
516+
517+
// Send.
518+
auto client = clientFactory_->createClient(sendHost, sendPort);
519+
client->pushMergedDataAsync(
520+
pushMergedData, conf_->clientPushDataTimeout(), callback);
521+
}
522+
523+
void ShuffleClientImpl::submitRetryPushMergedData(
524+
int shuffleId,
525+
int mapId,
526+
int attemptId,
527+
int numMappers,
528+
int numPartitions,
529+
int groupedBatchId,
530+
std::vector<DataBatch> batches,
531+
std::shared_ptr<PushState> pushState,
532+
const std::vector<std::shared_ptr<protocol::ReviveRequest>>&
533+
reviveRequests,
534+
int remainReviveTimes,
535+
long dueTimeMs) {
536+
long reviveWaitTimeMs = dueTimeMs - utils::currentTimeMillis();
537+
long accumulatedTimeMs = 0;
538+
const long deltaMs = 50;
539+
540+
// Wait for all revive requests to complete.
541+
bool allDone = false;
542+
while (!allDone && accumulatedTimeMs <= reviveWaitTimeMs) {
543+
allDone = true;
544+
for (const auto& req : reviveRequests) {
545+
if (req->reviveStatus.load() ==
546+
protocol::StatusCode::REVIVE_INITIALIZED) {
547+
allDone = false;
548+
break;
549+
}
550+
}
551+
if (!allDone) {
552+
std::this_thread::sleep_for(utils::MS(deltaMs));
553+
accumulatedTimeMs += deltaMs;
554+
}
555+
}
556+
557+
std::string oldHostAndPushPort;
558+
if (!reviveRequests.empty() && reviveRequests[0]->loc) {
559+
oldHostAndPushPort = reviveRequests[0]->loc->hostAndPushPort();
560+
}
561+
562+
if (mapperEnded(shuffleId, mapId)) {
563+
pushState->removeBatch(groupedBatchId, oldHostAndPushPort);
564+
return;
565+
}
566+
567+
// Check if all revives succeeded.
568+
for (const auto& req : reviveRequests) {
569+
if (req->reviveStatus.load() != protocol::StatusCode::SUCCESS) {
570+
pushState->setException(std::make_unique<std::runtime_error>(
571+
"Revive failed for pushMergedData retry"));
572+
return;
573+
}
574+
}
575+
576+
// Remove old in-flight tracking.
577+
pushState->removeBatch(groupedBatchId, oldHostAndPushPort);
578+
579+
// Regroup batches by new partition locations.
580+
auto locationMapOptional = partitionLocationMaps_.get(shuffleId);
581+
CELEBORN_CHECK(locationMapOptional.has_value());
582+
auto locationMap = locationMapOptional.value();
583+
584+
// Build a map from partitionId to new location.
585+
std::unordered_map<int, std::shared_ptr<const protocol::PartitionLocation>>
586+
newLocationMap;
587+
for (const auto& req : reviveRequests) {
588+
auto newLocOpt = locationMap->get(req->partitionId);
589+
if (newLocOpt.has_value()) {
590+
newLocationMap[req->partitionId] = newLocOpt.value();
591+
}
592+
}
593+
594+
// Group batches by new primary address.
595+
using AddressPair = PushState::AddressPair;
596+
struct AddressPairHasher {
597+
size_t operator()(const AddressPair& p) const {
598+
size_t h1 = std::hash<std::string>{}(p.first);
599+
size_t h2 = std::hash<std::string>{}(p.second);
600+
return h1 ^ (h2 << 1);
601+
}
602+
};
603+
std::unordered_map<AddressPair, std::vector<DataBatch>, AddressPairHasher>
604+
regrouped;
605+
606+
for (auto& batch : batches) {
607+
auto it = newLocationMap.find(batch.loc->id);
608+
if (it == newLocationMap.end()) {
609+
continue;
610+
}
611+
auto newLoc = it->second;
612+
std::string primaryAddr = newLoc->hostAndPushPort();
613+
std::string replicaAddr;
614+
if (newLoc->hasPeer()) {
615+
replicaAddr = newLoc->getPeer()->hostAndPushPort();
616+
}
617+
AddressPair newAddr(primaryAddr, replicaAddr);
618+
batch.loc = newLoc;
619+
regrouped[newAddr].push_back(std::move(batch));
620+
}
621+
622+
LOG(INFO) << "Revive for push merged data succeeded for shuffle "
623+
<< shuffleId << " map " << mapId << " attempt " << attemptId
624+
<< " groupedBatchId " << groupedBatchId
625+
<< ", regrouped into " << regrouped.size() << " groups.";
626+
627+
// Re-push each group.
628+
for (auto& [addrPair, groupBatches] : regrouped) {
629+
CELEBORN_CHECK_GT(remainReviveTimes, 0, "no remainReviveTime left");
630+
doPushMergedData(
631+
addrPair,
632+
shuffleId,
633+
mapId,
634+
attemptId,
635+
numMappers,
636+
numPartitions,
637+
std::move(groupBatches),
638+
pushState,
639+
remainReviveTimes);
640+
}
641+
}
642+
272643
void ShuffleClientImpl::mapperEnd(
273644
int shuffleId,
274645
int mapId,

0 commit comments

Comments
 (0)