Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions providers/apache/spark/docs/operators.rst
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ Python Kubernetes client rather than holding ``spark-submit`` open for the full
conn_id="spark_k8s",
deploy_mode="cluster",
track_driver_via_k8s_api=True,
reconnect_on_retry=True,
)

**Requirements**
Expand All @@ -245,8 +246,10 @@ Python Kubernetes client rather than holding ``spark-submit`` open for the full
conflicts with the flag and a ``ValueError`` will be raised at task start.
* The Airflow worker must be able to reach the Kubernetes API server and have permission to
read and delete pods in the driver's namespace; otherwise pod tracking and cleanup will fail.
* This path bypasses ``ResumableJobMixin``, so Airflow retries submit a fresh driver instead of
reconnecting to an existing one. Set ``execution_timeout`` to bound wall-clock time.
* Set ``reconnect_on_retry=True`` (the default) to enable crash recovery: the driver pod name is
persisted to task state before polling begins, so a worker crash and retry reconnects to the
existing pod instead of submitting a fresh one. Set ``reconnect_on_retry=False`` to always
submit a fresh driver on retry.
* Pod completion is detected from ``pod.status.phase``. If your driver pods have sidecar
containers (e.g. Istio injection enabled for the driver namespace), the pod phase may not
advance to ``Succeeded`` until all sidecars exit. In that case the poll loop will wait
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1138,8 +1138,14 @@ def _start_driver_status_tracking(self) -> None:
f"returncode = {returncode}"
)

def _poll_k8s_driver_via_api(self) -> None:
"""Poll the K8s driver pod phase until it reaches a terminal state."""
def _poll_k8s_driver_via_api(self) -> str | None:
"""
Poll the K8s driver pod phase until it reaches a terminal state.

Returns the terminal phase string (e.g. ``"Succeeded"``) on normal completion,
or ``None`` if the pod vanished mid-poll (404 — likely deleted by ``on_kill``).
Raises ``RuntimeError`` on failure phases or unrecoverable API errors.
"""
pod_name = self._kubernetes_driver_pod
namespace = self._connection["namespace"]
app_id = self._kubernetes_application_id or pod_name
Expand Down Expand Up @@ -1173,7 +1179,7 @@ def _poll_k8s_driver_via_api(self) -> None:
"Driver pod %s not found (404); pod was likely deleted by on_kill. Exiting poll loop.",
pod_name,
)
return
return None
consecutive_api_errors += 1
self.log.warning(
"ApiException polling pod %s (%d/%d): %s",
Expand All @@ -1193,6 +1199,18 @@ def _poll_k8s_driver_via_api(self) -> None:
phase = pod.status.phase or "Initializing"
self.log.info("Application status for %s (phase: %s)", app_id, phase)
if phase == "Succeeded":
if pod.status.container_statuses:
cs = pod.status.container_statuses[0]
if cs.state and cs.state.terminated:
t = cs.state.terminated
self.log.info(
"Container final status: exit_code=%s reason=%s started_at=%s finished_at=%s",
t.exit_code,
t.reason,
t.started_at,
t.finished_at,
)
terminal_phase = phase
break
if phase == "Failed":
container_state = ""
Expand Down Expand Up @@ -1224,7 +1242,12 @@ def _poll_k8s_driver_via_api(self) -> None:
else:
consecutive_unknown = 0
time.sleep(poll_interval)
self._delete_driver_pod()
# Pod deletion is best-effort cleanup. If it fails (e.g. already garbage collected or RBAC
# denied), suppress the error so terminal_phase is still returned and the task
# succeeds. Raising here would skip the task_store write and force an unnecessary retry.
with contextlib.suppress(Exception):
self._delete_driver_pod()
return terminal_phase
finally:
self._run_post_submit_commands()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,18 @@
import requests
from tenacity import retry, stop_after_attempt, wait_fixed

from airflow.providers.apache.spark.hooks.spark_submit import SparkSubmitHook
from airflow.providers.apache.spark.hooks.spark_submit import _K8S_WAIT_APP_COMPLETION_CONF, SparkSubmitHook
from airflow.providers.common.compat.openlineage.utils.spark import (
inject_parent_job_information_into_spark_properties,
inject_transport_information_into_spark_properties,
)
from airflow.providers.common.compat.sdk import BaseOperator, conf

try:
from airflow.providers.cncf.kubernetes import kube_client
except ImportError:
kube_client = None # type: ignore[assignment]

try:
from airflow.sdk import ResumableJobMixin
except ImportError:
Expand Down Expand Up @@ -140,6 +145,11 @@ class SparkSubmitOperator(ResumableJobMixin, BaseOperator):
# YARN application ID, K8s driver pod name).
external_id_key = "spark_job_id"

# Used only for k8s cluster mode. When reconnect_on_retry is enabled, poll_until_complete writes
# "Succeeded" to task_store under this key once the driver pod finishes successfully (no other
# phase is cached). On retry, get_job_status reads this before querying the K8s API
# so that a completed job can be identified even after the driver pod is garbage collected.
_K8S_DRIVER_STATUS_KEY = "k8s_driver_status"
Comment thread
amoghrajesh marked this conversation as resolved.

template_fields: Sequence[str] = (
"application",
"conf",
Expand Down Expand Up @@ -269,11 +279,12 @@ def execute(self, context: Context) -> None:
self.poll_until_complete(driver_id, context)
return self.get_job_result(driver_id, context)
if hook._should_track_driver_via_k8s_api():
# TODO: Wire into execute_resumable() via ResumableJobMixin
# (fill submit_job / poll_until_complete K8s stubs) to enable crash recovery.
hook.submit(self.application)
hook._poll_k8s_driver_via_api()
return
if self.reconnect_on_retry:
return self.execute_resumable(context)
# reconnect_on_retry=False: still submit-and-poll, just skip task_state persistence.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This path still writes k8s_driver_status="Succeeded" to the task store via poll_until_complete, so persistence isn't fully skipped. Mostly cosmetic today, but nothing ever clears that key, so if the operator is later switched to reconnect_on_retry=True, a retry there reads the stale "Succeeded" before the live pod phase and can skip a job that is still running. Maybe clear or overwrite the key when submitting fresh?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh god, good catch. The risk is real: a later retry with reconnect_on_retry=True reads the stale "Succeeded" from a previous non-resumable run and exits immediately, treating a newly submitted pod as complete. Fixed by gating the write on self.reconnect_on_retry: the cache exists only to serve crash recovery, so there's no reason to write it when reconnect is off. Added a test to cover this.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Handled in 431a5b5

driver_id = self.submit_job(context)
self.poll_until_complete(driver_id, context)
return self.get_job_result(driver_id, context)
if hook._is_yarn_cluster_mode:
if self.reconnect_on_retry and not hook._yarn_track_via_rm_api:
raise ValueError(
Expand All @@ -290,9 +301,19 @@ def execute(self, context: Context) -> None:
return self.get_job_result(driver_id, context)
hook.submit(self.application)

def submit_job(self, context: Context) -> str:
def submit_job(self, context: Context) -> str | None:
if self._hook is None:
self._hook = self._get_hook()
if self._hook._is_kubernetes:
self._hook._conf[_K8S_WAIT_APP_COMPLETION_CONF] = "false"
self._hook.submit(self.application)
pod_name = self._hook._kubernetes_driver_pod
namespace = self._hook._connection["namespace"]
if not pod_name:
raise RuntimeError("spark-submit did not capture a K8s driver pod name")
external_id = f"{namespace}:{pod_name}"
self.log.info("Spark K8s driver pod submitted: %s", external_id)
return external_id
if self._hook._is_yarn_cluster_mode:
if self._hook._conf.get("spark.yarn.submit.waitAppCompletion", "").strip().lower() == "true":
raise ValueError(
Expand Down Expand Up @@ -321,12 +342,24 @@ def get_job_status(self, external_id: JsonValue, context: Context) -> str:
if self._hook._is_yarn_cluster_mode:
return self._hook.query_yarn_application_status(external_id)
if self._hook._is_kubernetes:
# The K8s branches below (and in is_job_active, is_job_succeeded, poll_until_complete)
# are currently unreachable: execute_resumable is only called when _should_track_driver_status
# is True, which requires spark:// + cluster mode. They are scaffolding for a follow-up PR
# that extends ResumableJobMixin support to Kubernetes.
# TODO: call K8s pod status API
raise NotImplementedError("K8s job status not yet implemented")
if (task_store := context.get("task_store")) is not None:
if (cached := task_store.get(self._K8S_DRIVER_STATUS_KEY)) is not None:
if not isinstance(cached, str):
raise ValueError(f"Cached K8s driver status is not a string: {cached!r}")
return cached
if kube_client is None:
raise RuntimeError(
"apache-airflow-providers-cncf-kubernetes is required to query K8s pod status"
)
namespace, pod_name = self._parse_k8s_external_id(external_id)
try:
client = kube_client.get_kube_client()
pod = client.read_namespaced_pod(pod_name, namespace)
return pod.status.phase or "Pending"
except kube_client.ApiException as e:
if e.status == 404:
return "NotFound"
raise
scheme = self._hook._connection.get("rest_scheme", "http")
rest_port = self._hook._connection.get("rest_port", 6066)
# HA master URLs can look like spark://m1:7077,m2:7077 — try each host in order.
Expand All @@ -345,6 +378,14 @@ def get_job_status(self, external_id: JsonValue, context: Context) -> str:
last_exc = e
raise last_exc

@staticmethod
def _parse_k8s_external_id(external_id: str) -> tuple[str, str]:
    """
    Split a K8s external ID of the form 'namespace:pod_name' into ``(namespace, pod_name)``.

    Only the first ``:`` acts as the separator; any later ``:`` stays in the pod name.
    Raises ``ValueError`` when the ID contains no ``:`` at all.
    """
    namespace, separator, pod_name = external_id.partition(":")
    # An empty separator means no ":" was found, so the ID is not in the expected form.
    if not separator:
        raise ValueError(f"Invalid K8s external ID format {external_id!r}; expected 'namespace:pod_name'")
    return namespace, pod_name

@retry(stop=stop_after_attempt(3), wait=wait_fixed(1), reraise=True)
def _fetch_driver_status(self, url: str, external_id: str) -> str:
response = requests.get(url, timeout=30)
Expand Down Expand Up @@ -397,8 +438,18 @@ def poll_until_complete(self, external_id: JsonValue, context: Context) -> None:
self._hook._run_post_submit_commands()
return
if self._hook._is_kubernetes:
# TODO: poll K8s pod phase until terminal
raise NotImplementedError("K8s poll not yet implemented")
if external_id is not None:
_, pod_name = self._parse_k8s_external_id(external_id)
self._hook._kubernetes_driver_pod = pod_name
terminal_phase = self._hook._poll_k8s_driver_via_api()
# Cache only when the pod actually reached Succeeded. The 404/vanished path returns
# None (e.g. pod deleted by on_kill, or garbage collected after failure) and must not
# be cached; otherwise a retry would see "Succeeded" and skip resubmission.
if terminal_phase == "Succeeded" and self.reconnect_on_retry:
if (task_store := context.get("task_store")) is not None:
task_store.set(self._K8S_DRIVER_STATUS_KEY, "Succeeded")
return

self.log.info("Polling driver %s until completion", external_id)
self._hook._driver_id = external_id
try:
Expand Down
Loading
Loading