-
Notifications
You must be signed in to change notification settings - Fork 17.2k
Add crash recovery ability to SparkSubmitOperator against Kubernetes #68067
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
f2b426a
4fe565f
35fd5be
6167263
3b94d70
616b3be
fe5ce93
2698348
431a5b5
1b4be42
5620c44
921894f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,13 +23,18 @@ | |
| import requests | ||
| from tenacity import retry, stop_after_attempt, wait_fixed | ||
|
|
||
| from airflow.providers.apache.spark.hooks.spark_submit import SparkSubmitHook | ||
| from airflow.providers.apache.spark.hooks.spark_submit import _K8S_WAIT_APP_COMPLETION_CONF, SparkSubmitHook | ||
| from airflow.providers.common.compat.openlineage.utils.spark import ( | ||
| inject_parent_job_information_into_spark_properties, | ||
| inject_transport_information_into_spark_properties, | ||
| ) | ||
| from airflow.providers.common.compat.sdk import BaseOperator, conf | ||
|
|
||
| try: | ||
| from airflow.providers.cncf.kubernetes import kube_client | ||
| except ImportError: | ||
| kube_client = None # type: ignore[assignment] | ||
|
|
||
| try: | ||
| from airflow.sdk import ResumableJobMixin | ||
| except ImportError: | ||
|
|
@@ -140,6 +145,11 @@ class SparkSubmitOperator(ResumableJobMixin, BaseOperator): | |
| # YARN application ID, K8s driver pod name). | ||
| external_id_key = "spark_job_id" | ||
|
|
||
| # Used only for k8s cluster mode. Caches the pod phase ("Succeeded" / "Failed") to task_store at the end of | ||
| # poll_until_complete. On retry, get_job_status reads this before querying the K8s API | ||
| # so that a completed job can be identified even after the driver pod is garbage collected. | ||
| _K8S_DRIVER_STATUS_KEY = "k8s_driver_status" | ||
|
|
||
| template_fields: Sequence[str] = ( | ||
| "application", | ||
| "conf", | ||
|
|
@@ -269,11 +279,12 @@ def execute(self, context: Context) -> None: | |
| self.poll_until_complete(driver_id, context) | ||
| return self.get_job_result(driver_id, context) | ||
| if hook._should_track_driver_via_k8s_api(): | ||
| # TODO: Wire into execute_resumable() via ResumableJobMixin | ||
| # (fill submit_job / poll_until_complete K8s stubs) to enable crash recovery. | ||
| hook.submit(self.application) | ||
| hook._poll_k8s_driver_via_api() | ||
| return | ||
| if self.reconnect_on_retry: | ||
| return self.execute_resumable(context) | ||
| # reconnect_on_retry=False: still submit-and-poll, just skip task_state persistence. | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This path still writes
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh god, good catch. The risk is real: a later retry with
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Handled in 431a5b5 |
||
| driver_id = self.submit_job(context) | ||
| self.poll_until_complete(driver_id, context) | ||
| return self.get_job_result(driver_id, context) | ||
| if hook._is_yarn_cluster_mode: | ||
| if self.reconnect_on_retry and not hook._yarn_track_via_rm_api: | ||
| raise ValueError( | ||
|
|
@@ -290,9 +301,19 @@ def execute(self, context: Context) -> None: | |
| return self.get_job_result(driver_id, context) | ||
| hook.submit(self.application) | ||
|
|
||
| def submit_job(self, context: Context) -> str: | ||
| def submit_job(self, context: Context) -> str | None: | ||
| if self._hook is None: | ||
| self._hook = self._get_hook() | ||
| if self._hook._is_kubernetes: | ||
| self._hook._conf[_K8S_WAIT_APP_COMPLETION_CONF] = "false" | ||
| self._hook.submit(self.application) | ||
| pod_name = self._hook._kubernetes_driver_pod | ||
| namespace = self._hook._connection["namespace"] | ||
| if not pod_name: | ||
| raise RuntimeError("spark-submit did not capture a K8s driver pod name") | ||
| external_id = f"{namespace}:{pod_name}" | ||
| self.log.info("Spark K8s driver pod submitted: %s", external_id) | ||
| return external_id | ||
| if self._hook._is_yarn_cluster_mode: | ||
| if self._hook._conf.get("spark.yarn.submit.waitAppCompletion", "").strip().lower() == "true": | ||
| raise ValueError( | ||
|
|
@@ -321,12 +342,24 @@ def get_job_status(self, external_id: JsonValue, context: Context) -> str: | |
| if self._hook._is_yarn_cluster_mode: | ||
| return self._hook.query_yarn_application_status(external_id) | ||
| if self._hook._is_kubernetes: | ||
| # The K8s branches below (and in is_job_active, is_job_succeeded, poll_until_complete) | ||
| # are currently unreachable: execute_resumable is only called when _should_track_driver_status | ||
| # is True, which requires spark:// + cluster mode. They are scaffolding for a follow-up PR | ||
| # that extends ResumableJobMixin support to Kubernetes. | ||
| # TODO: call K8s pod status API | ||
| raise NotImplementedError("K8s job status not yet implemented") | ||
| if (task_store := context.get("task_store")) is not None: | ||
| if (cached := task_store.get(self._K8S_DRIVER_STATUS_KEY)) is not None: | ||
| if not isinstance(cached, str): | ||
| raise ValueError(f"Cached K8s driver status is not a string: {cached!r}") | ||
| return cached | ||
| if kube_client is None: | ||
| raise RuntimeError( | ||
| "apache-airflow-providers-cncf-kubernetes is required to query K8s pod status" | ||
| ) | ||
| namespace, pod_name = self._parse_k8s_external_id(external_id) | ||
| try: | ||
| client = kube_client.get_kube_client() | ||
| pod = client.read_namespaced_pod(pod_name, namespace) | ||
| return pod.status.phase or "Pending" | ||
| except kube_client.ApiException as e: | ||
| if e.status == 404: | ||
| return "NotFound" | ||
| raise | ||
| scheme = self._hook._connection.get("rest_scheme", "http") | ||
| rest_port = self._hook._connection.get("rest_port", 6066) | ||
| # HA master URLs can look like spark://m1:7077,m2:7077 — try each host in order. | ||
|
|
@@ -345,6 +378,14 @@ def get_job_status(self, external_id: JsonValue, context: Context) -> str: | |
| last_exc = e | ||
| raise last_exc | ||
|
|
||
| @staticmethod | ||
| def _parse_k8s_external_id(external_id: str) -> tuple[str, str]: | ||
| """Parse a K8s external ID of the form 'namespace:pod_name' into its components.""" | ||
| parts = external_id.split(":", 1) | ||
| if len(parts) != 2: | ||
| raise ValueError(f"Invalid K8s external ID format {external_id!r}; expected 'namespace:pod_name'") | ||
| return parts[0], parts[1] | ||
|
|
||
| @retry(stop=stop_after_attempt(3), wait=wait_fixed(1), reraise=True) | ||
| def _fetch_driver_status(self, url: str, external_id: str) -> str: | ||
| response = requests.get(url, timeout=30) | ||
|
|
@@ -397,8 +438,18 @@ def poll_until_complete(self, external_id: JsonValue, context: Context) -> None: | |
| self._hook._run_post_submit_commands() | ||
| return | ||
| if self._hook._is_kubernetes: | ||
| # TODO: poll K8s pod phase until terminal | ||
| raise NotImplementedError("K8s poll not yet implemented") | ||
| if external_id is not None: | ||
| _, pod_name = self._parse_k8s_external_id(external_id) | ||
| self._hook._kubernetes_driver_pod = pod_name | ||
| terminal_phase = self._hook._poll_k8s_driver_via_api() | ||
| # Cache only when the pod actually reached Succeeded, the 404/vanished path | ||
| # returns None for cases like: pod deleted by on_kill or garbage collected after failure) | ||
| # and must not be cached, otherwise a retry would see "Succeeded" and skip resubmission. | ||
| if terminal_phase == "Succeeded" and self.reconnect_on_retry: | ||
| if (task_store := context.get("task_store")) is not None: | ||
| task_store.set(self._K8S_DRIVER_STATUS_KEY, "Succeeded") | ||
| return | ||
|
|
||
| self.log.info("Polling driver %s until completion", external_id) | ||
| self._hook._driver_id = external_id | ||
| try: | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.