diff --git a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
index ac9b550409ff0..8eee0c3c0c216 100644
--- a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
+++ b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
@@ -17,6 +17,7 @@
 # under the License.
 from __future__ import annotations
 
+from abc import ABC, abstractmethod
 from collections.abc import Sequence
 from typing import TYPE_CHECKING, Any, cast
 
@@ -54,6 +55,185 @@ def execute_resumable(self, context):
     from airflow.providers.common.compat.sdk import Context
 
 
+class SparkSubmitResumableBackend(ABC):
+    """Base interface/abstract class for Spark submit resumable backends."""
+
+    def __init__(self, operator: SparkSubmitOperator) -> None:
+        self.operator = operator
+
+    @abstractmethod
+    def submit_job(self, context: Context) -> str:
+        """Submit the Spark job and return the external/driver ID."""
+
+    @abstractmethod
+    def get_job_status(self, external_id: str, context: Context) -> str:
+        """Retrieve the status of the Spark job."""
+
+    @abstractmethod
+    def is_job_active(self, status: str) -> bool:
+        """Check if the Spark job is in an active state."""
+
+    @abstractmethod
+    def is_job_succeeded(self, status: str) -> bool:
+        """Check if the Spark job has completed successfully."""
+
+    @abstractmethod
+    def poll_until_complete(self, external_id: str, context: Context) -> None:
+        """Poll the Spark job status until it reaches a terminal state."""
+
+    @abstractmethod
+    def on_kill(self) -> None:
+        """Handle execution termination/kill signal."""
+
+
+class YarnSparkSubmitBackend(SparkSubmitResumableBackend):
+    """Resumable backend strategy for YARN cluster mode."""
+
+    def submit_job(self, context: Context) -> str:
+        hook = self.operator._hook
+        if hook._conf.get("spark.yarn.submit.waitAppCompletion", "").strip().lower() == "true":
+            raise ValueError(
+                "spark.yarn.submit.waitAppCompletion=true cannot be set for cluster mode as it conflicts "
+                "with the need to exit spark-submit immediately to persist the application ID for tracking. "
+                "Either remove the explicit conf or set reconnect_on_retry=False."
+            )
+        hook._conf["spark.yarn.submit.waitAppCompletion"] = "false"
+        hook.submit(self.operator.application)
+        app_id = hook._yarn_application_id
+        if not app_id:
+            raise RuntimeError("spark-submit did not produce a YARN application ID")
+        self.operator.log.info("YARN application submitted: %s", app_id)
+        return app_id
+
+    def get_job_status(self, external_id: str, context: Context) -> str:
+        return self.operator._hook.query_yarn_application_status(external_id)
+
+    def is_job_active(self, status: str) -> bool:
+        # https://hadoop.apache.org/docs/stable/hadoop-yarn/hadoop-yarn-site/ResourceManagerRest.html
+        return status.upper() in {"NEW", "NEW_SAVING", "SUBMITTED", "ACCEPTED", "RUNNING"}
+
+    def is_job_succeeded(self, status: str) -> bool:
+        return status.upper() == "SUCCEEDED"
+
+    def poll_until_complete(self, external_id: str, context: Context) -> None:
+        hook = self.operator._hook
+        try:
+            hook._start_yarn_application_status_tracking(external_id)
+        finally:
+            hook._run_post_submit_commands()
+
+    def on_kill(self) -> None:
+        hook = self.operator._hook
+        if hook._yarn_application_id:
+            # spark-submit has already exited (waitAppCompletion=false), so the hook's
+            # CLI-based kill has nothing to terminate. Kill the YARN app via REST API instead.
+            hook._kill_yarn_application(hook._yarn_application_id)
+        else:
+            hook.on_kill()
+
+
+class KubernetesSparkSubmitBackend(SparkSubmitResumableBackend):
+    """Resumable backend strategy for Kubernetes driver-pod tracking."""
+
+    # Scaffolding: get_job_status and poll_until_complete are not implemented yet and
+    # raise NotImplementedError (see the TODO markers below).
+
+    def submit_job(self, context: Context) -> str:
+        hook = self.operator._hook
+        driver_id = hook.submit(self.operator.application)
+        if not driver_id:
+            raise RuntimeError("spark-submit did not return a driver ID")
+        self.operator.log.info("Spark driver submitted: %s", driver_id)
+        return driver_id
+
+    def get_job_status(self, external_id: str, context: Context) -> str:
+        # TODO: call K8s pod status API
+        raise NotImplementedError("K8s job status not yet implemented")
+
+    def is_job_active(self, status: str) -> bool:
+        return status.upper() in ("PENDING", "RUNNING")
+
+    def is_job_succeeded(self, status: str) -> bool:
+        return status.upper() == "SUCCEEDED"
+
+    def poll_until_complete(self, external_id: str, context: Context) -> None:
+        # TODO: poll K8s pod phase until terminal
+        raise NotImplementedError("K8s poll not yet implemented")
+
+    def on_kill(self) -> None:
+        self.operator._hook.on_kill()
+
+
+class StandaloneSparkSubmitBackend(SparkSubmitResumableBackend):
+    """Resumable backend strategy for Spark standalone driver-status tracking."""
+
+    def submit_job(self, context: Context) -> str:
+        hook = self.operator._hook
+        driver_id = hook.submit(self.operator.application)
+        if not driver_id:
+            raise RuntimeError("spark-submit did not return a driver ID")
+        self.operator.log.info("Spark driver submitted: %s", driver_id)
+        return driver_id
+
+    def get_job_status(self, external_id: str, context: Context) -> str:
+        hook = self.operator._hook
+        scheme = hook._connection.get("rest_scheme", "http")
+        rest_port = hook._connection.get("rest_port", 6066)
+        # HA master URLs can look like spark://m1:7077,m2:7077 — try each host in order.
+        # The master URL port (e.g. 7077) is the RPC port — not the REST API port.
+        # The "rest_port" connection value overrides the default REST port (6066).
+        master_urls = hook._connection["master"].replace("spark://", "").split(",")
+        last_exc: Exception = RuntimeError("No Spark masters to query")
+        for m in master_urls:
+            host = m.strip().split(":")[0]
+            url = f"{scheme}://{host}:{rest_port}/v1/submissions/status/{external_id}"
+            try:
+                return self._fetch_driver_status(url, external_id)
+            except Exception as e:
+                self.operator.log.warning("Could not reach Spark master %s: %s", host, e)
+                last_exc = e
+        raise last_exc
+
+    @retry(stop=stop_after_attempt(3), wait=wait_fixed(1), reraise=True)
+    def _fetch_driver_status(self, url: str, external_id: str) -> str:
+        response = requests.get(url, timeout=30)
+        response.raise_for_status()
+        # "success: false" means the master does not recognise the driver ID or is in recovery.
+        # https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/deploy/master/DriverState.scala
+        data = response.json()
+        if not data.get("success"):
+            raise RuntimeError(
+                f"Spark REST API returned failure for {external_id}: {data.get('message', 'unknown error')}"
+            )
+        status = data["driverState"]
+        self.operator.log.info("Driver %s status: %s", external_id, status)
+        return status
+
+    def is_job_active(self, status: str) -> bool:
+        # RELAUNCHING: driver is being restarted after a failure, still alive.
+        # UNKNOWN: master is in failure recovery, state is temporarily unavailable.
+        # https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/deploy/master/DriverState.scala
+        return status.upper() in ("SUBMITTED", "RUNNING", "RELAUNCHING", "UNKNOWN")
+
+    def is_job_succeeded(self, status: str) -> bool:
+        return status.upper() == "FINISHED"
+
+    def poll_until_complete(self, external_id: str, context: Context) -> None:
+        hook = self.operator._hook
+        self.operator.log.info("Polling driver %s until completion", external_id)
+        hook._driver_id = external_id
+        try:
+            hook._start_driver_status_tracking()
+            if hook._driver_status != "FINISHED":
+                raise RuntimeError(f"Driver {external_id} exited with status {hook._driver_status}")
+        finally:
+            # post-submit commands must fire whether the job succeeded or failed.
+            hook._run_post_submit_commands()
+
+    def on_kill(self) -> None:
+        self.operator._hook.on_kill()
+
+
 class SparkSubmitOperator(ResumableJobMixin, BaseOperator):
     """
     Wrap the spark-submit binary to kick off a spark-submit job; requires "spark-submit" binary in the PATH.
@@ -290,137 +470,47 @@ def execute(self, context: Context) -> None:
             return self.get_job_result(driver_id, context)
         hook.submit(self.application)
 
-    def submit_job(self, context: Context) -> str:
+    @property
+    def _resumable_backend(self) -> SparkSubmitResumableBackend:
+        """Return the backend for the hook's deployment mode, creating and caching it on first use."""
+        if hasattr(self, "_cached_resumable_backend"):
+            return self._cached_resumable_backend
+
         if self._hook is None:
             self._hook = self._get_hook()
+
         if self._hook._is_yarn_cluster_mode:
-            if self._hook._conf.get("spark.yarn.submit.waitAppCompletion", "").strip().lower() == "true":
-                raise ValueError(
-                    "spark.yarn.submit.waitAppCompletion=true cannot be set for cluster mode as it conflicts"
-                    "with the need to exit spark-submit immediately to persist the application ID for tracking. "
-                    "Either remove the explicit conf or set reconnect_on_retry=False."
-                )
-            self._hook._conf["spark.yarn.submit.waitAppCompletion"] = "false"
-            self._hook.submit(self.application)
-            app_id = self._hook._yarn_application_id
-            if not app_id:
-                raise RuntimeError("spark-submit did not produce a YARN application ID")
-            self.log.info("YARN application submitted: %s", app_id)
-            return app_id
-        driver_id = self._hook.submit(self.application)
-        if not driver_id:
-            raise RuntimeError("spark-submit did not return a driver ID")
-        self.log.info("Spark driver submitted: %s", driver_id)
-        return driver_id
+            backend = YarnSparkSubmitBackend(self)
+        elif self._hook._is_kubernetes:
+            backend = KubernetesSparkSubmitBackend(self)
+        else:
+            backend = StandaloneSparkSubmitBackend(self)
+
+        self._cached_resumable_backend = backend
+        return backend
+
+    def submit_job(self, context: Context) -> str:
+        return self._resumable_backend.submit_job(context)
 
     def get_job_status(self, external_id: JsonValue, context: Context) -> str:
-        # called from submit_job which always returns a str (Spark driver IDs are strings)
         external_id = cast("str", external_id)
-        if self._hook is None:
-            self._hook = self._get_hook()
-        if self._hook._is_yarn_cluster_mode:
-            return self._hook.query_yarn_application_status(external_id)
-        if self._hook._is_kubernetes:
-            # The K8s branches below (and in is_job_active, is_job_succeeded, poll_until_complete)
-            # are currently unreachable: execute_resumable is only called when _should_track_driver_status
-            # is True, which requires spark:// + cluster mode. They are scaffolding for a follow-up PR
-            # that extends ResumableJobMixin support to Kubernetes.
-            # TODO: call K8s pod status API
-            raise NotImplementedError("K8s job status not yet implemented")
-        scheme = self._hook._connection.get("rest_scheme", "http")
-        rest_port = self._hook._connection.get("rest_port", 6066)
-        # HA master URLs can look like spark://m1:7077,m2:7077 — try each host in order.
-        # The master URL port (e.g. 7077) is the RPC port — not the REST API port.
-        # Use rest-port connection extra to override spark.master.rest.port (default 6066).
-        master_urls = self._hook._connection["master"].replace("spark://", "").split(",")
-        last_exc: Exception = RuntimeError("No Spark masters to query")
-        for m in master_urls:
-            host = m.strip().split(":")[0]
-            url = f"{scheme}://{host}:{rest_port}/v1/submissions/status/{external_id}"
-            try:
-                status = self._fetch_driver_status(url, external_id)
-                return status
-            except Exception as e:
-                self.log.warning("Could not reach Spark master %s: %s", host, e)
-                last_exc = e
-        raise last_exc
-
-    @retry(stop=stop_after_attempt(3), wait=wait_fixed(1), reraise=True)
-    def _fetch_driver_status(self, url: str, external_id: str) -> str:
-        response = requests.get(url, timeout=30)
-        response.raise_for_status()
-        # "success:false" means the master does not recognise the driver ID or is in recovery.
-        # https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/deploy/master/DriverState.scala
-        data = response.json()
-        if not data.get("success"):
-            raise RuntimeError(
-                f"Spark REST API returned failure for {external_id}: {data.get('message', 'unknown error')}"
-            )
-        status = data["driverState"]
-        self.log.info("Driver %s status: %s", external_id, status)
-        return status
+        return self._resumable_backend.get_job_status(external_id, context)
 
     def is_job_active(self, status: str) -> bool:
-        if self._hook is None:
-            self._hook = self._get_hook()
-        status = status.upper()
-        if self._hook._is_yarn_cluster_mode:
-            # https://hadoop.apache.org/docs/stable/hadoop-yarn/hadoop-yarn-site/ResourceManagerRest.html
-            return status in {"NEW", "NEW_SAVING", "SUBMITTED", "ACCEPTED", "RUNNING"}
-        if self._hook._is_kubernetes:
-            return status in ("PENDING", "RUNNING")
-        # RELAUNCHING: driver is being restarted after a failure, still alive.
-        # UNKNOWN: master is in failure recovery, state is temporarily unavailable.
-        # https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/deploy/master/DriverState.scala
-        return status in ("SUBMITTED", "RUNNING", "RELAUNCHING", "UNKNOWN")
+        return self._resumable_backend.is_job_active(status)
 
     def is_job_succeeded(self, status: str) -> bool:
-        if self._hook is None:
-            self._hook = self._get_hook()
-        status = status.upper()
-        if self._hook._is_yarn_cluster_mode:
-            return status == "SUCCEEDED"
-        if self._hook._is_kubernetes:
-            return status == "SUCCEEDED"
-        # standalone and YARN both use FINISHED
-        return status == "FINISHED"
+        return self._resumable_backend.is_job_succeeded(status)
 
     def poll_until_complete(self, external_id: JsonValue, context: Context) -> None:
-        # called from submit_job which always returns a str (Spark driver IDs are strings)
         external_id = cast("str", external_id)
-        if self._hook is None:
-            self._hook = self._get_hook()
-        if self._hook._is_yarn_cluster_mode:
-            try:
-                self._hook._start_yarn_application_status_tracking(external_id)
-            finally:
-                self._hook._run_post_submit_commands()
-            return
-        if self._hook._is_kubernetes:
-            # TODO: poll K8s pod phase until terminal
-            raise NotImplementedError("K8s poll not yet implemented")
-        self.log.info("Polling driver %s until completion", external_id)
-        self._hook._driver_id = external_id
-        try:
-            self._hook._start_driver_status_tracking()
-            if self._hook._driver_status != "FINISHED":
-                raise RuntimeError(f"Driver {external_id} exited with status {self._hook._driver_status}")
-        finally:
-            # post-submit commands must fire whether the job succeeded or failed.
-            self._hook._run_post_submit_commands()
+        return self._resumable_backend.poll_until_complete(external_id, context)
 
     def get_job_result(self, external_id: JsonValue, context: Context) -> None:
         return None
 
     def on_kill(self) -> None:
-        if self._hook is None:
-            self._hook = self._get_hook()
-        if self._hook._is_yarn_cluster_mode and self._hook._yarn_application_id:
-            # spark-submit has already exited (waitAppCompletion=false), so the hook's
-            # CLI-based kill has nothing to terminate. Kill the YARN app via REST API instead.
-            self._hook._kill_yarn_application(self._hook._yarn_application_id)
-        else:
-            self._hook.on_kill()
+        self._resumable_backend.on_kill()
 
     def _get_hook(self) -> SparkSubmitHook:
         return SparkSubmitHook(