From 7350e3fda2b2d31b1a1445bd551e7d576d955276 Mon Sep 17 00:00:00 2001 From: deepinsight coder Date: Sat, 13 Jun 2026 21:44:44 +0000 Subject: [PATCH] Fix Databricks operators with templated json payloads --- .../providers/databricks/exceptions.py | 4 + .../databricks/operators/databricks.py | 387 ++++++++++++------ .../databricks/operators/test_databricks.py | 212 ++++++++-- scripts/ci/prek/known_airflow_exceptions.txt | 2 +- 4 files changed, 446 insertions(+), 159 deletions(-) diff --git a/providers/databricks/src/airflow/providers/databricks/exceptions.py b/providers/databricks/src/airflow/providers/databricks/exceptions.py index f384552a34a6e..59c8f3fb60649 100644 --- a/providers/databricks/src/airflow/providers/databricks/exceptions.py +++ b/providers/databricks/src/airflow/providers/databricks/exceptions.py @@ -30,3 +30,7 @@ class DatabricksSqlExecutionError(AirflowException): class DatabricksSqlExecutionTimeout(DatabricksSqlExecutionError): """Raised when a sql execution times out.""" + + +class DatabricksOperatorPayloadError(AirflowException): + """Raised when a Databricks operator payload is invalid.""" diff --git a/providers/databricks/src/airflow/providers/databricks/operators/databricks.py b/providers/databricks/src/airflow/providers/databricks/operators/databricks.py index 9898993d4147e..bbeb329798797 100644 --- a/providers/databricks/src/airflow/providers/databricks/operators/databricks.py +++ b/providers/databricks/src/airflow/providers/databricks/operators/databricks.py @@ -19,14 +19,17 @@ from __future__ import annotations +import ast import hashlib +import json as json_utils import time from abc import ABC, abstractmethod -from collections.abc import Sequence +from collections.abc import Mapping, Sequence from functools import cached_property -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast from airflow.providers.common.compat.sdk import AirflowException, BaseOperator, BaseOperatorLink, XCom, conf +from airflow.providers.databricks.exceptions import DatabricksOperatorPayloadError from airflow.providers.databricks.hooks.databricks import ( DatabricksHook, RunLifeCycleState, @@ -281,6 +284,50 @@ def _inject_airflow_params_into_task(task: dict, params: dict) -> None: task_def[field] = dict(params) +def _coerce_json_to_dict(json: Any) -> dict[str, Any]: + if json is None: + return {} + if isinstance(json, Mapping): + return dict(json) + if isinstance(json, str): + return _parse_json_string_to_dict(json) + raise DatabricksOperatorPayloadError( + f"Databricks json payload must resolve to a mapping, not {type(json).__name__}." + ) + + +def _parse_json_string_to_dict(json: str) -> dict[str, Any]: + if not json: + return {} + try: + parsed_json = json_utils.loads(json) + except json_utils.JSONDecodeError: + try: + parsed_json = ast.literal_eval(json) + except (SyntaxError, ValueError, TypeError, MemoryError) as err: + raise DatabricksOperatorPayloadError( + "Databricks json payload string must be valid JSON or a Python literal dict." + ) from err + + if not isinstance(parsed_json, Mapping): + raise DatabricksOperatorPayloadError( + f"Databricks json payload must resolve to a mapping, not {type(parsed_json).__name__}." + ) + return dict(parsed_json) + + +def _merge_json_with_named_parameters( + json: Any, named_parameters: Mapping[str, Any | None] +) -> dict[str, Any]: + merged_json = _coerce_json_to_dict(json) + merged_json.update( + (param_name, param_value) + for param_name, param_value in named_parameters.items() + if param_value is not None + ) + return merged_json + + class DatabricksJobRunLink(BaseOperatorLink): """Constructs a link to monitor a Databricks Job Run.""" @@ -353,7 +400,23 @@ class DatabricksCreateJobsOperator(BaseOperator): """ # Used in airflow.models.BaseOperator - template_fields: Sequence[str] = ("json", "databricks_conn_id") + template_fields: Sequence[str] = ( + "json", + "name", + "description", + "tags", + "tasks", + "job_clusters", + "email_notifications", + "webhook_notifications", + "notification_settings", + "timeout_seconds", + "schedule", + "max_concurrent_runs", + "git_source", + "access_control_list", + "databricks_conn_id", + ) # Databricks brand color (blue) under white text ui_color = "#1CB1C2" ui_fgcolor = "#fff" @@ -384,40 +447,45 @@ def __init__( ) -> None: """Create a new ``DatabricksCreateJobsOperator``.""" super().__init__(**kwargs) - self.json = json or {} + self.json = json + self.name = name + self.description = description + self.tags = tags + self.tasks = tasks + self.job_clusters = job_clusters + self.email_notifications = email_notifications + self.webhook_notifications = webhook_notifications + self.notification_settings = notification_settings + self.timeout_seconds = timeout_seconds + self.schedule = schedule + self.max_concurrent_runs = max_concurrent_runs + self.git_source = git_source + self.access_control_list = access_control_list self.databricks_conn_id = databricks_conn_id self.polling_period_seconds = polling_period_seconds self.databricks_retry_limit = databricks_retry_limit self.databricks_retry_delay = databricks_retry_delay self.databricks_retry_args = databricks_retry_args - if name is not None: - self.json["name"] = name - if description is not None: - self.json["description"] = description - if tags is not None: - self.json["tags"] = tags - if tasks is not None: - self.json["tasks"] = tasks - if job_clusters is not None: - self.json["job_clusters"] = job_clusters - if email_notifications is not None: - self.json["email_notifications"] = email_notifications - if webhook_notifications is not None: - self.json["webhook_notifications"] = webhook_notifications - if notification_settings is not None: - self.json["notification_settings"] = notification_settings - if timeout_seconds is not None: - self.json["timeout_seconds"] = timeout_seconds - if schedule is not None: - self.json["schedule"] = schedule - if max_concurrent_runs is not None: - self.json["max_concurrent_runs"] = max_concurrent_runs - if git_source is not None: - self.json["git_source"] = git_source - if access_control_list is not None: - self.json["access_control_list"] = access_control_list - if self.json: - self.json = normalise_json_content(self.json) + + def _get_named_json_parameters(self) -> dict[str, Any | None]: + return { + "name": self.name, + "description": self.description, + "tags": self.tags, + "tasks": self.tasks, + "job_clusters": self.job_clusters, + "email_notifications": self.email_notifications, + "webhook_notifications": self.webhook_notifications, + "notification_settings": self.notification_settings, + "timeout_seconds": self.timeout_seconds, + "schedule": self.schedule, + "max_concurrent_runs": self.max_concurrent_runs, + "git_source": self.git_source, + "access_control_list": self.access_control_list, + } + + def _get_merged_json(self) -> dict[str, Any]: + return _merge_json_with_named_parameters(self.json, self._get_named_json_parameters()) @cached_property def _hook(self): @@ -430,14 +498,16 @@ def _hook(self): ) def execute(self, context: Context) -> int: - if "name" not in self.json: + json = cast("dict[str, Any]", normalise_json_content(self._get_merged_json())) + if "name" not in json: raise AirflowException("Missing required parameter: name") - job_id = self._hook.find_job_id_by_name(self.json["name"]) - if not self.json.get("parameters") and self.params: - self.json["parameters"] = [{"name": k, "default": v} for k, v in dict(self.params).items()] + job_id = self._hook.find_job_id_by_name(json["name"]) + if not json.get("parameters") and self.params: + json["parameters"] = [{"name": k, "default": v} for k, v in dict(self.params).items()] + self.json = json if job_id is None: - return self._hook.create_job(self.json) - self._hook.reset_job(str(job_id), self.json) + return self._hook.create_job(json) + self._hook.reset_job(str(job_id), json) return job_id @@ -572,7 +642,25 @@ class DatabricksSubmitRunOperator(BaseOperator): """ # Used in airflow.models.BaseOperator - template_fields: Sequence[str] = ("json", "databricks_conn_id") + template_fields: Sequence[str] = ( + "json", + "tasks", + "spark_jar_task", + "notebook_task", + "spark_python_task", + "spark_submit_task", + "pipeline_task", + "dbt_task", + "new_cluster", + "existing_cluster_id", + "libraries", + "run_name", + "timeout_seconds", + "idempotency_token", + "access_control_list", + "git_source", + "databricks_conn_id", + ) template_ext: Sequence[str] = (".json-tpl",) # Databricks brand color (blue) under white text ui_color = "#1CB1C2" @@ -610,7 +698,22 @@ def __init__( ) -> None: """Create a new ``DatabricksSubmitRunOperator``.""" super().__init__(**kwargs) - self.json = json or {} + self.json = json + self.tasks = tasks + self.spark_jar_task = spark_jar_task + self.notebook_task = notebook_task + self.spark_python_task = spark_python_task + self.spark_submit_task = spark_submit_task + self.pipeline_task = pipeline_task + self.dbt_task = dbt_task + self.new_cluster = new_cluster + self.existing_cluster_id = existing_cluster_id + self.libraries = libraries + self.run_name = run_name + self.timeout_seconds = timeout_seconds + self.idempotency_token = idempotency_token + self.access_control_list = access_control_list + self.git_source = git_source self.databricks_conn_id = databricks_conn_id self.polling_period_seconds = polling_period_seconds self.databricks_retry_limit = databricks_retry_limit @@ -618,48 +721,50 @@ def __init__( self.databricks_retry_args = databricks_retry_args self.wait_for_termination = wait_for_termination self.deferrable = deferrable - if tasks is not None: - self.json["tasks"] = tasks - if spark_jar_task is not None: - self.json["spark_jar_task"] = spark_jar_task - if notebook_task is not None: - self.json["notebook_task"] = notebook_task - if spark_python_task is not None: - self.json["spark_python_task"] = spark_python_task - if spark_submit_task is not None: - self.json["spark_submit_task"] = spark_submit_task - if pipeline_task is not None: - self.json["pipeline_task"] = pipeline_task - if dbt_task is not None: - self.json["dbt_task"] = dbt_task - if new_cluster is not None: - self.json["new_cluster"] = new_cluster - if existing_cluster_id is not None: - self.json["existing_cluster_id"] = existing_cluster_id - if libraries is not None: - self.json["libraries"] = libraries - if run_name is not None: - self.json["run_name"] = run_name - if timeout_seconds is not None: - self.json["timeout_seconds"] = timeout_seconds - if "run_name" not in self.json: - self.json["run_name"] = run_name or kwargs["task_id"] - if idempotency_token is not None: - self.json["idempotency_token"] = idempotency_token - if access_control_list is not None: - self.json["access_control_list"] = access_control_list - if git_source is not None: - self.json["git_source"] = git_source - - if "dbt_task" in self.json and "git_source" not in self.json: - raise AirflowException("git_source is required for dbt_task") - if pipeline_task is not None and "pipeline_id" in pipeline_task and "pipeline_name" in pipeline_task: - raise AirflowException("'pipeline_name' is not allowed in conjunction with 'pipeline_id'") # This variable will be used in case our task gets killed. self.run_id: int | None = None self.do_xcom_push = do_xcom_push + def _get_named_json_parameters(self) -> dict[str, Any | None]: + return { + "tasks": self.tasks, + "spark_jar_task": self.spark_jar_task, + "notebook_task": self.notebook_task, + "spark_python_task": self.spark_python_task, + "spark_submit_task": self.spark_submit_task, + "pipeline_task": self.pipeline_task, + "dbt_task": self.dbt_task, + "new_cluster": self.new_cluster, + "existing_cluster_id": self.existing_cluster_id, + "libraries": self.libraries, + "run_name": self.run_name, + "timeout_seconds": self.timeout_seconds, + "idempotency_token": self.idempotency_token, + "access_control_list": self.access_control_list, + "git_source": self.git_source, + } + + def _get_merged_json(self) -> dict[str, Any]: + json = _merge_json_with_named_parameters(self.json, self._get_named_json_parameters()) + if "run_name" not in json: + json["run_name"] = self.task_id + return json + + @staticmethod + def _validate_merged_json(json: Mapping[str, Any]) -> None: + if "dbt_task" in json and "git_source" not in json: + raise DatabricksOperatorPayloadError("git_source is required for dbt_task") + pipeline_task = json.get("pipeline_task") + if ( + isinstance(pipeline_task, Mapping) + and "pipeline_id" in pipeline_task + and "pipeline_name" in pipeline_task + ): + raise DatabricksOperatorPayloadError( + "'pipeline_name' is not allowed in conjunction with 'pipeline_id'" + ) + @cached_property def _hook(self): return self._get_hook(caller="DatabricksSubmitRunOperator") @@ -674,28 +779,31 @@ def _get_hook(self, caller: str) -> DatabricksHook: ) def execute(self, context: Context): + json = self._get_merged_json() + self._validate_merged_json(json) if ( - "pipeline_task" in self.json - and self.json["pipeline_task"].get("pipeline_id") is None - and self.json["pipeline_task"].get("pipeline_name") + isinstance(json.get("pipeline_task"), Mapping) + and json["pipeline_task"].get("pipeline_id") is None + and json["pipeline_task"].get("pipeline_name") ): # If pipeline_id is not provided, we need to fetch it from the pipeline_name - pipeline_name = self.json["pipeline_task"]["pipeline_name"] - self.json["pipeline_task"]["pipeline_id"] = self._hook.find_pipeline_id_by_name(pipeline_name) - del self.json["pipeline_task"]["pipeline_name"] + pipeline_name = json["pipeline_task"]["pipeline_name"] + json["pipeline_task"] = dict(json["pipeline_task"]) + json["pipeline_task"]["pipeline_id"] = self._hook.find_pipeline_id_by_name(pipeline_name) + del json["pipeline_task"]["pipeline_name"] if self.params: params_dump = dict(self.params) - tasks = self.json.get("tasks") + tasks = json.get("tasks") if isinstance(tasks, list): for task in tasks: if isinstance(task, dict): _inject_airflow_params_into_task(task, params_dump) else: - _inject_airflow_params_into_task(self.json, params_dump) + _inject_airflow_params_into_task(json, params_dump) - json_normalised = normalise_json_content(self.json) - self.run_id = self._hook.submit_run(json_normalised) + self.json = normalise_json_content(json) + self.run_id = self._hook.submit_run(self.json) if self.deferrable: _handle_deferrable_databricks_operator_execution(self, self._hook, self.log, context) else: @@ -902,7 +1010,20 @@ class DatabricksRunNowOperator(BaseOperator): """ # Used in airflow.models.BaseOperator - template_fields: Sequence[str] = ("json", "databricks_conn_id") + template_fields: Sequence[str] = ( + "json", + "job_id", + "job_name", + "job_parameters", + "dbt_commands", + "notebook_params", + "python_params", + "python_named_params", + "jar_params", + "spark_submit_params", + "idempotency_token", + "databricks_conn_id", + ) template_ext: Sequence[str] = (".json-tpl",) # Databricks brand color (blue) under white text ui_color = "#1CB1C2" @@ -938,7 +1059,17 @@ def __init__( ) -> None: """Create a new ``DatabricksRunNowOperator``.""" super().__init__(**kwargs) - self.json = json or {} + self.json = json + self.job_id = job_id + self.job_name = job_name + self.job_parameters = job_parameters + self.dbt_commands = dbt_commands + self.notebook_params = notebook_params + self.python_params = python_params + self.python_named_params = python_named_params + self.jar_params = jar_params + self.spark_submit_params = spark_submit_params + self.idempotency_token = idempotency_token self.databricks_conn_id = databricks_conn_id self.polling_period_seconds = polling_period_seconds self.databricks_retry_limit = databricks_retry_limit @@ -950,34 +1081,32 @@ def __init__( self.databricks_repair_reason_new_settings = databricks_repair_reason_new_settings or {} self.cancel_previous_runs = cancel_previous_runs - if job_id is not None: - self.json["job_id"] = job_id - if job_name is not None: - self.json["job_name"] = job_name - if "job_id" in self.json and "job_name" in self.json: - raise AirflowException("Argument 'job_name' is not allowed with argument 'job_id'") - if notebook_params is not None: - self.json["notebook_params"] = notebook_params - if python_params is not None: - self.json["python_params"] = python_params - if python_named_params is not None: - self.json["python_named_params"] = python_named_params - if jar_params is not None: - self.json["jar_params"] = jar_params - if spark_submit_params is not None: - self.json["spark_submit_params"] = spark_submit_params - if idempotency_token is not None: - self.json["idempotency_token"] = idempotency_token - if job_parameters is not None: - self.json["job_parameters"] = job_parameters - if dbt_commands is not None: - self.json["dbt_commands"] = dbt_commands - if self.json: - self.json = normalise_json_content(self.json) # This variable will be used in case our task gets killed. self.run_id: int | None = None self.do_xcom_push = do_xcom_push + def _get_named_json_parameters(self) -> dict[str, Any | None]: + return { + "job_id": self.job_id, + "job_name": self.job_name, + "job_parameters": self.job_parameters, + "dbt_commands": self.dbt_commands, + "notebook_params": self.notebook_params, + "python_params": self.python_params, + "python_named_params": self.python_named_params, + "jar_params": self.jar_params, + "spark_submit_params": self.spark_submit_params, + "idempotency_token": self.idempotency_token, + } + + def _get_merged_json(self) -> dict[str, Any]: + return _merge_json_with_named_parameters(self.json, self._get_named_json_parameters()) + + @staticmethod + def _validate_merged_json(json: Mapping[str, Any]) -> None: + if "job_id" in json and "job_name" in json: + raise DatabricksOperatorPayloadError("Argument 'job_name' is not allowed with argument 'job_id'") + @cached_property def _hook(self): return self._get_hook(caller="DatabricksRunNowOperator") @@ -992,26 +1121,32 @@ def _get_hook(self, caller: str) -> DatabricksHook: ) def execute(self, context: Context): + json = self._get_merged_json() + self._validate_merged_json(json) hook = self._hook - if "job_name" in self.json: - job_id = hook.find_job_id_by_name(self.json["job_name"]) + if "job_name" in json: + job_id = hook.find_job_id_by_name(json["job_name"]) if job_id is None: - raise AirflowException(f"Job ID for job name {self.json['job_name']} can not be found") - self.json["job_id"] = job_id - del self.json["job_name"] + raise DatabricksOperatorPayloadError( + f"Job ID for job name {json['job_name']} can not be found" + ) + json["job_id"] = job_id + del json["job_name"] if self.cancel_previous_runs: - if (job_id := self.json.get("job_id")) is None: + if (job_id := json.get("job_id")) is None: raise ValueError( "cancel_previous_runs=True requires either job_id or job_name to be provided." ) hook.cancel_all_runs(job_id) - if not self.json.get("job_parameters") and self.params: - self.json["job_parameters"] = dict(self.params) + json = cast("dict[str, Any]", normalise_json_content(json)) + if not json.get("job_parameters") and self.params: + json["job_parameters"] = dict(self.params) - self.run_id = hook.run_now(self.json) + self.json = json + self.run_id = hook.run_now(json) if self.deferrable: _handle_deferrable_databricks_operator_execution(self, hook, self.log, context) else: @@ -1036,9 +1171,11 @@ def execute_complete(self, context: Context, event: dict[str, Any] | None = None repair_json = {"run_id": self.run_id, "rerun_all_failed_tasks": True} if latest_repair_id is not None: repair_json["latest_repair_id"] = latest_repair_id - if "job_parameters" in self.json: - repair_json["job_parameters"] = self.json["job_parameters"] - self.json["latest_repair_id"] = self._hook.repair_run(repair_json) + json = _coerce_json_to_dict(self.json) + if "job_parameters" in json: + repair_json["job_parameters"] = json["job_parameters"] + json["latest_repair_id"] = self._hook.repair_run(repair_json) + self.json = json _handle_deferrable_databricks_operator_execution(self, self._hook, self.log, context) def on_kill(self) -> None: diff --git a/providers/databricks/tests/unit/databricks/operators/test_databricks.py b/providers/databricks/tests/unit/databricks/operators/test_databricks.py index 4684b14282c4e..6df4ec78d1fc7 100644 --- a/providers/databricks/tests/unit/databricks/operators/test_databricks.py +++ b/providers/databricks/tests/unit/databricks/operators/test_databricks.py @@ -35,7 +35,7 @@ ExternalQueryRunFacet, SQLJobFacet, ) -from airflow.providers.common.compat.sdk import AirflowException, TaskDeferred +from airflow.providers.common.compat.sdk import AirflowException, BaseOperator, TaskDeferred from airflow.providers.databricks.hooks.databricks import RunState, SQLStatementState from airflow.providers.databricks.operators.databricks import ( DatabricksCreateJobsOperator, @@ -345,7 +345,7 @@ def test_init_with_named_parameters(self): } ) - assert expected == op.json + assert expected == utils.normalise_json_content(op._get_merged_json()) def test_init_with_json(self): """ @@ -382,7 +382,7 @@ def test_init_with_json(self): } ) - assert expected == op.json + assert expected == utils.normalise_json_content(op._get_merged_json()) def test_init_with_merging(self): """ @@ -447,7 +447,7 @@ def test_init_with_merging(self): } ) - assert expected == op.json + assert expected == utils.normalise_json_content(op._get_merged_json()) def test_init_with_templating(self): json = {"name": "test-{{ ds }}"} @@ -456,7 +456,30 @@ def test_init_with_templating(self): op = DatabricksCreateJobsOperator(dag=dag, task_id=TASK_ID, json=json) op.render_template_fields(context={"ds": DATE}) expected = utils.normalise_json_content({"name": f"test-{DATE}"}) - assert expected == op.json + assert expected == utils.normalise_json_content(op._get_merged_json()) + + @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") + def test_exec_with_rendered_python_literal_json_and_templated_named_parameters(self, db_mock_class): + class FakeTaskInstance: + @staticmethod + def xcom_pull(task_ids): + return {"name": JOB_NAME, "tasks": TASKS} + + op = DatabricksCreateJobsOperator( + task_id=TASK_ID, + json="{{ ti.xcom_pull(task_ids='payload') }}", + name="templated-{{ ds }}", + ) + op.render_template_fields(context={"ti": FakeTaskInstance(), "ds": DATE}) + db_mock = db_mock_class.return_value + db_mock.create_job.return_value = JOB_ID + db_mock.find_job_id_by_name.return_value = None + + return_result = op.execute({}) + + expected = utils.normalise_json_content({"name": f"templated-{DATE}", "tasks": TASKS}) + db_mock.create_job.assert_called_once_with(expected) + assert return_result == JOB_ID def test_init_with_bad_type(self): json = {"test": datetime.now()} @@ -465,8 +488,9 @@ def test_init_with_bad_type(self): r"Type \<(type|class) \'datetime.datetime\'\> used " r"for parameter json\[test\] is not a number or a string" ) + op = DatabricksCreateJobsOperator(task_id=TASK_ID, json=json) with pytest.raises(AirflowException, match=exception_message): - DatabricksCreateJobsOperator(task_id=TASK_ID, json=json) + op.execute(None) @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") def test_exec_create(self, db_mock_class): @@ -690,7 +714,7 @@ def test_init_with_notebook_task_named_parameters(self): {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK, "run_name": TASK_ID} ) - assert expected == utils.normalise_json_content(op.json) + assert expected == utils.normalise_json_content(op._get_merged_json()) def test_init_with_spark_python_task_named_parameters(self): """ @@ -703,7 +727,7 @@ def test_init_with_spark_python_task_named_parameters(self): {"new_cluster": NEW_CLUSTER, "spark_python_task": SPARK_PYTHON_TASK, "run_name": TASK_ID} ) - assert expected == utils.normalise_json_content(op.json) + assert expected == utils.normalise_json_content(op._get_merged_json()) def test_init_with_pipeline_name_task_named_parameters(self): """ @@ -712,7 +736,7 @@ def test_init_with_pipeline_name_task_named_parameters(self): op = DatabricksSubmitRunOperator(task_id=TASK_ID, pipeline_task=PIPELINE_NAME_TASK) expected = utils.normalise_json_content({"pipeline_task": PIPELINE_NAME_TASK, "run_name": TASK_ID}) - assert expected == utils.normalise_json_content(op.json) + assert expected == utils.normalise_json_content(op._get_merged_json()) def test_init_with_pipeline_id_task_named_parameters(self): """ @@ -721,7 +745,7 @@ def test_init_with_pipeline_id_task_named_parameters(self): op = DatabricksSubmitRunOperator(task_id=TASK_ID, pipeline_task=PIPELINE_ID_TASK) expected = utils.normalise_json_content({"pipeline_task": PIPELINE_ID_TASK, "run_name": TASK_ID}) - assert expected == utils.normalise_json_content(op.json) + assert expected == utils.normalise_json_content(op._get_merged_json()) def test_init_with_spark_submit_task_named_parameters(self): """ @@ -734,7 +758,7 @@ def test_init_with_spark_submit_task_named_parameters(self): {"new_cluster": NEW_CLUSTER, "spark_submit_task": SPARK_SUBMIT_TASK, "run_name": TASK_ID} ) - assert expected == utils.normalise_json_content(op.json) + assert expected == utils.normalise_json_content(op._get_merged_json()) def test_init_with_dbt_task_named_parameters(self): """ @@ -752,7 +776,7 @@ def test_init_with_dbt_task_named_parameters(self): {"new_cluster": NEW_CLUSTER, "dbt_task": DBT_TASK, "git_source": git_source, "run_name": TASK_ID} ) - assert expected == utils.normalise_json_content(op.json) + assert expected == utils.normalise_json_content(op._get_merged_json()) def test_init_with_dbt_task_mixed_parameters(self): """ @@ -771,15 +795,16 @@ def test_init_with_dbt_task_mixed_parameters(self): {"new_cluster": NEW_CLUSTER, "dbt_task": DBT_TASK, "git_source": git_source, "run_name": TASK_ID} ) - assert expected == utils.normalise_json_content(op.json) + assert expected == utils.normalise_json_content(op._get_merged_json()) def test_init_with_dbt_task_without_git_source_raises_error(self): """ Test the initializer without the necessary git_source for dbt_task raises error. """ exception_message = "git_source is required for dbt_task" + op = DatabricksSubmitRunOperator(task_id=TASK_ID, new_cluster=NEW_CLUSTER, dbt_task=DBT_TASK) with pytest.raises(AirflowException, match=exception_message): - DatabricksSubmitRunOperator(task_id=TASK_ID, new_cluster=NEW_CLUSTER, dbt_task=DBT_TASK) + op.execute(None) def test_init_with_dbt_task_json_without_git_source_raises_error(self): """ @@ -788,8 +813,9 @@ def test_init_with_dbt_task_json_without_git_source_raises_error(self): json = {"dbt_task": DBT_TASK, "new_cluster": NEW_CLUSTER} exception_message = "git_source is required for dbt_task" + op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json) with pytest.raises(AirflowException, match=exception_message): - DatabricksSubmitRunOperator(task_id=TASK_ID, json=json) + op.execute(None) def test_init_with_json(self): """ @@ -800,13 +826,13 @@ def test_init_with_json(self): expected = utils.normalise_json_content( {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK, "run_name": TASK_ID} ) - assert expected == utils.normalise_json_content(op.json) + assert expected == utils.normalise_json_content(op._get_merged_json()) def test_init_with_tasks(self): tasks = [{"task_key": 1, "new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK}] op = DatabricksSubmitRunOperator(task_id=TASK_ID, tasks=tasks) expected = utils.normalise_json_content({"run_name": TASK_ID, "tasks": tasks}) - assert expected == utils.normalise_json_content(op.json) + assert expected == utils.normalise_json_content(op._get_merged_json()) def test_init_with_specified_run_name(self): """ @@ -817,7 +843,7 @@ def test_init_with_specified_run_name(self): expected = utils.normalise_json_content( {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK, "run_name": RUN_NAME} ) - assert expected == utils.normalise_json_content(op.json) + assert expected == utils.normalise_json_content(op._get_merged_json()) def test_pipeline_task(self): """ @@ -829,7 +855,7 @@ def test_pipeline_task(self): expected = utils.normalise_json_content( {"new_cluster": NEW_CLUSTER, "pipeline_task": pipeline_task, "run_name": RUN_NAME} ) - assert expected == utils.normalise_json_content(op.json) + assert expected == utils.normalise_json_content(op._get_merged_json()) def test_init_with_merging(self): """ @@ -850,7 +876,7 @@ def test_init_with_merging(self): "run_name": TASK_ID, } ) - assert expected == utils.normalise_json_content(op.json) + assert expected == utils.normalise_json_content(op._get_merged_json()) def test_init_with_templating(self): json = { @@ -867,7 +893,37 @@ def test_init_with_templating(self): "run_name": TASK_ID, } ) - assert expected == utils.normalise_json_content(op.json) + assert expected == utils.normalise_json_content(op._get_merged_json()) + + @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") + def test_exec_with_xcom_arg_json_and_templated_named_parameters(self, db_mock_class): + with DAG("test", schedule=None, start_date=datetime.now()): + producer = BaseOperator(task_id="producer") + op = DatabricksSubmitRunOperator( + task_id=TASK_ID, + json=producer.output, + new_cluster={**NEW_CLUSTER, "spark_version": "{{ ds }}"}, + wait_for_termination=False, + ) + ti = MagicMock() + ti.xcom_pull.return_value = { + "new_cluster": {"spark_version": "old", "node_type_id": "old", "num_workers": 1}, + "notebook_task": NOTEBOOK_TASK, + } + op.render_template_fields(context={"ti": ti, "ds": DATE, "expanded_ti_count": None}) + db_mock = db_mock_class.return_value + db_mock.submit_run.return_value = RUN_ID + + op.execute(None) + + expected = utils.normalise_json_content( + { + "new_cluster": {**NEW_CLUSTER, "spark_version": DATE}, + "notebook_task": NOTEBOOK_TASK, + "run_name": TASK_ID, + } + ) + db_mock.submit_run.assert_called_once_with(expected) def test_init_with_git_source(self): json = {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK, "run_name": RUN_NAME} @@ -885,7 +941,7 @@ def test_init_with_git_source(self): "git_source": git_source, } ) - assert expected == utils.normalise_json_content(op.json) + assert expected == utils.normalise_json_content(op._get_merged_json()) def test_init_with_bad_type(self): json = {"test": datetime.now()} @@ -896,7 +952,7 @@ def test_init_with_bad_type(self): r"for parameter json\[test\] is not a number or a string" ) with pytest.raises(AirflowException, match=exception_message): - utils.normalise_json_content(op.json) + op.execute(None) @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") def test_exec_success(self, db_mock_class): @@ -1343,6 +1399,48 @@ def test_submit_run_does_not_override_existing_task_parameters(self, db_mock_cla actual = db_mock.submit_run.call_args.args[0] assert actual["notebook_task"]["base_parameters"] == {"explicit": "value"} + @pytest.mark.parametrize( + ("json", "exception_message"), + [ + pytest.param("[1, 2]", "Databricks json payload must resolve to a mapping", id="list"), + pytest.param("{not-valid", "Databricks json payload string must be valid JSON", id="invalid"), + ], + ) + @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") + def test_exec_with_invalid_rendered_json_raises(self, db_mock_class, json, exception_message): + op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json) + + with pytest.raises(AirflowException, match=exception_message): + op.execute(None) + + db_mock_class.assert_not_called() + + @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") + def test_exec_with_rendered_dbt_task_without_git_source_raises(self, db_mock_class): + op = DatabricksSubmitRunOperator( + task_id=TASK_ID, + json='{"new_cluster": {"spark_version": "1"}, "dbt_task": {"commands": ["dbt run"]}}', + ) + + with pytest.raises(AirflowException, match="git_source is required for dbt_task"): + op.execute(None) + + db_mock_class.assert_not_called() + + @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") + def test_exec_with_rendered_pipeline_id_and_name_raises(self, db_mock_class): + op = DatabricksSubmitRunOperator( + task_id=TASK_ID, + json='{"pipeline_task": {"pipeline_id": "1234abcd", "pipeline_name": "pipeline"}}', + ) + + with pytest.raises( + AirflowException, match="'pipeline_name' is not allowed in conjunction with 'pipeline_id'" + ): + op.execute(None) + + db_mock_class.assert_not_called() + class TestDatabricksRunNowOperator: def test_init_with_named_parameters(self): @@ -1352,7 +1450,7 @@ def test_init_with_named_parameters(self): op = DatabricksRunNowOperator(job_id=JOB_ID, task_id=TASK_ID) expected = utils.normalise_json_content({"job_id": 42}) - assert expected == op.json + assert expected == utils.normalise_json_content(op._get_merged_json()) def test_init_with_json(self): """ @@ -1381,7 +1479,7 @@ def test_init_with_json(self): } ) - assert expected == op.json + assert expected == utils.normalise_json_content(op._get_merged_json()) def test_init_with_merging(self): """ @@ -1415,7 +1513,7 @@ def test_init_with_merging(self): } ) - assert expected == op.json + assert expected == utils.normalise_json_content(op._get_merged_json()) def test_init_with_templating(self): json = {"notebook_params": NOTEBOOK_PARAMS, "jar_params": TEMPLATED_JAR_PARAMS} @@ -1430,17 +1528,45 @@ def test_init_with_templating(self): "job_id": JOB_ID, } ) - assert expected == op.json + assert expected == utils.normalise_json_content(op._get_merged_json()) - def test_init_with_bad_type(self): + @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") + def test_exec_with_json_string_and_templated_named_parameters(self, db_mock_class): + op = DatabricksRunNowOperator( + task_id=TASK_ID, + json='{"job_id": "1", "notebook_params": {"source": "json"}, "jar_params": ["json"]}', + job_id="{{ params.job_id }}", + notebook_params={"date": "{{ ds }}"}, + wait_for_termination=False, + ) + op.render_template_fields(context={"ds": DATE, "params": {"job_id": JOB_ID}}) + db_mock = db_mock_class.return_value + db_mock.run_now.return_value = RUN_ID + + op.execute(None) + + expected = utils.normalise_json_content( + { + "job_id": JOB_ID, + "notebook_params": {"date": DATE}, + "jar_params": ["json"], + } + ) + db_mock.run_now.assert_called_once_with(expected) + + @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") + def test_init_with_bad_type(self, db_mock_class): json = {"test": datetime.now()} # Looks a bit weird since we have to escape regex reserved symbols. exception_message = ( r"Type \<(type|class) \'datetime.datetime\'\> used " r"for parameter json\[test\] is not a number or a string" ) + op = DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, json=json) with pytest.raises(AirflowException, match=exception_message): - DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, json=json) + op.execute(None) + + db_mock_class.assert_called_once() @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") def test_exec_success(self, db_mock_class): @@ -1709,19 +1835,39 @@ def test_no_wait_for_termination(self, db_mock_class): db_mock.get_run_page_url.assert_called_once_with(RUN_ID) db_mock.get_run.assert_not_called() - def test_init_exception_with_job_name_and_job_id(self): + @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") + def test_init_exception_with_job_name_and_job_id(self, db_mock_class): exception_message = "Argument 'job_name' is not allowed with argument 'job_id'" + op = DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, job_name=JOB_NAME) with pytest.raises(AirflowException, match=exception_message): - DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, job_name=JOB_NAME) + op.execute(None) run = {"job_id": JOB_ID, "job_name": JOB_NAME} + op = DatabricksRunNowOperator(task_id=TASK_ID, json=run) with pytest.raises(AirflowException, match=exception_message): - DatabricksRunNowOperator(task_id=TASK_ID, json=run) + op.execute(None) run = {"job_id": JOB_ID} + op = DatabricksRunNowOperator(task_id=TASK_ID, json=run, job_name=JOB_NAME) with pytest.raises(AirflowException, match=exception_message): - DatabricksRunNowOperator(task_id=TASK_ID, json=run, job_name=JOB_NAME) + op.execute(None) + + db_mock_class.assert_not_called() + + @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") + def test_exec_exception_with_rendered_job_name_and_job_id(self, db_mock_class): + op = DatabricksRunNowOperator( + task_id=TASK_ID, + json='{"job_id": "42", "job_name": "job-name"}', + ) + + with pytest.raises( + AirflowException, match="Argument 'job_name' is not allowed with argument 'job_id'" + ): + op.execute(None) + + db_mock_class.assert_not_called() @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") def test_exec_with_job_name(self, db_mock_class): diff --git a/scripts/ci/prek/known_airflow_exceptions.txt b/scripts/ci/prek/known_airflow_exceptions.txt index 04c6e9534f07f..262f5d6ce547f 100644 --- a/scripts/ci/prek/known_airflow_exceptions.txt +++ b/scripts/ci/prek/known_airflow_exceptions.txt @@ -176,7 +176,7 @@ providers/common/sql/src/airflow/providers/common/sql/triggers/sql.py::1 providers/databricks/src/airflow/providers/databricks/hooks/databricks.py::8 providers/databricks/src/airflow/providers/databricks/hooks/databricks_base.py::46 providers/databricks/src/airflow/providers/databricks/hooks/databricks_sql.py::2 -providers/databricks/src/airflow/providers/databricks/operators/databricks.py::10 +providers/databricks/src/airflow/providers/databricks/operators/databricks.py::6 providers/databricks/src/airflow/providers/databricks/operators/databricks_repos.py::12 providers/databricks/src/airflow/providers/databricks/operators/databricks_sql.py::8 providers/databricks/src/airflow/providers/databricks/operators/databricks_workflow.py::4