@@ -269,6 +269,377 @@ int ShuffleClientImpl::pushData(
269269 return body->remainingSize ();
270270}
271271
272+ int ShuffleClientImpl::mergeData (
273+ int shuffleId,
274+ int mapId,
275+ int attemptId,
276+ int partitionId,
277+ const uint8_t * data,
278+ size_t offset,
279+ size_t length,
280+ int numMappers,
281+ int numPartitions) {
282+ const auto mapKey = utils::makeMapKey (shuffleId, mapId, attemptId);
283+ if (checkMapperEnded (shuffleId, mapId, mapKey)) {
284+ return 0 ;
285+ }
286+
287+ auto partitionLocationMap =
288+ getPartitionLocation (shuffleId, numMappers, numPartitions);
289+ CELEBORN_CHECK_NOT_NULL (partitionLocationMap);
290+ auto partitionLocationOptional = partitionLocationMap->get (partitionId);
291+ if (!partitionLocationOptional.has_value ()) {
292+ if (!revive (
293+ shuffleId,
294+ mapId,
295+ attemptId,
296+ partitionId,
297+ -1 ,
298+ nullptr ,
299+ protocol::StatusCode::PUSH_DATA_FAIL_NON_CRITICAL_CAUSE)) {
300+ CELEBORN_FAIL (fmt::format (
301+ " Revive for shuffleId {} partitionId {} failed." ,
302+ shuffleId,
303+ partitionId));
304+ }
305+ partitionLocationOptional = partitionLocationMap->get (partitionId);
306+ }
307+ if (checkMapperEnded (shuffleId, mapId, mapKey)) {
308+ return 0 ;
309+ }
310+
311+ CELEBORN_CHECK (partitionLocationOptional.has_value ());
312+ auto partitionLocation = partitionLocationOptional.value ();
313+ auto pushState = getPushState (mapKey);
314+ const int nextBatchId = pushState->nextBatchId ();
315+
316+ CELEBORN_CHECK (
317+ length <= static_cast <size_t >(std::numeric_limits<int >::max ()),
318+ fmt::format (
319+ " Data length {} exceeds maximum supported size {}" ,
320+ length,
321+ std::numeric_limits<int >::max ()));
322+
323+ // Compression support
324+ const uint8_t * dataToWrite = data + offset;
325+ int lengthToWrite = static_cast <int >(length);
326+ std::unique_ptr<uint8_t []> compressedBuffer;
327+
328+ if (shuffleCompressionEnabled_ && compressorFactory_) {
329+ auto compressor = compressorFactory_ ();
330+ const size_t compressedCapacity =
331+ compressor->getDstCapacity (static_cast <int >(length));
332+ compressedBuffer = std::make_unique<uint8_t []>(compressedCapacity);
333+
334+ const size_t compressedSize = compressor->compress (
335+ dataToWrite, 0 , static_cast <int >(length), compressedBuffer.get (), 0 );
336+
337+ CELEBORN_CHECK (
338+ compressedSize <= static_cast <size_t >(std::numeric_limits<int >::max ()),
339+ fmt::format (
340+ " Compressed size {} exceeds maximum supported size {}" ,
341+ compressedSize,
342+ std::numeric_limits<int >::max ()));
343+
344+ lengthToWrite = static_cast <int >(compressedSize);
345+ dataToWrite = compressedBuffer.get ();
346+ }
347+
348+ CELEBORN_CHECK (
349+ static_cast <size_t >(lengthToWrite) <=
350+ std::numeric_limits<size_t >::max () - kBatchHeaderSize ,
351+ fmt::format (
352+ " Buffer size {} + header {} would overflow" ,
353+ lengthToWrite,
354+ kBatchHeaderSize ));
355+
356+ auto writeBuffer = memory::ByteBuffer::createWriteOnly (
357+ kBatchHeaderSize + static_cast <size_t >(lengthToWrite));
358+ writeBuffer->writeLE <int >(mapId);
359+ writeBuffer->writeLE <int >(attemptId);
360+ writeBuffer->writeLE <int >(nextBatchId);
361+ writeBuffer->writeLE <int >(lengthToWrite);
362+ writeBuffer->writeFromBuffer (
363+ dataToWrite, 0 , static_cast <size_t >(lengthToWrite));
364+
365+ auto body = memory::ByteBuffer::toReadOnly (std::move (writeBuffer));
366+ int bodySize = static_cast <int >(body->remainingSize ());
367+
368+ // Build address pair for the target worker.
369+ std::string primaryAddr = partitionLocation->hostAndPushPort ();
370+ std::string replicaAddr;
371+ if (partitionLocation->hasPeer ()) {
372+ replicaAddr = partitionLocation->getPeer ()->hostAndPushPort ();
373+ }
374+ PushState::AddressPair addressPair (primaryAddr, replicaAddr);
375+
376+ int pushBufferMaxSize = conf_->clientPushBufferMaxSize ();
377+ bool exceedsThreshold = pushState->addBatchData (
378+ addressPair, partitionLocation, nextBatchId, std::move (body),
379+ pushBufferMaxSize);
380+
381+ if (exceedsThreshold) {
382+ auto hostAndPushPort = partitionLocation->hostAndPushPort ();
383+ limitMaxInFlight (mapKey, *pushState, hostAndPushPort);
384+
385+ auto dataBatches = pushState->takeDataBatches (addressPair);
386+ if (dataBatches) {
387+ auto batches = dataBatches->requireBatches ();
388+ if (!batches.empty ()) {
389+ doPushMergedData (
390+ addressPair,
391+ shuffleId,
392+ mapId,
393+ attemptId,
394+ numMappers,
395+ numPartitions,
396+ std::move (batches),
397+ pushState,
398+ conf_->clientPushMaxReviveTimes ());
399+ }
400+ }
401+ }
402+
403+ return bodySize;
404+ }
405+
406+ void ShuffleClientImpl::pushMergedData (
407+ int shuffleId,
408+ int mapId,
409+ int attemptId) {
410+ const auto mapKey = utils::makeMapKey (shuffleId, mapId, attemptId);
411+ auto pushState = getPushState (mapKey);
412+ int pushBufferMaxSize = conf_->clientPushBufferMaxSize ();
413+
414+ // Collect all address pairs first to avoid modifying map during iteration.
415+ std::vector<PushState::AddressPair> addressPairs;
416+ pushState->forEachBatchEntry (
417+ [&](const PushState::AddressPair& addr,
418+ std::shared_ptr<DataBatches> /* unused */ ) {
419+ addressPairs.push_back (addr);
420+ });
421+
422+ for (const auto & addressPair : addressPairs) {
423+ auto dataBatches = pushState->takeDataBatches (addressPair);
424+ if (!dataBatches) {
425+ continue ;
426+ }
427+ // Drain in chunks of pushBufferMaxSize.
428+ while (dataBatches->getTotalSize () > 0 ) {
429+ auto hostAndPushPort = addressPair.first ;
430+ limitMaxInFlight (mapKey, *pushState, hostAndPushPort);
431+ auto batches = dataBatches->requireBatches (pushBufferMaxSize);
432+ if (batches.empty ()) {
433+ break ;
434+ }
435+ doPushMergedData (
436+ addressPair,
437+ shuffleId,
438+ mapId,
439+ attemptId,
440+ 0 , // numMappers not needed for flush
441+ 0 , // numPartitions not needed for flush
442+ std::move (batches),
443+ pushState,
444+ conf_->clientPushMaxReviveTimes ());
445+ }
446+ }
447+ }
448+
449+ void ShuffleClientImpl::doPushMergedData (
450+ const PushState::AddressPair& addressPair,
451+ int shuffleId,
452+ int mapId,
453+ int attemptId,
454+ int numMappers,
455+ int numPartitions,
456+ std::vector<DataBatch> batches,
457+ std::shared_ptr<PushState> pushState,
458+ int remainReviveTimes) {
459+ int groupedBatchId = pushState->nextBatchId ();
460+
461+ // Build partitionUniqueIds, batchOffsets, and concatenated body.
462+ std::vector<std::string> partitionUniqueIds;
463+ std::vector<int > batchOffsets;
464+ int totalBytes = 0 ;
465+ for (const auto & batch : batches) {
466+ partitionUniqueIds.push_back (batch.loc ->uniqueId ());
467+ batchOffsets.push_back (totalBytes);
468+ totalBytes += static_cast <int >(batch.body ->remainingSize ());
469+ }
470+
471+ // Register in-flight.
472+ auto hostAndPushPort = addressPair.first ;
473+ pushState->addBatch (groupedBatchId, totalBytes, hostAndPushPort);
474+
475+ // Concatenate all bodies into a single buffer.
476+ auto combinedBuffer = memory::ByteBuffer::createWriteOnly (totalBytes);
477+ for (const auto & batch : batches) {
478+ auto bodyClone = batch.body ->clone ();
479+ size_t batchSize = bodyClone->remainingSize ();
480+ auto tmpBuf = std::make_unique<uint8_t []>(batchSize);
481+ bodyClone->readToBuffer (tmpBuf.get (), batchSize);
482+ combinedBuffer->writeFromBuffer (tmpBuf.get (), 0 , batchSize);
483+ }
484+ auto combinedBody =
485+ memory::ByteBuffer::toReadOnly (std::move (combinedBuffer));
486+
487+ // Get host/port from the first batch's location before moving batches.
488+ const auto & firstLoc = batches[0 ].loc ;
489+ const auto sendHost = firstLoc->host ;
490+ const auto sendPort = firstLoc->pushPort ;
491+
492+ // Build PushMergedData message.
493+ const auto shuffleKey = utils::makeShuffleKey (appUniqueId_, shuffleId);
494+ network::PushMergedData pushMergedData (
495+ network::Message::nextRequestId (),
496+ protocol::PartitionLocation::Mode::PRIMARY,
497+ shuffleKey,
498+ std::move (partitionUniqueIds),
499+ std::move (batchOffsets),
500+ std::move (combinedBody));
501+
502+ // Build callback.
503+ auto callback = PushMergedDataCallback::create (
504+ shuffleId,
505+ mapId,
506+ attemptId,
507+ numMappers,
508+ numPartitions,
509+ utils::makeMapKey (shuffleId, mapId, attemptId),
510+ groupedBatchId,
511+ std::move (batches),
512+ pushState,
513+ weak_from_this (),
514+ remainReviveTimes,
515+ addressPair);
516+
517+ // Send.
518+ auto client = clientFactory_->createClient (sendHost, sendPort);
519+ client->pushMergedDataAsync (
520+ pushMergedData, conf_->clientPushDataTimeout (), callback);
521+ }
522+
523+ void ShuffleClientImpl::submitRetryPushMergedData (
524+ int shuffleId,
525+ int mapId,
526+ int attemptId,
527+ int numMappers,
528+ int numPartitions,
529+ int groupedBatchId,
530+ std::vector<DataBatch> batches,
531+ std::shared_ptr<PushState> pushState,
532+ const std::vector<std::shared_ptr<protocol::ReviveRequest>>&
533+ reviveRequests,
534+ int remainReviveTimes,
535+ long dueTimeMs) {
536+ long reviveWaitTimeMs = dueTimeMs - utils::currentTimeMillis ();
537+ long accumulatedTimeMs = 0 ;
538+ const long deltaMs = 50 ;
539+
540+ // Wait for all revive requests to complete.
541+ bool allDone = false ;
542+ while (!allDone && accumulatedTimeMs <= reviveWaitTimeMs) {
543+ allDone = true ;
544+ for (const auto & req : reviveRequests) {
545+ if (req->reviveStatus .load () ==
546+ protocol::StatusCode::REVIVE_INITIALIZED) {
547+ allDone = false ;
548+ break ;
549+ }
550+ }
551+ if (!allDone) {
552+ std::this_thread::sleep_for (utils::MS (deltaMs));
553+ accumulatedTimeMs += deltaMs;
554+ }
555+ }
556+
557+ std::string oldHostAndPushPort;
558+ if (!reviveRequests.empty () && reviveRequests[0 ]->loc ) {
559+ oldHostAndPushPort = reviveRequests[0 ]->loc ->hostAndPushPort ();
560+ }
561+
562+ if (mapperEnded (shuffleId, mapId)) {
563+ pushState->removeBatch (groupedBatchId, oldHostAndPushPort);
564+ return ;
565+ }
566+
567+ // Check if all revives succeeded.
568+ for (const auto & req : reviveRequests) {
569+ if (req->reviveStatus .load () != protocol::StatusCode::SUCCESS) {
570+ pushState->setException (std::make_unique<std::runtime_error>(
571+ " Revive failed for pushMergedData retry" ));
572+ return ;
573+ }
574+ }
575+
576+ // Remove old in-flight tracking.
577+ pushState->removeBatch (groupedBatchId, oldHostAndPushPort);
578+
579+ // Regroup batches by new partition locations.
580+ auto locationMapOptional = partitionLocationMaps_.get (shuffleId);
581+ CELEBORN_CHECK (locationMapOptional.has_value ());
582+ auto locationMap = locationMapOptional.value ();
583+
584+ // Build a map from partitionId to new location.
585+ std::unordered_map<int , std::shared_ptr<const protocol::PartitionLocation>>
586+ newLocationMap;
587+ for (const auto & req : reviveRequests) {
588+ auto newLocOpt = locationMap->get (req->partitionId );
589+ if (newLocOpt.has_value ()) {
590+ newLocationMap[req->partitionId ] = newLocOpt.value ();
591+ }
592+ }
593+
594+ // Group batches by new primary address.
595+ using AddressPair = PushState::AddressPair;
596+ struct AddressPairHasher {
597+ size_t operator ()(const AddressPair& p) const {
598+ size_t h1 = std::hash<std::string>{}(p.first );
599+ size_t h2 = std::hash<std::string>{}(p.second );
600+ return h1 ^ (h2 << 1 );
601+ }
602+ };
603+ std::unordered_map<AddressPair, std::vector<DataBatch>, AddressPairHasher>
604+ regrouped;
605+
606+ for (auto & batch : batches) {
607+ auto it = newLocationMap.find (batch.loc ->id );
608+ if (it == newLocationMap.end ()) {
609+ continue ;
610+ }
611+ auto newLoc = it->second ;
612+ std::string primaryAddr = newLoc->hostAndPushPort ();
613+ std::string replicaAddr;
614+ if (newLoc->hasPeer ()) {
615+ replicaAddr = newLoc->getPeer ()->hostAndPushPort ();
616+ }
617+ AddressPair newAddr (primaryAddr, replicaAddr);
618+ batch.loc = newLoc;
619+ regrouped[newAddr].push_back (std::move (batch));
620+ }
621+
622+ LOG (INFO) << " Revive for push merged data succeeded for shuffle "
623+ << shuffleId << " map " << mapId << " attempt " << attemptId
624+ << " groupedBatchId " << groupedBatchId
625+ << " , regrouped into " << regrouped.size () << " groups." ;
626+
627+ // Re-push each group.
628+ for (auto & [addrPair, groupBatches] : regrouped) {
629+ CELEBORN_CHECK_GT (remainReviveTimes, 0 , " no remainReviveTime left" );
630+ doPushMergedData (
631+ addrPair,
632+ shuffleId,
633+ mapId,
634+ attemptId,
635+ numMappers,
636+ numPartitions,
637+ std::move (groupBatches),
638+ pushState,
639+ remainReviveTimes);
640+ }
641+ }
642+
272643void ShuffleClientImpl::mapperEnd (
273644 int shuffleId,
274645 int mapId,
0 commit comments