diff --git a/providers/apache/spark/docs/operators.rst b/providers/apache/spark/docs/operators.rst index 6bdd4bbcdc772..bb82950172840 100644 --- a/providers/apache/spark/docs/operators.rst +++ b/providers/apache/spark/docs/operators.rst @@ -236,6 +236,7 @@ Python Kubernetes client rather than holding ``spark-submit`` open for the full conn_id="spark_k8s", deploy_mode="cluster", track_driver_via_k8s_api=True, + reconnect_on_retry=True, ) **Requirements** @@ -245,8 +246,10 @@ Python Kubernetes client rather than holding ``spark-submit`` open for the full conflicts with the flag and a ``ValueError`` will be raised at task start. * The Airflow worker must be able to reach the Kubernetes API server and have permission to read and delete pods in the driver's namespace; otherwise pod tracking and cleanup will fail. -* This path bypasses ``ResumableJobMixin``, so Airflow retries submit a fresh driver instead of - reconnecting to an existing one. Set ``execution_timeout`` to bound wall-clock time. +* Set ``reconnect_on_retry=True`` (the default) to enable crash recovery: the driver pod name is + persisted to task state before polling begins, so a worker crash and retry reconnects to the + existing pod instead of submitting a fresh one. Set ``reconnect_on_retry=False`` to always + submit a fresh driver on retry. * Pod completion is detected from ``pod.status.phase``. If your driver pods have sidecar containers (e.g. Istio injection enabled for the driver namespace), the pod phase may not advance to ``Succeeded`` until all sidecars exit. 
In that case the poll loop will wait diff --git a/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py b/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py index 3a19950696a5e..662966e4e1459 100644 --- a/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py +++ b/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py @@ -1138,8 +1138,14 @@ def _start_driver_status_tracking(self) -> None: f"returncode = {returncode}" ) - def _poll_k8s_driver_via_api(self) -> None: - """Poll the K8s driver pod phase until it reaches a terminal state.""" + def _poll_k8s_driver_via_api(self) -> str | None: + """ + Poll the K8s driver pod phase until it reaches a terminal state. + + Returns the terminal phase string (e.g. ``"Succeeded"``) on normal completion, + or ``None`` if the pod vanished mid-poll (404 — likely deleted by ``on_kill``). + Raises ``RuntimeError`` on failure phases or unrecoverable API errors. + """ pod_name = self._kubernetes_driver_pod namespace = self._connection["namespace"] app_id = self._kubernetes_application_id or pod_name @@ -1173,7 +1179,7 @@ def _poll_k8s_driver_via_api(self) -> None: "Driver pod %s not found (404); pod was likely deleted by on_kill. 
Exiting poll loop.", pod_name, ) - return + return None consecutive_api_errors += 1 self.log.warning( "ApiException polling pod %s (%d/%d): %s", @@ -1193,6 +1199,18 @@ def _poll_k8s_driver_via_api(self) -> None: phase = pod.status.phase or "Initializing" self.log.info("Application status for %s (phase: %s)", app_id, phase) if phase == "Succeeded": + if pod.status.container_statuses: + cs = pod.status.container_statuses[0] + if cs.state and cs.state.terminated: + t = cs.state.terminated + self.log.info( + "Container final status: exit_code=%s reason=%s started_at=%s finished_at=%s", + t.exit_code, + t.reason, + t.started_at, + t.finished_at, + ) + terminal_phase = phase break if phase == "Failed": container_state = "" @@ -1224,7 +1242,12 @@ def _poll_k8s_driver_via_api(self) -> None: else: consecutive_unknown = 0 time.sleep(poll_interval) - self._delete_driver_pod() + # Pod deletion is best-effort cleanup. If it fails (e.g. already garbage collected or RBAC + # denied), suppress the error so terminal_phase is still returned and the task + # succeeds. Raising here would skip the task_store write and force an unnecessary retry. 
+ with contextlib.suppress(Exception): + self._delete_driver_pod() + return terminal_phase finally: self._run_post_submit_commands() diff --git a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py index ac9b550409ff0..5a53557f39b97 100644 --- a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py +++ b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py @@ -23,13 +23,18 @@ import requests from tenacity import retry, stop_after_attempt, wait_fixed -from airflow.providers.apache.spark.hooks.spark_submit import SparkSubmitHook +from airflow.providers.apache.spark.hooks.spark_submit import _K8S_WAIT_APP_COMPLETION_CONF, SparkSubmitHook from airflow.providers.common.compat.openlineage.utils.spark import ( inject_parent_job_information_into_spark_properties, inject_transport_information_into_spark_properties, ) from airflow.providers.common.compat.sdk import BaseOperator, conf +try: + from airflow.providers.cncf.kubernetes import kube_client +except ImportError: + kube_client = None # type: ignore[assignment] + try: from airflow.sdk import ResumableJobMixin except ImportError: @@ -140,6 +145,11 @@ class SparkSubmitOperator(ResumableJobMixin, BaseOperator): # YARN application ID, K8s driver pod name). external_id_key = "spark_job_id" + # Used only for k8s cluster mode. Caches the pod phase (only "Succeeded" is written) to task_store at the + # end of poll_until_complete. On retry, get_job_status reads this before querying the K8s API + # so that a completed job can be identified even after the driver pod is garbage collected.
+ _K8S_DRIVER_STATUS_KEY = "k8s_driver_status" + template_fields: Sequence[str] = ( "application", "conf", @@ -269,11 +279,12 @@ def execute(self, context: Context) -> None: self.poll_until_complete(driver_id, context) return self.get_job_result(driver_id, context) if hook._should_track_driver_via_k8s_api(): - # TODO: Wire into execute_resumable() via ResumableJobMixin - # (fill submit_job / poll_until_complete K8s stubs) to enable crash recovery. - hook.submit(self.application) - hook._poll_k8s_driver_via_api() - return + if self.reconnect_on_retry: + return self.execute_resumable(context) + # reconnect_on_retry=False: still submit-and-poll, just skip task_state persistence. + driver_id = self.submit_job(context) + self.poll_until_complete(driver_id, context) + return self.get_job_result(driver_id, context) if hook._is_yarn_cluster_mode: if self.reconnect_on_retry and not hook._yarn_track_via_rm_api: raise ValueError( @@ -290,9 +301,19 @@ def execute(self, context: Context) -> None: return self.get_job_result(driver_id, context) hook.submit(self.application) - def submit_job(self, context: Context) -> str: + def submit_job(self, context: Context) -> str | None: if self._hook is None: self._hook = self._get_hook() + if self._hook._is_kubernetes: + self._hook._conf[_K8S_WAIT_APP_COMPLETION_CONF] = "false" + self._hook.submit(self.application) + pod_name = self._hook._kubernetes_driver_pod + namespace = self._hook._connection["namespace"] + if not pod_name: + raise RuntimeError("spark-submit did not capture a K8s driver pod name") + external_id = f"{namespace}:{pod_name}" + self.log.info("Spark K8s driver pod submitted: %s", external_id) + return external_id if self._hook._is_yarn_cluster_mode: if self._hook._conf.get("spark.yarn.submit.waitAppCompletion", "").strip().lower() == "true": raise ValueError( @@ -321,12 +342,24 @@ def get_job_status(self, external_id: JsonValue, context: Context) -> str: if self._hook._is_yarn_cluster_mode: return 
self._hook.query_yarn_application_status(external_id) if self._hook._is_kubernetes: - # The K8s branches below (and in is_job_active, is_job_succeeded, poll_until_complete) - # are currently unreachable: execute_resumable is only called when _should_track_driver_status - # is True, which requires spark:// + cluster mode. They are scaffolding for a follow-up PR - # that extends ResumableJobMixin support to Kubernetes. - # TODO: call K8s pod status API - raise NotImplementedError("K8s job status not yet implemented") + if (task_state_store := context.get("task_state_store")) is not None: + if (cached := task_state_store.get(self._K8S_DRIVER_STATUS_KEY)) is not None: + if not isinstance(cached, str): + raise ValueError(f"Cached K8s driver status is not a string: {cached!r}") + return cached + if kube_client is None: + raise RuntimeError( + "apache-airflow-providers-cncf-kubernetes is required to query K8s pod status" + ) + namespace, pod_name = self._parse_k8s_external_id(external_id) + try: + client = kube_client.get_kube_client() + pod = client.read_namespaced_pod(pod_name, namespace) + return pod.status.phase or "Pending" + except kube_client.ApiException as e: + if e.status == 404: + return "NotFound" + raise scheme = self._hook._connection.get("rest_scheme", "http") rest_port = self._hook._connection.get("rest_port", 6066) # HA master URLs can look like spark://m1:7077,m2:7077 — try each host in order. 
@@ -345,6 +378,14 @@ def get_job_status(self, external_id: JsonValue, context: Context) -> str: last_exc = e raise last_exc + @staticmethod + def _parse_k8s_external_id(external_id: str) -> tuple[str, str]: + """Parse a K8s external ID of the form 'namespace:pod_name' into its components.""" + parts = external_id.split(":", 1) + if len(parts) != 2: + raise ValueError(f"Invalid K8s external ID format {external_id!r}; expected 'namespace:pod_name'") + return parts[0], parts[1] + @retry(stop=stop_after_attempt(3), wait=wait_fixed(1), reraise=True) def _fetch_driver_status(self, url: str, external_id: str) -> str: response = requests.get(url, timeout=30) @@ -397,8 +438,18 @@ def poll_until_complete(self, external_id: JsonValue, context: Context) -> None: self._hook._run_post_submit_commands() return if self._hook._is_kubernetes: - # TODO: poll K8s pod phase until terminal - raise NotImplementedError("K8s poll not yet implemented") + if external_id is not None: + _, pod_name = self._parse_k8s_external_id(external_id) + self._hook._kubernetes_driver_pod = pod_name + terminal_phase = self._hook._poll_k8s_driver_via_api() + # Cache only when the pod actually reached Succeeded. The 404/vanished path returns + # None (e.g. the pod was deleted by on_kill or garbage collected after a failure) + # and must not be cached; otherwise a retry would see "Succeeded" and skip resubmission.
+ if terminal_phase == "Succeeded" and self.reconnect_on_retry: + if (task_state_store := context.get("task_state_store")) is not None: + task_state_store.set(self._K8S_DRIVER_STATUS_KEY, "Succeeded") + return + self.log.info("Polling driver %s until completion", external_id) self._hook._driver_id = external_id try: diff --git a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py index a1fe95d60036a..180f6cdd38d09 100644 --- a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py +++ b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py @@ -322,6 +322,7 @@ def test_inject_simple_openlineage_config_to_spark(self, mock_get_openlineage_li **self._config, ) mock_get_hook.return_value._should_track_driver_status = False + mock_get_hook.return_value._should_track_driver_via_k8s_api.return_value = False operator.execute(MagicMock()) assert operator.conf == { @@ -389,6 +390,7 @@ def test_inject_composite_openlineage_config_to_spark(self, mock_get_openlineage **self._config, ) mock_get_hook.return_value._should_track_driver_status = False + mock_get_hook.return_value._should_track_driver_via_k8s_api.return_value = False operator.execute({"ti": mock_ti}) assert operator.conf == { @@ -428,6 +430,7 @@ def test_inject_openlineage_composite_config_wrong_transport_to_spark( ) mock_get_hook.return_value._should_track_driver_status = False + mock_get_hook.return_value._should_track_driver_via_k8s_api.return_value = False with caplog.at_level(logging.INFO): operator = SparkSubmitOperator( task_id="spark_submit_job", @@ -460,6 +463,7 @@ def test_inject_openlineage_simple_config_wrong_transport_to_spark( ) mock_get_hook.return_value._should_track_driver_status = False + mock_get_hook.return_value._should_track_driver_via_k8s_api.return_value = False with caplog.at_level(logging.INFO): operator = SparkSubmitOperator( 
task_id="spark_submit_job", @@ -892,7 +896,10 @@ def _make_k8s_hook(self): hook = MagicMock() hook._should_track_driver_status = False hook._should_track_driver_via_k8s_api.return_value = True + hook._is_kubernetes = True + hook._is_yarn = False hook._is_yarn_cluster_mode = False + hook._conf = {} return hook def test_execute_calls_submit_then_poll_when_flag_set(self): @@ -920,3 +927,185 @@ def test_execute_falls_through_to_plain_submit_when_flag_off(self): hook.submit.assert_called_once_with("test.jar") hook._poll_k8s_driver_via_api.assert_not_called() + + def test_k8s_submit_job_returns_encoded_external_id(self): + operator = self._make_operator(track_driver_via_k8s_api=True) + hook = self._make_k8s_hook() + hook._kubernetes_driver_pod = "spark-abc-driver" + hook._connection = {"namespace": "mynamespace"} + operator._hook = hook + + result = operator.submit_job(context={}) + + assert result == "mynamespace:spark-abc-driver" + assert hook._conf.get("spark.kubernetes.submission.waitAppCompletion") == "false" + hook.submit.assert_called_once_with("test.jar") + + def test_k8s_submit_job_raises_when_pod_name_missing(self): + operator = self._make_operator(track_driver_via_k8s_api=True) + hook = self._make_k8s_hook() + hook._kubernetes_driver_pod = None + hook._connection = {"namespace": "mynamespace"} + operator._hook = hook + + with pytest.raises(RuntimeError, match="did not capture a K8s driver pod name"): + operator.submit_job(context={}) + + def test_k8s_get_job_status_returns_k8s_driver_status(self): + operator = self._make_operator(track_driver_via_k8s_api=True) + operator._hook = self._make_k8s_hook() + task_store = FakeTaskState({"k8s_driver_status": "Succeeded"}) + + with mock.patch("airflow.providers.apache.spark.operators.spark_submit.kube_client") as mock_kube: + result = operator.get_job_status("mynamespace:spark-abc-driver", {"task_state_store": task_store}) + + assert result == "Succeeded" + mock_kube.get_kube_client.assert_not_called() + + def 
test_k8s_get_job_status_queries_k8s_api_when_no_k8s_driver_status(self): + operator = self._make_operator(track_driver_via_k8s_api=True) + operator._hook = self._make_k8s_hook() + task_store = FakeTaskState() + + mock_pod = MagicMock() + mock_pod.status.phase = "Running" + + with mock.patch("airflow.providers.apache.spark.operators.spark_submit.kube_client") as mock_kube: + mock_kube.get_kube_client.return_value.read_namespaced_pod.return_value = mock_pod + result = operator.get_job_status("mynamespace:spark-abc-driver", {"task_state_store": task_store}) + + assert result == "Running" + + def test_k8s_get_job_status_returns_pending_when_phase_is_none(self): + operator = self._make_operator(track_driver_via_k8s_api=True) + operator._hook = self._make_k8s_hook() + + mock_pod = MagicMock() + mock_pod.status.phase = None + + with mock.patch("airflow.providers.apache.spark.operators.spark_submit.kube_client") as mock_kube: + mock_kube.get_kube_client.return_value.read_namespaced_pod.return_value = mock_pod + result = operator.get_job_status("mynamespace:spark-abc-driver", {}) + + assert result == "Pending" + + def test_k8s_get_job_status_returns_not_found_on_404(self): + operator = self._make_operator(track_driver_via_k8s_api=True) + operator._hook = self._make_k8s_hook() + + class FakeApiException(Exception): + def __init__(self, status): + self.status = status + + with mock.patch("airflow.providers.apache.spark.operators.spark_submit.kube_client") as mock_kube: + mock_kube.ApiException = FakeApiException + mock_kube.get_kube_client.return_value.read_namespaced_pod.side_effect = FakeApiException(404) + result = operator.get_job_status("mynamespace:spark-abc-driver", {}) + + assert result == "NotFound" + + def test_k8s_get_job_status_reraises_non_404_api_exception(self): + operator = self._make_operator(track_driver_via_k8s_api=True) + operator._hook = self._make_k8s_hook() + + class FakeApiException(Exception): + def __init__(self, status): + self.status = status + + 
with mock.patch("airflow.providers.apache.spark.operators.spark_submit.kube_client") as mock_kube: + mock_kube.ApiException = FakeApiException + mock_kube.get_kube_client.return_value.read_namespaced_pod.side_effect = FakeApiException(500) + with pytest.raises(FakeApiException): + operator.get_job_status("mynamespace:spark-abc-driver", {}) + + def test_k8s_poll_until_complete_sets_pod_name_and_calls_poll_api(self): + operator = self._make_operator(track_driver_via_k8s_api=True) + hook = self._make_k8s_hook() + operator._hook = hook + + operator.poll_until_complete("mynamespace:spark-abc-driver", {}) + + assert hook._kubernetes_driver_pod == "spark-abc-driver" + hook._poll_k8s_driver_via_api.assert_called_once() + + def test_k8s_poll_until_complete_writes_succeeded_to_task_store(self): + operator = self._make_operator(track_driver_via_k8s_api=True) + hook = self._make_k8s_hook() + hook._poll_k8s_driver_via_api.return_value = "Succeeded" + operator._hook = hook + task_store = FakeTaskState() + + operator.poll_until_complete("mynamespace:spark-abc-driver", {"task_state_store": task_store}) + + assert task_store.get("k8s_driver_status") == "Succeeded" + + def test_k8s_polling_does_not_write_task_store_when_reconnect_disabled(self): + operator = self._make_operator(track_driver_via_k8s_api=True, reconnect_on_retry=False) + hook = self._make_k8s_hook() + hook._poll_k8s_driver_via_api.return_value = "Succeeded" + operator._hook = hook + task_store = FakeTaskState() + + operator.poll_until_complete("mynamespace:spark-abc-driver", {"task_state_store": task_store}) + + assert task_store.get("k8s_driver_status") is None + + def test_k8s_poll_until_complete_does_not_cache_and_reraises_on_failure(self): + operator = self._make_operator(track_driver_via_k8s_api=True) + hook = self._make_k8s_hook() + hook._poll_k8s_driver_via_api.side_effect = RuntimeError("Spark application failed (phase=Failed)") + operator._hook = hook + task_store = FakeTaskState() + + with 
pytest.raises(RuntimeError, match="phase=Failed"): + operator.poll_until_complete("mynamespace:spark-abc-driver", {"task_state_store": task_store}) + + assert task_store.get("k8s_driver_status") is None + + def test_k8s_poll_until_complete_tolerates_absent_task_store(self): + operator = self._make_operator(track_driver_via_k8s_api=True) + operator._hook = self._make_k8s_hook() + + operator.poll_until_complete("mynamespace:spark-abc-driver", {}) + + @pytest.mark.skipif( + not AIRFLOW_V_3_3_PLUS, + reason="ResumableJobMixin reconnect requires task_state, available in Airflow 3.3+", + ) + def test_k8s_execute_persists_pod_id_when_reconnect_on_retry(self): + """execute() with reconnect_on_retry=True stores the pod ID in task_store before polling.""" + operator = self._make_operator(track_driver_via_k8s_api=True, reconnect_on_retry=True) + hook = self._make_k8s_hook() + hook._kubernetes_driver_pod = "spark-abc-driver" + hook._connection = {"namespace": "mynamespace"} + operator._hook = hook + task_store = FakeTaskState() + persisted_before_poll: list[str | None] = [] + + def track_poll(external_id, context): + persisted_before_poll.append(task_store.get("spark_job_id")) + + operator.poll_until_complete = track_poll + + operator.execute(context={"task_state_store": task_store}) + + assert persisted_before_poll == ["mynamespace:spark-abc-driver"] + + @pytest.mark.skipif( + not AIRFLOW_V_3_3_PLUS, + reason="ResumableJobMixin reconnect requires task_state, available in Airflow 3.3+", + ) + def test_k8s_execute_reconnect_on_retry_false_does_not_persist_pod_id(self): + """execute() with reconnect_on_retry=False does not write spark_job_id to task_store.""" + operator = self._make_operator(track_driver_via_k8s_api=True, reconnect_on_retry=False) + hook = self._make_k8s_hook() + hook._kubernetes_driver_pod = "spark-abc-driver" + hook._connection = {"namespace": "mynamespace"} + operator._hook = hook + task_store = FakeTaskState() + + operator.poll_until_complete = lambda 
external_id, context: None + + operator.execute(context={"task_state_store": task_store}) + + assert task_store.get("spark_job_id") is None