From f2b426a16e3f92e40f7bbd8c7002f5cc7cff300a Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Fri, 5 Jun 2026 13:58:14 +0530 Subject: [PATCH 1/8] Add crash recovery ability to SparkSubmitOperator against Kubernetes --- providers/apache/spark/docs/operators.rst | 7 +- .../apache/spark/operators/spark_submit.py | 85 +++++++-- .../spark/operators/test_spark_submit.py | 166 ++++++++++++++++++ 3 files changed, 241 insertions(+), 17 deletions(-) diff --git a/providers/apache/spark/docs/operators.rst b/providers/apache/spark/docs/operators.rst index 6bdd4bbcdc772..bb82950172840 100644 --- a/providers/apache/spark/docs/operators.rst +++ b/providers/apache/spark/docs/operators.rst @@ -236,6 +236,7 @@ Python Kubernetes client rather than holding ``spark-submit`` open for the full conn_id="spark_k8s", deploy_mode="cluster", track_driver_via_k8s_api=True, + reconnect_on_retry=True, ) **Requirements** @@ -245,8 +246,10 @@ Python Kubernetes client rather than holding ``spark-submit`` open for the full conflicts with the flag and a ``ValueError`` will be raised at task start. * The Airflow worker must be able to reach the Kubernetes API server and have permission to read and delete pods in the driver's namespace; otherwise pod tracking and cleanup will fail. -* This path bypasses ``ResumableJobMixin``, so Airflow retries submit a fresh driver instead of - reconnecting to an existing one. Set ``execution_timeout`` to bound wall-clock time. +* Set ``reconnect_on_retry=True`` (the default) to enable crash recovery: the driver pod name is + persisted to task state before polling begins, so a worker crash and retry reconnects to the + existing pod instead of submitting a fresh one. Set ``reconnect_on_retry=False`` to always + submit a fresh driver on retry. * Pod completion is detected from ``pod.status.phase``. If your driver pods have sidecar containers (e.g. Istio injection enabled for the driver namespace), the pod phase may not advance to ``Succeeded`` until all sidecars exit. In that case the poll loop will wait diff --git a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py index 3ac4870f313fd..c42b2aa697631 100644 --- a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py +++ b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py @@ -23,13 +23,18 @@ import requests from tenacity import retry, stop_after_attempt, wait_fixed -from airflow.providers.apache.spark.hooks.spark_submit import SparkSubmitHook +from airflow.providers.apache.spark.hooks.spark_submit import _K8S_WAIT_APP_COMPLETION_CONF, SparkSubmitHook from airflow.providers.common.compat.openlineage.utils.spark import ( inject_parent_job_information_into_spark_properties, inject_transport_information_into_spark_properties, ) from airflow.providers.common.compat.sdk import BaseOperator, conf +try: + from airflow.providers.cncf.kubernetes import kube_client +except ImportError: + kube_client = None # type: ignore[assignment] + try: from airflow.sdk.bases.resumablemixin import ResumableJobMixin except ImportError: @@ -140,6 +145,11 @@ class SparkSubmitOperator(ResumableJobMixin, BaseOperator): # YARN application ID, K8s driver pod name). external_id_key = "spark_job_id" + # Used only for k8s cluster mode. Caches the pod phase ("Succeeded" / "Failed") to task_store at the end of + # poll_until_complete. On retry, get_job_status reads this before querying the K8s API + # so that a completed job can be identified even after the driver pod is garbage collected. + _K8S_DRIVER_STATUS_KEY = "k8s_driver_status" + template_fields: Sequence[str] = ( "application", "conf", @@ -269,16 +279,31 @@ def execute(self, context: Context) -> None: self.poll_until_complete(driver_id, context) return self.get_job_result(driver_id, context) if hook._should_track_driver_via_k8s_api(): - # TODO: Wire into execute_resumable() via ResumableJobMixin - # (fill submit_job / poll_until_complete K8s stubs) to enable crash recovery. - hook.submit(self.application) - hook._poll_k8s_driver_via_api() - return + if self.reconnect_on_retry: + return self.execute_resumable(context) + # reconnect_on_retry=False: still submit-and-poll, just skip task_state persistence. + driver_id = self.submit_job(context) + self.poll_until_complete(driver_id, context) + return self.get_job_result(driver_id, context) hook.submit(self.application) - def submit_job(self, context: Context) -> str: + def submit_job(self, context: Context) -> str | None: if self._hook is None: self._hook = self._get_hook() + if self._hook._is_kubernetes: + self._hook._conf[_K8S_WAIT_APP_COMPLETION_CONF] = "false" + self._hook.submit(self.application) + pod_name = self._hook._kubernetes_driver_pod + namespace = self._hook._connection["namespace"] + if not pod_name: + self.log.warning( + "spark-submit did not capture a K8s driver pod name; " + "crash recovery will not be available for this run" + ) + return None + external_id = f"{namespace}:{pod_name}" + self.log.info("Spark K8s driver pod submitted: %s", external_id) + return external_id driver_id = self._hook.submit(self.application) if not driver_id: raise RuntimeError("spark-submit did not return a driver ID") @@ -290,17 +315,35 @@ def get_job_status(self, external_id: JsonValue, context: Context) -> str: external_id = cast("str", external_id) if self._hook is None: self._hook = self._get_hook() - # The YARN and K8s branches below (and in is_job_active, is_job_succeeded, poll_until_complete) - # are currently unreachable: execute_resumable is only called when _should_track_driver_status - # is True, which requires spark:// + cluster mode. They are scaffolding for a follow-up PR - # that extends ResumableJobMixin support to YARN and Kubernetes. if self._hook._is_yarn: # TODO: call YARN ResourceManager REST API # GET http://rm:8088/ws/v1/cluster/apps/{external_id} raise NotImplementedError("YARN job status not yet implemented") if self._hook._is_kubernetes: - # TODO: call K8s pod status API - raise NotImplementedError("K8s job status not yet implemented") + task_store = context.get("task_store") + if task_store is not None: + cached = task_store.get(self._K8S_DRIVER_STATUS_KEY) + if cached: + return cached + if kube_client is None: + raise RuntimeError( + "apache-airflow-providers-cncf-kubernetes is required to query K8s pod status" + ) + namespace, pod_name = external_id.split(":", 1) + parts = external_id.split(":", 1) + if len(parts) != 2: + raise ValueError( + f"Invalid K8s external ID format {external_id!r}; expected 'namespace:pod_name'" + ) + namespace, pod_name = parts + try: + client = kube_client.get_kube_client() + pod = client.read_namespaced_pod(pod_name, namespace) + return pod.status.phase + except kube_client.ApiException as e: + if e.status == 404: + return "NotFound" + raise scheme = self._hook._connection.get("rest_scheme", "http") rest_port = self._hook._connection.get("rest_port", 6066) # HA master URLs can look like spark://m1:7077,m2:7077 — try each host in order. @@ -365,9 +408,21 @@ def poll_until_complete(self, external_id: JsonValue, context: Context) -> None: if self._hook._is_yarn: # TODO: poll YARN ResourceManager until app reaches terminal state raise NotImplementedError("YARN poll not yet implemented") + if self._hook._is_kubernetes: - # TODO: poll K8s pod phase until terminal - raise NotImplementedError("K8s poll not yet implemented") + if external_id is not None: + _, pod_name = str(external_id).split(":", 1) + self._hook._kubernetes_driver_pod = pod_name + self._hook._poll_k8s_driver_via_api() + # The driver pod is deleted on success, so cache the terminal phase before it + # disappears. Failed jobs raise before reaching here, so only "Succeeded" is ever + # cached. A missing key on retry means the pod was garbage collected after failure, and + # resubmitting fresh is the right behaviour in that case. + task_store = context.get("task_store") + if task_store is not None: + task_store.set(self._K8S_DRIVER_STATUS_KEY, "Succeeded") + return + self.log.info("Polling driver %s until completion", external_id) self._hook._driver_id = external_id try: diff --git a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py index 95ad9f5142a57..072be2d928d49 100644 --- a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py +++ b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py @@ -322,6 +322,7 @@ def test_inject_simple_openlineage_config_to_spark(self, mock_get_openlineage_li **self._config, ) mock_get_hook.return_value._should_track_driver_status = False + mock_get_hook.return_value._should_track_driver_via_k8s_api.return_value = False operator.execute(MagicMock()) assert operator.conf == { @@ -389,6 +390,7 @@ def test_inject_composite_openlineage_config_to_spark(self, mock_get_openlineage **self._config, ) mock_get_hook.return_value._should_track_driver_status = False + mock_get_hook.return_value._should_track_driver_via_k8s_api.return_value = False operator.execute({"ti": mock_ti}) assert operator.conf == { @@ -428,6 +430,7 @@ def test_inject_openlineage_composite_config_wrong_transport_to_spark( ) mock_get_hook.return_value._should_track_driver_status = False + mock_get_hook.return_value._should_track_driver_via_k8s_api.return_value = False with caplog.at_level(logging.INFO): operator = SparkSubmitOperator( task_id="spark_submit_job", @@ -460,6 +463,7 @@ def test_inject_openlineage_simple_config_wrong_transport_to_spark( ) mock_get_hook.return_value._should_track_driver_status = False + mock_get_hook.return_value._should_track_driver_via_k8s_api.return_value = False with caplog.at_level(logging.INFO): operator = SparkSubmitOperator( task_id="spark_submit_job", @@ -508,6 +512,7 @@ def _make_operator(self, **kwargs): def _make_hook(self, should_track=False, is_yarn=False, is_kubernetes=False): hook = MagicMock() hook._should_track_driver_status = should_track + hook._should_track_driver_via_k8s_api.return_value = False hook._is_yarn = is_yarn hook._is_kubernetes = is_kubernetes hook._connection = {"master": "spark://localhost:7077"} @@ -746,6 +751,16 @@ def _make_k8s_hook(self): hook = MagicMock() hook._should_track_driver_status = False hook._should_track_driver_via_k8s_api.return_value = True + hook._is_kubernetes = True + hook._is_yarn = False + hook._conf = {} + return hook + + def _make_k8s_resumable_hook(self): + hook = self._make_k8s_hook() + hook._is_kubernetes = True + hook._is_yarn = False + hook._conf = {} return hook def test_execute_calls_submit_then_poll_when_flag_set(self): @@ -773,3 +788,154 @@ def test_execute_falls_through_to_plain_submit_when_flag_off(self): hook.submit.assert_called_once_with("test.jar") hook._poll_k8s_driver_via_api.assert_not_called() + + def test_k8s_submit_job_returns_encoded_external_id(self): + operator = self._make_operator(track_driver_via_k8s_api=True) + hook = self._make_k8s_resumable_hook() + hook._kubernetes_driver_pod = "spark-abc-driver" + hook._connection = {"namespace": "mynamespace"} + operator._hook = hook + + result = operator.submit_job(context={}) + + assert result == "mynamespace:spark-abc-driver" + assert hook._conf.get("spark.kubernetes.submission.waitAppCompletion") == "false" + hook.submit.assert_called_once_with("test.jar") + + def test_k8s_submit_job_warns_when_pod_name_missing(self, caplog): + operator = self._make_operator(track_driver_via_k8s_api=True) + hook = self._make_k8s_resumable_hook() + hook._kubernetes_driver_pod = None + hook._connection = {"namespace": "mynamespace"} + operator._hook = hook + + with caplog.at_level(logging.WARNING): + result = operator.submit_job(context={}) + + assert result is None + assert "crash recovery will not be available" in caplog.text + + def test_k8s_get_job_status_returns_k8s_driver_status(self): + operator = self._make_operator(track_driver_via_k8s_api=True) + operator._hook = self._make_k8s_resumable_hook() + task_store = FakeTaskState({"k8s_driver_status": "Succeeded"}) + + with mock.patch("airflow.providers.apache.spark.operators.spark_submit.kube_client") as mock_kube: + result = operator.get_job_status("mynamespace:spark-abc-driver", {"task_store": task_store}) + + assert result == "Succeeded" + mock_kube.get_kube_client.assert_not_called() + + def test_k8s_get_job_status_queries_k8s_api_when_no_k8s_driver_status(self): + operator = self._make_operator(track_driver_via_k8s_api=True) + operator._hook = self._make_k8s_resumable_hook() + task_store = FakeTaskState() + + mock_pod = MagicMock() + mock_pod.status.phase = "Running" + + with mock.patch("airflow.providers.apache.spark.operators.spark_submit.kube_client") as mock_kube: + mock_kube.get_kube_client.return_value.read_namespaced_pod.return_value = mock_pod + result = operator.get_job_status("mynamespace:spark-abc-driver", {"task_store": task_store}) + + assert result == "Running" + + def test_k8s_get_job_status_returns_not_found_on_404(self): + operator = self._make_operator(track_driver_via_k8s_api=True) + operator._hook = self._make_k8s_resumable_hook() + + class FakeApiException(Exception): + def __init__(self, status): + self.status = status + + with mock.patch("airflow.providers.apache.spark.operators.spark_submit.kube_client") as mock_kube: + mock_kube.ApiException = FakeApiException + mock_kube.get_kube_client.return_value.read_namespaced_pod.side_effect = FakeApiException(404) + result = operator.get_job_status("mynamespace:spark-abc-driver", {}) + + assert result == "NotFound" + + def test_k8s_get_job_status_reraises_non_404_api_exception(self): + operator = self._make_operator(track_driver_via_k8s_api=True) + operator._hook = self._make_k8s_resumable_hook() + + class FakeApiException(Exception): + def __init__(self, status): + self.status = status + + with mock.patch("airflow.providers.apache.spark.operators.spark_submit.kube_client") as mock_kube: + mock_kube.ApiException = FakeApiException + mock_kube.get_kube_client.return_value.read_namespaced_pod.side_effect = FakeApiException(500) + with pytest.raises(FakeApiException): + operator.get_job_status("mynamespace:spark-abc-driver", {}) + + def test_k8s_poll_until_complete_sets_pod_name_and_calls_poll_api(self): + operator = self._make_operator(track_driver_via_k8s_api=True) + hook = self._make_k8s_resumable_hook() + operator._hook = hook + + operator.poll_until_complete("mynamespace:spark-abc-driver", {}) + + assert hook._kubernetes_driver_pod == "spark-abc-driver" + hook._poll_k8s_driver_via_api.assert_called_once() + + def test_k8s_poll_until_complete_writes_succeeded_to_task_store(self): + operator = self._make_operator(track_driver_via_k8s_api=True) + operator._hook = self._make_k8s_resumable_hook() + task_store = FakeTaskState() + + operator.poll_until_complete("mynamespace:spark-abc-driver", {"task_store": task_store}) + + assert task_store.get("k8s_driver_status") == "Succeeded" + + def test_k8s_poll_until_complete_does_not_cache_and_reraises_on_failure(self): + operator = self._make_operator(track_driver_via_k8s_api=True) + hook = self._make_k8s_resumable_hook() + hook._poll_k8s_driver_via_api.side_effect = RuntimeError("Spark application failed (phase=Failed)") + operator._hook = hook + task_store = FakeTaskState() + + with pytest.raises(RuntimeError, match="phase=Failed"): + operator.poll_until_complete("mynamespace:spark-abc-driver", {"task_store": task_store}) + + assert task_store.get("k8s_driver_status") is None + + def test_k8s_poll_until_complete_tolerates_absent_task_store(self): + operator = self._make_operator(track_driver_via_k8s_api=True) + operator._hook = self._make_k8s_resumable_hook() + + operator.poll_until_complete("mynamespace:spark-abc-driver", {}) + + def test_k8s_execute_persists_pod_id_to_task_store_when_reconnect_on_retry(self): + """execute() with reconnect_on_retry=True stores the pod ID in task_store before polling.""" + operator = self._make_operator(track_driver_via_k8s_api=True, reconnect_on_retry=True) + hook = self._make_k8s_resumable_hook() + hook._kubernetes_driver_pod = "spark-abc-driver" + hook._connection = {"namespace": "mynamespace"} + operator._hook = hook + task_store = FakeTaskState() + persisted_before_poll: list[str | None] = [] + + def track_poll(external_id, context): + persisted_before_poll.append(task_store.get("spark_job_id")) + + operator.poll_until_complete = track_poll + + operator.execute(context={"task_store": task_store}) + + assert persisted_before_poll == ["mynamespace:spark-abc-driver"] + + def test_k8s_execute_reconnect_on_retry_false_does_not_persist_pod_id(self): + """execute() with reconnect_on_retry=False does not write spark_job_id to task_store.""" + operator = self._make_operator(track_driver_via_k8s_api=True, reconnect_on_retry=False) + hook = self._make_k8s_resumable_hook() + hook._kubernetes_driver_pod = "spark-abc-driver" + hook._connection = {"namespace": "mynamespace"} + operator._hook = hook + task_store = FakeTaskState() + + operator.poll_until_complete = lambda external_id, context: None + + operator.execute(context={"task_store": task_store}) + + assert task_store.get("spark_job_id") is None From 4fe565f4e3b4698f87148805469463b0855564fd Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Fri, 5 Jun 2026 17:10:36 +0530 Subject: [PATCH 2/8] enrich logs --- .../providers/apache/spark/hooks/spark_submit.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py b/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py index 7cf1f3248adc8..763e575aa6ecb 100644 --- a/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py +++ b/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py @@ -1165,6 +1165,17 @@ def _poll_k8s_driver_via_api(self) -> None: phase = pod.status.phase or "Initializing" self.log.info("Application status for %s (phase: %s)", app_id, phase) if phase == "Succeeded": + if pod.status.container_statuses: + cs = pod.status.container_statuses[0] + if cs.state and cs.state.terminated: + t = cs.state.terminated + self.log.info( + "Container final status: exit_code=%s reason=%s started_at=%s finished_at=%s", + t.exit_code, + t.reason, + t.started_at, + t.finished_at, + ) break if phase == "Failed": container_state = "" From 35fd5be4fa524a4844da176b56ab2cf3254a4874 Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Sun, 7 Jun 2026 10:38:40 +0530 Subject: [PATCH 3/8] comments from potiuk --- .../providers/apache/spark/operators/spark_submit.py | 3 ++- .../unit/apache/spark/operators/test_spark_submit.py | 8 ++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py index c42b2aa697631..d452b6e31dc4b 100644 --- a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py +++ b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py @@ -324,12 +324,13 @@ def get_job_status(self, external_id: JsonValue, context: Context) -> str: if task_store is not None: cached = task_store.get(self._K8S_DRIVER_STATUS_KEY) if cached: + if TYPE_CHECKING: + assert isinstance(cached, str) return cached if kube_client is None: raise RuntimeError( "apache-airflow-providers-cncf-kubernetes is required to query K8s pod status" ) - namespace, pod_name = external_id.split(":", 1) parts = external_id.split(":", 1) if len(parts) != 2: raise ValueError( diff --git a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py index 072be2d928d49..bba9223329d1b 100644 --- a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py +++ b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py @@ -906,6 +906,10 @@ def test_k8s_poll_until_complete_tolerates_absent_task_store(self): operator.poll_until_complete("mynamespace:spark-abc-driver", {}) + @pytest.mark.skipif( + not AIRFLOW_V_3_3_PLUS, + reason="ResumableJobMixin reconnect requires task_state, available in Airflow 3.3+", + ) def test_k8s_execute_persists_pod_id_to_task_store_when_reconnect_on_retry(self): """execute() with reconnect_on_retry=True stores the pod ID in task_store before polling.""" operator = self._make_operator(track_driver_via_k8s_api=True, reconnect_on_retry=True) @@ -925,6 +929,10 @@ def track_poll(external_id, context): assert persisted_before_poll == ["mynamespace:spark-abc-driver"] + @pytest.mark.skipif( + not AIRFLOW_V_3_3_PLUS, + reason="ResumableJobMixin reconnect requires task_state, available in Airflow 3.3+", + ) def test_k8s_execute_reconnect_on_retry_false_does_not_persist_pod_id(self): """execute() with reconnect_on_retry=False does not write spark_job_id to task_store.""" operator = self._make_operator(track_driver_via_k8s_api=True, reconnect_on_retry=False) From fe5ce9339cf4a58156f5922e78541062c8179bef Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Wed, 10 Jun 2026 11:25:30 +0530 Subject: [PATCH 4/8] test fixes --- .../spark/operators/test_spark_submit.py | 31 +++++++------------ 1 file changed, 12 insertions(+), 19 deletions(-) diff --git a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py index 73603f5a2dd64..530a1d40c5ca8 100644 --- a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py +++ b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py @@ -902,13 +902,6 @@ def _make_k8s_hook(self): hook._conf = {} return hook - def _make_k8s_resumable_hook(self): - hook = self._make_k8s_hook() - hook._is_kubernetes = True - hook._is_yarn = False - hook._conf = {} - return hook - def test_execute_calls_submit_then_poll_when_flag_set(self): operator = self._make_operator(track_driver_via_k8s_api=True) hook = self._make_k8s_hook() @@ -937,7 +930,7 @@ def test_execute_falls_through_to_plain_submit_when_flag_off(self): def test_k8s_submit_job_returns_encoded_external_id(self): operator = self._make_operator(track_driver_via_k8s_api=True) - hook = self._make_k8s_resumable_hook() + hook = self._make_k8s_hook() hook._kubernetes_driver_pod = "spark-abc-driver" hook._connection = {"namespace": "mynamespace"} operator._hook = hook @@ -950,7 +943,7 @@ def test_k8s_submit_job_returns_encoded_external_id(self): def test_k8s_submit_job_warns_when_pod_name_missing(self, caplog): operator = self._make_operator(track_driver_via_k8s_api=True) - hook = self._make_k8s_resumable_hook() + hook = self._make_k8s_hook() hook._kubernetes_driver_pod = None hook._connection = {"namespace": "mynamespace"} operator._hook = hook @@ -963,7 +956,7 @@ def test_k8s_submit_job_warns_when_pod_name_missing(self, caplog): def test_k8s_get_job_status_returns_k8s_driver_status(self): operator = self._make_operator(track_driver_via_k8s_api=True) - operator._hook = self._make_k8s_resumable_hook() + operator._hook = self._make_k8s_hook() task_store = FakeTaskState({"k8s_driver_status": "Succeeded"}) with mock.patch("airflow.providers.apache.spark.operators.spark_submit.kube_client") as mock_kube: @@ -974,7 +967,7 @@ def test_k8s_get_job_status_returns_k8s_driver_status(self): def test_k8s_get_job_status_queries_k8s_api_when_no_k8s_driver_status(self): operator = self._make_operator(track_driver_via_k8s_api=True) - operator._hook = self._make_k8s_resumable_hook() + operator._hook = self._make_k8s_hook() task_store = FakeTaskState() mock_pod = MagicMock() @@ -988,7 +981,7 @@ def test_k8s_get_job_status_queries_k8s_api_when_no_k8s_driver_status(self): def test_k8s_get_job_status_returns_not_found_on_404(self): operator = self._make_operator(track_driver_via_k8s_api=True) - operator._hook = self._make_k8s_resumable_hook() + operator._hook = self._make_k8s_hook() class FakeApiException(Exception): def __init__(self, status): @@ -1003,7 +996,7 @@ def __init__(self, status): def test_k8s_get_job_status_reraises_non_404_api_exception(self): operator = self._make_operator(track_driver_via_k8s_api=True) - operator._hook = self._make_k8s_resumable_hook() + operator._hook = self._make_k8s_hook() class FakeApiException(Exception): def __init__(self, status): @@ -1017,7 +1010,7 @@ def __init__(self, status): def test_k8s_poll_until_complete_sets_pod_name_and_calls_poll_api(self): operator = self._make_operator(track_driver_via_k8s_api=True) - hook = self._make_k8s_resumable_hook() + hook = self._make_k8s_hook() operator._hook = hook operator.poll_until_complete("mynamespace:spark-abc-driver", {}) @@ -1027,7 +1020,7 @@ def test_k8s_poll_until_complete_sets_pod_name_and_calls_poll_api(self): def test_k8s_poll_until_complete_writes_succeeded_to_task_store(self): operator = self._make_operator(track_driver_via_k8s_api=True) - operator._hook = self._make_k8s_resumable_hook() + operator._hook = self._make_k8s_hook() task_store = FakeTaskState() operator.poll_until_complete("mynamespace:spark-abc-driver", {"task_store": task_store}) @@ -1036,7 +1029,7 @@ def test_k8s_poll_until_complete_writes_succeeded_to_task_store(self): def test_k8s_poll_until_complete_does_not_cache_and_reraises_on_failure(self): operator = self._make_operator(track_driver_via_k8s_api=True) - hook = self._make_k8s_resumable_hook() + hook = self._make_k8s_hook() hook._poll_k8s_driver_via_api.side_effect = RuntimeError("Spark application failed (phase=Failed)") operator._hook = hook task_store = FakeTaskState() @@ -1048,7 +1041,7 @@ def test_k8s_poll_until_complete_does_not_cache_and_reraises_on_failure(self): def test_k8s_poll_until_complete_tolerates_absent_task_store(self): operator = self._make_operator(track_driver_via_k8s_api=True) - operator._hook = self._make_k8s_resumable_hook() + operator._hook = self._make_k8s_hook() operator.poll_until_complete("mynamespace:spark-abc-driver", {}) @@ -1059,7 +1052,7 @@ def test_k8s_poll_until_complete_tolerates_absent_task_store(self): def test_k8s_execute_persists_pod_id_to_task_store_when_reconnect_on_retry(self): """execute() with reconnect_on_retry=True stores the pod ID in task_store before polling.""" operator = self._make_operator(track_driver_via_k8s_api=True, reconnect_on_retry=True) - hook = self._make_k8s_resumable_hook() + hook = self._make_k8s_hook() hook._kubernetes_driver_pod = "spark-abc-driver" hook._connection = {"namespace": "mynamespace"} operator._hook = hook @@ -1082,7 +1075,7 @@ def track_poll(external_id, context): def test_k8s_execute_reconnect_on_retry_false_does_not_persist_pod_id(self): """execute() with reconnect_on_retry=False does not write spark_job_id to task_store.""" operator = self._make_operator(track_driver_via_k8s_api=True, reconnect_on_retry=False) - hook = self._make_k8s_resumable_hook() + hook = self._make_k8s_hook() hook._kubernetes_driver_pod = "spark-abc-driver" hook._connection = {"namespace": "mynamespace"} operator._hook = hook From 431a5b51a3ff32fb7969a65d7836d51da39c488b Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Thu, 11 Jun 2026 10:22:25 +0530 Subject: [PATCH 5/8] comments from jason and kaxil --- .../apache/spark/hooks/spark_submit.py | 14 ++++-- .../apache/spark/operators/spark_submit.py | 50 +++++++++---------- .../spark/operators/test_spark_submit.py | 31 +++++++++--- 3 files changed, 59 insertions(+), 36 deletions(-) diff --git a/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py b/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py index c13ae8cb2191e..7a8c7e2c2c3bf 100644 --- a/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py +++ b/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py @@ -1138,8 +1138,14 @@ def _start_driver_status_tracking(self) -> None: f"returncode = {returncode}" ) - def _poll_k8s_driver_via_api(self) -> None: - """Poll the K8s driver pod phase until it reaches a terminal state.""" + def _poll_k8s_driver_via_api(self) -> str | None: + """ + Poll the K8s driver pod phase until it reaches a terminal state. + + Returns the terminal phase string (e.g. ``"Succeeded"``) on normal completion, + or ``None`` if the pod vanished mid-poll (404 — likely deleted by ``on_kill``). + Raises ``RuntimeError`` on failure phases or unrecoverable API errors. + """ pod_name = self._kubernetes_driver_pod namespace = self._connection["namespace"] app_id = self._kubernetes_application_id or pod_name @@ -1173,7 +1179,7 @@ def _poll_k8s_driver_via_api(self) -> None: "Driver pod %s not found (404); pod was likely deleted by on_kill. Exiting poll loop.", pod_name, ) - return + return None consecutive_api_errors += 1 self.log.warning( "ApiException polling pod %s (%d/%d): %s", @@ -1204,6 +1210,7 @@ def _poll_k8s_driver_via_api(self) -> None: t.started_at, t.finished_at, ) + terminal_phase = phase break if phase == "Failed": container_state = "" @@ -1236,6 +1243,7 @@ def _poll_k8s_driver_via_api(self) -> None: consecutive_unknown = 0 time.sleep(poll_interval) self._delete_driver_pod() + return terminal_phase finally: self._run_post_submit_commands() diff --git a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py index ea79b390b482d..3594b2f88ef6c 100644 --- a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py +++ b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py @@ -310,11 +310,7 @@ def submit_job(self, context: Context) -> str | None: pod_name = self._hook._kubernetes_driver_pod namespace = self._hook._connection["namespace"] if not pod_name: - self.log.warning( - "spark-submit did not capture a K8s driver pod name; " - "crash recovery will not be available for this run" - ) - return None + raise RuntimeError("spark-submit did not capture a K8s driver pod name") external_id = f"{namespace}:{pod_name}" self.log.info("Spark K8s driver pod submitted: %s", external_id) return external_id @@ -346,27 +342,20 @@ def get_job_status(self, external_id: JsonValue, context: Context) -> str: if self._hook._is_yarn_cluster_mode: return self._hook.query_yarn_application_status(external_id) if self._hook._is_kubernetes: - task_store = context.get("task_store") - if task_store is not None: - cached = task_store.get(self._K8S_DRIVER_STATUS_KEY) - if cached: - if TYPE_CHECKING: - assert isinstance(cached, str) + if (task_store := context.get("task_store")) is not None: + if (cached := task_store.get(self._K8S_DRIVER_STATUS_KEY)) is not None: + if not isinstance(cached, str): + raise ValueError(f"Cached K8s driver status is not a string: {cached!r}") return cached if kube_client is None: raise RuntimeError( "apache-airflow-providers-cncf-kubernetes is required to query K8s pod status" ) - parts = external_id.split(":", 1) - if len(parts) != 2: - raise ValueError( - f"Invalid K8s external ID format {external_id!r}; expected 'namespace:pod_name'" - ) - namespace, pod_name = parts + namespace, pod_name = self._parse_k8s_external_id(external_id) try: client = kube_client.get_kube_client() pod = client.read_namespaced_pod(pod_name, namespace) - return pod.status.phase + return pod.status.phase or "Pending" except kube_client.ApiException as e: if e.status == 404: return "NotFound" @@ -389,6 +378,14 @@ def get_job_status(self, external_id: JsonValue, context: Context) -> str: last_exc = e raise last_exc + @staticmethod + def _parse_k8s_external_id(external_id: str) -> tuple[str, str]: + """Parse a K8s external ID of the form 'namespace:pod_name' into its components.""" + parts = external_id.split(":", 1) + if len(parts) != 2: + raise ValueError(f"Invalid K8s external ID format {external_id!r}; expected 'namespace:pod_name'") + return parts[0], parts[1] + @retry(stop=stop_after_attempt(3), wait=wait_fixed(1), reraise=True) def _fetch_driver_status(self, url: str, external_id: str) -> str: response = requests.get(url, timeout=30) @@ -442,16 +439,15 @@ def poll_until_complete(self, external_id: JsonValue, context: Context) -> None: return if self._hook._is_kubernetes: if external_id is not None: - _, pod_name = str(external_id).split(":", 1) + _, pod_name = self._parse_k8s_external_id(external_id) self._hook._kubernetes_driver_pod = pod_name - self._hook._poll_k8s_driver_via_api() - # The driver pod is deleted on success, so cache the terminal phase before it - # disappears. Failed jobs raise before reaching here, so only "Succeeded" is ever - # cached. A missing key on retry means the pod was garbage collected after failure, and - # resubmitting fresh is the right behaviour in that case. - task_store = context.get("task_store") - if task_store is not None: - task_store.set(self._K8S_DRIVER_STATUS_KEY, "Succeeded") + terminal_phase = self._hook._poll_k8s_driver_via_api() + # Cache only when the pod actually reached Succeeded, the 404/vanished path + # returns None for cases like: pod deleted by on_kill or garbage collected after failure) + # and must not be cached, otherwise a retry would see "Succeeded" and skip resubmission. + if terminal_phase == "Succeeded" and self.reconnect_on_retry: + if (task_store := context.get("task_store")) is not None: + task_store.set(self._K8S_DRIVER_STATUS_KEY, "Succeeded") return self.log.info("Polling driver %s until completion", external_id) diff --git a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py index 530a1d40c5ca8..4b9c65356b963 100644 --- a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py +++ b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py @@ -941,18 +941,15 @@ def test_k8s_submit_job_returns_encoded_external_id(self): assert hook._conf.get("spark.kubernetes.submission.waitAppCompletion") == "false" hook.submit.assert_called_once_with("test.jar") - def test_k8s_submit_job_warns_when_pod_name_missing(self, caplog): + def test_k8s_submit_job_raises_when_pod_name_missing(self): operator = self._make_operator(track_driver_via_k8s_api=True) hook = self._make_k8s_hook() hook._kubernetes_driver_pod = None hook._connection = {"namespace": "mynamespace"} operator._hook = hook - with caplog.at_level(logging.WARNING): - result = operator.submit_job(context={}) - - assert result is None - assert "crash recovery will not be available" in caplog.text + with pytest.raises(RuntimeError, match="did not capture a K8s driver pod name"): + operator.submit_job(context={}) def test_k8s_get_job_status_returns_k8s_driver_status(self): operator = self._make_operator(track_driver_via_k8s_api=True) @@ -979,6 +976,19 @@ def test_k8s_get_job_status_queries_k8s_api_when_no_k8s_driver_status(self): assert result == "Running" + def test_k8s_get_job_status_returns_pending_when_phase_is_none(self): + operator = self._make_operator(track_driver_via_k8s_api=True) + operator._hook = self._make_k8s_hook() + + mock_pod = MagicMock() + mock_pod.status.phase = None + + with mock.patch("airflow.providers.apache.spark.operators.spark_submit.kube_client") as mock_kube: + mock_kube.get_kube_client.return_value.read_namespaced_pod.return_value = mock_pod + result = operator.get_job_status("mynamespace:spark-abc-driver", {}) + + assert result == "Pending" + def test_k8s_get_job_status_returns_not_found_on_404(self): operator = self._make_operator(track_driver_via_k8s_api=True) operator._hook = self._make_k8s_hook() @@ -1027,6 +1037,15 @@ def test_k8s_poll_until_complete_writes_succeeded_to_task_store(self): assert task_store.get("k8s_driver_status") == "Succeeded" + def test_k8s_polling_does_not_write_task_store_when_reconnect_disabled(self): + operator = self._make_operator(track_driver_via_k8s_api=True, reconnect_on_retry=False) + operator._hook = self._make_k8s_hook() + task_store = FakeTaskState() + + operator.poll_until_complete("mynamespace:spark-abc-driver", {"task_store": task_store}) + + assert task_store.get("k8s_driver_status") is None + def test_k8s_poll_until_complete_does_not_cache_and_reraises_on_failure(self): operator = self._make_operator(track_driver_via_k8s_api=True) hook = self._make_k8s_hook() From 1b4be42b2a607fc1128872bb43a49916f0f29f7c Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Thu, 11 Jun 2026 11:40:14 +0530 Subject: [PATCH 6/8] fixing tests --- .../unit/apache/spark/operators/test_spark_submit.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py index 4b9c65356b963..01bb69692c224 100644 --- a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py +++ b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py @@ -1030,7 +1030,9 @@ def test_k8s_poll_until_complete_sets_pod_name_and_calls_poll_api(self): def test_k8s_poll_until_complete_writes_succeeded_to_task_store(self): operator = self._make_operator(track_driver_via_k8s_api=True) - operator._hook = self._make_k8s_hook() + hook = self._make_k8s_hook() + hook._poll_k8s_driver_via_api.return_value = "Succeeded" + operator._hook = hook task_store = FakeTaskState() operator.poll_until_complete("mynamespace:spark-abc-driver", {"task_store": task_store}) @@ -1039,7 +1041,9 @@ def test_k8s_poll_until_complete_writes_succeeded_to_task_store(self): def test_k8s_polling_does_not_write_task_store_when_reconnect_disabled(self): operator = self._make_operator(track_driver_via_k8s_api=True, reconnect_on_retry=False) - operator._hook = self._make_k8s_hook() + hook = self._make_k8s_hook() + hook._poll_k8s_driver_via_api.return_value = "Succeeded" + operator._hook = hook task_store = FakeTaskState() operator.poll_until_complete("mynamespace:spark-abc-driver", {"task_store": task_store}) From 5620c44a3c75de216b17a59b38964b0484b451e0 Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Sat, 13 Jun 2026 16:29:59 +0530 Subject: [PATCH 7/8] comments from jason --- .../airflow/providers/apache/spark/hooks/spark_submit.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py b/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py index 7a8c7e2c2c3bf..662966e4e1459 100644 --- a/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py +++ b/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py @@ -1242,7 +1242,11 @@ def _poll_k8s_driver_via_api(self) -> str | None: else: consecutive_unknown = 0 time.sleep(poll_interval) - self._delete_driver_pod() + # Pod deletion is best-effort cleanup. If it fails (e.g. already garbage collected or RBAC + # denied), suppress the error so terminal_phase is still returned and the task + # succeeds. Raising here would skip the task_store write and force an unnecessary retry. + with contextlib.suppress(Exception): + self._delete_driver_pod() return terminal_phase finally: self._run_post_submit_commands() From 47aa8982dd9dbcfb395c5af69c259ecad0bb7d37 Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Mon, 15 Jun 2026 11:10:15 +0530 Subject: [PATCH 8/8] fixing rebase issues --- .../apache/spark/operators/spark_submit.py | 8 ++++---- .../apache/spark/operators/test_spark_submit.py | 16 ++++++++-------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py index 3594b2f88ef6c..5a53557f39b97 100644 --- a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py +++ b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py @@ -342,8 +342,8 @@ def get_job_status(self, external_id: JsonValue, context: Context) -> str: if self._hook._is_yarn_cluster_mode: return self._hook.query_yarn_application_status(external_id) if self._hook._is_kubernetes: - if (task_store := context.get("task_store")) is not None: - if (cached := task_store.get(self._K8S_DRIVER_STATUS_KEY)) is not None: + if (task_state_store := context.get("task_state_store")) is not None: + if (cached := task_state_store.get(self._K8S_DRIVER_STATUS_KEY)) is not None: if not isinstance(cached, str): raise ValueError(f"Cached K8s driver status is not a string: {cached!r}") return cached @@ -446,8 +446,8 @@ def poll_until_complete(self, external_id: JsonValue, context: Context) -> None: # returns None for cases like: pod deleted by on_kill or garbage collected after failure) # and must not be cached, otherwise a retry would see "Succeeded" and skip resubmission. if terminal_phase == "Succeeded" and self.reconnect_on_retry: - if (task_store := context.get("task_store")) is not None: - task_store.set(self._K8S_DRIVER_STATUS_KEY, "Succeeded") + if (task_state_store := context.get("task_state_store")) is not None: + task_state_store.set(self._K8S_DRIVER_STATUS_KEY, "Succeeded") return self.log.info("Polling driver %s until completion", external_id) diff --git a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py index 6b56231e7a34b..180f6cdd38d09 100644 --- a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py +++ b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py @@ -957,7 +957,7 @@ def test_k8s_get_job_status_returns_k8s_driver_status(self): task_store = FakeTaskState({"k8s_driver_status": "Succeeded"}) with mock.patch("airflow.providers.apache.spark.operators.spark_submit.kube_client") as mock_kube: - result = operator.get_job_status("mynamespace:spark-abc-driver", {"task_store": task_store}) + result = operator.get_job_status("mynamespace:spark-abc-driver", {"task_state_store": task_store}) assert result == "Succeeded" mock_kube.get_kube_client.assert_not_called() @@ -972,7 +972,7 @@ def test_k8s_get_job_status_queries_k8s_api_when_no_k8s_driver_status(self): with mock.patch("airflow.providers.apache.spark.operators.spark_submit.kube_client") as mock_kube: mock_kube.get_kube_client.return_value.read_namespaced_pod.return_value = mock_pod - result = operator.get_job_status("mynamespace:spark-abc-driver", {"task_store": task_store}) + result = operator.get_job_status("mynamespace:spark-abc-driver", {"task_state_store": task_store}) assert result == "Running" @@ -1035,7 +1035,7 @@ def test_k8s_poll_until_complete_writes_succeeded_to_task_store(self): operator._hook = hook task_store = FakeTaskState() - operator.poll_until_complete("mynamespace:spark-abc-driver", {"task_store": task_store}) + operator.poll_until_complete("mynamespace:spark-abc-driver", {"task_state_store": task_store}) assert task_store.get("k8s_driver_status") == "Succeeded" @@ -1046,7 +1046,7 @@ def test_k8s_polling_does_not_write_task_store_when_reconnect_disabled(self): operator._hook = hook task_store = FakeTaskState() - operator.poll_until_complete("mynamespace:spark-abc-driver", {"task_store": task_store}) + operator.poll_until_complete("mynamespace:spark-abc-driver", {"task_state_store": task_store}) assert task_store.get("k8s_driver_status") is None @@ -1058,7 +1058,7 @@ def test_k8s_poll_until_complete_does_not_cache_and_reraises_on_failure(self): task_store = FakeTaskState() with pytest.raises(RuntimeError, match="phase=Failed"): - operator.poll_until_complete("mynamespace:spark-abc-driver", {"task_store": task_store}) + operator.poll_until_complete("mynamespace:spark-abc-driver", {"task_state_store": task_store}) assert task_store.get("k8s_driver_status") is None @@ -1072,7 +1072,7 @@ def test_k8s_poll_until_complete_tolerates_absent_task_store(self): not AIRFLOW_V_3_3_PLUS, reason="ResumableJobMixin reconnect requires task_state, available in Airflow 3.3+", ) - def test_k8s_execute_persists_pod_id_to_task_store_when_reconnect_on_retry(self): + def test_k8s_execute_persists_pod_id_when_reconnect_on_retry(self): """execute() with reconnect_on_retry=True stores the pod ID in task_store before polling.""" operator = self._make_operator(track_driver_via_k8s_api=True, reconnect_on_retry=True) hook = self._make_k8s_hook() @@ -1087,7 +1087,7 @@ def track_poll(external_id, context): operator.poll_until_complete = track_poll - operator.execute(context={"task_store": task_store}) + operator.execute(context={"task_state_store": task_store}) assert persisted_before_poll == ["mynamespace:spark-abc-driver"] @@ -1106,6 +1106,6 @@ def test_k8s_execute_reconnect_on_retry_false_does_not_persist_pod_id(self): operator.poll_until_complete = lambda external_id, context: None - operator.execute(context={"task_store": task_store}) + operator.execute(context={"task_state_store": task_store}) assert task_store.get("spark_job_id") is None