Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 33 additions & 13 deletions core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ import org.apache.spark.internal.LogKeys._
import org.apache.spark.internal.config._
import org.apache.spark.scheduler.SchedulingMode._
import org.apache.spark.util.{AccumulatorV2, Clock, LongAccumulator, SystemClock, Utils}
import org.apache.spark.util.collection.PercentileHeap
import org.apache.spark.util.collection.{OpenHashSet, PercentileHeap}

/**
* Schedules the tasks within a single TaskSet in the TaskSchedulerImpl. This class keeps track of
Expand Down Expand Up @@ -200,6 +200,12 @@ private[spark] class TaskSetManager(
// Task index, start and finish time for each task attempt (indexed by task ID)
private[scheduler] val taskInfos = new HashMap[Long, TaskInfo]

// Reverse index: executor ID -> set of task IDs that were launched on that executor.
// This includes both running and completed tasks, used to efficiently look up tasks
// when an executor is lost, avoiding O(N) scans over all taskInfos.
// Uses OpenHashSet[Long] (specialized for Long) to avoid boxing overhead.
private[scheduler] val executorIdToTaskIds = new HashMap[String, OpenHashSet[Long]]

// Use a MedianHeap to record durations of successful tasks so we know when to launch
// speculative tasks. This is only used when speculation is enabled, to avoid the overhead
// of inserting into the heap when the heap won't be used.
Expand Down Expand Up @@ -537,6 +543,7 @@ private[spark] class TaskSetManager(
taskId, index, attemptNum, task.partitionId, launchTime,
execId, host, taskLocality, speculative)
taskInfos(taskId) = info
executorIdToTaskIds.getOrElseUpdate(execId, new OpenHashSet[Long]).add(taskId)
taskAttempts(index) = info :: taskAttempts(index)
// Serialize and return the task
val serializedTask: ByteBuffer = try {
Expand Down Expand Up @@ -1141,6 +1148,7 @@ private[spark] class TaskSetManager(

/** Called by TaskScheduler when an executor is lost so we can re-enqueue our tasks */
override def executorLost(execId: String, host: String, reason: ExecutorLossReason): Unit = {
val taskIdsOnExec = executorIdToTaskIds.getOrElse(execId, TaskSetManager.emptyLongSet)
// Re-enqueue any tasks with potential shuffle data loss that ran on the failed executor
// if this is a shuffle map stage, and we are not using an external shuffle server which
// could serve the shuffle outputs or the executor lost is caused by decommission (which
Expand All @@ -1150,7 +1158,10 @@ private[spark] class TaskSetManager(
!sched.sc.shuffleDriverComponents.supportsReliableStorage() &&
(reason.isInstanceOf[ExecutorDecommission] || !env.blockManager.externalShuffleServiceEnabled)
if (maybeShuffleMapOutputLoss && !isZombie) {
for ((tid, info) <- taskInfos if info.executorId == execId) {
val iter1 = taskIdsOnExec.iterator
while (iter1.hasNext) {
val tid = iter1.next()
val info = taskInfos(tid)
val index = info.index
lazy val isShuffleMapOutputAvailable = reason match {
case ExecutorDecommission(_, _) =>
Expand Down Expand Up @@ -1192,18 +1203,23 @@ private[spark] class TaskSetManager(
}
}
}
for ((tid, info) <- taskInfos if info.running && info.executorId == execId) {
val exitCausedByApp: Boolean = reason match {
case ExecutorExited(_, false, _) => false
case ExecutorKilled | ExecutorDecommission(_, _) => false
case ExecutorProcessLost(_, _, false) => false
// If the task is launching, this indicates that Driver has sent LaunchTask to Executor,
// but Executor has not sent StatusUpdate(TaskState.RUNNING) to Driver. Hence, we assume
// that the task is not running, and it is NetworkFailure rather than TaskFailure.
case _ => !info.launching
val iter2 = taskIdsOnExec.iterator
while (iter2.hasNext) {
val tid = iter2.next()
val info = taskInfos(tid)
if (info.running) {
val exitCausedByApp: Boolean = reason match {
case ExecutorExited(_, false, _) => false
case ExecutorKilled | ExecutorDecommission(_, _) => false
case ExecutorProcessLost(_, _, false) => false
// If the task is launching, this indicates that Driver has sent LaunchTask to Executor,
// but Executor has not sent StatusUpdate(TaskState.RUNNING) to Driver. Hence, we assume
// that the task is not running, and it is NetworkFailure rather than TaskFailure.
case _ => !info.launching
}
handleFailedTask(tid, TaskState.FAILED, ExecutorLostFailure(info.executorId,
exitCausedByApp, Some(reason.toString)))
}
handleFailedTask(tid, TaskState.FAILED, ExecutorLostFailure(info.executorId, exitCausedByApp,
Some(reason.toString)))
}
// recalculate valid locality levels and waits when executor is lost
recomputeLocality()
Expand Down Expand Up @@ -1448,6 +1464,10 @@ private[spark] object TaskSetManager {

// 1 minute
val BARRIER_LOGGING_INTERVAL = 60000

// Shared empty set used as default value for executorIdToTaskIds lookups
// to avoid allocating a new empty set on each executorLost call.
private val emptyLongSet = new OpenHashSet[Long](0)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1745,6 +1745,114 @@ class TaskSetManagerSuite

}

test("SPARK-56235 Reverse index is correctly maintained and used by executorLost") {
  sc = new SparkContext("local", "test")
  sched = new FakeTaskScheduler(sc, ("exec1", "host1"), ("exec2", "host2"))
  // Four tasks: indices 0-1 prefer exec1 on host1, indices 2-3 prefer exec2 on host2.
  val fourTaskSet = FakeTask.createTaskSet(4,
    Seq(TaskLocation("host1", "exec1")),
    Seq(TaskLocation("host1", "exec1")),
    Seq(TaskLocation("host2", "exec2")),
    Seq(TaskLocation("host2", "exec2")))
  val manualClock = new ManualClock()
  manualClock.advance(1)
  val tsm = new TaskSetManager(sched, fourTaskSet, MAX_TASK_FAILURES, clock = manualClock)

  // Nothing has been launched yet, so the reverse index must start out empty.
  assert(tsm.executorIdToTaskIds.isEmpty)

  // Launch everything: tasks 0 and 1 land on exec1, tasks 2 and 3 on exec2.
  val exec1Attempt0 = tsm.resourceOffer("exec1", "host1", PROCESS_LOCAL)._1.get
  val exec1Attempt1 = tsm.resourceOffer("exec1", "host1", PROCESS_LOCAL)._1.get
  val exec2Attempt0 = tsm.resourceOffer("exec2", "host2", PROCESS_LOCAL)._1.get
  val exec2Attempt1 = tsm.resourceOffer("exec2", "host2", PROCESS_LOCAL)._1.get

  // Each executor's entry should contain exactly the two task IDs launched on it.
  assert(tsm.executorIdToTaskIds.size === 2)
  assert(tsm.executorIdToTaskIds("exec1").size === 2)
  Seq(exec1Attempt0, exec1Attempt1).foreach { desc =>
    assert(tsm.executorIdToTaskIds("exec1").contains(desc.taskId))
  }
  assert(tsm.executorIdToTaskIds("exec2").size === 2)
  Seq(exec2Attempt0, exec2Attempt1).foreach { desc =>
    assert(tsm.executorIdToTaskIds("exec2").contains(desc.taskId))
  }

  assert(tsm.runningTasks === 4)

  // Drop exec1: only the attempts launched there should be failed by executorLost.
  sched.removeExecutor("exec1")
  tsm.executorLost("exec1", "host1", ExecutorProcessLost())

  // Two tasks (those on exec1) are gone from the running set.
  assert(tsm.runningTasks === 2)
  // Tasks launched on exec2 are untouched and keep running.
  assert(tsm.taskInfos(exec2Attempt0.taskId).running)
  assert(tsm.taskInfos(exec2Attempt1.taskId).running)
  // Tasks launched on exec1 must no longer be marked running.
  assert(!tsm.taskInfos(exec1Attempt0.taskId).running)
  assert(!tsm.taskInfos(exec1Attempt1.taskId).running)
}

test("SPARK-56235 Reverse index works correctly with speculative tasks") {
  val conf = new SparkConf().set(config.SPECULATION_ENABLED, true)
  sc = new SparkContext("local", "test", conf)
  // With a zero multiplier and 0.5 quantile, one finished task is enough to make
  // every remaining running task eligible for speculation.
  sc.conf.set(config.SPECULATION_MULTIPLIER, 0.0)
  sc.conf.set(config.SPECULATION_QUANTILE, 0.5)

  sched = new FakeTaskScheduler(sc, ("exec1", "host1"),
    ("exec2", "host2"), ("exec3", "host3"))
  // Backend whose killTask is a no-op so speculative bookkeeping cannot fail the test.
  sched.initialize(new FakeSchedulerBackend() {
    override def killTask(
        taskId: Long,
        executorId: String,
        interruptThread: Boolean,
        reason: String): Unit = {}
  })

  // Two tasks: index 0 prefers exec1, index 1 prefers exec2.
  val twoTaskSet = FakeTask.createTaskSet(2,
    Seq(TaskLocation("host1", "exec1")),
    Seq(TaskLocation("host2", "exec2")))
  val manualClock = new ManualClock()
  val tsm = new TaskSetManager(sched, twoTaskSet, MAX_TASK_FAILURES, clock = manualClock)

  // Launch the original attempts: index 0 on exec1, index 1 on exec2.
  val originalAttempt0 = tsm.resourceOffer("exec1", "host1", PROCESS_LOCAL)._1.get
  val originalAttempt1 = tsm.resourceOffer("exec2", "host2", PROCESS_LOCAL)._1.get

  // One entry per executor, each holding exactly its own attempt's task ID.
  assert(tsm.executorIdToTaskIds("exec1").size === 1)
  assert(tsm.executorIdToTaskIds("exec1").contains(originalAttempt0.taskId))
  assert(tsm.executorIdToTaskIds("exec2").size === 1)
  assert(tsm.executorIdToTaskIds("exec2").contains(originalAttempt1.taskId))

  // Finish task 0 so a successful-task duration exists and task 1 can speculate.
  manualClock.advance(1)
  val emptyResult = new DirectTaskResult[String]() {
    override def value(resultSer: SerializerInstance): String = ""
  }
  tsm.handleSuccessfulTask(originalAttempt0.taskId, emptyResult)

  // Force a speculative copy of task 1 and launch it on exec3.
  manualClock.advance(1)
  tsm.checkSpeculatableTasks(0)
  tsm.speculatableTasks += 1
  tsm.addPendingTask(1, speculatable = true)
  val speculativeAttempt = tsm.resourceOffer("exec3", "host3", ANY)._1.get
  assert(speculativeAttempt.index === 1)
  assert(speculativeAttempt.attemptNumber === 1)

  // The speculative attempt must be tracked under exec3 in the reverse index.
  assert(tsm.executorIdToTaskIds("exec3").size === 1)
  assert(tsm.executorIdToTaskIds("exec3").contains(speculativeAttempt.taskId))

  // Lose exec2, where the original attempt of task 1 runs.
  sched.removeExecutor("exec2")
  tsm.executorLost("exec2", "host2", ExecutorProcessLost())

  // Only the original attempt fails; the speculative copy on exec3 keeps running.
  assert(!tsm.taskInfos(originalAttempt1.taskId).running)
  assert(tsm.taskInfos(speculativeAttempt.taskId).running)
  assert(tsm.runningTasks === 1)
}

private def createTaskResult(
id: Int,
accumUpdates: Seq[AccumulatorV2[_, _]] = Seq.empty,
Expand Down