diff --git a/CHANGELOG.md b/CHANGELOG.md index 1a4ce80a9..fbef31c32 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,7 @@ ### Features - Add `invocation_id` to the default query comment ([#1377](https://github.com/databricks/dbt-databricks/issues/1377)) +- Include `job_id`, `run_id`, and `task_key` from Databricks Job context in `adapter_response`, enabling correlation between dbt runs and Databricks workflow executions via `run_results.json` ([#722](https://github.com/databricks/dbt-databricks/issues/722)) ### Fixes diff --git a/dbt/adapters/databricks/connections.py b/dbt/adapters/databricks/connections.py index 3c9fd3867..f4bfe58c8 100644 --- a/dbt/adapters/databricks/connections.py +++ b/dbt/adapters/databricks/connections.py @@ -46,7 +46,12 @@ ConnectionCreateError, ) from dbt.adapters.databricks.events.other_events import QueryError -from dbt.adapters.databricks.handle import CursorWrapper, DatabricksHandle, SqlUtils +from dbt.adapters.databricks.handle import ( + CursorWrapper, + DatabricksAdapterResponse, + DatabricksHandle, + SqlUtils, +) from dbt.adapters.databricks.logging import logger from dbt.adapters.databricks.python_models.run_tracking import PythonRunTracker from dbt.adapters.databricks.utils import QueryTagsUtils, is_cluster_http_path, redact_credentials @@ -549,7 +554,7 @@ def get_response(cls, cursor: Any) -> AdapterResponse: if isinstance(cursor, CursorWrapper): return cursor.get_response() else: - return AdapterResponse("OK") + return DatabricksAdapterResponse.from_cursor(cursor) def clear_transaction(self) -> None: """Noop.""" diff --git a/dbt/adapters/databricks/handle.py b/dbt/adapters/databricks/handle.py index aadf871c1..cc99424d5 100644 --- a/dbt/adapters/databricks/handle.py +++ b/dbt/adapters/databricks/handle.py @@ -1,7 +1,9 @@ import decimal +import os import re import sys from collections.abc import Callable, Sequence +from dataclasses import dataclass from types import TracebackType from typing import TYPE_CHECKING, Any, Optional, TypeVar @@ -28,6 +30,46 @@ FailLogOp = Callable[[Exception], str] +# Databricks sets these environment variables when dbt runs inside a Databricks +# workflow job task. Reading them once per process is sufficient since they are +# constant for the lifetime of the task execution. +_JOB_RUN_ENV_VARS = { + "job_id": "DATABRICKS_JOB_ID", + "run_id": "DATABRICKS_RUN_ID", + "task_key": "DATABRICKS_TASK_KEY", +} + + +def _get_job_run_context() -> dict[str, Optional[str]]: + """Return Databricks job-run context from environment variables. + + Returns a dict with all three keys; values are ``None`` when the + corresponding environment variable is not set (i.e. dbt is running + outside of a Databricks Job). + """ + return {key: os.environ.get(env_var) for key, env_var in _JOB_RUN_ENV_VARS.items()} + + +@dataclass +class DatabricksAdapterResponse(AdapterResponse): + """Extends the base adapter response with Databricks Job context.""" + + job_id: Optional[str] = None + run_id: Optional[str] = None + task_key: Optional[str] = None + + @classmethod + def from_cursor(cls, cursor: Any) -> "DatabricksAdapterResponse": + ctx = _get_job_run_context() + return cls( + _message="OK", + query_id=getattr(cursor, "query_id", None) or "N/A", + job_id=ctx.get("job_id"), + run_id=ctx.get("run_id"), + task_key=ctx.get("task_key"), + ) + + class CursorWrapper: """ Wrap the DBSQL cursor to abstract the details from DatabricksConnectionManager. @@ -79,7 +121,7 @@ def fetchmany(self, size: int) -> Sequence[tuple]: return self._safe_execute(lambda cursor: cursor.fetchmany(size)) def get_response(self) -> AdapterResponse: - return AdapterResponse(_message="OK", query_id=self._cursor.query_id or "N/A") + return DatabricksAdapterResponse.from_cursor(self._cursor) T = TypeVar("T") diff --git a/tests/unit/test_handle.py b/tests/unit/test_handle.py index 5e7a34a93..30fee6d0d 100644 --- a/tests/unit/test_handle.py +++ b/tests/unit/test_handle.py @@ -1,13 +1,20 @@ +import os import sys from decimal import Decimal -from unittest.mock import Mock +from unittest.mock import Mock, patch import pytest from databricks.sql.client import Cursor from dbt.adapters.contracts.connection import AdapterResponse from dbt_common.exceptions import DbtRuntimeError -from dbt.adapters.databricks.handle import CursorWrapper, DatabricksHandle, SqlUtils +from dbt.adapters.databricks.handle import ( + CursorWrapper, + DatabricksAdapterResponse, + DatabricksHandle, + SqlUtils, + _get_job_run_context, +) class TestSqlUtils: @@ -101,12 +108,40 @@ def test_fetchmany(self, cursor): def test_get_response__no_query_id(self, cursor): cursor.query_id = None wrapper = CursorWrapper(cursor) - assert wrapper.get_response() == AdapterResponse("OK", query_id="N/A") + response = wrapper.get_response() + assert response._message == "OK" + assert response.query_id == "N/A" def test_get_response__with_query_id(self, cursor): cursor.query_id = "id" wrapper = CursorWrapper(cursor) - assert wrapper.get_response() == AdapterResponse("OK", query_id="id") + response = wrapper.get_response() + assert response._message == "OK" + assert response.query_id == "id" + + @patch.dict( + os.environ, + {"DATABRICKS_JOB_ID": "123", "DATABRICKS_RUN_ID": "456", "DATABRICKS_TASK_KEY": "my_task"}, + ) + def test_get_response__with_job_context(self, cursor): + cursor.query_id = "qid" + wrapper = CursorWrapper(cursor) + response = wrapper.get_response() + assert isinstance(response, DatabricksAdapterResponse) + assert response.job_id == "123" + assert response.run_id == "456" + assert response.task_key == "my_task" + assert response.query_id == "qid" + + @patch.dict(os.environ, {}, clear=True) + def test_get_response__without_job_context(self, cursor): + cursor.query_id = "qid" + wrapper = CursorWrapper(cursor) + response = wrapper.get_response() + assert isinstance(response, DatabricksAdapterResponse) + assert response.job_id is None + assert response.run_id is None + assert response.task_key is None def test_with__no_exception(self, cursor): with CursorWrapper(cursor) as c: @@ -209,3 +244,71 @@ def test_close__open_raising_exception(self, conn, cursor): handle.close() cursor.close.assert_called_once() conn.close.assert_called_once() + + +class TestGetJobRunContext: + @patch.dict( + os.environ, + {"DATABRICKS_JOB_ID": "111", "DATABRICKS_RUN_ID": "222", "DATABRICKS_TASK_KEY": "etl"}, + ) + def test_all_vars_set(self): + ctx = _get_job_run_context() + assert ctx == {"job_id": "111", "run_id": "222", "task_key": "etl"} + + @patch.dict(os.environ, {"DATABRICKS_JOB_ID": "111"}, clear=True) + def test_partial_vars(self): + ctx = _get_job_run_context() + assert ctx["job_id"] == "111" + assert ctx["run_id"] is None + assert ctx["task_key"] is None + + @patch.dict(os.environ, {}, clear=True) + def test_no_vars(self): + ctx = _get_job_run_context() + assert ctx == {"job_id": None, "run_id": None, "task_key": None} + + +class TestDatabricksAdapterResponse: + def test_from_cursor__with_all_context(self): + cursor = Mock() + cursor.query_id = "q1" + with patch.dict( + os.environ, + { + "DATABRICKS_JOB_ID": "10", + "DATABRICKS_RUN_ID": "20", + "DATABRICKS_TASK_KEY": "transform", + }, + ): + resp = DatabricksAdapterResponse.from_cursor(cursor) + assert resp._message == "OK" + assert resp.query_id == "q1" + assert resp.job_id == "10" + assert resp.run_id == "20" + assert resp.task_key == "transform" + + def test_from_cursor__no_context(self): + cursor = Mock() + cursor.query_id = "q2" + with patch.dict(os.environ, {}, clear=True): + resp = DatabricksAdapterResponse.from_cursor(cursor) + assert resp._message == "OK" + assert resp.query_id == "q2" + assert resp.job_id is None + assert resp.run_id is None + assert resp.task_key is None + + def test_from_cursor__no_query_id(self): + cursor = Mock() + cursor.query_id = None + with patch.dict(os.environ, {}, clear=True): + resp = DatabricksAdapterResponse.from_cursor(cursor) + assert resp.query_id == "N/A" + + def test_str_representation(self): + resp = DatabricksAdapterResponse(_message="OK", query_id="q1", job_id="10") + assert str(resp) == "OK" + + def test_is_adapter_response_subclass(self): + resp = DatabricksAdapterResponse(_message="OK") + assert isinstance(resp, AdapterResponse)