diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 69e0a10a34b28..5c077a7a3bbb8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -35,7 +35,7 @@ import org.apache.spark.internal.LogKeys._ import org.apache.spark.internal.config._ import org.apache.spark.scheduler.SchedulingMode._ import org.apache.spark.util.{AccumulatorV2, Clock, LongAccumulator, SystemClock, Utils} -import org.apache.spark.util.collection.PercentileHeap +import org.apache.spark.util.collection.{OpenHashSet, PercentileHeap} /** * Schedules the tasks within a single TaskSet in the TaskSchedulerImpl. This class keeps track of @@ -200,6 +200,12 @@ private[spark] class TaskSetManager( // Task index, start and finish time for each task attempt (indexed by task ID) private[scheduler] val taskInfos = new HashMap[Long, TaskInfo] + // Reverse index: executor ID -> IDs of all task attempts launched on that executor. + // Entries are added at launch and never removed, so each set holds running and finished + // attempts alike. executorLost uses it to avoid an O(N) scan over all taskInfos. + // Uses OpenHashSet[Long] (specialized for Long) to avoid boxing overhead. + private[scheduler] val executorIdToTaskIds = new HashMap[String, OpenHashSet[Long]] + // Use a MedianHeap to record durations of successful tasks so we know when to launch // speculative tasks. This is only used when speculation is enabled, to avoid the overhead // of inserting into the heap when the heap won't be used. 
@@ -537,6 +543,7 @@ private[spark] class TaskSetManager( taskId, index, attemptNum, task.partitionId, launchTime, execId, host, taskLocality, speculative) taskInfos(taskId) = info + executorIdToTaskIds.getOrElseUpdate(execId, new OpenHashSet[Long]).add(taskId) taskAttempts(index) = info :: taskAttempts(index) // Serialize and return the task val serializedTask: ByteBuffer = try { @@ -1141,6 +1148,7 @@ private[spark] class TaskSetManager( /** Called by TaskScheduler when an executor is lost so we can re-enqueue our tasks */ override def executorLost(execId: String, host: String, reason: ExecutorLossReason): Unit = { + val taskIdsOnExec = executorIdToTaskIds.getOrElse(execId, TaskSetManager.EMPTY_LONG_SET) // Re-enqueue any tasks with potential shuffle data loss that ran on the failed executor // if this is a shuffle map stage, and we are not using an external shuffle server which // could serve the shuffle outputs or the executor lost is caused by decommission (which @@ -1150,7 +1158,10 @@ private[spark] class TaskSetManager( !sched.sc.shuffleDriverComponents.supportsReliableStorage() && (reason.isInstanceOf[ExecutorDecommission] || !env.blockManager.externalShuffleServiceEnabled) if (maybeShuffleMapOutputLoss && !isZombie) { - for ((tid, info) <- taskInfos if info.executorId == execId) { + val iter1 = taskIdsOnExec.iterator + while (iter1.hasNext) { + val tid = iter1.next() + val info = taskInfos(tid) val index = info.index lazy val isShuffleMapOutputAvailable = reason match { case ExecutorDecommission(_, _) => @@ -1192,18 +1203,23 @@ private[spark] class TaskSetManager( } } } - for ((tid, info) <- taskInfos if info.running && info.executorId == execId) { - val exitCausedByApp: Boolean = reason match { - case ExecutorExited(_, false, _) => false - case ExecutorKilled | ExecutorDecommission(_, _) => false - case ExecutorProcessLost(_, _, false) => false - // If the task is launching, this indicates that Driver has sent LaunchTask to Executor, - // but Executor has 
not sent StatusUpdate(TaskState.RUNNING) to Driver. Hence, we assume - // that the task is not running, and it is NetworkFailure rather than TaskFailure. - case _ => !info.launching + val iter2 = taskIdsOnExec.iterator + while (iter2.hasNext) { + val tid = iter2.next() + val info = taskInfos(tid) + if (info.running) { + val exitCausedByApp: Boolean = reason match { + case ExecutorExited(_, false, _) => false + case ExecutorKilled | ExecutorDecommission(_, _) => false + case ExecutorProcessLost(_, _, false) => false + // If the task is launching, this indicates that Driver has sent LaunchTask to Executor, + // but Executor has not sent StatusUpdate(TaskState.RUNNING) to Driver. Hence, we assume + // that the task is not running, and it is NetworkFailure rather than TaskFailure. + case _ => !info.launching + } + handleFailedTask(tid, TaskState.FAILED, ExecutorLostFailure(info.executorId, + exitCausedByApp, Some(reason.toString))) } - handleFailedTask(tid, TaskState.FAILED, ExecutorLostFailure(info.executorId, exitCausedByApp, - Some(reason.toString))) } // recalculate valid locality levels and waits when executor is lost recomputeLocality() @@ -1448,6 +1464,10 @@ private[spark] object TaskSetManager { // 1 minute val BARRIER_LOGGING_INTERVAL = 60000 + + // Shared empty set returned when executorIdToTaskIds has no entry for an executor, + // so a miss does not allocate a new set. Read-only: never add elements to it. 
+ private val EMPTY_LONG_SET = new OpenHashSet[Long](0) } /** diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index be1bc5fe3212a..adcb57a0187a4 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -1745,6 +1745,114 @@ class TaskSetManagerSuite } + test("SPARK-56235 Reverse index is correctly maintained and used by executorLost") { + sc = new SparkContext("local", "test") + sched = new FakeTaskScheduler(sc, ("exec1", "host1"), ("exec2", "host2")) + // Create a task set with 4 tasks + val taskSet = FakeTask.createTaskSet(4, + Seq(TaskLocation("host1", "exec1")), + Seq(TaskLocation("host1", "exec1")), + Seq(TaskLocation("host2", "exec2")), + Seq(TaskLocation("host2", "exec2"))) + val clock = new ManualClock() + clock.advance(1) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock) + + // Initially, the reverse index should be empty + assert(manager.executorIdToTaskIds.isEmpty) + + // Offer resources: tasks 0, 1 on exec1; tasks 2, 3 on exec2 + val task0 = manager.resourceOffer("exec1", "host1", PROCESS_LOCAL)._1.get + val task1 = manager.resourceOffer("exec1", "host1", PROCESS_LOCAL)._1.get + val task2 = manager.resourceOffer("exec2", "host2", PROCESS_LOCAL)._1.get + val task3 = manager.resourceOffer("exec2", "host2", PROCESS_LOCAL)._1.get + + // Verify reverse index is correctly populated + assert(manager.executorIdToTaskIds.size === 2) + assert(manager.executorIdToTaskIds("exec1").size === 2) + assert(manager.executorIdToTaskIds("exec1").contains(task0.taskId)) + assert(manager.executorIdToTaskIds("exec1").contains(task1.taskId)) + assert(manager.executorIdToTaskIds("exec2").size === 2) + assert(manager.executorIdToTaskIds("exec2").contains(task2.taskId)) + 
assert(manager.executorIdToTaskIds("exec2").contains(task3.taskId)) + + assert(manager.runningTasks === 4) + + // Lose exec1 - only tasks on exec1 should be affected + sched.removeExecutor("exec1") + manager.executorLost("exec1", "host1", ExecutorProcessLost()) + + // Tasks on exec1 should be failed (no longer running) + assert(manager.runningTasks === 2) + // Tasks on exec2 should still be running + assert(manager.taskInfos(task2.taskId).running) + assert(manager.taskInfos(task3.taskId).running) + // Tasks on exec1 should not be running + assert(!manager.taskInfos(task0.taskId).running) + assert(!manager.taskInfos(task1.taskId).running) + } + + test("SPARK-56235 Reverse index works correctly with speculative tasks") { + val conf = new SparkConf().set(config.SPECULATION_ENABLED, true) + sc = new SparkContext("local", "test", conf) + sc.conf.set(config.SPECULATION_MULTIPLIER, 0.0) + sc.conf.set(config.SPECULATION_QUANTILE, 0.5) + + sched = new FakeTaskScheduler(sc, ("exec1", "host1"), + ("exec2", "host2"), ("exec3", "host3")) + sched.initialize(new FakeSchedulerBackend() { + override def killTask( + taskId: Long, + executorId: String, + interruptThread: Boolean, + reason: String): Unit = {} + }) + + val taskSet = FakeTask.createTaskSet(2, + Seq(TaskLocation("host1", "exec1")), + Seq(TaskLocation("host2", "exec2"))) + val clock = new ManualClock() + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock) + + // Launch tasks: task 0 on exec1, task 1 on exec2 + val task0 = manager.resourceOffer("exec1", "host1", PROCESS_LOCAL)._1.get + val task1 = manager.resourceOffer("exec2", "host2", PROCESS_LOCAL)._1.get + + assert(manager.executorIdToTaskIds("exec1").size === 1) + assert(manager.executorIdToTaskIds("exec1").contains(task0.taskId)) + assert(manager.executorIdToTaskIds("exec2").size === 1) + assert(manager.executorIdToTaskIds("exec2").contains(task1.taskId)) + + // Complete task 0, so that speculative task can be launched for task 1 + 
clock.advance(1) + val directTaskResult = new DirectTaskResult[String]() { + override def value(resultSer: SerializerInstance): String = "" + } + manager.handleSuccessfulTask(task0.taskId, directTaskResult) + + // Launch speculative copy of task 1 on exec3 + clock.advance(1) + manager.checkSpeculatableTasks(0) + manager.speculatableTasks += 1 + manager.addPendingTask(1, speculatable = true) + val specTask = manager.resourceOffer("exec3", "host3", ANY)._1.get + assert(specTask.index === 1) + assert(specTask.attemptNumber === 1) + + // Verify reverse index now has speculative task on exec3 + assert(manager.executorIdToTaskIds("exec3").size === 1) + assert(manager.executorIdToTaskIds("exec3").contains(specTask.taskId)) + + // Lose exec2 (where original task 1 is running) + sched.removeExecutor("exec2") + manager.executorLost("exec2", "host2", ExecutorProcessLost()) + + // Original task 1 should be failed, speculative task 1 on exec3 should still be running + assert(!manager.taskInfos(task1.taskId).running) + assert(manager.taskInfos(specTask.taskId).running) + assert(manager.runningTasks === 1) + } + private def createTaskResult( id: Int, accumUpdates: Seq[AccumulatorV2[_, _]] = Seq.empty,