Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 29 additions & 9 deletions core/src/main/scala/org/apache/spark/executor/Executor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -422,17 +422,24 @@ private[spark] class Executor(

def launchTask(context: ExecutorBackend, taskDescription: TaskDescription): Unit = {
val taskId = taskDescription.taskId
val tr = createTaskRunner(context, taskDescription)
runningTasks.put(taskId, tr)
val killMark = killMarks.get(taskId)
if (killMark != null) {
tr.kill(killMark._1, killMark._2)
killMarks.remove(taskId)
}
var taskRunnerOpt: Option[TaskRunner] = None
try {
val tr = createTaskRunner(context, taskDescription)
taskRunnerOpt = Some(tr)
runningTasks.put(taskId, tr)
val killMark = killMarks.get(taskId)
if (killMark != null) {
tr.kill(killMark._1, killMark._2)
killMarks.remove(taskId)
}
threadPool.execute(tr)
} catch {
case t: Throwable =>
// Clean up if task was added to runningTasks before the failure.
// If TaskRunner construction failed, taskRunnerOpt will be None and nothing to clean up.
taskRunnerOpt.foreach { tr =>
runningTasks.remove(tr.taskId)
}
try {
logError(log"Executor launch task ${MDC(TASK_NAME, taskDescription.name)} failed," +
log" reason: ${MDC(REASON, t.getMessage)}")
Expand All @@ -441,9 +448,22 @@ private[spark] class Executor(
TaskState.FAILED,
env.closureSerializer.newInstance().serialize(new ExceptionFailure(t, Seq.empty)))
} catch {
case NonFatal(e) if env.isStopped =>
logError(
log"Executor update launching task " +
log"${MDC(TASK_NAME, taskDescription.name)} " +
log"failed status failed, reason: ${MDC(REASON, t.getMessage)}" +
log", spark env is stopped"
)
// No need to exit the executor as the executor is already stopped.
// Leave it alive to clean up the remaining tasks and log info (similar to SPARK-19147).
case t: Throwable =>
logError(log"Executor update launching task ${MDC(TASK_NAME, taskDescription.name)} " +
log"failed status failed, reason: ${MDC(REASON, t.getMessage)}")
logError(
log"Executor update launching task " +
log"${MDC(TASK_NAME, taskDescription.name)} " +
log"failed status failed, reason: ${MDC(REASON, t.getMessage)}" +
log", shutting down the executor"
)
System.exit(-1)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,8 @@ class CoarseGrainedExecutorBackendSuite extends SparkFunSuite
when(executor.threadPool).thenReturn(threadPool)
val runningTasks = new ConcurrentHashMap[Long, Executor#TaskRunner]
when(executor.runningTasks).thenAnswer(_ => runningTasks)
val killMarks = spy(new ConcurrentHashMap[Long, Long])
when(executor.killMarks).thenAnswer(_ => killMarks)
when(executor.conf).thenReturn(conf)

def getFakeTaskRunner(taskDescription: TaskDescription): Executor#TaskRunner = {
Expand Down Expand Up @@ -417,6 +419,8 @@ class CoarseGrainedExecutorBackendSuite extends SparkFunSuite
val runningTasks = spy[ConcurrentHashMap[Long, Executor#TaskRunner]](
new ConcurrentHashMap[Long, Executor#TaskRunner])
when(executor.runningTasks).thenAnswer(_ => runningTasks)
val killMarks = spy(new ConcurrentHashMap[Long, Long])
when(executor.killMarks).thenAnswer(_ => killMarks)
when(executor.conf).thenReturn(conf)

// We don't really verify the data, just pass it around.
Expand Down Expand Up @@ -509,6 +513,8 @@ class CoarseGrainedExecutorBackendSuite extends SparkFunSuite
val runningTasks = spy[ConcurrentHashMap[Long, Executor#TaskRunner]](
new ConcurrentHashMap[Long, Executor#TaskRunner])
when(executor.runningTasks).thenAnswer(_ => runningTasks)
val killMarks = spy(new ConcurrentHashMap[Long, Long])
when(executor.killMarks).thenAnswer(_ => killMarks)
when(executor.conf).thenReturn(conf)

// We don't really verify the data, just pass it around.
Expand Down
58 changes: 58 additions & 0 deletions core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -655,6 +655,64 @@ class ExecutorSuite extends SparkFunSuite
}
}

test(
  "SPARK-55093: launchTask should handle TaskRunner construction failures"
) {
  val sparkConf = new SparkConf
  val javaSer = new JavaSerializer(sparkConf)
  val sparkEnv = createMockEnv(sparkConf, javaSer)
  val taskBytes = javaSer.newInstance().serialize(new FakeTask(0, 0))
  val fakeTaskDescription = createFakeTaskDescription(taskBytes)

  val backend = mock[ExecutorBackend]
  val payloadCaptor = ArgumentCaptor.forClass(classOf[ByteBuffer])

  withExecutor("id", "localhost", sparkEnv) { executor =>
    // Swap the executor's private `runningTasks` map (via reflection) for a mock whose
    // `put` throws. This makes launchTask's try-block blow up right after the
    // TaskRunner is built, driving the same cleanup/statusUpdate path that a
    // TaskRunner construction failure would take.
    val runningTasksField = classOf[Executor].getDeclaredField("runningTasks")
    runningTasksField.setAccessible(true)
    val savedRunningTasks = runningTasksField.get(executor)

    val boom = new RuntimeException("TaskRunner construction failed")
    type TR = executor.TaskRunner
    val throwingMap = mock[java.util.concurrent.ConcurrentHashMap[Long, TR]]
    when(throwingMap.put(any[Long], any[TR])).thenThrow(boom)
    runningTasksField.set(executor, throwingMap)

    try {
      // launchTask must swallow the exception and report the failure to the backend.
      executor.launchTask(backend, fakeTaskDescription)

      verify(backend).statusUpdate(
        meq(fakeTaskDescription.taskId),
        meq(TaskState.FAILED),
        payloadCaptor.capture()
      )

      // The captured payload must deserialize back to the injected exception.
      val failReason = javaSer
        .newInstance()
        .deserialize[ExceptionFailure](payloadCaptor.getValue)
      assert(failReason.exception.isDefined)
      assert(failReason.exception.get.isInstanceOf[RuntimeException])
      assert(
        failReason.exception.get.getMessage === "TaskRunner construction failed"
      )
    } finally {
      // Restore the real map so later interactions with this executor behave normally.
      runningTasksField.set(executor, savedRunningTasks)
    }
  }
}

private def createMockEnv(conf: SparkConf, serializer: JavaSerializer): SparkEnv = {
val mockEnv = mock[SparkEnv]
val mockRpcEnv = mock[RpcEnv]
Expand Down
Loading