Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# under the License.
from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, cast

Expand Down Expand Up @@ -54,6 +55,171 @@ def execute_resumable(self, context):
from airflow.providers.common.compat.sdk import Context


class SparkSubmitResumableBackend(ABC):
    """
    Strategy interface for the resumable-job behaviour of ``SparkSubmitOperator``.

    A concrete backend implements submission, status lookup, status
    classification, polling and kill handling for one kind of Spark deployment.
    The owning operator is kept on ``self.operator`` so implementations can
    reach whatever they need through it.
    """

    def __init__(self, operator: SparkSubmitOperator) -> None:
        # Keep a back-reference to the operator this backend works for.
        self.operator = operator

    @abstractmethod
    def submit_job(self, context: Context) -> str:
        """
        Submit the Spark job.

        :param context: Airflow task context.
        :return: The external/driver ID identifying the submitted job.
        """

    @abstractmethod
    def get_job_status(self, external_id: str, context: Context) -> str:
        """
        Look up the current status of a previously submitted Spark job.

        :param external_id: ID returned by :meth:`submit_job`.
        :param context: Airflow task context.
        :return: The job status as a string.
        """

    @abstractmethod
    def is_job_active(self, status: str) -> bool:
        """
        Tell whether ``status`` denotes a job that is still active.

        :param status: A status string for the job.
        """

    @abstractmethod
    def is_job_succeeded(self, status: str) -> bool:
        """
        Tell whether ``status`` denotes a job that completed successfully.

        :param status: A status string for the job.
        """

    @abstractmethod
    def poll_until_complete(self, external_id: str, context: Context) -> None:
        """
        Keep polling the Spark job until it reaches a terminal state.

        :param external_id: ID returned by :meth:`submit_job`.
        :param context: Airflow task context.
        """

    @abstractmethod
    def on_kill(self) -> None:
        """React to the task being terminated/killed."""


class YarnSparkSubmitBackend(SparkSubmitResumableBackend):
    """
    Resumable backend strategy for YARN cluster mode.

    ``spark-submit`` is made to return right after submission by forcing
    ``spark.yarn.submit.waitAppCompletion=false``; the application is then
    followed through the hook's YARN helpers using its application ID.
    """

    def submit_job(self, context: Context) -> str:
        """
        Submit the application through the hook and return its YARN application ID.

        :param context: Airflow task context (unused here).
        :return: The YARN application ID recorded on the hook after submission.
        :raises ValueError: If the conf explicitly sets
            ``spark.yarn.submit.waitAppCompletion`` to ``true``.
        :raises RuntimeError: If the hook has no YARN application ID after submission.
        """
        hook = self.operator._hook
        conf_key = "spark.yarn.submit.waitAppCompletion"
        # str() keeps the check working for non-string conf values (e.g. a bare ``True``),
        # which would otherwise fail on ``.strip()``.
        if str(hook._conf.get(conf_key, "")).strip().lower() == "true":
            raise ValueError(
                "spark.yarn.submit.waitAppCompletion=true cannot be set for cluster mode as it conflicts "
                "with the need to exit spark-submit immediately to persist the application ID for tracking. "
                "Either remove the explicit conf or set reconnect_on_retry=False."
            )
        hook._conf[conf_key] = "false"
        hook.submit(self.operator.application)
        app_id = hook._yarn_application_id
        if not app_id:
            raise RuntimeError("spark-submit did not produce a YARN application ID")
        self.operator.log.info("YARN application submitted: %s", app_id)
        return app_id

    def get_job_status(self, external_id: str, context: Context) -> str:
        """
        Return the status reported by the hook for the given YARN application.

        :param external_id: YARN application ID returned by :meth:`submit_job`.
        :param context: Airflow task context (unused here).
        """
        return self.operator._hook.query_yarn_application_status(external_id)

    def is_job_active(self, status: str) -> bool:
        """Return True when ``status`` (case-insensitive) is a non-terminal YARN state."""
        # https://hadoop.apache.org/docs/stable/hadoop-yarn/hadoop-yarn-site/ResourceManagerRest.html
        return status.upper() in {"NEW", "NEW_SAVING", "SUBMITTED", "ACCEPTED", "RUNNING"}

    def is_job_succeeded(self, status: str) -> bool:
        """Return True when ``status`` (case-insensitive) is ``SUCCEEDED``."""
        return status.upper() == "SUCCEEDED"

    def poll_until_complete(self, external_id: str, context: Context) -> None:
        """
        Track the YARN application via the hook until tracking returns or raises.

        :param external_id: YARN application ID to track.
        :param context: Airflow task context (unused here).
        """
        hook = self.operator._hook
        try:
            hook._start_yarn_application_status_tracking(external_id)
        finally:
            # Post-submit commands run whether tracking finished normally or raised.
            hook._run_post_submit_commands()

    def on_kill(self) -> None:
        """Kill the YARN application when its ID is known, else fall back to the hook's own kill."""
        hook = self.operator._hook
        if hook._yarn_application_id:
            # spark-submit has already exited (waitAppCompletion=false), so the hook's
            # CLI-based kill has nothing to terminate. Kill the YARN app via REST API instead.
            hook._kill_yarn_application(hook._yarn_application_id)
        else:
            hook.on_kill()


class KubernetesSparkSubmitBackend(SparkSubmitResumableBackend):
    """
    Resumable backend strategy for Kubernetes driver-pod tracking.

    Submission and kill handling are delegated to the operator's hook. Status
    lookup and polling are placeholders that raise ``NotImplementedError``.
    """

    def submit_job(self, context: Context) -> str:
        """
        Run the submission through the hook and hand back the driver ID it returns.

        :param context: Airflow task context (unused here).
        :raises RuntimeError: If the hook returns no driver ID.
        """
        submitted_id = self.operator._hook.submit(self.operator.application)
        if submitted_id:
            self.operator.log.info("Spark driver submitted: %s", submitted_id)
            return submitted_id
        raise RuntimeError("spark-submit did not return a driver ID")

    def get_job_status(self, external_id: str, context: Context) -> str:
        """Not available yet; always raises ``NotImplementedError``."""
        # TODO: call K8s pod status API
        raise NotImplementedError("K8s job status not yet implemented")

    def is_job_active(self, status: str) -> bool:
        """Return True when ``status`` (case-insensitive) is ``PENDING`` or ``RUNNING``."""
        normalized = status.upper()
        return normalized == "PENDING" or normalized == "RUNNING"

    def is_job_succeeded(self, status: str) -> bool:
        """Return True when ``status`` (case-insensitive) is ``SUCCEEDED``."""
        return "SUCCEEDED" == status.upper()

    def poll_until_complete(self, external_id: str, context: Context) -> None:
        """Not available yet; always raises ``NotImplementedError``."""
        # TODO: poll K8s pod phase until terminal
        raise NotImplementedError("K8s poll not yet implemented")

    def on_kill(self) -> None:
        """Forward the kill request to the hook."""
        hook = self.operator._hook
        hook.on_kill()


class StandaloneSparkSubmitBackend(SparkSubmitResumableBackend):
    """
    Resumable backend strategy for Spark standalone driver-status tracking.

    The driver is submitted through the hook; its status is read from the
    Spark master REST endpoint ``/v1/submissions/status/<driver id>``.
    """

    def submit_job(self, context: Context) -> str:
        """
        Submit the application through the hook and return the driver ID.

        :param context: Airflow task context (unused here).
        :raises RuntimeError: If the hook returns no driver ID.
        """
        hook = self.operator._hook
        driver_id = hook.submit(self.operator.application)
        if not driver_id:
            raise RuntimeError("spark-submit did not return a driver ID")
        self.operator.log.info("Spark driver submitted: %s", driver_id)
        return driver_id

    def get_job_status(self, external_id: str, context: Context) -> str:
        """
        Ask each configured Spark master in turn for the driver's state.

        :param external_id: Driver ID returned by :meth:`submit_job`.
        :param context: Airflow task context (unused here).
        :return: The ``driverState`` reported by the first master that answers.
        :raises Exception: The error from the last master tried, or a
            ``RuntimeError`` when the connection lists no usable master host.
        """
        hook = self.operator._hook
        scheme = hook._connection.get("rest_scheme", "http")
        rest_port = hook._connection.get("rest_port", 6066)
        # HA master URLs can look like spark://m1:7077,m2:7077 — try each host in order.
        # The master URL port (e.g. 7077) is the RPC port — not the REST API port, so only
        # the host part is kept and ``rest_port`` from the connection (default 6066) is used.
        master_urls = hook._connection["master"].replace("spark://", "").split(",")
        last_exc: Exception = RuntimeError("No Spark masters to query")
        for m in master_urls:
            host = m.strip().split(":")[0]
            if not host:
                # Blank entry (empty master string or stray comma): there is nothing to
                # contact, so skip it instead of issuing requests to "scheme://:port/...".
                continue
            url = f"{scheme}://{host}:{rest_port}/v1/submissions/status/{external_id}"
            try:
                return self._fetch_driver_status(url, external_id)
            except Exception as e:
                # Best effort across masters: remember the failure and move on to the next host.
                self.operator.log.warning("Could not reach Spark master %s: %s", host, e)
                last_exc = e
        raise last_exc

    @retry(stop=stop_after_attempt(3), wait=wait_fixed(1), reraise=True)
    def _fetch_driver_status(self, url: str, external_id: str) -> str:
        """
        Fetch and return ``driverState`` from one master's status URL.

        :param url: Full status URL for the driver on one master.
        :param external_id: Driver ID, used in log and error messages.
        :raises RuntimeError: If the response body does not report success.
        """
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        # "success:false" means the master does not recognise the driver ID or is in recovery.
        # https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/deploy/master/DriverState.scala
        data = response.json()
        if not data.get("success"):
            raise RuntimeError(
                f"Spark REST API returned failure for {external_id}: {data.get('message', 'unknown error')}"
            )
        status = data["driverState"]
        self.operator.log.info("Driver %s status: %s", external_id, status)
        return status

    def is_job_active(self, status: str) -> bool:
        """Return True when ``status`` (case-insensitive) denotes a driver that is still alive."""
        # RELAUNCHING: driver is being restarted after a failure, still alive.
        # UNKNOWN: master is in failure recovery, state is temporarily unavailable.
        # https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/deploy/master/DriverState.scala
        return status.upper() in ("SUBMITTED", "RUNNING", "RELAUNCHING", "UNKNOWN")

    def is_job_succeeded(self, status: str) -> bool:
        """Return True when ``status`` (case-insensitive) is ``FINISHED``."""
        return status.upper() == "FINISHED"

    def poll_until_complete(self, external_id: str, context: Context) -> None:
        """
        Track the driver through the hook and fail unless it ends as ``FINISHED``.

        :param external_id: Driver ID to track; it is written to the hook's ``_driver_id``.
        :param context: Airflow task context (unused here).
        :raises RuntimeError: If the hook's driver status after tracking is not ``FINISHED``.
        """
        hook = self.operator._hook
        self.operator.log.info("Polling driver %s until completion", external_id)
        hook._driver_id = external_id
        try:
            hook._start_driver_status_tracking()
            if hook._driver_status != "FINISHED":
                raise RuntimeError(f"Driver {external_id} exited with status {hook._driver_status}")
        finally:
            # post-submit commands must fire whether the job succeeded or failed.
            hook._run_post_submit_commands()

    def on_kill(self) -> None:
        """Forward the kill request to the hook."""
        self.operator._hook.on_kill()


class SparkSubmitOperator(ResumableJobMixin, BaseOperator):
"""
Wrap the spark-submit binary to kick off a spark-submit job; requires "spark-submit" binary in the PATH.
Expand Down Expand Up @@ -290,137 +456,46 @@ def execute(self, context: Context) -> None:
return self.get_job_result(driver_id, context)
hook.submit(self.application)

def submit_job(self, context: Context) -> str:
@property
def _resumable_backend(self) -> SparkSubmitResumableBackend:
if hasattr(self, "_cached_resumable_backend"):
return self._cached_resumable_backend

if self._hook is None:
self._hook = self._get_hook()

if self._hook._is_yarn_cluster_mode:
if self._hook._conf.get("spark.yarn.submit.waitAppCompletion", "").strip().lower() == "true":
raise ValueError(
"spark.yarn.submit.waitAppCompletion=true cannot be set for cluster mode as it conflicts"
"with the need to exit spark-submit immediately to persist the application ID for tracking. "
"Either remove the explicit conf or set reconnect_on_retry=False."
)
self._hook._conf["spark.yarn.submit.waitAppCompletion"] = "false"
self._hook.submit(self.application)
app_id = self._hook._yarn_application_id
if not app_id:
raise RuntimeError("spark-submit did not produce a YARN application ID")
self.log.info("YARN application submitted: %s", app_id)
return app_id
driver_id = self._hook.submit(self.application)
if not driver_id:
raise RuntimeError("spark-submit did not return a driver ID")
self.log.info("Spark driver submitted: %s", driver_id)
return driver_id
backend = YarnSparkSubmitBackend(self)
elif self._hook._is_kubernetes:
backend = KubernetesSparkSubmitBackend(self)
else:
backend = StandaloneSparkSubmitBackend(self)

self._cached_resumable_backend = backend
return backend

def submit_job(self, context: Context) -> str:
return self._resumable_backend.submit_job(context)

def get_job_status(self, external_id: JsonValue, context: Context) -> str:
# called from submit_job which always returns a str (Spark driver IDs are strings)
external_id = cast("str", external_id)
if self._hook is None:
self._hook = self._get_hook()
if self._hook._is_yarn_cluster_mode:
return self._hook.query_yarn_application_status(external_id)
if self._hook._is_kubernetes:
# The K8s branches below (and in is_job_active, is_job_succeeded, poll_until_complete)
# are currently unreachable: execute_resumable is only called when _should_track_driver_status
# is True, which requires spark:// + cluster mode. They are scaffolding for a follow-up PR
# that extends ResumableJobMixin support to Kubernetes.
# TODO: call K8s pod status API
raise NotImplementedError("K8s job status not yet implemented")
scheme = self._hook._connection.get("rest_scheme", "http")
rest_port = self._hook._connection.get("rest_port", 6066)
# HA master URLs can look like spark://m1:7077,m2:7077 — try each host in order.
# The master URL port (e.g. 7077) is the RPC port — not the REST API port.
# Use rest-port connection extra to override spark.master.rest.port (default 6066).
master_urls = self._hook._connection["master"].replace("spark://", "").split(",")
last_exc: Exception = RuntimeError("No Spark masters to query")
for m in master_urls:
host = m.strip().split(":")[0]
url = f"{scheme}://{host}:{rest_port}/v1/submissions/status/{external_id}"
try:
status = self._fetch_driver_status(url, external_id)
return status
except Exception as e:
self.log.warning("Could not reach Spark master %s: %s", host, e)
last_exc = e
raise last_exc

@retry(stop=stop_after_attempt(3), wait=wait_fixed(1), reraise=True)
def _fetch_driver_status(self, url: str, external_id: str) -> str:
response = requests.get(url, timeout=30)
response.raise_for_status()
# "success:false" means the master does not recognise the driver ID or is in recovery.
# https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/deploy/master/DriverState.scala
data = response.json()
if not data.get("success"):
raise RuntimeError(
f"Spark REST API returned failure for {external_id}: {data.get('message', 'unknown error')}"
)
status = data["driverState"]
self.log.info("Driver %s status: %s", external_id, status)
return status
return self._resumable_backend.get_job_status(external_id, context)

def is_job_active(self, status: str) -> bool:
if self._hook is None:
self._hook = self._get_hook()
status = status.upper()
if self._hook._is_yarn_cluster_mode:
# https://hadoop.apache.org/docs/stable/hadoop-yarn/hadoop-yarn-site/ResourceManagerRest.html
return status in {"NEW", "NEW_SAVING", "SUBMITTED", "ACCEPTED", "RUNNING"}
if self._hook._is_kubernetes:
return status in ("PENDING", "RUNNING")
# RELAUNCHING: driver is being restarted after a failure, still alive.
# UNKNOWN: master is in failure recovery, state is temporarily unavailable.
# https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/deploy/master/DriverState.scala
return status in ("SUBMITTED", "RUNNING", "RELAUNCHING", "UNKNOWN")
return self._resumable_backend.is_job_active(status)

def is_job_succeeded(self, status: str) -> bool:
if self._hook is None:
self._hook = self._get_hook()
status = status.upper()
if self._hook._is_yarn_cluster_mode:
return status == "SUCCEEDED"
if self._hook._is_kubernetes:
return status == "SUCCEEDED"
# standalone and YARN both use FINISHED
return status == "FINISHED"
return self._resumable_backend.is_job_succeeded(status)

def poll_until_complete(self, external_id: JsonValue, context: Context) -> None:
# called from submit_job which always returns a str (Spark driver IDs are strings)
external_id = cast("str", external_id)
if self._hook is None:
self._hook = self._get_hook()
if self._hook._is_yarn_cluster_mode:
try:
self._hook._start_yarn_application_status_tracking(external_id)
finally:
self._hook._run_post_submit_commands()
return
if self._hook._is_kubernetes:
# TODO: poll K8s pod phase until terminal
raise NotImplementedError("K8s poll not yet implemented")
self.log.info("Polling driver %s until completion", external_id)
self._hook._driver_id = external_id
try:
self._hook._start_driver_status_tracking()
if self._hook._driver_status != "FINISHED":
raise RuntimeError(f"Driver {external_id} exited with status {self._hook._driver_status}")
finally:
# post-submit commands must fire whether the job succeeded or failed.
self._hook._run_post_submit_commands()
return self._resumable_backend.poll_until_complete(external_id, context)

def get_job_result(self, external_id: JsonValue, context: Context) -> None:
return None

def on_kill(self) -> None:
if self._hook is None:
self._hook = self._get_hook()
if self._hook._is_yarn_cluster_mode and self._hook._yarn_application_id:
# spark-submit has already exited (waitAppCompletion=false), so the hook's
# CLI-based kill has nothing to terminate. Kill the YARN app via REST API instead.
self._hook._kill_yarn_application(self._hook._yarn_application_id)
else:
self._hook.on_kill()
self._resumable_backend.on_kill()

def _get_hook(self) -> SparkSubmitHook:
return SparkSubmitHook(
Expand Down
Loading