Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
### Features

- Add `invocation_id` to the default query comment ([#1377](https://github.com/databricks/dbt-databricks/issues/1377))
- Include `job_id`, `run_id`, and `task_key` from Databricks Job context in `adapter_response`, enabling correlation between dbt runs and Databricks workflow executions via `run_results.json` ([#722](https://github.com/databricks/dbt-databricks/issues/722))

### Fixes

Expand Down
9 changes: 7 additions & 2 deletions dbt/adapters/databricks/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,12 @@
ConnectionCreateError,
)
from dbt.adapters.databricks.events.other_events import QueryError
from dbt.adapters.databricks.handle import CursorWrapper, DatabricksHandle, SqlUtils
from dbt.adapters.databricks.handle import (
CursorWrapper,
DatabricksAdapterResponse,
DatabricksHandle,
SqlUtils,
)
from dbt.adapters.databricks.logging import logger
from dbt.adapters.databricks.python_models.run_tracking import PythonRunTracker
from dbt.adapters.databricks.utils import QueryTagsUtils, is_cluster_http_path, redact_credentials
Expand Down Expand Up @@ -549,7 +554,7 @@ def get_response(cls, cursor: Any) -> AdapterResponse:
if isinstance(cursor, CursorWrapper):
return cursor.get_response()
else:
return AdapterResponse("OK")
return DatabricksAdapterResponse.from_cursor(cursor)

def clear_transaction(self) -> None:
"""Noop."""
Expand Down
44 changes: 43 additions & 1 deletion dbt/adapters/databricks/handle.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import decimal
import os
import re
import sys
from collections.abc import Callable, Sequence
from dataclasses import dataclass
from types import TracebackType
from typing import TYPE_CHECKING, Any, Optional, TypeVar

Expand All @@ -28,6 +30,46 @@
FailLogOp = Callable[[Exception], str]


# Databricks sets these environment variables when dbt runs inside a Databricks
# workflow job task. Reading them once per process is sufficient since they are
# constant for the lifetime of the task execution.
_JOB_RUN_ENV_VARS = {
"job_id": "DATABRICKS_JOB_ID",
"run_id": "DATABRICKS_RUN_ID",
"task_key": "DATABRICKS_TASK_KEY",
}


def _get_job_run_context() -> dict[str, Optional[str]]:
"""Return Databricks job-run context from environment variables.

Returns a dict with all three keys; values are ``None`` when the
corresponding environment variable is not set (i.e. dbt is running
outside of a Databricks Job).
"""
return {key: os.environ.get(env_var) for key, env_var in _JOB_RUN_ENV_VARS.items()}


@dataclass
class DatabricksAdapterResponse(AdapterResponse):
"""Extends the base adapter response with Databricks Job context."""

job_id: Optional[str] = None
run_id: Optional[str] = None
task_key: Optional[str] = None

@classmethod
def from_cursor(cls, cursor: Any) -> "DatabricksAdapterResponse":
ctx = _get_job_run_context()
return cls(
_message="OK",
query_id=getattr(cursor, "query_id", None) or "N/A",
job_id=ctx.get("job_id"),
run_id=ctx.get("run_id"),
task_key=ctx.get("task_key"),
)


class CursorWrapper:
"""
Wrap the DBSQL cursor to abstract the details from DatabricksConnectionManager.
Expand Down Expand Up @@ -79,7 +121,7 @@ def fetchmany(self, size: int) -> Sequence[tuple]:
return self._safe_execute(lambda cursor: cursor.fetchmany(size))

def get_response(self) -> AdapterResponse:
return AdapterResponse(_message="OK", query_id=self._cursor.query_id or "N/A")
return DatabricksAdapterResponse.from_cursor(self._cursor)

T = TypeVar("T")

Expand Down
111 changes: 107 additions & 4 deletions tests/unit/test_handle.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
import os
import sys
from decimal import Decimal
from unittest.mock import Mock
from unittest.mock import Mock, patch

import pytest
from databricks.sql.client import Cursor
from dbt.adapters.contracts.connection import AdapterResponse
from dbt_common.exceptions import DbtRuntimeError

from dbt.adapters.databricks.handle import CursorWrapper, DatabricksHandle, SqlUtils
from dbt.adapters.databricks.handle import (
CursorWrapper,
DatabricksAdapterResponse,
DatabricksHandle,
SqlUtils,
_get_job_run_context,
)


class TestSqlUtils:
Expand Down Expand Up @@ -101,12 +108,40 @@ def test_fetchmany(self, cursor):
def test_get_response__no_query_id(self, cursor):
cursor.query_id = None
wrapper = CursorWrapper(cursor)
assert wrapper.get_response() == AdapterResponse("OK", query_id="N/A")
response = wrapper.get_response()
assert response._message == "OK"
assert response.query_id == "N/A"

def test_get_response__with_query_id(self, cursor):
cursor.query_id = "id"
wrapper = CursorWrapper(cursor)
assert wrapper.get_response() == AdapterResponse("OK", query_id="id")
response = wrapper.get_response()
assert response._message == "OK"
assert response.query_id == "id"

@patch.dict(
os.environ,
{"DATABRICKS_JOB_ID": "123", "DATABRICKS_RUN_ID": "456", "DATABRICKS_TASK_KEY": "my_task"},
)
def test_get_response__with_job_context(self, cursor):
cursor.query_id = "qid"
wrapper = CursorWrapper(cursor)
response = wrapper.get_response()
assert isinstance(response, DatabricksAdapterResponse)
assert response.job_id == "123"
assert response.run_id == "456"
assert response.task_key == "my_task"
assert response.query_id == "qid"

@patch.dict(os.environ, {}, clear=True)
def test_get_response__without_job_context(self, cursor):
cursor.query_id = "qid"
wrapper = CursorWrapper(cursor)
response = wrapper.get_response()
assert isinstance(response, DatabricksAdapterResponse)
assert response.job_id is None
assert response.run_id is None
assert response.task_key is None

def test_with__no_exception(self, cursor):
with CursorWrapper(cursor) as c:
Expand Down Expand Up @@ -209,3 +244,71 @@ def test_close__open_raising_exception(self, conn, cursor):
handle.close()
cursor.close.assert_called_once()
conn.close.assert_called_once()


class TestGetJobRunContext:
@patch.dict(
os.environ,
{"DATABRICKS_JOB_ID": "111", "DATABRICKS_RUN_ID": "222", "DATABRICKS_TASK_KEY": "etl"},
)
def test_all_vars_set(self):
ctx = _get_job_run_context()
assert ctx == {"job_id": "111", "run_id": "222", "task_key": "etl"}

@patch.dict(os.environ, {"DATABRICKS_JOB_ID": "111"}, clear=True)
def test_partial_vars(self):
ctx = _get_job_run_context()
assert ctx["job_id"] == "111"
assert ctx["run_id"] is None
assert ctx["task_key"] is None

@patch.dict(os.environ, {}, clear=True)
def test_no_vars(self):
ctx = _get_job_run_context()
assert ctx == {"job_id": None, "run_id": None, "task_key": None}


class TestDatabricksAdapterResponse:
def test_from_cursor__with_all_context(self):
cursor = Mock()
cursor.query_id = "q1"
with patch.dict(
os.environ,
{
"DATABRICKS_JOB_ID": "10",
"DATABRICKS_RUN_ID": "20",
"DATABRICKS_TASK_KEY": "transform",
},
):
resp = DatabricksAdapterResponse.from_cursor(cursor)
assert resp._message == "OK"
assert resp.query_id == "q1"
assert resp.job_id == "10"
assert resp.run_id == "20"
assert resp.task_key == "transform"

def test_from_cursor__no_context(self):
cursor = Mock()
cursor.query_id = "q2"
with patch.dict(os.environ, {}, clear=True):
resp = DatabricksAdapterResponse.from_cursor(cursor)
assert resp._message == "OK"
assert resp.query_id == "q2"
assert resp.job_id is None
assert resp.run_id is None
assert resp.task_key is None

def test_from_cursor__no_query_id(self):
cursor = Mock()
cursor.query_id = None
with patch.dict(os.environ, {}, clear=True):
resp = DatabricksAdapterResponse.from_cursor(cursor)
assert resp.query_id == "N/A"

def test_str_representation(self):
resp = DatabricksAdapterResponse(_message="OK", query_id="q1", job_id="10")
assert str(resp) == "OK"

def test_is_adapter_response_subclass(self):
resp = DatabricksAdapterResponse(_message="OK")
assert isinstance(resp, AdapterResponse)
Loading