From 368b4aae4d7576c34b39b5eef546cfc0af2154df Mon Sep 17 00:00:00 2001 From: yew1eb Date: Wed, 11 Feb 2026 11:43:14 +0800 Subject: [PATCH 1/7] [CELEBORN-2264] Support cancel shuffle when write bytes exceeds threshold --- .../celeborn/client/ShuffleClientImpl.java | 12 ++++- .../celeborn/client/LifecycleManager.scala | 47 ++++++++++++++++--- .../celeborn/common/write/PushState.java | 11 +++++ common/src/main/proto/TransportMessages.proto | 1 + .../apache/celeborn/common/CelebornConf.scala | 19 ++++++++ .../protocol/message/ControlMessages.scala | 10 ++-- .../celeborn/common/util/UtilsSuite.scala | 4 +- .../LifecycleManagerReserveSlotsSuite.scala | 2 +- .../JavaCppHybridReadWriteTestBase.scala | 2 +- .../cluster/LocalReadByChunkOffsetsTest.scala | 2 +- .../cluster/PushMergedDataSplitSuite.scala | 2 +- .../deploy/cluster/ReadWriteTestBase.scala | 2 +- .../cluster/ReadWriteTestWithFailures.scala | 2 +- 13 files changed, 98 insertions(+), 18 deletions(-) diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java index be2bdf87d11..26f4eef6e7a 100644 --- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java +++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java @@ -149,6 +149,8 @@ protected Compressor initialValue() { private final boolean dataPushFailureTrackingEnabled; + private final boolean shuffleWriteLimitEnabled; + public static class ReduceFileGroups { public Map> partitionGroups; public Map pushFailedBatches; @@ -211,6 +213,7 @@ public ShuffleClientImpl(String appUniqueId, CelebornConf conf, UserIdentifier u } authEnabled = conf.authEnabledOnClient(); dataPushFailureTrackingEnabled = conf.clientAdaptiveOptimizeSkewedPartitionReadEnabled(); + shuffleWriteLimitEnabled = conf.shuffleWriteLimitEnabled(); // init rpc env rpcEnv = @@ -1067,6 +1070,10 @@ public int pushOrMergeData( Platform.putInt(body, Platform.BYTE_ARRAY_OFFSET + 12, length); System.arraycopy(data, offset, body, BATCH_HEADER_SIZE, length); + if (shuffleWriteLimitEnabled) { + pushState.addWrittenBytes(body.length); + } + if (doPush) { // check limit limitMaxInFlight(mapKey, pushState, loc.hostAndPushPort()); @@ -1789,6 +1796,8 @@ private void mapEndInternal( long[] bytesPerPartition = pushState.getBytesWrittenPerPartition(shuffleIntegrityCheckEnabled, numPartitions); + long bytesWritten = pushState.getBytesWritten(); + MapperEndResponse response = lifecycleManagerRef.askSync( new MapperEnd( @@ -1801,7 +1810,8 @@ private void mapEndInternal( numPartitions, crc32PerPartition, bytesPerPartition, - SerdeVersion.V1), + SerdeVersion.V1, + bytesWritten), rpcMaxRetries, rpcRetryWait, ClassTag$.MODULE$.apply(MapperEndResponse.class)); diff --git a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala index f48f3cd72ba..56a125b443d 100644 --- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala +++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala @@ -23,9 +23,8 @@ import java.security.SecureRandom import java.util import java.util.{function, List => JList} import java.util.concurrent._ -import java.util.concurrent.atomic.{AtomicInteger, LongAdder} +import java.util.concurrent.atomic.{AtomicInteger, AtomicLong, LongAdder} import java.util.function.{BiConsumer, BiFunction, Consumer} - import scala.collection.JavaConverters._ import scala.collection.generic.CanBuildFrom import scala.collection.mutable @@ -33,11 +32,9 @@ import scala.collection.mutable.ArrayBuffer import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration.Duration import scala.util.Random - import com.google.common.annotations.VisibleForTesting import com.google.common.cache.{Cache, CacheBuilder} import org.roaringbitmap.RoaringBitmap - import org.apache.celeborn.client.LifecycleManager.{ShuffleAllocatedWorkers, ShuffleFailedWorkers} import org.apache.celeborn.client.listener.WorkerStatusListener import org.apache.celeborn.common.{CelebornConf, CommitMetadata} @@ -132,6 +129,12 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends private val mockDestroyFailure = conf.testMockDestroySlotsFailure private val authEnabled = conf.authEnabledOnClient private var applicationMeta: ApplicationMeta = _ + + + private val shuffleWriteLimitEnabled = conf.shuffleWriteLimitEnabled + private val shuffleWriteLimitThreshold = conf.shuffleWriteLimitThreshold + private val shuffleTotalWrittenBytes = JavaUtils.newConcurrentHashMap[Int, AtomicLong]() + @VisibleForTesting def workerSnapshots(shuffleId: Int): util.Map[String, ShufflePartitionLocationInfo] = shuffleAllocatedWorkers.get(shuffleId) @@ -439,7 +442,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends numPartitions, crc32PerPartition, bytesWrittenPerPartition, - serdeVersion) => + serdeVersion, + bytesWritten) => logTrace(s"Received MapperEnd TaskEnd request, " + s"${Utils.makeMapKey(shuffleId, mapId, attemptId)}") val partitionType = getPartitionType(shuffleId) @@ -455,7 +459,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends numPartitions, crc32PerPartition, bytesWrittenPerPartition, - serdeVersion) + serdeVersion, + bytesWritten) case PartitionType.MAP => handleMapPartitionEnd( context, @@ -933,7 +938,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends numPartitions: Int, crc32PerPartition: Array[Int], bytesWrittenPerPartition: Array[Long], - serdeVersion: SerdeVersion): Unit = { + serdeVersion: SerdeVersion, + bytesWritten: Long): Unit = { val (mapperAttemptFinishedSuccess, allMapperFinished) = commitManager.finishMapperAttempt( @@ -945,6 +951,13 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends numPartitions = numPartitions, crc32PerPartition = crc32PerPartition, bytesWrittenPerPartition = bytesWrittenPerPartition) + + if(mapperAttemptFinishedSuccess && shuffleWriteLimitEnabled) { + handleShuffleWriteLimitCheck(shuffleId, bytesWritten) + logDebug(s"Shuffle $shuffleId, mapId: $mapId, attemptId: $attemptId, " + + s"map written bytes: $bytesWritten, shuffle total written bytes: ${shuffleTotalWrittenBytes.get(shuffleId).get()}, write limit threshold: $shuffleWriteLimitThreshold") + } + if (mapperAttemptFinishedSuccess && allMapperFinished) { // last mapper finished. call mapper end logInfo(s"Last MapperEnd, call StageEnd with shuffleKey:" + @@ -2081,4 +2094,24 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends } def getShuffleIdMapping = shuffleIdMapping + + private def handleShuffleWriteLimitCheck(shuffleId: Int, + writtenBytes: Long): Unit = { + if (!shuffleWriteLimitEnabled || shuffleWriteLimitThreshold <= 0) return + + if (writtenBytes > 0) { + val totalBytesAccumulator = shuffleTotalWrittenBytes.computeIfAbsent(shuffleId, _ => new AtomicLong(0)) + val currentTotalBytes = totalBytesAccumulator.addAndGet(writtenBytes) + + if (currentTotalBytes > shuffleWriteLimitThreshold) { + val reason = s"Shuffle $shuffleId exceeded write limit threshold: current total ${currentTotalBytes} bytes, max allowed ${shuffleWriteLimitThreshold} bytes" + logError(reason) + + cancelShuffleCallback match { + case Some(c) => c.accept(shuffleId, reason) + case _ => None + } + } + } + } } diff --git a/common/src/main/java/org/apache/celeborn/common/write/PushState.java b/common/src/main/java/org/apache/celeborn/common/write/PushState.java index 46714c4e826..03daf2cc32a 100644 --- a/common/src/main/java/org/apache/celeborn/common/write/PushState.java +++ b/common/src/main/java/org/apache/celeborn/common/write/PushState.java @@ -40,10 +40,13 @@ public class PushState { private final Map failedBatchMap; + private long bytesWritten; + public PushState(CelebornConf conf) { pushBufferMaxSize = conf.clientPushBufferMaxSize(); inFlightRequestTracker = new InFlightRequestTracker(conf, this); failedBatchMap = JavaUtils.newConcurrentHashMap(); + bytesWritten = 0; } public void cleanup() { @@ -136,4 +139,12 @@ public void addDataWithOffsetAndLength(int partitionId, byte[] data, int offset, commitMetadataMap.computeIfAbsent(partitionId, id -> new CommitMetadata()); commitMetadata.addDataWithOffsetAndLength(data, offset, length); } + + public void addWrittenBytes(int length) { + bytesWritten += length; + } + + public long getBytesWritten() { + return bytesWritten; + } } diff --git a/common/src/main/proto/TransportMessages.proto b/common/src/main/proto/TransportMessages.proto index a813a9e5015..2807d3f5464 100644 --- a/common/src/main/proto/TransportMessages.proto +++ b/common/src/main/proto/TransportMessages.proto @@ -377,6 +377,7 @@ message PbMapperEnd { int32 numPartitions = 7; repeated int32 crc32PerPartition = 8; repeated int64 bytesWrittenPerPartition = 9; + int64 bytesWritten = 10; } message PbLocationPushFailedBatches { diff --git a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala index 8b406c3e906..f10f155c415 100644 --- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala +++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala @@ -1674,6 +1674,10 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se def secretRedactionPattern = get(SECRET_REDACTION_PATTERN) def containerInfoProviderClass = get(CONTAINER_INFO_PROVIDER) + + def shuffleWriteLimitEnabled: Boolean = get(SHUFFLE_WRITE_LIMIT_ENABLED) + + def shuffleWriteLimitThreshold: Long = get(SHUFFLE_WRITE_LIMIT_THRESHOLD) } object CelebornConf extends Logging { @@ -6854,4 +6858,19 @@ object CelebornConf extends Logging { .booleanConf .createWithDefault(false) + val SHUFFLE_WRITE_LIMIT_ENABLED: ConfigEntry[Boolean] = + buildConf("celeborn.client.shuffle.write.limit.enabled") + .categories("client") + .doc("Enable shuffle write limit check to prevent cluster resource exhaustion.") + .version("0.7.0") + .booleanConf + .createWithDefault(false) + + val SHUFFLE_WRITE_LIMIT_THRESHOLD: ConfigEntry[Long] = + buildConf("celeborn.client.shuffle.write.limit.threshold") + .categories("client") + .doc("Shuffle write limit threshold, exceed to cancel oversized shuffle tasks.") + .version("0.7.0") + .bytesConf(ByteUnit.BYTE) + .createWithDefaultString("5TB") } diff --git a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala index 36f164d697e..a56a0b1bad3 100644 --- a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala +++ b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala @@ -218,7 +218,8 @@ object ControlMessages extends Logging { numPartitions: Int, crc32PerPartition: Array[Int], bytesWrittenPerPartition: Array[Long], - serdeVersion: SerdeVersion) + serdeVersion: SerdeVersion, + bytesWritten: Long) extends MasterMessage case class ReadReducerPartitionEnd( @@ -737,7 +738,8 @@ object ControlMessages extends Logging { numPartitions, crc32PerPartition, bytesWrittenPerPartition, - serdeVersion) => + serdeVersion, + bytesWritten) => val pushFailedMap = pushFailedBatch.asScala.map { case (k, v) => val resultValue = PbSerDeUtils.toPbLocationPushFailedBatches(v) (k, resultValue) @@ -753,6 +755,7 @@ object ControlMessages extends Logging { .addAllCrc32PerPartition(crc32PerPartition.map(Integer.valueOf).toSeq.asJava) .addAllBytesWrittenPerPartition(bytesWrittenPerPartition.map( java.lang.Long.valueOf).toSeq.asJava) + .setBytesWritten(bytesWritten) .build().toByteArray new TransportMessage(MessageType.MAPPER_END, payload, serdeVersion) @@ -1248,7 +1251,8 @@ object ControlMessages extends Logging { pbMapperEnd.getNumPartitions, crc32Array, bytesWrittenPerPartitionArray, - message.getSerdeVersion) + message.getSerdeVersion, + pbMapperEnd.getBytesWritten) case READ_REDUCER_PARTITION_END_VALUE => val pbReadReducerPartitionEnd = PbReadReducerPartitionEnd.parseFrom(message.getPayload) diff --git a/common/src/test/scala/org/apache/celeborn/common/util/UtilsSuite.scala b/common/src/test/scala/org/apache/celeborn/common/util/UtilsSuite.scala index 8be472b6447..26d06777e25 100644 --- a/common/src/test/scala/org/apache/celeborn/common/util/UtilsSuite.scala +++ b/common/src/test/scala/org/apache/celeborn/common/util/UtilsSuite.scala @@ -160,7 +160,8 @@ class UtilsSuite extends CelebornFunSuite { 1, Array.emptyIntArray, Array.emptyLongArray, - SerdeVersion.V1) + SerdeVersion.V1, + 1) val mapperEndTrans = Utils.fromTransportMessage(Utils.toTransportMessage(mapperEnd)).asInstanceOf[MapperEnd] assert(mapperEnd.shuffleId == mapperEndTrans.shuffleId) @@ -172,6 +173,7 @@ class UtilsSuite extends CelebornFunSuite { assert(mapperEnd.numPartitions == mapperEndTrans.numPartitions) mapperEnd.crc32PerPartition.array should contain theSameElementsInOrderAs mapperEndTrans.crc32PerPartition mapperEnd.bytesWrittenPerPartition.array should contain theSameElementsInOrderAs mapperEndTrans.bytesWrittenPerPartition + assert(mapperEnd.bytesWritten == mapperEndTrans.bytesWritten) } test("validate HDFS compatible fs path") { diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerReserveSlotsSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerReserveSlotsSuite.scala index 1626f2226b7..608d2093c6f 100644 --- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerReserveSlotsSuite.scala +++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerReserveSlotsSuite.scala @@ -157,7 +157,7 @@ class LifecycleManagerReserveSlotsSuite extends AnyFunSuite // push merged data, we expect that partition(0) will be split, while partition(1) will not be split shuffleClient1.pushMergedData(SHUFFLE_ID, MAP_ID, ATTEMPT_ID) - shuffleClient1.mapperEnd(SHUFFLE_ID, MAP_ID, ATTEMPT_ID, MAP_NUM, PARTITION_NUM) + shuffleClient1.mapperEnd(SHUFFLE_ID, MAP_ID, ATTEMPT_ID, MAP_NUM, PARTITION_NUM, getBytesWritten()) // partition(1) will not be split assert(partitionLocationMap1.get(partitions(1)).getEpoch == 0) diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaCppHybridReadWriteTestBase.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaCppHybridReadWriteTestBase.scala index 325f9c8b77a..57bdc3a8384 100644 --- a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaCppHybridReadWriteTestBase.scala +++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaCppHybridReadWriteTestBase.scala @@ -116,7 +116,7 @@ trait JavaCppHybridReadWriteTestBase extends AnyFunSuite } } shuffleClient.pushMergedData(shuffleId, mapId, attemptId) - shuffleClient.mapperEnd(shuffleId, mapId, attemptId, numMappers, numPartitions) + shuffleClient.mapperEnd(shuffleId, mapId, attemptId, numMappers, numPartitions, getBytesWritten()) } // Launch cpp reader to read data, calculate result and write to specific result file. diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/LocalReadByChunkOffsetsTest.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/LocalReadByChunkOffsetsTest.scala index 24f745857ff..4070af0248e 100644 --- a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/LocalReadByChunkOffsetsTest.scala +++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/LocalReadByChunkOffsetsTest.scala @@ -116,7 +116,7 @@ class LocalReadByChunkOffsetsTest extends AnyFunSuite shuffleClient.pushMergedData(1, 0, 0) Thread.sleep(1000) - shuffleClient.mapperEnd(1, 0, 0, 1, 0) + shuffleClient.mapperEnd(1, 0, 0, 1, 0, getBytesWritten()) val metricsCallback = new MetricsCallback { override def incBytesRead(bytesWritten: Long): Unit = {} diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/PushMergedDataSplitSuite.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/PushMergedDataSplitSuite.scala index f095bb5754f..8e25ed8b71a 100644 --- a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/PushMergedDataSplitSuite.scala +++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/PushMergedDataSplitSuite.scala @@ -153,7 +153,7 @@ class PushMergedDataSplitSuite extends AnyFunSuite // push merged data, we expect that partition(0) will be split, while partition(1) will not be split shuffleClient.pushMergedData(SHUFFLE_ID, MAP_ID, ATTEMPT_ID) - shuffleClient.mapperEnd(SHUFFLE_ID, MAP_ID, ATTEMPT_ID, MAP_NUM, PARTITION_NUM) + shuffleClient.mapperEnd(SHUFFLE_ID, MAP_ID, ATTEMPT_ID, MAP_NUM, PARTITION_NUM, getBytesWritten()) assert( partitionLocationMap.get(partitions(1)).getEpoch == 0 ) // means partition(1) will not be split diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestBase.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestBase.scala index dec30f8c621..b99fee97e79 100644 --- a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestBase.scala +++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestBase.scala @@ -98,7 +98,7 @@ trait ReadWriteTestBase extends AnyFunSuite shuffleClient.pushMergedData(1, 0, 0) Thread.sleep(1000) - shuffleClient.mapperEnd(1, 0, 0, 1, 1) + shuffleClient.mapperEnd(1, 0, 0, 1, 1, getBytesWritten()) val metricsCallback = new MetricsCallback { override def incBytesRead(bytesWritten: Long): Unit = {} diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestWithFailures.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestWithFailures.scala index 06383806771..f6a2d19c8ba 100644 --- a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestWithFailures.scala +++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestWithFailures.scala @@ -92,7 +92,7 @@ class ReadWriteTestWithFailures extends AnyFunSuite shuffleClient.pushMergedData(1, 0, 0) Thread.sleep(1000) - shuffleClient.mapperEnd(1, 0, 0, 1, 1) + shuffleClient.mapperEnd(1, 0, 0, 1, 1, getBytesWritten()) var duplicateBytesRead = new AtomicLong(0) val metricsCallback = new MetricsCallback { From ca5f9568e4e8efd0f59350ec5006cf9404701c68 Mon Sep 17 00:00:00 2001 From: yew1eb Date: Wed, 11 Feb 2026 15:02:09 +0800 Subject: [PATCH 2/7] up --- .../celeborn/client/LifecycleManager.scala | 18 +++++++++++------- .../LifecycleManagerReserveSlotsSuite.scala | 2 +- .../JavaCppHybridReadWriteTestBase.scala | 2 +- .../cluster/LocalReadByChunkOffsetsTest.scala | 2 +- .../cluster/PushMergedDataSplitSuite.scala | 2 +- .../deploy/cluster/ReadWriteTestBase.scala | 2 +- .../cluster/ReadWriteTestWithFailures.scala | 2 +- 7 files changed, 17 insertions(+), 13 deletions(-) diff --git a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala index 56a125b443d..b0f10a1a99d 100644 --- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala +++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala @@ -25,6 +25,7 @@ import java.util.{function, List => JList} import java.util.concurrent._ import java.util.concurrent.atomic.{AtomicInteger, AtomicLong, LongAdder} import java.util.function.{BiConsumer, BiFunction, Consumer} + import scala.collection.JavaConverters._ import scala.collection.generic.CanBuildFrom import scala.collection.mutable @@ -32,9 +33,11 @@ import scala.collection.mutable.ArrayBuffer import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration.Duration import scala.util.Random + import com.google.common.annotations.VisibleForTesting import com.google.common.cache.{Cache, CacheBuilder} import org.roaringbitmap.RoaringBitmap + import org.apache.celeborn.client.LifecycleManager.{ShuffleAllocatedWorkers, ShuffleFailedWorkers} import org.apache.celeborn.client.listener.WorkerStatusListener import org.apache.celeborn.common.{CelebornConf, CommitMetadata} @@ -130,7 +133,6 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends private val authEnabled = conf.authEnabledOnClient private var applicationMeta: ApplicationMeta = _ - private val shuffleWriteLimitEnabled = conf.shuffleWriteLimitEnabled private val shuffleWriteLimitThreshold = conf.shuffleWriteLimitThreshold private val shuffleTotalWrittenBytes = JavaUtils.newConcurrentHashMap[Int, AtomicLong]() @@ -952,10 +954,11 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends crc32PerPartition = crc32PerPartition, bytesWrittenPerPartition = bytesWrittenPerPartition) - if(mapperAttemptFinishedSuccess && shuffleWriteLimitEnabled) { + if (mapperAttemptFinishedSuccess && shuffleWriteLimitEnabled) { handleShuffleWriteLimitCheck(shuffleId, bytesWritten) logDebug(s"Shuffle $shuffleId, mapId: $mapId, attemptId: $attemptId, " + - s"map written bytes: $bytesWritten, shuffle total written bytes: ${shuffleTotalWrittenBytes.get(shuffleId).get()}, write limit threshold: $shuffleWriteLimitThreshold") + s"map written bytes: $bytesWritten, shuffle total written bytes: ${shuffleTotalWrittenBytes.get( + shuffleId).get()}, write limit threshold: $shuffleWriteLimitThreshold") } if (mapperAttemptFinishedSuccess && allMapperFinished) { @@ -2095,16 +2098,17 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends def getShuffleIdMapping = shuffleIdMapping - private def handleShuffleWriteLimitCheck(shuffleId: Int, - writtenBytes: Long): Unit = { + private def handleShuffleWriteLimitCheck(shuffleId: Int, writtenBytes: Long): Unit = { if (!shuffleWriteLimitEnabled || shuffleWriteLimitThreshold <= 0) return if (writtenBytes > 0) { - val totalBytesAccumulator = shuffleTotalWrittenBytes.computeIfAbsent(shuffleId, _ => new AtomicLong(0)) + val totalBytesAccumulator = + shuffleTotalWrittenBytes.computeIfAbsent(shuffleId, (id: Int) => new AtomicLong(0)) val currentTotalBytes = totalBytesAccumulator.addAndGet(writtenBytes) if (currentTotalBytes > shuffleWriteLimitThreshold) { - val reason = s"Shuffle $shuffleId exceeded write limit threshold: current total ${currentTotalBytes} bytes, max allowed ${shuffleWriteLimitThreshold} bytes" + val reason = + s"Shuffle $shuffleId exceeded write limit threshold: current total ${currentTotalBytes} bytes, max allowed ${shuffleWriteLimitThreshold} bytes" logError(reason) cancelShuffleCallback match { diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerReserveSlotsSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerReserveSlotsSuite.scala index 608d2093c6f..1626f2226b7 100644 --- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerReserveSlotsSuite.scala +++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerReserveSlotsSuite.scala @@ -157,7 +157,7 @@ class LifecycleManagerReserveSlotsSuite extends AnyFunSuite // push merged data, we expect that partition(0) will be split, while partition(1) will not be split shuffleClient1.pushMergedData(SHUFFLE_ID, MAP_ID, ATTEMPT_ID) - shuffleClient1.mapperEnd(SHUFFLE_ID, MAP_ID, ATTEMPT_ID, MAP_NUM, PARTITION_NUM, getBytesWritten()) + shuffleClient1.mapperEnd(SHUFFLE_ID, MAP_ID, ATTEMPT_ID, MAP_NUM, PARTITION_NUM) // partition(1) will not be split assert(partitionLocationMap1.get(partitions(1)).getEpoch == 0) diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaCppHybridReadWriteTestBase.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaCppHybridReadWriteTestBase.scala index 57bdc3a8384..325f9c8b77a 100644 --- a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaCppHybridReadWriteTestBase.scala +++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaCppHybridReadWriteTestBase.scala @@ -116,7 +116,7 @@ trait JavaCppHybridReadWriteTestBase extends AnyFunSuite } } shuffleClient.pushMergedData(shuffleId, mapId, attemptId) - shuffleClient.mapperEnd(shuffleId, mapId, attemptId, numMappers, numPartitions, getBytesWritten()) + shuffleClient.mapperEnd(shuffleId, mapId, attemptId, numMappers, numPartitions) } // Launch cpp reader to read data, calculate result and write to specific result file. diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/LocalReadByChunkOffsetsTest.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/LocalReadByChunkOffsetsTest.scala index 4070af0248e..24f745857ff 100644 --- a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/LocalReadByChunkOffsetsTest.scala +++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/LocalReadByChunkOffsetsTest.scala @@ -116,7 +116,7 @@ class LocalReadByChunkOffsetsTest extends AnyFunSuite shuffleClient.pushMergedData(1, 0, 0) Thread.sleep(1000) - shuffleClient.mapperEnd(1, 0, 0, 1, 0, getBytesWritten()) + shuffleClient.mapperEnd(1, 0, 0, 1, 0) val metricsCallback = new MetricsCallback { override def incBytesRead(bytesWritten: Long): Unit = {} diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/PushMergedDataSplitSuite.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/PushMergedDataSplitSuite.scala index 8e25ed8b71a..f095bb5754f 100644 --- a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/PushMergedDataSplitSuite.scala +++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/PushMergedDataSplitSuite.scala @@ -153,7 +153,7 @@ class PushMergedDataSplitSuite extends AnyFunSuite // push merged data, we expect that partition(0) will be split, while partition(1) will not be split shuffleClient.pushMergedData(SHUFFLE_ID, MAP_ID, ATTEMPT_ID) - shuffleClient.mapperEnd(SHUFFLE_ID, MAP_ID, ATTEMPT_ID, MAP_NUM, PARTITION_NUM, getBytesWritten()) + shuffleClient.mapperEnd(SHUFFLE_ID, MAP_ID, ATTEMPT_ID, MAP_NUM, PARTITION_NUM) assert( partitionLocationMap.get(partitions(1)).getEpoch == 0 ) // means partition(1) will not be split diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestBase.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestBase.scala index b99fee97e79..dec30f8c621 100644 --- a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestBase.scala +++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestBase.scala @@ -98,7 +98,7 @@ trait ReadWriteTestBase extends AnyFunSuite shuffleClient.pushMergedData(1, 0, 0) Thread.sleep(1000) - shuffleClient.mapperEnd(1, 0, 0, 1, 1, getBytesWritten()) + shuffleClient.mapperEnd(1, 0, 0, 1, 1) val metricsCallback = new MetricsCallback { override def incBytesRead(bytesWritten: Long): Unit = {} diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestWithFailures.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestWithFailures.scala index f6a2d19c8ba..06383806771 100644 --- a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestWithFailures.scala +++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestWithFailures.scala @@ -92,7 +92,7 @@ class ReadWriteTestWithFailures extends AnyFunSuite shuffleClient.pushMergedData(1, 0, 0) Thread.sleep(1000) - shuffleClient.mapperEnd(1, 0, 0, 1, 1, getBytesWritten()) + shuffleClient.mapperEnd(1, 0, 0, 1, 1) var duplicateBytesRead = new AtomicLong(0) val metricsCallback = new MetricsCallback { From 0a9435ee765e66e67bac25456b44a162ec648a14 Mon Sep 17 00:00:00 2001 From: yew1eb Date: Sun, 15 Feb 2026 20:29:45 +0800 Subject: [PATCH 3/7] up --- .../scala/org/apache/celeborn/client/LifecycleManager.scala | 4 ++++ .../scala/org/apache/celeborn/common/CelebornConf.scala | 6 +++--- docs/configuration/client.md | 2 ++ 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala index b0f10a1a99d..b34274f3980 100644 --- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala +++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala @@ -1274,6 +1274,10 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends logInfo(s"[handleUnregisterShuffle] Wait for handleStageEnd complete costs ${cost}ms") } } + + if(shuffleWriteLimitEnabled) { + shuffleTotalWrittenBytes.remove(shuffleId) + } } // add shuffleKey to delay shuffle removal set diff --git a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala index f10f155c415..b4a86de2495 100644 --- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala +++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala @@ -6859,7 +6859,7 @@ object CelebornConf extends Logging { .createWithDefault(false) val SHUFFLE_WRITE_LIMIT_ENABLED: ConfigEntry[Boolean] = - buildConf("celeborn.client.shuffle.write.limit.enabled") + buildConf("celeborn.client.spark.shuffle.write.limit.enabled") .categories("client") .doc("Enable shuffle write limit check to prevent cluster resource exhaustion.") .version("0.7.0") @@ -6867,9 +6867,9 @@ object CelebornConf extends Logging { .createWithDefault(false) val SHUFFLE_WRITE_LIMIT_THRESHOLD: ConfigEntry[Long] = - buildConf("celeborn.client.shuffle.write.limit.threshold") + buildConf("celeborn.client.spark.shuffle.write.limit.threshold") .categories("client") - .doc("Shuffle write limit threshold, exceed to cancel oversized shuffle tasks.") + .doc("Shuffle write limit threshold, exceed to cancel oversized shuffle.") .version("0.7.0") .bytesConf(ByteUnit.BYTE) .createWithDefaultString("5TB") diff --git a/docs/configuration/client.md b/docs/configuration/client.md index fcd28ec2c9c..b8a29a21f99 100644 --- a/docs/configuration/client.md +++ b/docs/configuration/client.md @@ -121,6 +121,8 @@ license: | | celeborn.client.shuffle.rangeReadFilter.enabled | false | false | If a spark application have skewed partition, this value can set to true to improve performance. | 0.2.0 | celeborn.shuffle.rangeReadFilter.enabled | | celeborn.client.shuffle.register.filterExcludedWorker.enabled | false | false | Whether to filter excluded worker when register shuffle. | 0.4.0 | | | celeborn.client.shuffle.reviseLostShuffles.enabled | false | false | Whether to revise lost shuffles. | 0.6.0 | | +| celeborn.client.shuffle.write.limit.enabled | false | false | Enable shuffle write limit check to prevent cluster resource exhaustion. | 0.7.0 | | +| celeborn.client.shuffle.write.limit.threshold | 5TB | false | Shuffle write limit threshold, exceed to cancel oversized shuffle. | 0.7.0 | | | celeborn.client.shuffleDataLostOnUnknownWorker.enabled | false | false | Whether to mark shuffle data lost when unknown worker is detected. | 0.6.3 | | | celeborn.client.slot.assign.maxWorkers | 10000 | false | Max workers that slots of one shuffle can be allocated on. Will choose the smaller positive one from Master side and Client side, see `celeborn.master.slot.assign.maxWorkers`. | 0.3.1 | | | celeborn.client.spark.fetch.cleanFailedShuffle | false | false | whether to clean those disk space occupied by shuffles which cannot be fetched | 0.6.0 | | From 2966bf21d01b4d961b2b67b002356fa1a0ef70b9 Mon Sep 17 00:00:00 2001 From: yew1eb Date: Sun, 15 Feb 2026 21:58:30 +0800 Subject: [PATCH 4/7] up --- .../celeborn/client/LifecycleManager.scala | 2 +- tests/spark-it/pom.xml | 5 ++ .../tests/client/LifecycleManagerSuite.scala | 68 +++++++++++++++++++ 3 files changed, 74 insertions(+), 1 deletion(-) diff --git a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala index b34274f3980..f288f7911af 100644 --- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala +++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala @@ -1275,7 +1275,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends } } - if(shuffleWriteLimitEnabled) { + if (shuffleWriteLimitEnabled) { shuffleTotalWrittenBytes.remove(shuffleId) } } diff --git a/tests/spark-it/pom.xml b/tests/spark-it/pom.xml index a87594f5ce2..733b71ea8eb 100644 --- a/tests/spark-it/pom.xml +++ b/tests/spark-it/pom.xml @@ -187,6 +187,11 @@ minio test + + org.mockito + mockito-core + test + diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerSuite.scala index d34c79419b8..ce79262c717 100644 --- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerSuite.scala +++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerSuite.scala @@ -18,7 +18,10 @@ package org.apache.celeborn.tests.client import java.util +import java.util.Collections +import java.util.function.BiConsumer +import org.mockito.Mockito import org.scalatest.concurrent.Eventually.eventually import org.scalatest.concurrent.Futures.{interval, timeout} import org.scalatest.time.SpanSugar.convertIntToGrainOfTime @@ -26,7 +29,10 @@ import org.scalatest.time.SpanSugar.convertIntToGrainOfTime import org.apache.celeborn.client.{LifecycleManager, WithShuffleClientSuite} import org.apache.celeborn.common.CelebornConf import org.apache.celeborn.common.identity.UserIdentifier +import org.apache.celeborn.common.network.protocol.SerdeVersion +import org.apache.celeborn.common.protocol.message.ControlMessages.MapperEnd import org.apache.celeborn.common.protocol.message.StatusCode +import org.apache.celeborn.common.rpc.RpcCallContext import org.apache.celeborn.service.deploy.MiniClusterFeature class LifecycleManagerSuite extends WithShuffleClientSuite with MiniClusterFeature { @@ -126,6 +132,68 @@ class LifecycleManagerSuite extends WithShuffleClientSuite with MiniClusterFeatu } } + test("CELEBORN-2264: Support cancel shuffle when write bytes exceeds threshold") { + val conf = celebornConf.clone + conf.set(CelebornConf.SHUFFLE_WRITE_LIMIT_ENABLED.key, "true") + .set(CelebornConf.SHUFFLE_WRITE_LIMIT_THRESHOLD.key, "2000") + val lifecycleManager: LifecycleManager = new LifecycleManager(APP, conf) + val ctx = Mockito.mock(classOf[RpcCallContext]) + + // Custom BiConsumer callback to track if cancelShuffle is invoked + var isCancelShuffleInvoked = false + val cancelShuffleCallback = new BiConsumer[Integer, String] { + override def accept(shuffleId: Integer, reason: String): Unit = { + isCancelShuffleInvoked = true + } + } + lifecycleManager.registerCancelShuffleCallback(cancelShuffleCallback) + + // Scenario 1: Same mapper with multiple attempts (total bytes exceed threshold but no cancel) + val shuffleId = 0 + val mapId1 = 0 + lifecycleManager.receiveAndReply(ctx)(MapperEnd( + shuffleId = shuffleId, + mapId = mapId1, + attemptId = 0, + 2, + 1, + Collections.emptyMap(), + 1, + Array.emptyIntArray, + Array.emptyLongArray, + SerdeVersion.V1, + bytesWritten = 1500)) + lifecycleManager.receiveAndReply(ctx)(MapperEnd( + shuffleId = shuffleId, + mapId = mapId1, + attemptId = 1, + 2, + 1, + Collections.emptyMap(), + 1, + Array.emptyIntArray, + Array.emptyLongArray, + SerdeVersion.V1, + bytesWritten = 1500)) + assert(!isCancelShuffleInvoked) + + // Scenario 2: Total bytes of mapId1 + mapId2 exceed threshold (trigger cancel) + val mapId2 = 1 + lifecycleManager.receiveAndReply(ctx)(MapperEnd( + shuffleId = shuffleId, + mapId = mapId2, + attemptId = 0, + 2, + 1, + Collections.emptyMap(), + 1, + Array.emptyIntArray, + Array.emptyLongArray, + SerdeVersion.V1, + bytesWritten = 1000)) + assert(isCancelShuffleInvoked) + } + override def afterAll(): Unit = { logInfo("all test complete , stop celeborn mini cluster") shutdownMiniCluster() From 9e32fcfeeefae8166e7e554de273a65241f3a09d Mon Sep 17 00:00:00 2001 From: yew1eb Date: Tue, 17 Feb 2026 16:39:45 +0800 Subject: [PATCH 5/7] up --- docs/configuration/client.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/configuration/client.md b/docs/configuration/client.md index b8a29a21f99..81f06d8814a 100644 --- a/docs/configuration/client.md +++ b/docs/configuration/client.md @@ -121,8 +121,6 @@ license: | | celeborn.client.shuffle.rangeReadFilter.enabled | false | false | If a spark application have skewed partition, this value can set to true to improve performance. | 0.2.0 | celeborn.shuffle.rangeReadFilter.enabled | | celeborn.client.shuffle.register.filterExcludedWorker.enabled | false | false | Whether to filter excluded worker when register shuffle. | 0.4.0 | | | celeborn.client.shuffle.reviseLostShuffles.enabled | false | false | Whether to revise lost shuffles. | 0.6.0 | | -| celeborn.client.shuffle.write.limit.enabled | false | false | Enable shuffle write limit check to prevent cluster resource exhaustion. | 0.7.0 | | -| celeborn.client.shuffle.write.limit.threshold | 5TB | false | Shuffle write limit threshold, exceed to cancel oversized shuffle. | 0.7.0 | | | celeborn.client.shuffleDataLostOnUnknownWorker.enabled | false | false | Whether to mark shuffle data lost when unknown worker is detected. | 0.6.3 | | | celeborn.client.slot.assign.maxWorkers | 10000 | false | Max workers that slots of one shuffle can be allocated on. Will choose the smaller positive one from Master side and Client side, see `celeborn.master.slot.assign.maxWorkers`. | 0.3.1 | | | celeborn.client.spark.fetch.cleanFailedShuffle | false | false | whether to clean those disk space occupied by shuffles which cannot be fetched | 0.6.0 | | @@ -139,6 +137,8 @@ license: | | celeborn.client.spark.shuffle.forceFallback.enabled | false | false | Always use spark built-in shuffle implementation. This configuration is deprecated, consider configuring `celeborn.client.spark.shuffle.fallback.policy` instead. | 0.3.0 | celeborn.shuffle.forceFallback.enabled | | celeborn.client.spark.shuffle.getReducerFileGroup.broadcast.enabled | false | false | Whether to leverage Spark broadcast mechanism to send the GetReducerFileGroupResponse. If the response size is large and Spark executor number is large, the Spark driver network may be exhausted because each executor will pull the response from the driver. With broadcasting GetReducerFileGroupResponse, it prevents the driver from being the bottleneck in sending out multiple copies of the GetReducerFileGroupResponse (one per executor). | 0.6.0 | | | celeborn.client.spark.shuffle.getReducerFileGroup.broadcast.miniSize | 512k | false | The size at which we use Broadcast to send the GetReducerFileGroupResponse to the executors. | 0.6.0 | | +| celeborn.client.spark.shuffle.write.limit.enabled | false | false | Enable shuffle write limit check to prevent cluster resource exhaustion. | 0.7.0 | | +| celeborn.client.spark.shuffle.write.limit.threshold | 5TB | false | Shuffle write limit threshold, exceed to cancel oversized shuffle. | 0.7.0 | | | celeborn.client.spark.shuffle.writer | HASH | false | Celeborn supports the following kind of shuffle writers. 1. hash: hash-based shuffle writer works fine when shuffle partition count is normal; 2. sort: sort-based shuffle writer works fine when memory pressure is high or shuffle partition count is huge. This configuration only takes effect when celeborn.client.spark.push.dynamicWriteMode.enabled is false. | 0.3.0 | celeborn.shuffle.writer | | celeborn.client.spark.stageRerun.enabled | true | false | Whether to enable stage rerun. If true, client throws FetchFailedException instead of CelebornIOException. | 0.4.0 | celeborn.client.spark.fetch.throwsFetchFailure | | celeborn.identity.provider | org.apache.celeborn.common.identity.DefaultIdentityProvider | false | IdentityProvider class name. Default class is `org.apache.celeborn.common.identity.DefaultIdentityProvider`. Optional values: org.apache.celeborn.common.identity.HadoopBasedIdentityProvider user name will be obtained by UserGroupInformation.getUserName; org.apache.celeborn.common.identity.DefaultIdentityProvider user name and tenant id are default values or user-specific values. | 0.6.0 | celeborn.quota.identity.provider | From 3fadf69fadddb6d4f10319a5d0c55409b60c9da6 Mon Sep 17 00:00:00 2001 From: yew1eb Date: Wed, 25 Feb 2026 16:39:36 +0800 Subject: [PATCH 6/7] up --- .../tests/client/LifecycleManagerSuite.scala | 57 +++++++++++++++++-- 1 file changed, 53 insertions(+), 4 deletions(-) diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerSuite.scala index ce79262c717..2cb77e93927 100644 --- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerSuite.scala +++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerSuite.scala @@ -132,14 +132,14 @@ class LifecycleManagerSuite extends WithShuffleClientSuite with MiniClusterFeatu } } - test("CELEBORN-2264: Support cancel shuffle when write bytes exceeds threshold") { + test("CELEBORN-2264: Support cancel shuffle when write bytes exceeds threshold (enabled)") { val conf = celebornConf.clone conf.set(CelebornConf.SHUFFLE_WRITE_LIMIT_ENABLED.key, "true") .set(CelebornConf.SHUFFLE_WRITE_LIMIT_THRESHOLD.key, "2000") val lifecycleManager: LifecycleManager = new LifecycleManager(APP, conf) val ctx = Mockito.mock(classOf[RpcCallContext]) - // Custom BiConsumer callback to track if cancelShuffle is invoked + // Track cancelShuffle invocation var isCancelShuffleInvoked = false val cancelShuffleCallback = new BiConsumer[Integer, String] { override def accept(shuffleId: Integer, reason: String): Unit = { @@ -148,7 +148,7 @@ class LifecycleManagerSuite extends WithShuffleClientSuite with MiniClusterFeatu } lifecycleManager.registerCancelShuffleCallback(cancelShuffleCallback) - // Scenario 1: Same mapper with multiple attempts (total bytes exceed threshold but no cancel) + // Same mapper multiple attempts (total > threshold, no cancel) val shuffleId = 0 val mapId1 = 0 lifecycleManager.receiveAndReply(ctx)(MapperEnd( @@ -177,7 +177,7 @@ class LifecycleManagerSuite extends WithShuffleClientSuite with MiniClusterFeatu bytesWritten = 1500)) assert(!isCancelShuffleInvoked) - // Scenario 2: Total bytes of mapId1 + mapId2 exceed threshold (trigger cancel) + // mapId1 + mapId2 exceed threshold (trigger cancel) val mapId2 = 1 lifecycleManager.receiveAndReply(ctx)(MapperEnd( shuffleId = shuffleId, @@ -194,6 +194,55 @@ class LifecycleManagerSuite extends WithShuffleClientSuite with MiniClusterFeatu assert(isCancelShuffleInvoked) } + test("CELEBORN-2264: Support cancel shuffle when write bytes exceeds threshold (disable)") { + val conf = celebornConf.clone + conf.set(CelebornConf.SHUFFLE_WRITE_LIMIT_ENABLED.key, "false") + .set(CelebornConf.SHUFFLE_WRITE_LIMIT_THRESHOLD.key, "2000") + val lifecycleManager: LifecycleManager = new LifecycleManager(APP, conf) + val ctx = Mockito.mock(classOf[RpcCallContext]) + + // Track cancelShuffle invocation + var isCancelShuffleInvoked = false + val cancelShuffleCallback = new BiConsumer[Integer, String] { + override def accept(shuffleId: Integer, reason: String): Unit = { + isCancelShuffleInvoked = true + } + } + lifecycleManager.registerCancelShuffleCallback(cancelShuffleCallback) + + // Cumulative bytes exceed threshold (no cancel when disabled) + val shuffleId = 0 + val mapId1 = 0 + lifecycleManager.receiveAndReply(ctx)(MapperEnd( + shuffleId = shuffleId, + mapId = mapId1, + attemptId = 0, + 2, + 1, + Collections.emptyMap(), + 1, + Array.emptyIntArray, + Array.emptyLongArray, + SerdeVersion.V1, + bytesWritten = 1500)) + + val mapId2 = 1 + lifecycleManager.receiveAndReply(ctx)(MapperEnd( + shuffleId = shuffleId, + mapId = mapId2, + attemptId = 0, + 2, + 1, + Collections.emptyMap(), + 1, + Array.emptyIntArray, + Array.emptyLongArray, + SerdeVersion.V1, + bytesWritten = 1500)) + + assert(!isCancelShuffleInvoked) + } + override def afterAll(): Unit = { logInfo("all test complete , stop celeborn mini cluster") shutdownMiniCluster() From dd6ea657a667f62e2fe852c230a1957d32bcb156 Mon Sep 17 00:00:00 2001 From: yew1eb Date: Thu, 26 Feb 2026 15:55:11 +0800 Subject: [PATCH 7/7] up --- .../apache/celeborn/client/LifecycleManager.scala | 12 +++++++----- .../org/apache/celeborn/common/CelebornConf.scala | 6 +++--- docs/configuration/client.md | 2 +- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala index f288f7911af..21527c5dcf4 100644 --- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala +++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala @@ -957,8 +957,9 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends if (mapperAttemptFinishedSuccess && shuffleWriteLimitEnabled) { handleShuffleWriteLimitCheck(shuffleId, bytesWritten) logDebug(s"Shuffle $shuffleId, mapId: $mapId, attemptId: $attemptId, " + - s"map written bytes: $bytesWritten, shuffle total written bytes: ${shuffleTotalWrittenBytes.get( - shuffleId).get()}, write limit threshold: $shuffleWriteLimitThreshold") + s"map written bytes: ${Utils.bytesToString(bytesWritten)}, shuffle total written bytes: ${Utils.bytesToString(shuffleTotalWrittenBytes.get( + shuffleId).get())}, write limit threshold: ${Utils.bytesToString( + shuffleWriteLimitThreshold.getOrElse(0L))}") } if (mapperAttemptFinishedSuccess && allMapperFinished) { @@ -2103,16 +2104,17 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends def getShuffleIdMapping = shuffleIdMapping private def handleShuffleWriteLimitCheck(shuffleId: Int, writtenBytes: Long): Unit = { - if (!shuffleWriteLimitEnabled || shuffleWriteLimitThreshold <= 0) return + if (!shuffleWriteLimitEnabled || shuffleWriteLimitThreshold.isEmpty) return if (writtenBytes > 0) { val totalBytesAccumulator = shuffleTotalWrittenBytes.computeIfAbsent(shuffleId, (id: Int) => new AtomicLong(0)) val currentTotalBytes = totalBytesAccumulator.addAndGet(writtenBytes) - if (currentTotalBytes > shuffleWriteLimitThreshold) { + if (currentTotalBytes > shuffleWriteLimitThreshold.get) { val reason = - s"Shuffle $shuffleId exceeded write limit threshold: current total ${currentTotalBytes} bytes, max allowed ${shuffleWriteLimitThreshold} bytes" + s"Shuffle $shuffleId exceeded write limit threshold: current total ${Utils.bytesToString( + currentTotalBytes)}, max allowed ${Utils.bytesToString(shuffleWriteLimitThreshold.get)}" logError(reason) cancelShuffleCallback match { diff --git a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala index b4a86de2495..8ea88d97558 100644 --- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala +++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala @@ -1677,7 +1677,7 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se def shuffleWriteLimitEnabled: Boolean = get(SHUFFLE_WRITE_LIMIT_ENABLED) - def shuffleWriteLimitThreshold: Long = get(SHUFFLE_WRITE_LIMIT_THRESHOLD) + def shuffleWriteLimitThreshold: Option[Long] = get(SHUFFLE_WRITE_LIMIT_THRESHOLD) } object CelebornConf extends Logging { @@ -6866,11 +6866,11 @@ object CelebornConf extends Logging { .booleanConf .createWithDefault(false) - val SHUFFLE_WRITE_LIMIT_THRESHOLD: ConfigEntry[Long] = + val SHUFFLE_WRITE_LIMIT_THRESHOLD: OptionalConfigEntry[Long] = buildConf("celeborn.client.spark.shuffle.write.limit.threshold") .categories("client") .doc("Shuffle write limit threshold, exceed to cancel oversized shuffle.") .version("0.7.0") .bytesConf(ByteUnit.BYTE) - .createWithDefaultString("5TB") + .createOptional } diff --git a/docs/configuration/client.md b/docs/configuration/client.md index 81f06d8814a..0d6a59f194f 100644 --- a/docs/configuration/client.md +++ b/docs/configuration/client.md @@ -138,7 +138,7 @@ license: | | celeborn.client.spark.shuffle.getReducerFileGroup.broadcast.enabled | false | false | Whether to leverage Spark broadcast mechanism to send the GetReducerFileGroupResponse. If the response size is large and Spark executor number is large, the Spark driver network may be exhausted because each executor will pull the response from the driver. With broadcasting GetReducerFileGroupResponse, it prevents the driver from being the bottleneck in sending out multiple copies of the GetReducerFileGroupResponse (one per executor). | 0.6.0 | | | celeborn.client.spark.shuffle.getReducerFileGroup.broadcast.miniSize | 512k | false | The size at which we use Broadcast to send the GetReducerFileGroupResponse to the executors. | 0.6.0 | | | celeborn.client.spark.shuffle.write.limit.enabled | false | false | Enable shuffle write limit check to prevent cluster resource exhaustion. | 0.7.0 | | -| celeborn.client.spark.shuffle.write.limit.threshold | 5TB | false | Shuffle write limit threshold, exceed to cancel oversized shuffle. | 0.7.0 | | +| celeborn.client.spark.shuffle.write.limit.threshold | <undefined> | false | Shuffle write limit threshold, exceed to cancel oversized shuffle. | 0.7.0 | | | celeborn.client.spark.shuffle.writer | HASH | false | Celeborn supports the following kind of shuffle writers. 1. hash: hash-based shuffle writer works fine when shuffle partition count is normal; 2. sort: sort-based shuffle writer works fine when memory pressure is high or shuffle partition count is huge. This configuration only takes effect when celeborn.client.spark.push.dynamicWriteMode.enabled is false. | 0.3.0 | celeborn.shuffle.writer | | celeborn.client.spark.stageRerun.enabled | true | false | Whether to enable stage rerun. If true, client throws FetchFailedException instead of CelebornIOException. | 0.4.0 | celeborn.client.spark.fetch.throwsFetchFailure | | celeborn.identity.provider | org.apache.celeborn.common.identity.DefaultIdentityProvider | false | IdentityProvider class name. Default class is `org.apache.celeborn.common.identity.DefaultIdentityProvider`. Optional values: org.apache.celeborn.common.identity.HadoopBasedIdentityProvider user name will be obtained by UserGroupInformation.getUserName; org.apache.celeborn.common.identity.DefaultIdentityProvider user name and tenant id are default values or user-specific values. | 0.6.0 | celeborn.quota.identity.provider |