From afe48aa3c70d404ebaa5ab8fed5e832b98db7a8c Mon Sep 17 00:00:00 2001 From: Enrico Minack Date: Tue, 10 Feb 2026 14:26:28 +0100 Subject: [PATCH 1/8] Make FallbackStorage.read return a truly lazy ManagedBuffer --- .../spark/storage/FallbackStorage.scala | 66 +++++++++++++++---- 1 file changed, 53 insertions(+), 13 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/FallbackStorage.scala b/core/src/main/scala/org/apache/spark/storage/FallbackStorage.scala index 19cdebd80ebf9..bb1f2687ebbbf 100644 --- a/core/src/main/scala/org/apache/spark/storage/FallbackStorage.scala +++ b/core/src/main/scala/org/apache/spark/storage/FallbackStorage.scala @@ -17,12 +17,13 @@ package org.apache.spark.storage -import java.io.DataInputStream +import java.io.{DataInputStream, InputStream} import java.nio.ByteBuffer import scala.concurrent.Future import scala.reflect.ClassTag +import io.netty.buffer.Unpooled import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} @@ -31,8 +32,8 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.internal.LogKeys._ import org.apache.spark.internal.config.{STORAGE_DECOMMISSION_FALLBACK_STORAGE_CLEANUP, STORAGE_DECOMMISSION_FALLBACK_STORAGE_PATH} -import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} -import org.apache.spark.network.util.JavaUtils +import org.apache.spark.network.buffer.ManagedBuffer +import org.apache.spark.network.util.{JavaUtils, LimitedInputStream} import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcTimeout} import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleBlockInfo} import org.apache.spark.shuffle.IndexShuffleBlockResolver.NOOP_REDUCE_ID @@ -114,6 +115,51 @@ private[storage] class FallbackStorageRpcEndpointRef(conf: SparkConf, hadoopConf } } +/** + * Lazily reads a segment of an Hadoop FileSystem file, i.e. when createInputStream is called. + * @param filesystem hadoop filesystem + * @param file path of the file + * @param offset offset of the segment + * @param length size of the segmetn + */ +private[storage] class FileSystemSegmentManagedBuffer( + filesystem: FileSystem, + file: Path, + offset: Long, + length: Long) extends ManagedBuffer with Logging { + + override def size(): Long = length + + override def nioByteBuffer(): ByteBuffer = { + Utils.tryWithResource(createInputStream()) { in => + ByteBuffer.wrap(in.readAllBytes()) + } + } + + override def createInputStream(): InputStream = { + val startTimeNs = System.nanoTime() + try { + val in = filesystem.open(file) + in.seek(offset) + new LimitedInputStream(in, length) + } finally { + logDebug(s"Took ${(System.nanoTime() - startTimeNs) / (1000 * 1000)}ms") + } + } + + override def retain(): ManagedBuffer = this + + override def release(): ManagedBuffer = this + + override def convertToNetty(): AnyRef = { + Unpooled.wrappedBuffer(nioByteBuffer()); + } + + override def convertToNettyForSsl(): AnyRef = { + Unpooled.wrappedBuffer(nioByteBuffer()); + } +} + private[spark] object FallbackStorage extends Logging { /** We use one block manager id as a place holder. */ val FALLBACK_BLOCK_MANAGER_ID: BlockManagerId = BlockManagerId("fallback", "remote", 7337) @@ -168,7 +214,9 @@ private[spark] object FallbackStorage extends Logging { } /** - * Read a ManagedBuffer. + * Read a block as ManagedBuffer. This reads the index for offset and block size + * but does not read the actual block data. Those data are later read when calling + * createInputStream() on the returned ManagedBuffer. */ def read(conf: SparkConf, blockId: BlockId): ManagedBuffer = { logInfo(log"Read ${MDC(BLOCK_ID, blockId)}") @@ -202,15 +250,7 @@ private[spark] object FallbackStorage extends Logging { val hash = JavaUtils.nonNegativeHash(name) val dataFile = new Path(fallbackPath, s"$appId/$shuffleId/$hash/$name") val size = nextOffset - offset - logDebug(s"To byte array $size") - val array = new Array[Byte](size.toInt) - val startTimeNs = System.nanoTime() - Utils.tryWithResource(fallbackFileSystem.open(dataFile)) { f => - f.seek(offset) - f.readFully(array) - logDebug(s"Took ${(System.nanoTime() - startTimeNs) / (1000 * 1000)}ms") - } - new NioManagedBuffer(ByteBuffer.wrap(array)) + new FileSystemSegmentManagedBuffer(fallbackFileSystem, dataFile, offset, size) } } } From ff9e3b371ddcbf8572600f1924725f155c9f61b5 Mon Sep 17 00:00:00 2001 From: Enrico Minack Date: Tue, 10 Feb 2026 20:48:44 +0100 Subject: [PATCH 2/8] Add unit tests --- .../spark/storage/FallbackStorageSuite.scala | 50 ++++++- .../ShuffleBlockFetcherIteratorSuite.scala | 137 +++++++++++++++--- 2 files changed, 167 insertions(+), 20 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/storage/FallbackStorageSuite.scala b/core/src/test/scala/org/apache/spark/storage/FallbackStorageSuite.scala index 6df8bc85b5104..ad09958baca0a 100644 --- a/core/src/test/scala/org/apache/spark/storage/FallbackStorageSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/FallbackStorageSuite.scala @@ -22,10 +22,11 @@ import java.nio.file.Files import scala.concurrent.duration._ import scala.util.Random +import io.netty.buffer.ByteBuf import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FSDataInputStream, LocalFileSystem, Path, PositionedReadable, Seekable} +import org.apache.hadoop.fs.{FileSystem, FSDataInputStream, LocalFileSystem, Path, PositionedReadable, Seekable} import org.mockito.{ArgumentMatchers => mc} -import org.mockito.Mockito.{mock, never, verify, when} +import org.mockito.Mockito.{mock, never, spy, times, verify, when} import org.scalatest.concurrent.Eventually.{eventually, interval, timeout} import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite, TestUtils} @@ -110,7 +111,9 @@ class FallbackStorageSuite extends SparkFunSuite with LocalSparkContext { intercept[java.io.EOFException] { FallbackStorage.read(conf, ShuffleBlockId(1, 1L, 0)) } - FallbackStorage.read(conf, ShuffleBlockId(1, 2L, 0)) + val readResult = FallbackStorage.read(conf, ShuffleBlockId(1, 2L, 0)) + assert(readResult.isInstanceOf[FileSystemSegmentManagedBuffer]) + readResult.createInputStream().close() } test("SPARK-39200: fallback storage APIs - readFully") { @@ -155,9 +158,49 @@ class FallbackStorageSuite extends SparkFunSuite with LocalSparkContext { assert(fallbackStorage.exists(1, ShuffleDataBlockId(1, 2L, NOOP_REDUCE_ID).name)) val readResult = FallbackStorage.read(conf, ShuffleBlockId(1, 2L, 0)) + assert(readResult.isInstanceOf[FileSystemSegmentManagedBuffer]) assert(readResult.nioByteBuffer().array().sameElements(content)) } + test("SPARK-55469: FileSystemSegmentManagedBuffer reads block data lazily") { + withTempDir { dir => + val fs = FileSystem.getLocal(new Configuration()) + val file = new Path(dir.getAbsolutePath, "file") + val data = Array[Byte](1, 2, 3, 4, 5, 6, 7, 8, 9, 10) + tryWithResource(fs.create(file)) { os => os.write(data) } + + Seq((0, 4), (1, 2), (4, 4), (7, 2), (8, 0)).foreach { case (offset, length) => + val clue = s"offset: $offset, length: $length" + + // creating the managed buffer does not open the file + val mfs = spy(fs) + val buf = new FileSystemSegmentManagedBuffer(mfs, file, offset, length) + verify(mfs, never()).open(mc.any[Path]()) + assert(buf.size() === length, clue) + + // creating the input stream opens the file + { + val bytes = buf.createInputStream().readAllBytes() + verify(mfs, times(1)).open(mc.any[Path]()) + assert(bytes.mkString(",") === data.slice(offset, offset + length).mkString(","), clue) + } + + // getting a NIO ByteBuffer opens the file again + { + val bytes = buf.nioByteBuffer().array() + verify(mfs, times(2)).open(mc.any[Path]()) + assert(bytes.mkString(",") === data.slice(offset, offset + length).mkString(","), clue) + } + + // getting a Netty ByteBufs opens the file again and again + assert(buf.convertToNetty().asInstanceOf[ByteBuf].release() === length > 0, clue) + verify(mfs, times(3)).open(mc.any[Path]()) + assert(buf.convertToNettyForSsl().asInstanceOf[ByteBuf].release() === length > 0, clue) + verify(mfs, times(4)).open(mc.any[Path]()) + } + } + } + test("SPARK-34142: fallback storage API - cleanUp app") { withTempDir { dir => Seq(true, false).foreach { cleanUp => @@ -372,6 +415,7 @@ class FallbackStorageSuite extends SparkFunSuite with LocalSparkContext { } } } + class ReadPartialInputStream(val in: FSDataInputStream) extends InputStream with Seekable with PositionedReadable { override def read: Int = in.read diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 211de2e8729eb..08220a26010fc 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -32,7 +32,7 @@ import scala.concurrent.Future import io.netty.util.internal.OutOfDirectMemoryError import org.apache.logging.log4j.Level import org.mockito.ArgumentMatchers.{any, eq => meq} -import org.mockito.Mockito.{doThrow, mock, times, verify, when} +import org.mockito.Mockito.{doThrow, mock, never, times, verify, when} import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer import org.roaringbitmap.RoaringBitmap @@ -300,11 +300,11 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { } } - test("successful 3 local + 4 host local + 2 remote reads") { + test("successful 3 local + 4 host local + 2 remote + 2 fallback storage reads") { val blockManager = createMockBlockManager() - val localBmId = blockManager.blockManagerId // Make sure blockManager.getBlockData would return the blocks + val localBmId = blockManager.blockManagerId val localBlocks = Map[BlockId, ManagedBuffer]( ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(), ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(), @@ -334,19 +334,37 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { // returning local dir for hostLocalBmId initHostLocalDirManager(blockManager, hostLocalDirs) + // Make sure fallback storage blocks would return + val fallbackBmId = FallbackStorage.FALLBACK_BLOCK_MANAGER_ID + val fallbackBlocks = Map[BlockId, ManagedBuffer]( + ShuffleBlockId(0, 9, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 10, 0) -> createMockManagedBuffer()) + fallbackBlocks.foreach { case (blockId, buf) => + doReturn(buf).when(blockManager).getLocalBlockData(meq(blockId)) + } + val iterator = createShuffleBlockIteratorWithDefaults( Map( localBmId -> toBlockList(localBlocks.keys, 1L, 0), remoteBmId -> toBlockList(remoteBlocks.keys, 1L, 1), - hostLocalBmId -> toBlockList(hostLocalBlocks.keys, 1L, 1) + hostLocalBmId -> toBlockList(hostLocalBlocks.keys, 1L, 1), + fallbackBmId -> toBlockList(fallbackBlocks.keys, 1L, 1) ), blockManager = Some(blockManager) ) - // 3 local blocks fetched in initialization - verify(blockManager, times(3)).getLocalBlockData(any()) + // 3 local blocks and 2 fallback blocks fetched in initialization + verify(blockManager, times(3 + 2)).getLocalBlockData(any()) + + // SPARK-55469: but buffer data have never been materialized + fallbackBlocks.values.foreach { mockBuf => + verify(mockBuf, never()).nioByteBuffer() + verify(mockBuf, never()).createInputStream() + verify(mockBuf, never()).convertToNetty() + verify(mockBuf, never()).convertToNettyForSsl() + } - val allBlocks = localBlocks ++ remoteBlocks ++ hostLocalBlocks + val allBlocks = localBlocks ++ remoteBlocks ++ hostLocalBlocks ++ fallbackBlocks for (i <- 0 until allBlocks.size) { assert(iterator.hasNext, s"iterator should have ${allBlocks.size} elements but actually has $i elements") @@ -356,14 +374,23 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { val mockBuf = allBlocks(blockId) verifyBufferRelease(mockBuf, inputStream) } + assert(!iterator.hasNext) // 4 host-local locks fetched verify(blockManager, times(4)) .getHostLocalShuffleData(any(), meq(Array("local-dir"))) - // 2 remote blocks are read from the same block manager + // 2 remote blocks are read from the same block manager in one fetch verifyFetchBlocksInvocationCount(1) assert(blockManager.hostLocalDirManager.get.getCachedHostLocalDirs.size === 1) + + // SPARK-55469: fallback buffer data have been materialized once + fallbackBlocks.values.foreach { mockBuf => + verify(mockBuf, never()).nioByteBuffer() + verify(mockBuf, times(1)).createInputStream() + verify(mockBuf, never()).convertToNetty() + verify(mockBuf, never()).convertToNettyForSsl() + } } test("error during accessing host local dirs for executors") { @@ -451,10 +478,12 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { assert(!iterator.hasNext) } - test("fetch continuous blocks in batch successful 3 local + 4 host local + 2 remote reads") { + test("fetch continuous blocks in batch successful 3 local + 4 host local + 2 remote + " + + "2 fallback storage reads") { val blockManager = createMockBlockManager() - val localBmId = blockManager.blockManagerId + // Make sure blockManager.getBlockData would return the merged block + val localBmId = blockManager.blockManagerId val localBlocks = Seq[BlockId]( ShuffleBlockId(0, 0, 0), ShuffleBlockId(0, 0, 1), @@ -465,6 +494,17 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { doReturn(buf).when(blockManager).getLocalBlockData(meq(blockId)) } + // Make sure fallback storage would return the merged block + val fallbackBmId = FallbackStorage.FALLBACK_BLOCK_MANAGER_ID + val fallbackBlocks = Seq[BlockId]( + ShuffleBlockId(0, 1, 0), + ShuffleBlockId(0, 1, 1)) + val mergedFallbackBlocks = Map[BlockId, ManagedBuffer]( + ShuffleBlockBatchId(0, 1, 0, 2) -> createMockManagedBuffer()) + mergedFallbackBlocks.foreach { case (blockId, buf) => + doReturn(buf).when(blockManager).getLocalBlockData(meq(blockId)) + } + // Make sure remote blocks would return the merged block val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) val remoteBlocks = Seq[BlockId]( @@ -496,6 +536,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { val iterator = createShuffleBlockIteratorWithDefaults( Map( localBmId -> toBlockList(localBlocks, 1L, 0), + fallbackBmId -> toBlockList(fallbackBlocks, 1L, 1), remoteBmId -> toBlockList(remoteBlocks, 1L, 1), hostLocalBmId -> toBlockList(hostLocalBlocks.keys, 1L, 1) ), @@ -503,23 +544,41 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { doBatchFetch = true ) - // 3 local blocks batch fetched in initialization - verify(blockManager, times(1)).getLocalBlockData(any()) + // 1 local merge block and 1 fallback merge block fetched in initialization + verify(blockManager, times(1 + 1)).getLocalBlockData(any()) - val allBlocks = mergedLocalBlocks ++ mergedRemoteBlocks ++ mergedHostLocalBlocks - for (i <- 0 until 3) { - assert(iterator.hasNext, s"iterator should have 3 elements but actually has $i elements") + // SPARK-55469: but buffer data have never been materialized + mergedFallbackBlocks.values.foreach { mockBuf => + verify(mockBuf, never()).nioByteBuffer() + verify(mockBuf, never()).createInputStream() + verify(mockBuf, never()).convertToNetty() + verify(mockBuf, never()).convertToNettyForSsl() + } + + val allBlocks = mergedLocalBlocks ++ mergedFallbackBlocks ++ mergedRemoteBlocks ++ + mergedHostLocalBlocks + for (i <- 0 until 4) { + assert(iterator.hasNext, s"iterator should have 4 elements but actually has $i elements") val (blockId, inputStream) = iterator.next() verifyFetchBlocksInvocationCount(1) // Make sure we release buffers when a wrapped input stream is closed. val mockBuf = allBlocks(blockId) verifyBufferRelease(mockBuf, inputStream) } + assert(!iterator.hasNext) - // 4 host-local locks fetched + // 1 merged host-local locks fetched verify(blockManager, times(1)) .getHostLocalShuffleData(any(), meq(Array("local-dir"))) + // SPARK-55469: merged fallback buffer data have been materialized once + mergedFallbackBlocks.values.foreach { mockBuf => + verify(mockBuf, never()).nioByteBuffer() + verify(mockBuf, times(1)).createInputStream() + verify(mockBuf, never()).convertToNetty() + verify(mockBuf, never()).convertToNettyForSsl() + } + assert(blockManager.hostLocalDirManager.get.getCachedHostLocalDirs.size === 1) } @@ -1051,6 +1110,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { val mockBuf = remoteBlocks(blockId) verifyBufferRelease(mockBuf, inputStream) } + assert(!iterator.hasNext) // 1st fetch request (contains 1 block) would fail due to Netty OOM // 2nd fetch request retry the block of the 1st fetch request @@ -1091,6 +1151,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { val mockBuf = remoteBlocks(blockId) verifyBufferRelease(mockBuf, inputStream) } + assert(!iterator.hasNext) // 1st fetch request (contains 3 blocks) would fail on the someone block due to Netty OOM // but succeed for the remaining blocks @@ -2037,9 +2098,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { test("SPARK-52395: Fast fail when failed to get host local dirs") { val blockManager = createMockBlockManager() - val localBmId = blockManager.blockManagerId // Make sure blockManager.getBlockData would return the blocks + val localBmId = blockManager.blockManagerId val localBlocks = Map[BlockId, ManagedBuffer]( ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(), ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer()) @@ -2076,4 +2137,46 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { assert(iterator.next()._1 === ShuffleBlockId(0, 1, 0)) assert(!iterator.hasNext) } + + test("Fast fail when failed to get fallback storage blocks") { + val blockManager = createMockBlockManager() + + // Make sure blockManager.getBlockData would return the blocks + val localBmId = blockManager.blockManagerId + val localBlocks = Map[BlockId, ManagedBuffer]( + ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer()) + localBlocks.foreach { case (blockId, buf) => + doReturn(buf).when(blockManager).getLocalBlockData(meq(blockId)) + } + + // Make sure fallback storage would return the blocks + val fallbackBmId = FallbackStorage.FALLBACK_BLOCK_MANAGER_ID + val fallbackBlocks = Map[BlockId, ManagedBuffer]( + ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 3, 0) -> createMockManagedBuffer()) + fallbackBlocks.take(1).foreach { case (blockId, buf) => + doReturn(buf).when(blockManager).getLocalBlockData(meq(blockId)) + } + fallbackBlocks.takeRight(1).foreach { case (blockId, _) => + doThrow(new RuntimeException("Cannot read from fallback storage")) + .when(blockManager).getLocalBlockData(meq(blockId)) + } + + val iterator = createShuffleBlockIteratorWithDefaults( + Map( + localBmId -> toBlockList(localBlocks.keys, 1L, 0), + fallbackBmId -> toBlockList(fallbackBlocks.keys, 1L, 1) + ), + blockManager = Some(blockManager) + ) + + // Fetch failure should be placed in the head of results, exception should be thrown for the + // 1st instance. + intercept[FetchFailedException] { iterator.next() } + assert(iterator.next()._1 === ShuffleBlockId(0, 0, 0)) + assert(iterator.next()._1 === ShuffleBlockId(0, 1, 0)) + assert(iterator.next()._1 === ShuffleBlockId(0, 2, 0)) + assert(!iterator.hasNext) + } } From f623d1018adc0502010a5b8ca9af109d6b4be485 Mon Sep 17 00:00:00 2001 From: Enrico Minack Date: Tue, 10 Feb 2026 14:30:49 +0100 Subject: [PATCH 3/8] Separate fallback storage blocks and local blocks --- .../org/apache/spark/internal/LogKeys.java | 2 + .../spark/network/BlockDataManager.scala | 6 ++ .../apache/spark/storage/BlockManager.scala | 30 ++++--- .../storage/ShuffleBlockFetcherIterator.scala | 84 ++++++++++++++++--- 4 files changed, 102 insertions(+), 20 deletions(-) diff --git a/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java b/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java index 59df0423fad26..c70382bf93190 100644 --- a/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java +++ b/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java @@ -232,6 +232,7 @@ public enum LogKeys implements LogKey { EXPR, EXPR_TERMS, EXTENDED_EXPLAIN_GENERATOR, + FALLBACK_STORAGE_BLOCKS_SIZE, FAILED_STAGE, FAILED_STAGE_NAME, FAILURES, @@ -473,6 +474,7 @@ public enum LogKeys implements LogKey { NUM_EXECUTOR_DESIRED, NUM_EXECUTOR_LAUNCH, NUM_EXECUTOR_TARGET, + NUM_FALLBACK_STORAGE_BLOCKS, NUM_FAILURES, NUM_FEATURES, NUM_FILES, diff --git a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala index 89177346a789a..6350c2eef785c 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala @@ -53,6 +53,12 @@ trait BlockDataManager { */ def getLocalBlockData(blockId: BlockId): ManagedBuffer + /** + * Interface to get fallback storage block data. Throws an exception if the block cannot be found + * or cannot be read successfully. + */ + def getFallbackStorageBlockData(blockId: BlockId): ManagedBuffer + /** * Put the block locally, using the given storage level. * diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 5fbc8dca74f68..3e69af157693a 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -759,16 +759,7 @@ private[spark] class BlockManager( override def getLocalBlockData(blockId: BlockId): ManagedBuffer = { if (blockId.isShuffle) { logDebug(s"Getting local shuffle block ${blockId}") - try { - shuffleManager.shuffleBlockResolver.getBlockData(blockId) - } catch { - case e: IOException => - if (conf.get(config.STORAGE_DECOMMISSION_FALLBACK_STORAGE_PATH).isDefined) { - FallbackStorage.read(conf, blockId) - } else { - throw e - } - } + shuffleManager.shuffleBlockResolver.getBlockData(blockId) } else { getLocalBytes(blockId) match { case Some(blockData) => @@ -783,6 +774,25 @@ private[spark] class BlockManager( } } + /** + * Interface to get fallback storage block data. Throws an exception if the block cannot be found + * or cannot be read successfully. + */ + override def getFallbackStorageBlockData(blockId: BlockId): ManagedBuffer = { + require(conf.get(config.STORAGE_DECOMMISSION_FALLBACK_STORAGE_PATH).isDefined) + + if (blockId.isShuffle) { + logDebug(s"Getting fallback storage block ${blockId}") + FallbackStorage.read(conf, blockId) + } else { + // If this block manager receives a request for a block that it doesn't have then it's + // likely that the master has outdated block statuses for this block. Therefore, we send + // an RPC so that this block is marked as being unavailable from this block manager. + reportBlockStatus(blockId, BlockStatus.empty) + throw SparkCoreErrors.blockNotFoundError(blockId) + } + } + /** * Put the block locally, using the given storage level. * diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index cc552a2985f7e..0c3721e0eee33 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -391,6 +391,7 @@ final class ShuffleBlockFetcherIterator( private[this] def partitionBlocksByFetchMode( blocksByAddress: Iterator[(BlockManagerId, collection.Seq[(BlockId, Long, Int)])], localBlocks: mutable.LinkedHashSet[(BlockId, Int)], + fallbackStorageBlocks: mutable.LinkedHashSet[(BlockId, Int)], hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId, collection.Seq[(BlockId, Long, Int)]], pushMergedLocalBlocks: mutable.LinkedHashSet[BlockId]): ArrayBuffer[FetchRequest] = { @@ -402,13 +403,15 @@ final class ShuffleBlockFetcherIterator( // in order to limit the amount of data in flight val collectedRemoteRequests = new ArrayBuffer[FetchRequest] var localBlockBytes = 0L + var fallbackStorageBlockBytes = 0L var hostLocalBlockBytes = 0L var numHostLocalBlocks = 0 var pushMergedLocalBlockBytes = 0L val prevNumBlocksToFetch = numBlocksToFetch - val fallback = FallbackStorage.FALLBACK_BLOCK_MANAGER_ID.executorId - val localExecIds = Set(blockManager.blockManagerId.executorId, fallback) + val localExecId = blockManager.blockManagerId.executorId + val fallbackExecId = FallbackStorage.FALLBACK_BLOCK_MANAGER_ID.executorId + val localAndFallbackExecIds = Set(localExecId, fallbackExecId) for ((address, blockInfos) <- blocksByAddress) { checkBlockSizes(blockInfos) if (pushBasedFetchHelper.isPushMergedShuffleBlockAddress(address)) { @@ -420,12 +423,23 @@ final class ShuffleBlockFetcherIterator( } else { collectFetchRequests(address, blockInfos, collectedRemoteRequests) } - } else if (localExecIds.contains(address.executorId)) { + } else if (address.executorId == localExecId) { val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded( blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)), doBatchFetch) numBlocksToFetch += mergedBlockInfos.size localBlocks ++= mergedBlockInfos.map(info => (info.blockId, info.mapIndex)) localBlockBytes += mergedBlockInfos.map(_.size).sum + } else if (localAndFallbackExecIds.contains(address.executorId)) { + val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded( + blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)), doBatchFetch) + numBlocksToFetch += mergedBlockInfos.size + if (address.executorId == localExecId) { + localBlocks ++= mergedBlockInfos.map(info => (info.blockId, info.mapIndex)) + localBlockBytes += mergedBlockInfos.map(_.size).sum + } else { + fallbackStorageBlocks ++= mergedBlockInfos.map(info => (info.blockId, info.mapIndex)) + fallbackStorageBlockBytes += mergedBlockInfos.map(_.size).sum + } } else if (blockManager.hostLocalDirManager.isDefined && address.host == blockManager.blockManagerId.host) { val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded( @@ -445,13 +459,14 @@ final class ShuffleBlockFetcherIterator( } val (remoteBlockBytes, numRemoteBlocks) = collectedRemoteRequests.foldLeft((0L, 0))((x, y) => (x._1 + y.size, x._2 + y.blocks.size)) - val totalBytes = localBlockBytes + remoteBlockBytes + hostLocalBlockBytes + - pushMergedLocalBlockBytes + val totalBytes = localBlockBytes + fallbackStorageBlockBytes + remoteBlockBytes + + hostLocalBlockBytes + pushMergedLocalBlockBytes val blocksToFetchCurrentIteration = numBlocksToFetch - prevNumBlocksToFetch - assert(blocksToFetchCurrentIteration == localBlocks.size + + assert(blocksToFetchCurrentIteration == localBlocks.size + fallbackStorageBlocks.size + numHostLocalBlocks + numRemoteBlocks + pushMergedLocalBlocks.size, s"The number of non-empty blocks $blocksToFetchCurrentIteration doesn't equal to the sum " + s"of the number of local blocks ${localBlocks.size} + " + + s"the number of fallback storage blocks ${fallbackStorageBlocks.size} + " + s"the number of host-local blocks ${numHostLocalBlocks} " + s"the number of push-merged-local blocks ${pushMergedLocalBlocks.size} " + s"+ the number of remote blocks ${numRemoteBlocks} ") @@ -459,8 +474,10 @@ final class ShuffleBlockFetcherIterator( log"Getting ${MDC(NUM_BLOCKS, blocksToFetchCurrentIteration)} " + log"(${MDC(TOTAL_SIZE, Utils.bytesToString(totalBytes))}) non-empty blocks including " + log"${MDC(NUM_LOCAL_BLOCKS, localBlocks.size)} " + - log"(${MDC(LOCAL_BLOCKS_SIZE, Utils.bytesToString(localBlockBytes))}) local and " + - log"${MDC(NUM_HOST_LOCAL_BLOCKS, numHostLocalBlocks)} " + + log"(${MDC(LOCAL_BLOCKS_SIZE, Utils.bytesToString(localBlockBytes))}) " + + log"local and ${MDC(NUM_FALLBACK_STORAGE_BLOCKS, fallbackStorageBlocks.size)} " + + log"(${MDC(FALLBACK_STORAGE_BLOCKS_SIZE, Utils.bytesToString(fallbackStorageBlockBytes))}) " + + log"fallback storage and ${MDC(NUM_HOST_LOCAL_BLOCKS, numHostLocalBlocks)} " + log"(${MDC(HOST_LOCAL_BLOCKS_SIZE, Utils.bytesToString(hostLocalBlockBytes))}) " + log"host-local and ${MDC(NUM_PUSH_MERGED_LOCAL_BLOCKS, pushMergedLocalBlocks.size)} " + log"(${MDC(PUSH_MERGED_LOCAL_BLOCKS_SIZE, Utils.bytesToString(pushMergedLocalBlockBytes))})" + @@ -608,6 +625,42 @@ final class ShuffleBlockFetcherIterator( } } + /** + * Fetch the blocks from fallback storage while we are fetching remote blocks. + */ + private[this] def fetchFallbackStorageBlocks( + blocks: mutable.LinkedHashSet[(BlockId, Int)]): Unit = { + logDebug(s"Start fetching fallback storage blocks: ${blocks.mkString(", ")}") + val iter = blocks.iterator + while (iter.hasNext) { + val (blockId, mapIndex) = iter.next() + try { + val buf = blockManager.getFallbackStorageBlockData(blockId) + // TODO: add fallback storage metrics + shuffleMetrics.incLocalBlocksFetched(1) + shuffleMetrics.incLocalBytesRead(buf.size) + buf.retain() + results.put(SuccessFetchResult(blockId, mapIndex, blockManager.blockManagerId, + buf.size(), buf, false)) + } catch { + // If we see an exception, stop immediately. + case e: Exception => + e match { + // ClosedByInterruptException is an excepted exception when kill task, + // don't log the exception stack trace to avoid confusing users. + // See: SPARK-28340 + case ce: ClosedByInterruptException => + logError( + log"Error occurred while fetching local blocks, ${MDC(ERROR, ce.getMessage)}") + case ex: Exception => logError("Error occurred while fetching local blocks", ex) + } + results.putFirst( + FailureFetchResult(blockId, mapIndex, blockManager.blockManagerId, e)) + return + } + } + } + private[this] def fetchHostLocalBlock( blockId: BlockId, mapIndex: Int, @@ -712,13 +765,15 @@ final class ShuffleBlockFetcherIterator( context.addTaskCompletionListener(onCompleteCallback) // Local blocks to fetch, excluding zero-sized blocks. val localBlocks = mutable.LinkedHashSet[(BlockId, Int)]() + val fallbackStorageBlocks = mutable.LinkedHashSet[(BlockId, Int)]() val hostLocalBlocksByExecutor = mutable.LinkedHashMap[BlockManagerId, collection.Seq[(BlockId, Long, Int)]]() val pushMergedLocalBlocks = mutable.LinkedHashSet[BlockId]() // Partition blocks by the different fetch modes: local, host-local, push-merged-local and // remote blocks. val remoteRequests = partitionBlocksByFetchMode( - blocksByAddress, localBlocks, hostLocalBlocksByExecutor, pushMergedLocalBlocks) + blocksByAddress, localBlocks, fallbackStorageBlocks, hostLocalBlocksByExecutor, + pushMergedLocalBlocks) // Add the remote requests into our queue in a random order fetchRequests ++= Utils.randomize(remoteRequests) assert ((0 == reqsInFlight) == (0 == bytesInFlight), @@ -738,6 +793,11 @@ final class ShuffleBlockFetcherIterator( // Get Local Blocks fetchLocalBlocks(localBlocks) logDebug(s"Got local blocks in ${Utils.getUsedTimeNs(startTimeNs)}") + + // Get Fallback Storage Blocks + fetchFallbackStorageBlocks(fallbackStorageBlocks) + logDebug(s"Got fallback storage blocks in ${Utils.getUsedTimeNs(startTimeNs)}") + // Get host local blocks if any withFetchWaitTimeTracked(fetchAllHostLocalBlocks(hostLocalBlocksByExecutor)) pushBasedFetchHelper.fetchAllPushMergedLocalBlocks(pushMergedLocalBlocks) @@ -1287,17 +1347,21 @@ final class ShuffleBlockFetcherIterator( originalBlocksByAddr: Iterator[(BlockManagerId, collection.Seq[(BlockId, Long, Int)])]): Unit = { val originalLocalBlocks = mutable.LinkedHashSet[(BlockId, Int)]() + val originalFallbackStorageBlocks = mutable.LinkedHashSet[(BlockId, Int)]() val originalHostLocalBlocksByExecutor = mutable.LinkedHashMap[BlockManagerId, collection.Seq[(BlockId, Long, Int)]]() val originalMergedLocalBlocks = mutable.LinkedHashSet[BlockId]() val originalRemoteReqs = partitionBlocksByFetchMode(originalBlocksByAddr, - originalLocalBlocks, originalHostLocalBlocksByExecutor, originalMergedLocalBlocks) + originalLocalBlocks, originalFallbackStorageBlocks, originalHostLocalBlocksByExecutor, + originalMergedLocalBlocks) // Add the remote requests into our queue in a random order fetchRequests ++= Utils.randomize(originalRemoteReqs) logInfo(log"Created ${MDC(NUM_REQUESTS, originalRemoteReqs.size)} fallback remote requests " + log"for push-merged") // fetch all the fallback blocks that are local. fetchLocalBlocks(originalLocalBlocks) + // fetch all the fallback blocks from fallback storage. + fetchFallbackStorageBlocks(originalFallbackStorageBlocks) // Merged local blocks should be empty during fallback assert(originalMergedLocalBlocks.isEmpty, "There should be zero push-merged blocks during fallback") From 5dbe284a97fe122c2e93523338558a7ad8938318 Mon Sep 17 00:00:00 2001 From: Enrico Minack Date: Tue, 10 Feb 2026 16:41:29 +0100 Subject: [PATCH 4/8] Read from fallback storage multithreaded --- .../spark/internal/config/package.scala | 7 + .../shuffle/BlockStoreShuffleReader.scala | 1 + .../storage/ShuffleBlockFetcherIterator.scala | 198 ++++++++++++------ .../ShuffleBlockFetcherIteratorSuite.scala | 2 + 4 files changed, 149 insertions(+), 59 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 9fee7a36a0445..feddf7d12b340 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -1501,6 +1501,13 @@ package object config { "maxRemoteBlockSizeFetchToMem cannot be larger than (Int.MaxValue - 512) bytes.") .createWithDefaultString("200m") + private[spark] val REDUCER_FALLBACK_STORAGE_READ_THREADS = + ConfigBuilder("spark.reducer.fallbackStorage.readThreads") + .doc("Number of threads used by the reducer to read shuffle blocks from fallback storage.") + .version("4.2.0") + .intConf + .createWithDefault(5) + private[spark] val TASK_METRICS_TRACK_UPDATED_BLOCK_STATUSES = ConfigBuilder("spark.taskMetrics.trackUpdatedBlockStatuses") .doc("Enable tracking of updatedBlockStatuses in the TaskMetrics. Off by default since " + diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 7918d1618eb06..381089ff8bdd2 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -83,6 +83,7 @@ private[spark] class BlockStoreShuffleReader[K, C]( SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS), SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM), SparkEnv.get.conf.get(config.SHUFFLE_MAX_ATTEMPTS_ON_NETTY_OOM), + SparkEnv.get.conf.get(config.REDUCER_FALLBACK_STORAGE_READ_THREADS), SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT), SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT_MEMORY), SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ENABLED), diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 0c3721e0eee33..b69cfefc8f109 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -19,7 +19,7 @@ package org.apache.spark.storage import java.io.{InputStream, IOException} import java.nio.channels.ClosedByInterruptException -import java.util.concurrent.{LinkedBlockingDeque, TimeUnit} +import java.util.concurrent.{LinkedBlockingDeque, ThreadPoolExecutor, TimeUnit} import java.util.concurrent.atomic.AtomicBoolean import java.util.zip.CheckedInputStream import javax.annotation.concurrent.GuardedBy @@ -27,6 +27,7 @@ import javax.annotation.concurrent.GuardedBy import scala.collection import scala.collection.mutable import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue} +import scala.concurrent.{ExecutionContext, ExecutionContextExecutor, Future} import scala.util.{Failure, Success} import io.netty.util.internal.OutOfDirectMemoryError @@ -37,12 +38,12 @@ import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID import org.apache.spark.errors.SparkCoreErrors import org.apache.spark.internal.Logging import org.apache.spark.internal.LogKeys._ -import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} +import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer, NioManagedBuffer} import org.apache.spark.network.shuffle._ import org.apache.spark.network.shuffle.checksum.{Cause, ShuffleChecksumHelper} import org.apache.spark.network.util.{NettyUtils, TransportConf} import org.apache.spark.shuffle.ShuffleReadMetricsReporter -import org.apache.spark.util.{Clock, CompletionIterator, SystemClock, TaskCompletionListener, Utils} +import org.apache.spark.util.{Clock, CompletionIterator, SystemClock, TaskCompletionListener, ThreadUtils, Utils} /** * An iterator that fetches multiple blocks. For local blocks, it fetches from the local block @@ -73,6 +74,7 @@ import org.apache.spark.util.{Clock, CompletionIterator, SystemClock, TaskComple * @param maxReqSizeShuffleToMem max size (in bytes) of a request that can be shuffled to memory. * @param maxAttemptsOnNettyOOM The max number of a block could retry due to Netty OOM before * throwing the fetch failure. + * @param fallbackStorageReadThreads number of threads reading concurrently from fallback storage * @param detectCorrupt whether to detect any corruption in fetched blocks. * @param checksumEnabled whether the shuffle checksum is enabled. When enabled, Spark will try to * diagnose the cause of the block corruption. @@ -95,6 +97,7 @@ final class ShuffleBlockFetcherIterator( maxBlocksInFlightPerAddress: Int, val maxReqSizeShuffleToMem: Long, maxAttemptsOnNettyOOM: Int, + fallbackStorageReadThreads: Int, detectCorrupt: Boolean, detectCorruptUseExtraMemory: Boolean, checksumEnabled: Boolean, @@ -139,9 +142,25 @@ final class ShuffleBlockFetcherIterator( */ @volatile private[this] var currentResult: SuccessFetchResult = null + /** + * Queue of fallback storage requests to issue; we'll pull requests off this gradually to make + * sure that the number of bytes and requests in flight is limited to maxBytesInFlight and + * maxReqsInFlight. + */ + private[this] val fallbackStorageRequests = new Queue[FallbackStorageRequest] + + /** + * Thread pool reading from fallback storage, first creating FallbackStorageRequest from + * block id and map index, then materializing requests to SuccessFetchResult. + */ + private[this] val fallbackStorageReadPool: ThreadPoolExecutor = + ThreadUtils.newDaemonFixedThreadPool(fallbackStorageReadThreads, "fallback-storage-read") + private[this] val fallbackStorageReadContext: ExecutionContextExecutor = + ExecutionContext.fromExecutor(fallbackStorageReadPool) + /** * Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that - * the number of bytes in flight is limited to maxBytesInFlight. + * the number of bytes and requests in flight is limited to maxBytesInFlight and maxReqsInFlight. */ private[this] val fetchRequests = new Queue[FetchRequest] @@ -259,6 +278,25 @@ final class ShuffleBlockFetcherIterator( logWarning(log"Failed to cleanup shuffle fetch temp file ${MDC(PATH, file.path())}") } } + fallbackStorageReadPool.shutdownNow() + } + + private[this] def createFallbackStorageRequest(blockId: BlockId, mapIndex: Int): Unit = { + Future { + try { + val block = blockManager.getFallbackStorageBlockData(blockId) + val request = FallbackStorageRequest(blockId, mapIndex, block) + results.put(PreparedFallbackStorageRequestResult(request)) + } catch { + case e: Throwable => + logError(log"Failed to prepare request to read block ${MDC(BLOCK_ID, blockId)} " + + log"from fallback storage", e) + results.put( + FailureFetchResult(blockId, mapIndex, FallbackStorage.FALLBACK_BLOCK_MANAGER_ID, e)) + // stop processing any further fallback storage requests + fallbackStorageReadPool.shutdownNow() + } + }(fallbackStorageReadContext) } private[this] def sendRequest(req: FetchRequest): Unit = { @@ -391,10 +429,10 @@ final class ShuffleBlockFetcherIterator( private[this] def partitionBlocksByFetchMode( blocksByAddress: Iterator[(BlockManagerId, collection.Seq[(BlockId, Long, Int)])], localBlocks: mutable.LinkedHashSet[(BlockId, Int)], - fallbackStorageBlocks: mutable.LinkedHashSet[(BlockId, Int)], hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId, collection.Seq[(BlockId, Long, Int)]], - pushMergedLocalBlocks: mutable.LinkedHashSet[BlockId]): ArrayBuffer[FetchRequest] = { + pushMergedLocalBlocks: mutable.LinkedHashSet[BlockId], + fallbackStorageBlocks: mutable.LinkedHashSet[(BlockId, Int)]): ArrayBuffer[FetchRequest] = { logDebug(s"maxBytesInFlight: $maxBytesInFlight, targetRemoteRequestSize: " + s"$targetRemoteRequestSize, maxBlocksInFlightPerAddress: $maxBlocksInFlightPerAddress") @@ -625,42 +663,6 @@ final class ShuffleBlockFetcherIterator( } } - /** - * Fetch the blocks from fallback storage while we are fetching remote blocks. - */ - private[this] def fetchFallbackStorageBlocks( - blocks: mutable.LinkedHashSet[(BlockId, Int)]): Unit = { - logDebug(s"Start fetching fallback storage blocks: ${blocks.mkString(", ")}") - val iter = blocks.iterator - while (iter.hasNext) { - val (blockId, mapIndex) = iter.next() - try { - val buf = blockManager.getFallbackStorageBlockData(blockId) - // TODO: add fallback storage metrics - shuffleMetrics.incLocalBlocksFetched(1) - shuffleMetrics.incLocalBytesRead(buf.size) - buf.retain() - results.put(SuccessFetchResult(blockId, mapIndex, blockManager.blockManagerId, - buf.size(), buf, false)) - } catch { - // If we see an exception, stop immediately. - case e: Exception => - e match { - // ClosedByInterruptException is an excepted exception when kill task, - // don't log the exception stack trace to avoid confusing users. - // See: SPARK-28340 - case ce: ClosedByInterruptException => - logError( - log"Error occurred while fetching local blocks, ${MDC(ERROR, ce.getMessage)}") - case ex: Exception => logError("Error occurred while fetching local blocks", ex) - } - results.putFirst( - FailureFetchResult(blockId, mapIndex, blockManager.blockManagerId, e)) - return - } - } - } - private[this] def fetchHostLocalBlock( blockId: BlockId, mapIndex: Int, @@ -765,15 +767,22 @@ final class ShuffleBlockFetcherIterator( context.addTaskCompletionListener(onCompleteCallback) // Local blocks to fetch, excluding zero-sized blocks. val localBlocks = mutable.LinkedHashSet[(BlockId, Int)]() - val fallbackStorageBlocks = mutable.LinkedHashSet[(BlockId, Int)]() + val fallbackBlocks = mutable.LinkedHashSet[(BlockId, Int)]() val hostLocalBlocksByExecutor = mutable.LinkedHashMap[BlockManagerId, collection.Seq[(BlockId, Long, Int)]]() val pushMergedLocalBlocks = mutable.LinkedHashSet[BlockId]() - // Partition blocks by the different fetch modes: local, host-local, push-merged-local and - // remote blocks. + + // Partition blocks by the different fetch modes: local, host-local, push-merged-local, + // fallback storage and remote blocks. val remoteRequests = partitionBlocksByFetchMode( - blocksByAddress, localBlocks, fallbackStorageBlocks, hostLocalBlocksByExecutor, - pushMergedLocalBlocks) + blocksByAddress, localBlocks, hostLocalBlocksByExecutor, + pushMergedLocalBlocks, fallbackBlocks) + + // Turn the fallback storage blocks into read requests in random order. + Utils.randomize(fallbackBlocks).foreach { case (blockId, mapIndex) => + createFallbackStorageRequest(blockId, mapIndex) + } + // Add the remote requests into our queue in a random order fetchRequests ++= Utils.randomize(remoteRequests) assert ((0 == reqsInFlight) == (0 == bytesInFlight), @@ -794,10 +803,6 @@ final class ShuffleBlockFetcherIterator( fetchLocalBlocks(localBlocks) logDebug(s"Got local blocks in ${Utils.getUsedTimeNs(startTimeNs)}") - // Get Fallback Storage Blocks - fetchFallbackStorageBlocks(fallbackStorageBlocks) - logDebug(s"Got fallback storage blocks in ${Utils.getUsedTimeNs(startTimeNs)}") - // Get host local blocks if any withFetchWaitTimeTracked(fetchAllHostLocalBlocks(hostLocalBlocksByExecutor)) pushBasedFetchHelper.fetchAllPushMergedLocalBlocks(pushMergedLocalBlocks) @@ -1058,6 +1063,10 @@ final class ShuffleBlockFetcherIterator( defReqQueue.enqueue(request) result = null + case PreparedFallbackStorageRequestResult(request) => + fallbackStorageRequests.enqueue(request) + result = null + case FallbackOnPushMergedFailureResult(blockId, address, size, isNetworkReqDone) => // We get this result in 3 cases: // 1. Failure to fetch the data of a remote shuffle chunk. In this case, the @@ -1248,13 +1257,18 @@ final class ShuffleBlockFetcherIterator( } } + // Send fallback storage requests up to maxBytesInFlight + while (isBlockFetchable(fallbackStorageRequests)) { + sendFallbackStorageRequest(fallbackStorageRequests.dequeue()) + } + // Send fetch requests up to maxBytesInFlight. If you cannot fetch from a remote host // immediately, defer the request until the next time it can be processed. // Process any outstanding deferred fetch requests if possible. if (deferredFetchRequests.nonEmpty) { for ((remoteAddress, defReqQueue) <- deferredFetchRequests) { - while (isRemoteBlockFetchable(defReqQueue) && + while (isBlockFetchable(defReqQueue) && !isRemoteAddressMaxedOut(remoteAddress, defReqQueue.front)) { val request = defReqQueue.dequeue() logDebug(s"Processing deferred fetch request for $remoteAddress with " @@ -1268,7 +1282,7 @@ final class ShuffleBlockFetcherIterator( } // Process any regular fetch requests if possible. - while (isRemoteBlockFetchable(fetchRequests)) { + while (isBlockFetchable(fetchRequests)) { val request = fetchRequests.dequeue() val remoteAddress = request.address if (isRemoteAddressMaxedOut(remoteAddress, request)) { @@ -1291,7 +1305,42 @@ final class ShuffleBlockFetcherIterator( numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size } - def isRemoteBlockFetchable(fetchReqQueue: Queue[FetchRequest]): Boolean = { + def sendFallbackStorageRequest(request: FallbackStorageRequest): Unit = { + bytesInFlight += request.size + reqsInFlight += 1 + + Future { + if (!isZombie) { + logDebug(log"Reading block ${MDC(BLOCK_ID, request.blockId)} from fallback storage") + try { + // materialize the block ManagedBuffer and store data in SuccessFetchResult + val buf = new NioManagedBuffer(request.block.nioByteBuffer()) + // TODO: add fallback storage metrics + shuffleMetrics.incLocalBlocksFetched(1) + shuffleMetrics.incLocalBytesRead(buf.size) + val result = SuccessFetchResult( + request.blockId, request.mapIndex, FallbackStorage.FALLBACK_BLOCK_MANAGER_ID, + request.size, buf, isNetworkReqDone = true) + results.put(result) + } catch { + case e: Throwable => + logError(log"Failed to read block ${MDC(BLOCK_ID, request.blockId)} " + + log"from fallback storage", e) + val result = FailureFetchResult( + request.blockId, request.mapIndex, FallbackStorage.FALLBACK_BLOCK_MANAGER_ID, e) + results.put(result) + // stop processing any further fallback storage requests + fallbackStorageReadPool.shutdownNow() + } + } + }(fallbackStorageReadContext) + + // TODO: needed? + numBlocksInFlightPerAddress(FallbackStorage.FALLBACK_BLOCK_MANAGER_ID) = + numBlocksInFlightPerAddress.getOrElse(FallbackStorage.FALLBACK_BLOCK_MANAGER_ID, 0) + 1 + } + + def isBlockFetchable[T <: Request](fetchReqQueue: Queue[T]): Boolean = { fetchReqQueue.nonEmpty && (bytesInFlight == 0 || (reqsInFlight + 1 <= maxReqsInFlight && @@ -1352,16 +1401,18 @@ final class ShuffleBlockFetcherIterator( mutable.LinkedHashMap[BlockManagerId, collection.Seq[(BlockId, Long, Int)]]() val originalMergedLocalBlocks = mutable.LinkedHashSet[BlockId]() val originalRemoteReqs = partitionBlocksByFetchMode(originalBlocksByAddr, - originalLocalBlocks, originalFallbackStorageBlocks, originalHostLocalBlocksByExecutor, - originalMergedLocalBlocks) + originalLocalBlocks, originalHostLocalBlocksByExecutor, + originalMergedLocalBlocks, originalFallbackStorageBlocks) + // Turn the fallback storage blocks into read requests in random order. + Utils.randomize(originalFallbackStorageBlocks).foreach { case (blockId, mapIndex) => + createFallbackStorageRequest(blockId, mapIndex) + } // Add the remote requests into our queue in a random order fetchRequests ++= Utils.randomize(originalRemoteReqs) logInfo(log"Created ${MDC(NUM_REQUESTS, originalRemoteReqs.size)} fallback remote requests " + log"for push-merged") // fetch all the fallback blocks that are local. fetchLocalBlocks(originalLocalBlocks) - // fetch all the fallback blocks from fallback storage. - fetchFallbackStorageBlocks(originalFallbackStorageBlocks) // Merged local blocks should be empty during fallback assert(originalMergedLocalBlocks.isEmpty, "There should be zero push-merged blocks during fallback") @@ -1602,6 +1653,10 @@ object ShuffleBlockFetcherIterator { result } + private[storage] trait Request { + val size: Long + } + /** * The block information to fetch used in FetchRequest. * @param blockId block id @@ -1624,10 +1679,25 @@ object ShuffleBlockFetcherIterator { case class FetchRequest( address: BlockManagerId, blocks: collection.Seq[FetchBlockInfo], - forMergedMetas: Boolean = false) { + forMergedMetas: Boolean = false) extends Request { val size = blocks.map(_.size).sum } + /** + * A request to fetch blocks from the Fallback Storage. Holds block data lazily. + * We read the data asynchronously and multithreaded. The result is a SuccessFetchResult + * where buf contains the materialized data. + * @param blockId The block id to read + * @param mapIndex The mapId of the block + * @param block the block as a lazy ManagedBuffer + */ + case class FallbackStorageRequest( + blockId: BlockId, + mapIndex: Int, + block: ManagedBuffer) extends Request { + val size: Long = block.size() + } + /** * Result of a fetch from a remote block. */ @@ -1674,6 +1744,16 @@ object ShuffleBlockFetcherIterator { private[storage] case class DeferFetchRequestResult(fetchRequest: FetchRequest) extends FetchResult + /** + * Fetching block data from the fallback storage is a two-steps process: + * 1. read offset and size of the shuffle block from fallback storage + * 2. read the block data from fallback storage + * A PreparedFallbackStorageRequestResult is the outcome of the first step, + * the SuccessFetchResult is the outcome of the second step. + */ + private[storage] case class PreparedFallbackStorageRequestResult( + fallbackStorageRequest: FallbackStorageRequest) extends FetchResult + /** * Result of an un-successful fetch of either of these: * 1) Remote shuffle chunk. diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 08220a26010fc..780af6ff9473a 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -191,6 +191,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { maxBlocksInFlightPerAddress: Int = Int.MaxValue, maxReqSizeShuffleToMem: Int = Int.MaxValue, maxAttemptsOnNettyOOM: Int = 10, + fallbackStorageReadThreads: Int = 5, detectCorrupt: Boolean = true, detectCorruptUseExtraMemory: Boolean = true, checksumEnabled: Boolean = true, @@ -217,6 +218,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { maxBlocksInFlightPerAddress, maxReqSizeShuffleToMem, maxAttemptsOnNettyOOM, + fallbackStorageReadThreads, detectCorrupt, detectCorruptUseExtraMemory, checksumEnabled, From c8bd1418144074ca56426c7838c39636a614069d Mon Sep 17 00:00:00 2001 From: Enrico Minack Date: Tue, 10 Feb 2026 20:48:44 +0100 Subject: [PATCH 5/8] Update unit tests Removes "Fast fail when failed to get fallback storage blocks" as fallback storage blocks are fetched concurrently and fast fail is not guaranteed any more. --- .../storage/ShuffleBlockFetcherIterator.scala | 3 +- .../ShuffleBlockFetcherIteratorSuite.scala | 96 ++++++++----------- 2 files changed, 41 insertions(+), 58 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index b69cfefc8f109..2180efcf3fe74 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -153,7 +153,8 @@ final class ShuffleBlockFetcherIterator( * Thread pool reading from fallback storage, first creating FallbackStorageRequest from * block id and map index, then materializing requests to SuccessFetchResult. */ - private[this] val fallbackStorageReadPool: ThreadPoolExecutor = + // This is visible for testing + private[storage] val fallbackStorageReadPool: ThreadPoolExecutor = ThreadUtils.newDaemonFixedThreadPool(fallbackStorageReadThreads, "fallback-storage-read") private[this] val fallbackStorageReadContext: ExecutionContextExecutor = ExecutionContext.fromExecutor(fallbackStorageReadPool) diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 780af6ff9473a..fe07d6d2cbb90 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.storage import java.io._ import java.nio.ByteBuffer +import java.util import java.util.UUID import java.util.concurrent.{CompletableFuture, Semaphore} import java.util.zip.CheckedInputStream @@ -36,6 +37,8 @@ import org.mockito.Mockito.{doThrow, mock, never, times, verify, when} import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer import org.roaringbitmap.RoaringBitmap +import org.scalatest.concurrent.Eventually +import org.scalatest.time.SpanSugar.convertIntToGrainOfTime import org.apache.spark.{MapOutputTracker, SparkFunSuite, TaskContext} import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID @@ -49,7 +52,7 @@ import org.apache.spark.storage.ShuffleBlockFetcherIterator._ import org.apache.spark.util.Utils -class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { +class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with Eventually { private var transfer: BlockTransferService = _ private var mapOutputTracker: MapOutputTracker = _ @@ -153,6 +156,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { val in = mock(classOf[InputStream]) when(in.read(any())).thenReturn(1) when(in.read(any(), any(), any())).thenReturn(1) + val buf = ByteBuffer.allocate(size) + util.Arrays.fill(buf.array(), 1.byteValue) + when(mockManagedBuffer.nioByteBuffer()).thenReturn(buf) when(mockManagedBuffer.createInputStream()).thenReturn(in) when(mockManagedBuffer.size()).thenReturn(size) mockManagedBuffer @@ -342,7 +348,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { ShuffleBlockId(0, 9, 0) -> createMockManagedBuffer(), ShuffleBlockId(0, 10, 0) -> createMockManagedBuffer()) fallbackBlocks.foreach { case (blockId, buf) => - doReturn(buf).when(blockManager).getLocalBlockData(meq(blockId)) + doReturn(buf).when(blockManager).getFallbackStorageBlockData(meq(blockId)) } val iterator = createShuffleBlockIteratorWithDefaults( @@ -355,9 +361,15 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { blockManager = Some(blockManager) ) - // 3 local blocks and 2 fallback blocks fetched in initialization - verify(blockManager, times(3 + 2)).getLocalBlockData(any()) + // 3 local blocks fetched in initialization + verify(blockManager, times(3)).getLocalBlockData(any()) + // 2 fallback storage blocks fetched in initialization + // initialize creates futures that eventually call into getFallbackStorageBlockData + eventually(timeout(1.seconds), interval(10.millis)) { + assert(iterator.fallbackStorageReadPool.getCompletedTaskCount >= 2) + } + verify(blockManager, times(2)).getFallbackStorageBlockData(any()) // SPARK-55469: but buffer data have never been materialized fallbackBlocks.values.foreach { mockBuf => verify(mockBuf, never()).nioByteBuffer() @@ -374,7 +386,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { // Make sure we release buffers when a wrapped input stream is closed. val mockBuf = allBlocks(blockId) - verifyBufferRelease(mockBuf, inputStream) + if (!fallbackBlocks.contains(blockId)) { + verifyBufferRelease(mockBuf, inputStream) + } } assert(!iterator.hasNext) @@ -388,8 +402,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { // SPARK-55469: fallback buffer data have been materialized once fallbackBlocks.values.foreach { mockBuf => - verify(mockBuf, never()).nioByteBuffer() - verify(mockBuf, times(1)).createInputStream() + verify(mockBuf, times(1)).nioByteBuffer() + verify(mockBuf, never()).createInputStream() verify(mockBuf, never()).convertToNetty() verify(mockBuf, never()).convertToNettyForSsl() } @@ -504,7 +518,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { val mergedFallbackBlocks = Map[BlockId, ManagedBuffer]( ShuffleBlockBatchId(0, 1, 0, 2) -> createMockManagedBuffer()) mergedFallbackBlocks.foreach { case (blockId, buf) => - doReturn(buf).when(blockManager).getLocalBlockData(meq(blockId)) + doReturn(buf).when(blockManager).getFallbackStorageBlockData(meq(blockId)) } // Make sure remote blocks would return the merged block @@ -546,9 +560,15 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { doBatchFetch = true ) - // 1 local merge block and 1 fallback merge block fetched in initialization - verify(blockManager, times(1 + 1)).getLocalBlockData(any()) + // 1 local merge block fetched in initialization + verify(blockManager, times(1)).getLocalBlockData(any()) + // 1 fallback merge block fetched in initialization + // initialize creates futures that eventually call into getFallbackStorageBlockData + eventually(timeout(1.seconds), interval(10.millis)) { + assert(iterator.fallbackStorageReadPool.getCompletedTaskCount >= 1) + } + verify(blockManager, times(1)).getFallbackStorageBlockData(any()) // SPARK-55469: but buffer data have never been materialized mergedFallbackBlocks.values.foreach { mockBuf => verify(mockBuf, never()).nioByteBuffer() @@ -565,7 +585,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { verifyFetchBlocksInvocationCount(1) // Make sure we release buffers when a wrapped input stream is closed. val mockBuf = allBlocks(blockId) - verifyBufferRelease(mockBuf, inputStream) + if (!mergedFallbackBlocks.contains(blockId)) { + verifyBufferRelease(mockBuf, inputStream) + } } assert(!iterator.hasNext) @@ -573,15 +595,17 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { verify(blockManager, times(1)) .getHostLocalShuffleData(any(), meq(Array("local-dir"))) + // 1 merged remote block is read from the same block manager + verifyFetchBlocksInvocationCount(1) + assert(blockManager.hostLocalDirManager.get.getCachedHostLocalDirs.size === 1) + // SPARK-55469: merged fallback buffer data have been materialized once mergedFallbackBlocks.values.foreach { mockBuf => - verify(mockBuf, never()).nioByteBuffer() - verify(mockBuf, times(1)).createInputStream() + verify(mockBuf, times(1)).nioByteBuffer() + verify(mockBuf, never()).createInputStream() verify(mockBuf, never()).convertToNetty() verify(mockBuf, never()).convertToNettyForSsl() } - - assert(blockManager.hostLocalDirManager.get.getCachedHostLocalDirs.size === 1) } test("fetch continuous blocks in batch should respect maxBytesInFlight") { @@ -2139,46 +2163,4 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { assert(iterator.next()._1 === ShuffleBlockId(0, 1, 0)) assert(!iterator.hasNext) } - - test("Fast fail when failed to get fallback storage blocks") { - val blockManager = createMockBlockManager() - - // Make sure blockManager.getBlockData would return the blocks - val localBmId = blockManager.blockManagerId - val localBlocks = Map[BlockId, ManagedBuffer]( - ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(), - ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer()) - localBlocks.foreach { case (blockId, buf) => - doReturn(buf).when(blockManager).getLocalBlockData(meq(blockId)) - } - - // Make sure fallback storage would return the blocks - val fallbackBmId = FallbackStorage.FALLBACK_BLOCK_MANAGER_ID - val fallbackBlocks = Map[BlockId, ManagedBuffer]( - ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer(), - ShuffleBlockId(0, 3, 0) -> createMockManagedBuffer()) - fallbackBlocks.take(1).foreach { case (blockId, buf) => - doReturn(buf).when(blockManager).getLocalBlockData(meq(blockId)) - } - fallbackBlocks.takeRight(1).foreach { case (blockId, _) => - doThrow(new RuntimeException("Cannot read from fallback storage")) - .when(blockManager).getLocalBlockData(meq(blockId)) - } - - val iterator = createShuffleBlockIteratorWithDefaults( - Map( - localBmId -> toBlockList(localBlocks.keys, 1L, 0), - fallbackBmId -> toBlockList(fallbackBlocks.keys, 1L, 1) - ), - blockManager = Some(blockManager) - ) - - // Fetch failure should be placed in the head of results, exception should be thrown for the - // 1st instance. - intercept[FetchFailedException] { iterator.next() } - assert(iterator.next()._1 === ShuffleBlockId(0, 0, 0)) - assert(iterator.next()._1 === ShuffleBlockId(0, 1, 0)) - assert(iterator.next()._1 === ShuffleBlockId(0, 2, 0)) - assert(!iterator.hasNext) - } } From d4dde4c44a72725e6302b00843818dc7420070d5 Mon Sep 17 00:00:00 2001 From: Enrico Minack Date: Wed, 4 Mar 2026 10:07:22 +0100 Subject: [PATCH 6/8] Guard adding FailureFetchResult to results by isZombie and synchronized This blocks cleanup() calling fallbackStorageReadPool.shutdownNow() while futures are locking results to put FailureFetchResults. Otherwise, that put in catch clauses would be interrupted and that exception kills the executor. Logging errors only if !isZombie, meaning the iterator is not yet cleaning up. Further, FailureFetchResult are putFirst to stop iteration as quickly as possible. Finally, fallbackStorageReadPool.shutdownNow() is only called in cleanup(). --- .../storage/ShuffleBlockFetcherIterator.scala | 55 ++++++++++++------- 1 file changed, 36 insertions(+), 19 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 2180efcf3fe74..5f00c0a513b14 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -284,18 +284,28 @@ final class ShuffleBlockFetcherIterator( private[this] def createFallbackStorageRequest(blockId: BlockId, mapIndex: Int): Unit = { Future { - try { - val block = blockManager.getFallbackStorageBlockData(blockId) - val request = FallbackStorageRequest(blockId, mapIndex, block) - results.put(PreparedFallbackStorageRequestResult(request)) - } catch { - case e: Throwable => - logError(log"Failed to prepare request to read block ${MDC(BLOCK_ID, blockId)} " + - log"from fallback storage", e) - results.put( - FailureFetchResult(blockId, mapIndex, FallbackStorage.FALLBACK_BLOCK_MANAGER_ID, e)) - // stop processing any further fallback storage requests - fallbackStorageReadPool.shutdownNow() + if (!isZombie) { + try { + val block = blockManager.getFallbackStorageBlockData(blockId) + val request = FallbackStorageRequest(blockId, mapIndex, block) + results.put(PreparedFallbackStorageRequestResult(request)) + } catch { + case e: Throwable => + // the FailureFetchResult will stop iteration of this iterator + // task completion listener will shut down the thread pool / execution context + // the synchronized protects isZombie and blocks cleanup() from calling + // fallbackStorageReadPool.shutdownNow(), which would interrupt results.put + // that interrupted exception would kill the executor + synchronized { + if (!isZombie) { + logError(log"Failed to prepare request to read block ${MDC(BLOCK_ID, blockId)} " + + log"from fallback storage", e) + val result = FailureFetchResult( + blockId, mapIndex, FallbackStorage.FALLBACK_BLOCK_MANAGER_ID, e) + results.putFirst(result) + } + } + } } }(fallbackStorageReadContext) } @@ -1325,13 +1335,20 @@ final class ShuffleBlockFetcherIterator( results.put(result) } catch { case e: Throwable => - logError(log"Failed to read block ${MDC(BLOCK_ID, request.blockId)} " + - log"from fallback storage", e) - val result = FailureFetchResult( - request.blockId, request.mapIndex, FallbackStorage.FALLBACK_BLOCK_MANAGER_ID, e) - results.put(result) - // stop processing any further fallback storage requests - fallbackStorageReadPool.shutdownNow() + // the FailureFetchResult will stop iteration of this iterator + // task completion listener will shut down the thread pool / execution context + // the synchronized protects isZombie and blocks cleanup() from calling + // fallbackStorageReadPool.shutdownNow(), which would interrupt results.put + // that interrupted exception would kill the executor + synchronized { + if (!isZombie) { + logError(log"Failed to read block ${MDC(BLOCK_ID, request.blockId)} " + + log"from fallback storage", e) + val result = FailureFetchResult( + request.blockId, request.mapIndex, FallbackStorage.FALLBACK_BLOCK_MANAGER_ID, e) + results.putFirst(result) + } + } } } }(fallbackStorageReadContext) From 5df69d10b0b89251174e35c4a5061c8e41b79d49 Mon Sep 17 00:00:00 2001 From: Enrico Minack Date: Fri, 13 Mar 2026 10:03:37 +0100 Subject: [PATCH 7/8] Don't maintain numBlocksInFlightPerAddress for fallback storage --- .../spark/storage/ShuffleBlockFetcherIterator.scala | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 5f00c0a513b14..f0c37658dd02c 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -907,7 +907,9 @@ final class ShuffleBlockFetcherIterator( // It is a host local block or a local shuffle chunk shuffleMetricsUpdate(blockId, buf, local = true) } else { - numBlocksInFlightPerAddress(address) -= 1 + if (address != FallbackStorage.FALLBACK_BLOCK_MANAGER_ID) { + numBlocksInFlightPerAddress(address) -= 1 + } shuffleMetricsUpdate(blockId, buf, local = false) bytesInFlight -= size } @@ -1352,10 +1354,6 @@ final class ShuffleBlockFetcherIterator( } } }(fallbackStorageReadContext) - - // TODO: needed? - numBlocksInFlightPerAddress(FallbackStorage.FALLBACK_BLOCK_MANAGER_ID) = - numBlocksInFlightPerAddress.getOrElse(FallbackStorage.FALLBACK_BLOCK_MANAGER_ID, 0) + 1 } def isBlockFetchable[T <: Request](fetchReqQueue: Queue[T]): Boolean = { From 6ed4e0f99649da7037ccedd3484a019182fa12f2 Mon Sep 17 00:00:00 2001 From: Enrico Minack Date: Wed, 25 Mar 2026 09:59:23 +0100 Subject: [PATCH 8/8] Remove redundant if branch for only local blocks --- .../apache/spark/storage/ShuffleBlockFetcherIterator.scala | 6 ------ 1 file changed, 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index f0c37658dd02c..f8dddee9cacac 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -472,12 +472,6 @@ final class ShuffleBlockFetcherIterator( } else { collectFetchRequests(address, blockInfos, collectedRemoteRequests) } - } else if (address.executorId == localExecId) { - val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded( - blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)), doBatchFetch) - numBlocksToFetch += mergedBlockInfos.size - localBlocks ++= mergedBlockInfos.map(info => (info.blockId, info.mapIndex)) - localBlockBytes += mergedBlockInfos.map(_.size).sum } else if (localAndFallbackExecIds.contains(address.executorId)) { val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded( blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)), doBatchFetch)