From da8f70bca899e70d2ecf9a2793d368f27199f276 Mon Sep 17 00:00:00 2001 From: Alexey Egorov <5102843+alexeyegorov@users.noreply.github.com> Date: Sun, 25 Jan 2026 19:52:26 +0100 Subject: [PATCH 1/2] chore: update to version and add session mode support This commit updates the Databricks adapter to version 1.10.15+session, introducing support for session mode execution. Key changes include: - Added `DatabricksSessionHandle` and `SessionCursorWrapper` for handling SparkSession-based execution. - Enhanced `DatabricksCredentials` to manage connection methods and validate session mode configurations. - Updated connection management to support session mode, including automatic selection of submission methods for Python models. - Improve SparkSession retrieval in Databricks adapter. This commit enhances the `DatabricksSessionHandle` and `SessionPythonJobHelper` classes to improve the retrieval of the existing SparkSession. It introduces multiple methods to obtain the SparkSession, ensuring compatibility with various Databricks environments. Additionally, it refactors method signatures for consistency and readability. --- dbt/adapters/databricks/__version__.py | 2 +- dbt/adapters/databricks/connections.py | 267 +++++++---- dbt/adapters/databricks/credentials.py | 218 ++++++--- dbt/adapters/databricks/handle.py | 41 ++ dbt/adapters/databricks/impl.py | 13 + .../python_models/python_submissions.py | 435 +++++++++++++----- dbt/adapters/databricks/session.py | 355 ++++++++++++++ pyproject.toml | 2 +- tests/unit/test_session.py | 423 +++++++++++++++++ 9 files changed, 1508 insertions(+), 248 deletions(-) create mode 100644 dbt/adapters/databricks/session.py create mode 100644 tests/unit/test_session.py diff --git a/dbt/adapters/databricks/__version__.py b/dbt/adapters/databricks/__version__.py index b7e5cc689..bf0b4d2c5 100644 --- a/dbt/adapters/databricks/__version__.py +++ b/dbt/adapters/databricks/__version__.py @@ -1 +1 @@ -version = "1.11.4" +version = "1.10.15-5" diff --git a/dbt/adapters/databricks/connections.py b/dbt/adapters/databricks/connections.py index 291ac7677..c28917b20 100644 --- a/dbt/adapters/databricks/connections.py +++ b/dbt/adapters/databricks/connections.py @@ -49,12 +49,12 @@ from dbt.adapters.databricks.handle import CursorWrapper, DatabricksHandle, SqlUtils from dbt.adapters.databricks.logging import logger from dbt.adapters.databricks.python_models.run_tracking import PythonRunTracker +from dbt.adapters.databricks.session import DatabricksSessionHandle, SessionCursorWrapper from dbt.adapters.databricks.utils import QueryTagsUtils, is_cluster_http_path, redact_credentials if TYPE_CHECKING: from agate import Table - DATABRICKS_QUERY_COMMENT = f""" {{%- set comment_dict = {{}} -%}} {{%- do comment_dict.update( @@ -96,16 +96,20 @@ def from_context(query_header_context: Any) -> "QueryContextWrapper": compute_name = None model_query_tags_override = None materialized = None - relation_name = getattr(query_header_context, "relation_name", "[unknown]") + relation_name = getattr(query_header_context, "relation_name", + "[unknown]") # Extract config-related attributes safely - if hasattr(query_header_context, "config") and query_header_context.config: + if hasattr(query_header_context, + "config") and query_header_context.config: config = query_header_context.config compute_name = config.get("databricks_compute") - query_tags_str = config.extra.get("query_tags") if hasattr(config, "extra") else None + query_tags_str = config.extra.get("query_tags") if hasattr( + config, "extra") else 
None if query_tags_str: - model_query_tags_override = QueryTagsUtils.parse_query_tags(query_tags_str) + model_query_tags_override = QueryTagsUtils.parse_query_tags( + query_tags_str) if hasattr(config, "materialized"): materialized = config.materialized @@ -121,6 +125,7 @@ def from_context(query_header_context: Any) -> "QueryContextWrapper": class DatabricksMacroQueryStringSetter(MacroQueryStringSetter): + def _get_comment_macro(self) -> Optional[str]: if self.config.query_comment.comment == DEFAULT_QUERY_COMMENT: return DATABRICKS_QUERY_COMMENT @@ -149,31 +154,40 @@ def has_capability(self, capability: DBRCapability) -> bool: class DatabricksConnectionManager(SparkConnectionManager): TYPE: str = "databricks" credentials_manager: Optional[DatabricksCredentialManager] = None + # Cache for session mode (1.10.x doesn't have DBRCapabilities, so we use a simple dict) + _session_capabilities: Optional[dict] = None _dbr_capabilities_cache: dict[str, DBRCapabilities] = {} - def __init__(self, profile: AdapterRequiredConfig, mp_context: SpawnContext): + def __init__(self, profile: AdapterRequiredConfig, + mp_context: SpawnContext): super().__init__(profile, mp_context) self._api_client: Optional[DatabricksApiClient] = None @property def api_client(self) -> DatabricksApiClient: if self._api_client is None: - credentials = cast(DatabricksCredentials, self.profile.credentials) - self._api_client = DatabricksApiClient(credentials, 15 * 60) + self._api_client = DatabricksApiClient.create( + cast(DatabricksCredentials, self.profile.credentials), 15 * 60) return self._api_client + def is_session_mode(self) -> bool: + """Check if the connection is using session mode.""" + credentials = cast(DatabricksCredentials, self.profile.credentials) + return credentials.is_session_mode + def is_cluster(self) -> bool: conn = self.get_thread_connection() databricks_conn = cast(DatabricksDBTConnection, conn) - return is_cluster_http_path(databricks_conn.http_path, conn.credentials.cluster_id) + return is_cluster_http_path(databricks_conn.http_path, + conn.credentials.cluster_id) - def _get_capabilities_for_http_path(self, http_path: str) -> DBRCapabilities: + def _get_capabilities_for_http_path(self, + http_path: str) -> DBRCapabilities: return self._dbr_capabilities_cache.get(http_path, DBRCapabilities()) @classmethod - def _query_dbr_version( - cls, creds: DatabricksCredentials, http_path: str - ) -> Optional[tuple[int, int]]: + def _query_dbr_version(cls, creds: DatabricksCredentials, + http_path: str) -> Optional[tuple[int, int]]: is_cluster = is_cluster_http_path(http_path, creds.cluster_id) if not is_cluster: @@ -181,14 +195,16 @@ def _query_dbr_version( try: if cls.credentials_manager is None: - raise DbtRuntimeError("credentials_manager must be set before querying DBR version") + raise DbtRuntimeError( + "credentials_manager must be set before querying DBR version" + ) conn_args = SqlUtils.prepare_connection_arguments( - creds, cls.credentials_manager, http_path, {} - ) + creds, cls.credentials_manager, http_path, {}) with dbsql.connect(**conn_args) as conn: with conn.cursor() as cursor: - cursor.execute("SET spark.databricks.clusterUsageTags.sparkVersion") + cursor.execute( + "SET spark.databricks.clusterUsageTags.sparkVersion") result = cursor.fetchone() if result: return SqlUtils.extract_dbr_version(result[1]) @@ -198,7 +214,8 @@ def _query_dbr_version( return None @classmethod - def _cache_dbr_capabilities(cls, creds: DatabricksCredentials, http_path: str) -> None: + def _cache_dbr_capabilities(cls, creds: 
DatabricksCredentials, + http_path: str) -> None: if http_path not in cls._dbr_capabilities_cache: is_cluster = is_cluster_http_path(http_path, creds.cluster_id) dbr_version = cls._query_dbr_version(creds, http_path) @@ -210,19 +227,23 @@ def _cache_dbr_capabilities(cls, creds: DatabricksCredentials, http_path: str) - def cancel_open(self) -> list[str]: cancelled = super().cancel_open() - logger.info("Cancelling open python jobs") - PythonRunTracker.cancel_runs(self.api_client) + # Only cancel Python jobs via API if not in session mode + if not self.is_session_mode(): + logger.info("Cancelling open python jobs") + PythonRunTracker.cancel_runs(self.api_client) return cancelled def compare_dbr_version(self, major: int, minor: int) -> int: version = (major, minor) - handle: DatabricksHandle = self.get_thread_connection().handle + handle: DatabricksHandle | DatabricksSessionHandle = self.get_thread_connection( + ).handle dbr_version = handle.dbr_version return (dbr_version > version) - (dbr_version < version) def set_query_header(self, query_header_context: dict[str, Any]) -> None: - self.query_header = DatabricksMacroQueryStringSetter(self.profile, query_header_context) + self.query_header = DatabricksMacroQueryStringSetter( + self.profile, query_header_context) @contextmanager def exception_handler(self, sql: str) -> Iterator[None]: @@ -248,9 +269,9 @@ def exception_handler(self, sql: str) -> Iterator[None]: raise DbtDatabaseError(str(exc)) from exc # override/overload - def set_connection_name( - self, name: Optional[str] = None, query_header_context: Any = None - ) -> Connection: + def set_connection_name(self, + name: Optional[str] = None, + query_header_context: Any = None) -> Connection: conn_name: str = "master" if name is None else name wrapped = QueryContextWrapper.from_context(query_header_context) @@ -269,12 +290,14 @@ def set_connection_name( if conn.name != conn_name: orig_conn_name: str = conn.name or "" conn.name = conn_name - fire_event(ConnectionReused(orig_conn_name=orig_conn_name, conn_name=conn_name)) + fire_event( + ConnectionReused(orig_conn_name=orig_conn_name, + conn_name=conn_name)) return conn def _create_fresh_connection( - self, conn_name: str, query_header_context: QueryContextWrapper + self, conn_name: str, query_header_context: QueryContextWrapper ) -> DatabricksDBTConnection: conn = DatabricksDBTConnection( type=Identifier(self.TYPE), @@ -285,18 +308,23 @@ def _create_fresh_connection( credentials=self.profile.credentials, ) creds = cast(DatabricksCredentials, self.profile.credentials) - conn.http_path = QueryConfigUtils.get_http_path(query_header_context, creds) - conn.thread_identifier = cast(tuple[int, int], self.get_thread_identifier()) + conn.http_path = QueryConfigUtils.get_http_path( + query_header_context, creds) + conn.thread_identifier = cast(tuple[int, int], + self.get_thread_identifier()) conn._query_header_context = query_header_context - conn.capabilities = self._get_capabilities_for_http_path(conn.http_path) + conn.capabilities = self._get_capabilities_for_http_path( + conn.http_path) + conn.handle = LazyHandle(self.open) logger.debug(ConnectionCreate(str(conn))) self.set_thread_connection(conn) fire_event( - NewConnection(conn_name=conn_name, conn_type=self.TYPE, node_info=get_node_info()) - ) + NewConnection(conn_name=conn_name, + conn_type=self.TYPE, + node_info=get_node_info())) return conn @@ -318,10 +346,12 @@ def add_query( close_cursor: bool = False, ) -> tuple[Connection, Any]: connection = self.get_thread_connection() - 
fire_event(ConnectionUsed(conn_type=self.TYPE, conn_name=cast_to_str(connection.name))) + fire_event( + ConnectionUsed(conn_type=self.TYPE, + conn_name=cast_to_str(connection.name))) with self.exception_handler(sql): - cursor: Optional[CursorWrapper] = None + cursor: Optional[CursorWrapper | SessionCursorWrapper] = None try: log_sql = redact_credentials(sql) if abridge_sql_log: @@ -332,22 +362,22 @@ def add_query( conn_name=cast_to_str(connection.name), sql=log_sql, node_info=get_node_info(), - ) - ) + )) pre = time.time() - handle: DatabricksHandle = connection.handle + handle: DatabricksHandle | DatabricksSessionHandle = connection.handle cursor = handle.execute(sql, bindings) response = self.get_response(cursor) + # SQLQueryStatus in 1.10.x may not support query_id parameter + query_id = getattr(response, 'query_id', None) fire_event( SQLQueryStatus( - status=str(cursor.get_response()), + status=str(response), elapsed=round((time.time() - pre), 2), node_info=get_node_info(), query_id=response.query_id, - ) - ) + )) return connection, cursor except Error: @@ -380,50 +410,56 @@ def execute( cursor.close() def _execute_with_cursor( - self, log_sql: str, f: Callable[[DatabricksHandle], CursorWrapper] + self, + log_sql: str, + f: Callable[[DatabricksHandle | DatabricksSessionHandle], + CursorWrapper | SessionCursorWrapper], ) -> "Table": connection = self.get_thread_connection() - fire_event(ConnectionUsed(conn_type=self.TYPE, conn_name=cast_to_str(connection.name))) + fire_event( + ConnectionUsed(conn_type=self.TYPE, + conn_name=cast_to_str(connection.name))) with self.exception_handler(log_sql): - cursor: Optional[CursorWrapper] = None + cursor: Optional[CursorWrapper | SessionCursorWrapper] = None try: fire_event( SQLQuery( conn_name=cast_to_str(connection.name), sql=log_sql, node_info=get_node_info(), - ) - ) + )) pre = time.time() - handle: DatabricksHandle = connection.handle + handle: DatabricksHandle | DatabricksSessionHandle = connection.handle cursor = f(handle) response = self.get_response(cursor) + # SQLQueryStatus in 1.10.x may not support query_id parameter fire_event( SQLQueryStatus( status=str(response), - query_id=response.query_id, elapsed=round((time.time() - pre), 2), node_info=get_node_info(), - ) - ) + )) return self.get_result_from_cursor(cursor, None) finally: if cursor: cursor.close() - def list_schemas(self, database: str, schema: Optional[str] = None) -> "Table": + def list_schemas(self, + database: str, + schema: Optional[str] = None) -> "Table": database = database.strip("`") if schema: schema = schema.strip("`").lower() return self._execute_with_cursor( f"GetSchemas(database={database}, schema={schema})", - lambda cursor: cursor.list_schemas(database=database, schema=schema), + lambda cursor: cursor.list_schemas(database=database, + schema=schema), ) def list_tables(self, database: str, schema: str) -> "Table": @@ -431,13 +467,15 @@ def list_tables(self, database: str, schema: str) -> "Table": schema = schema.strip("`").lower() return self._execute_with_cursor( f"GetTables(database={database}, schema={schema})", - lambda cursor: cursor.list_tables(database=database, schema=schema), + lambda cursor: cursor.list_tables(database=database, schema=schema + ), ) # override def release(self) -> None: with self.lock: - conn = cast(Optional[DatabricksDBTConnection], self.get_if_exists()) + conn = cast(Optional[DatabricksDBTConnection], + self.get_if_exists()) if conn is None: return @@ -448,9 +486,12 @@ def release(self) -> None: def cleanup_all(self) -> None: with 
self.lock: # Close the current thread connection if it exists - conn = cast(Optional[DatabricksDBTConnection], self.get_if_exists()) + conn = cast(Optional[DatabricksDBTConnection], + self.get_if_exists()) if conn: - fire_event(ConnectionClosedInCleanup(conn_name=cast_to_str(conn.name))) + fire_event( + ConnectionClosedInCleanup( + conn_name=cast_to_str(conn.name))) self.close(conn) # garbage collect these connections @@ -464,33 +505,93 @@ def open(cls, connection: Connection) -> Connection: return connection creds: DatabricksCredentials = connection.credentials + + # Dispatch based on connection method + if creds.is_session_mode: + return cls._open_session(databricks_connection, creds) + else: + return cls._open_dbsql(databricks_connection, creds) + + @classmethod + def _open_session(cls, databricks_connection: DatabricksDBTConnection, + creds: DatabricksCredentials) -> Connection: + """Open a connection using SparkSession mode.""" + logger.debug("Opening connection in session mode") + + def connect() -> DatabricksSessionHandle: + try: + handle = DatabricksSessionHandle.create( + catalog=creds.database, + schema=creds.schema, + session_properties=creds.session_properties, + ) + databricks_connection.session_id = handle.session_id + + # Cache capabilities for session mode (simplified for 1.10.x) + cls._cache_session_capabilities(handle) + + logger.debug(f"Session mode connection opened: {handle}") + return handle + except Exception as exc: + logger.error(ConnectionCreateError(exc)) + raise DbtDatabaseError( + f"Failed to create session connection: {exc}") from exc + + # Session mode doesn't need retry logic as SparkSession is already available + databricks_connection.handle = connect() + databricks_connection.state = ConnectionState.OPEN + return databricks_connection + + @classmethod + def _cache_session_capabilities(cls, + handle: DatabricksSessionHandle) -> None: + """Cache DBR capabilities for session mode (simplified for 1.10.x).""" + if cls._session_capabilities is None: + dbr_version = handle.dbr_version + cls._session_capabilities = { + "dbr_version": dbr_version, + "is_sql_warehouse": False, + } + logger.debug( + f"Cached session capabilities: DBR version {dbr_version}") + + @classmethod + def _open_dbsql(cls, databricks_connection: DatabricksDBTConnection, + creds: DatabricksCredentials) -> Connection: + """Open a connection using DBSQL connector.""" timeout = creds.connect_timeout - cls.credentials_manager = creds.authenticate() + credentials_manager = creds.authenticate() + # In DBSQL mode, authenticate() always returns a credentials manager + assert credentials_manager is not None, "Credentials manager is required for DBSQL mode" + cls.credentials_manager = credentials_manager # Get merged query tags if we have query header context - query_header_context = getattr(databricks_connection, "_query_header_context", None) + query_header_context = getattr(databricks_connection, + "_query_header_context", None) merged_query_tags = {} if query_header_context: - merged_query_tags = QueryConfigUtils.get_merged_query_tags(query_header_context, creds) + merged_query_tags = QueryConfigUtils.get_merged_query_tags( + query_header_context, creds) conn_args = SqlUtils.prepare_connection_arguments( - creds, cls.credentials_manager, databricks_connection.http_path, merged_query_tags - ) + creds, cls.credentials_manager, databricks_connection.http_path, + merged_query_tags) def connect() -> DatabricksHandle: try: # TODO: what is the error when a user specifies a catalog they don't have access to 
conn = DatabricksHandle.from_connection_args( conn_args, - is_cluster_http_path(databricks_connection.http_path, creds.cluster_id), + is_cluster_http_path(databricks_connection.http_path, + creds.cluster_id), ) if conn: databricks_connection.session_id = conn.session_id - cls._cache_dbr_capabilities(creds, databricks_connection.http_path) + cls._cache_dbr_capabilities( + creds, databricks_connection.http_path) databricks_connection.capabilities = cls._dbr_capabilities_cache[ - databricks_connection.http_path - ] + databricks_connection.http_path] return conn else: raise DbtDatabaseError("Failed to create connection") @@ -507,12 +608,13 @@ def exponential_backoff(attempt: int) -> int: retryable_exceptions = [Error] return cls.retry_connection( - connection, + databricks_connection, connect=connect, logger=logger, retryable_exceptions=retryable_exceptions, retry_limit=creds.connect_retries, - retry_timeout=(timeout if timeout is not None else exponential_backoff), + retry_timeout=(timeout + if timeout is not None else exponential_backoff), ) # override @@ -527,7 +629,7 @@ def close(cls, connection: Connection) -> Connection: @classmethod def get_response(cls, cursor: Any) -> AdapterResponse: - if isinstance(cursor, CursorWrapper): + if isinstance(cursor, (CursorWrapper, SessionCursorWrapper)): return cursor.get_response() else: return AdapterResponse("OK") @@ -547,7 +649,8 @@ class QueryConfigUtils: """ @staticmethod - def get_http_path(context: QueryContextWrapper, creds: DatabricksCredentials) -> str: + def get_http_path(context: QueryContextWrapper, + creds: DatabricksCredentials) -> str: """ Get the http_path for the compute specified for the node. If none is specified default will be used. @@ -559,7 +662,8 @@ def get_http_path(context: QueryContextWrapper, creds: DatabricksCredentials) -> # Get the http_path for the named compute. 
http_path = None if creds.compute: - http_path = creds.compute.get(context.compute_name, {}).get("http_path", None) + http_path = creds.compute.get(context.compute_name, + {}).get("http_path", None) # no http_path for the named compute resource is an error condition if not http_path: @@ -588,27 +692,24 @@ def get_merged_query_tags( # Default tags that will only exists for queries tied to a specific model if query_header_context: - if hasattr(query_header_context, "model_name") and query_header_context.model_name: + if hasattr(query_header_context, + "model_name") and query_header_context.model_name: default_tags[QueryTagsUtils.DBT_MODEL_NAME_QUERY_TAG_KEY] = ( - query_header_context.model_name - ) - if hasattr(query_header_context, "materialized") and query_header_context.materialized: + query_header_context.model_name) + if hasattr(query_header_context, + "materialized") and query_header_context.materialized: default_tags[QueryTagsUtils.DBT_MATERIALIZED_QUERY_TAG_KEY] = ( - query_header_context.materialized - ) + query_header_context.materialized) # Parse connection tags from JSON string - connection_tags = ( - QueryTagsUtils.parse_query_tags(creds.query_tags) if creds.query_tags else {} - ) + connection_tags = (QueryTagsUtils.parse_query_tags(creds.query_tags) + if creds.query_tags else {}) # Extract model-level query tags from context model_tags = {} - if ( - query_header_context - and hasattr(query_header_context, "model_query_tags_override") - and query_header_context.model_query_tags_override - ): + if (query_header_context + and hasattr(query_header_context, "model_query_tags_override") + and query_header_context.model_query_tags_override): model_tags = query_header_context.model_query_tags_override return QueryTagsUtils.merge_query_tags( diff --git a/dbt/adapters/databricks/credentials.py b/dbt/adapters/databricks/credentials.py index 7e8af786b..fa0671cd3 100644 --- a/dbt/adapters/databricks/credentials.py +++ b/dbt/adapters/databricks/credentials.py @@ -1,12 +1,13 @@ import itertools import json +import os import re from collections.abc import Iterable from dataclasses import dataclass, field from typing import Any, Callable, Optional, cast from dbt.adapters.contracts.connection import Credentials -from dbt_common.exceptions import DbtConfigError, DbtValidationError +from dbt_common.exceptions import DbtConfigError, DbtRuntimeError, DbtValidationError from mashumaro import DataClassDictMixin from requests import PreparedRequest from requests.auth import AuthBase @@ -16,9 +17,18 @@ from dbt.adapters.databricks.global_state import GlobalState from dbt.adapters.databricks.logging import logger +# Connection method constants +CONNECTION_METHOD_SESSION = "session" +CONNECTION_METHOD_DBSQL = "dbsql" + +# Environment variable for session mode +DBT_DATABRICKS_SESSION_MODE_ENV = "DBT_DATABRICKS_SESSION_MODE" +DATABRICKS_RUNTIME_VERSION_ENV = "DATABRICKS_RUNTIME_VERSION" + CATALOG_KEY_IN_SESSION_PROPERTIES = "databricks.catalog" DBT_DATABRICKS_INVOCATION_ENV_REGEX = re.compile("^[A-z0-9\\-]+$") -EXTRACT_CLUSTER_ID_FROM_HTTP_PATH_REGEX = re.compile(r"/?sql/protocolv1/o/\d+/(.*)") +EXTRACT_CLUSTER_ID_FROM_HTTP_PATH_REGEX = re.compile( + r"/?sql/protocolv1/o/\d+/(.*)") DBT_DATABRICKS_HTTP_SESSION_HEADERS = "DBT_DATABRICKS_HTTP_SESSION_HEADERS" REDIRECT_URL = "http://localhost:8020" @@ -48,6 +58,9 @@ class DatabricksCredentials(Credentials): connection_parameters: Optional[dict[str, Any]] = None auth_type: Optional[str] = None + # Connection method: "session" for SparkSession mode, "dbsql" for 
DBSQL connector (default) + method: Optional[str] = None + # Named compute resources specified in the profile. Used for # creating a connection when a model specifies a compute resource. compute: Optional[dict[str, Any]] = None @@ -71,7 +84,8 @@ def __pre_deserialize__(cls, data: dict[Any, Any]) -> dict[Any, Any]: data = super().__pre_deserialize__(data) data.setdefault("database", None) data.setdefault("connection_parameters", {}) - data["connection_parameters"].setdefault("_retry_stop_after_attempts_count", 30) + data["connection_parameters"].setdefault( + "_retry_stop_after_attempts_count", 30) data["connection_parameters"].setdefault("_retry_delay_max", 60) return data @@ -85,57 +99,112 @@ def __post_init__(self) -> None: session_properties = self.session_properties or {} if CATALOG_KEY_IN_SESSION_PROPERTIES in session_properties: if self.database is None: - self.database = session_properties[CATALOG_KEY_IN_SESSION_PROPERTIES] + self.database = session_properties[ + CATALOG_KEY_IN_SESSION_PROPERTIES] del session_properties[CATALOG_KEY_IN_SESSION_PROPERTIES] else: raise DbtValidationError( f"Got duplicate keys: (`{CATALOG_KEY_IN_SESSION_PROPERTIES}` " - 'in session_properties) all map to "database"' - ) + 'in session_properties) all map to "database"') self.session_properties = session_properties if self.database is not None: database = self.database.strip() if not database: - raise DbtValidationError(f"Invalid catalog name : `{self.database}`.") + raise DbtValidationError( + f"Invalid catalog name : `{self.database}`.") self.database = database else: self.database = "hive_metastore" connection_parameters = self.connection_parameters or {} for key in ( - "server_hostname", - "http_path", - "access_token", - "client_id", - "client_secret", - "session_configuration", - "catalog", - "schema", - "_user_agent_entry", - "user_agent_entry", + "server_hostname", + "http_path", + "access_token", + "client_id", + "client_secret", + "session_configuration", + "catalog", + "schema", + "_user_agent_entry", + "user_agent_entry", ): if key in connection_parameters: - raise DbtValidationError(f"The connection parameter `{key}` is reserved.") + raise DbtValidationError( + f"The connection parameter `{key}` is reserved.") if "http_headers" in connection_parameters: http_headers = connection_parameters["http_headers"] if not isinstance(http_headers, dict) or any( - not isinstance(key, str) or not isinstance(value, str) - for key, value in http_headers.items() - ): + not isinstance(key, str) or not isinstance(value, str) + for key, value in http_headers.items()): raise DbtValidationError( "The connection parameter `http_headers` should be dict of strings: " - f"{http_headers}." 
- ) + f"{http_headers}.") if "_socket_timeout" not in connection_parameters: connection_parameters["_socket_timeout"] = 600 self.connection_parameters = connection_parameters - self._credentials_manager = DatabricksCredentialManager.create_from(self) + + # Auto-detect and validate connection method + self._init_connection_method() + + # Only create credentials manager for non-session mode + if not self.is_session_mode: + self._credentials_manager = DatabricksCredentialManager.create_from( + self) + + def _init_connection_method(self) -> None: + """Initialize and validate the connection method.""" + if self.method is None: + # Auto-detect session mode + if os.getenv(DBT_DATABRICKS_SESSION_MODE_ENV, + "").lower() == "true": + self.method = CONNECTION_METHOD_SESSION + elif os.getenv(DATABRICKS_RUNTIME_VERSION_ENV) and not self.host: + # Running on Databricks cluster without host configured + self.method = CONNECTION_METHOD_SESSION + else: + self.method = CONNECTION_METHOD_DBSQL + + # Validate method value + if self.method not in (CONNECTION_METHOD_SESSION, + CONNECTION_METHOD_DBSQL): + raise DbtValidationError( + f"Invalid connection method: '{self.method}'. " + f"Must be '{CONNECTION_METHOD_SESSION}' or '{CONNECTION_METHOD_DBSQL}'." + ) + + @property + def is_session_mode(self) -> bool: + """Check if using session mode (SparkSession) for connections.""" + return self.method == CONNECTION_METHOD_SESSION + + def _validate_session_mode(self) -> None: + """Validate configuration for session mode.""" + try: + from pyspark.sql import SparkSession # noqa: F401 + except ImportError: + raise DbtRuntimeError( + "Session mode requires PySpark. " + "Please ensure you are running on a Databricks cluster with PySpark available." + ) + + if self.schema is None: + raise DbtValidationError("Schema is required for session mode.") def validate_creds(self) -> None: + """Validate credentials based on connection method.""" + if self.is_session_mode: + self._validate_session_mode() + else: + self._validate_dbsql_creds() + + def _validate_dbsql_creds(self) -> None: + """Validate credentials for DBSQL connector mode.""" for key in ["host", "http_path"]: if not getattr(self, key): - raise DbtConfigError(f"The config '{key}' is required to connect to Databricks") + raise DbtConfigError( + f"The config '{key}' is required to connect to Databricks") if not self.token and self.auth_type != "oauth": raise DbtConfigError( "The config `auth_type: oauth` is required when not using access token" @@ -144,16 +213,13 @@ def validate_creds(self) -> None: if not self.client_id and self.client_secret: raise DbtConfigError( "The config 'client_id' is required to connect " - "to Databricks when 'client_secret' is present" - ) + "to Databricks when 'client_secret' is present") if (not self.azure_client_id and self.azure_client_secret) or ( - self.azure_client_id and not self.azure_client_secret - ): + self.azure_client_id and not self.azure_client_secret): raise DbtConfigError( "The config 'azure_client_id' and 'azure_client_secret' " - "must be both present or both absent" - ) + "must be both present or both absent") @classmethod def get_invocation_env(cls) -> Optional[str]: @@ -162,25 +228,23 @@ def get_invocation_env(cls) -> Optional[str]: # Thrift doesn't allow nested () so we need to ensure # that the passed user agent is valid. 
if not DBT_DATABRICKS_INVOCATION_ENV_REGEX.search(invocation_env): - raise DbtValidationError(f"Invalid invocation environment: {invocation_env}") + raise DbtValidationError( + f"Invalid invocation environment: {invocation_env}") return invocation_env @classmethod - def get_all_http_headers(cls, user_http_session_headers: dict[str, str]) -> dict[str, str]: + def get_all_http_headers( + cls, user_http_session_headers: dict[str, str]) -> dict[str, str]: http_session_headers_str = GlobalState.get_http_session_headers() - http_session_headers_dict: dict[str, str] = ( - { - k: v if isinstance(v, str) else json.dumps(v) - for k, v in json.loads(http_session_headers_str).items() - } - if http_session_headers_str is not None - else {} - ) + http_session_headers_dict: dict[str, str] = ({ + k: + v if isinstance(v, str) else json.dumps(v) + for k, v in json.loads(http_session_headers_str).items() + } if http_session_headers_str is not None else {}) - intersect_http_header_keys = ( - user_http_session_headers.keys() & http_session_headers_dict.keys() - ) + intersect_http_header_keys = (user_http_session_headers.keys() + & http_session_headers_dict.keys()) if len(intersect_http_header_keys) > 0: raise DbtValidationError( @@ -197,19 +261,39 @@ def type(self) -> str: @property def unique_field(self) -> str: + if self.is_session_mode: + # For session mode, use a unique identifier based on catalog and schema + return f"session://{self.database}/{self.schema}" return cast(str, self.host) - def connection_info(self, *, with_aliases: bool = False) -> Iterable[tuple[str, Any]]: + def connection_info(self, + *, + with_aliases: bool = False + ) -> Iterable[tuple[str, Any]]: as_dict = self.to_dict(omit_none=False) connection_keys = set(self._connection_keys(with_aliases=with_aliases)) aliases: list[str] = [] if with_aliases: - aliases = [k for k, v in self._ALIASES.items() if v in connection_keys] - for key in itertools.chain(self._connection_keys(with_aliases=with_aliases), aliases): + aliases = [ + k for k, v in self._ALIASES.items() if v in connection_keys + ] + for key in itertools.chain( + self._connection_keys(with_aliases=with_aliases), aliases): if key in as_dict: yield key, as_dict[key] - def _connection_keys(self, *, with_aliases: bool = False) -> tuple[str, ...]: + def _connection_keys_session(self) -> tuple[str, ...]: + """Connection keys for session mode.""" + connection_keys = ["method", "schema"] + if self.database: + connection_keys.insert(1, "catalog") + if self.session_properties: + connection_keys.append("session_properties") + return tuple(connection_keys) + + def _connection_keys(self, + *, + with_aliases: bool = False) -> tuple[str, ...]: # Assuming `DatabricksCredentials.connection_info(self, *, with_aliases: bool = False)` # is called from only: # @@ -218,6 +302,11 @@ def _connection_keys(self, *, with_aliases: bool = False) -> tuple[str, ...]: # # Thus, if `with_aliases` is `True`, `DatabricksCredentials._connection_keys` should return # the internal key names; otherwise it can use aliases to show in `dbt debug`. 
+ + # Session mode has different connection keys + if self.is_session_mode: + return self._connection_keys_session() + connection_keys = ["host", "http_path", "schema"] if with_aliases: connection_keys.insert(2, "database") @@ -237,10 +326,18 @@ def extract_cluster_id(cls, http_path: str) -> Optional[str]: @property def cluster_id(self) -> Optional[str]: - return self.extract_cluster_id(self.http_path) # type: ignore[arg-type] + return self.extract_cluster_id( + self.http_path) # type: ignore[arg-type] + + def authenticate(self) -> Optional["DatabricksCredentialManager"]: + """Authenticate and return credentials manager. - def authenticate(self) -> "DatabricksCredentialManager": + For session mode, returns None as no external authentication is needed. + For DBSQL mode, validates credentials and returns the credentials manager. + """ self.validate_creds() + if self.is_session_mode: + return None assert self._credentials_manager is not None, "Credentials manager is not set." return self._credentials_manager @@ -279,7 +376,9 @@ class DatabricksCredentialManager(DataClassDictMixin): auth_type: Optional[str] = None @classmethod - def create_from(cls, credentials: DatabricksCredentials) -> "DatabricksCredentialManager": + def create_from( + cls, credentials: DatabricksCredentials + ) -> "DatabricksCredentialManager": return DatabricksCredentialManager( host=credentials.host or "", token=credentials.token, @@ -344,8 +443,10 @@ def __post_init__(self) -> None: self._config = self.authenticate_with_external_browser() else: auth_methods = { - "oauth-m2m": self.authenticate_with_oauth_m2m, - "legacy-azure-client-secret": self.legacy_authenticate_with_azure_client_secret, + "oauth-m2m": + self.authenticate_with_oauth_m2m, + "legacy-azure-client-secret": + self.legacy_authenticate_with_azure_client_secret, } # If the secret starts with dose, high chance is it is a databricks secret @@ -367,18 +468,20 @@ def __post_init__(self) -> None: break # Exit loop if authentication is successful except Exception as e: exceptions.append((auth_type, e)) - next_auth_type = auth_sequence[i + 1] if i + 1 < len(auth_sequence) else None + next_auth_type = auth_sequence[i + 1] if i + 1 < len( + auth_sequence) else None if next_auth_type: logger.warning( f"Failed to authenticate with {auth_type}, " - f"trying {next_auth_type} next. Error: {e}" - ) + f"trying {next_auth_type} next. Error: {e}") else: logger.error( f"Failed to authenticate with {auth_type}. " f"No more authentication methods to try. Error: {e}" ) - raise Exception(f"All authentication methods failed. Details: {exceptions}") + raise Exception( + f"All authentication methods failed. Details: {exceptions}" + ) @property def api_client(self) -> WorkspaceClient: @@ -386,6 +489,7 @@ def api_client(self) -> WorkspaceClient: @property def credentials_provider(self) -> PySQLCredentialProvider: + def inner() -> Callable[[], dict[str, str]]: return self.header_factory diff --git a/dbt/adapters/databricks/handle.py b/dbt/adapters/databricks/handle.py index aadf871c1..e50e049a3 100644 --- a/dbt/adapters/databricks/handle.py +++ b/dbt/adapters/databricks/handle.py @@ -288,6 +288,47 @@ def translate_bindings(bindings: Optional[Sequence[Any]]) -> Optional[Sequence[A return list(map(lambda x: float(x) if isinstance(x, decimal.Decimal) else x, bindings)) return None + @staticmethod + def format_bindings_for_sql(bindings: Optional[Sequence[Any]]) -> Optional[Sequence[str]]: + """ + Format bindings as SQL literals for string substitution in session mode. 
+ + This method properly quotes string values and handles special cases to ensure + SQL injection safety and correct SQL syntax. Used when executing SQL via + SparkSession.sql() which doesn't support parameterized queries. + + Args: + bindings: Sequence of binding values (strings, numbers, None, etc.) + + Returns: + Sequence of SQL literal strings, or None if bindings is None/empty + """ + if not bindings: + return None + + formatted = [] + for value in bindings: + if value is None: + formatted.append("NULL") + elif isinstance(value, bool): + formatted.append("TRUE" if value else "FALSE") + elif isinstance(value, str): + # Escape single quotes by doubling them, then wrap in quotes + escaped = value.replace("'", "''") + formatted.append(f"'{escaped}'") + elif isinstance(value, (int, float, decimal.Decimal)): + # Numbers don't need quotes + if isinstance(value, decimal.Decimal): + formatted.append(str(float(value))) + else: + formatted.append(str(value)) + else: + # For other types, convert to string and quote + escaped = str(value).replace("'", "''") + formatted.append(f"'{escaped}'") + + return formatted + @staticmethod def clean_sql(sql: str) -> str: cleaned = sql.strip() diff --git a/dbt/adapters/databricks/impl.py b/dbt/adapters/databricks/impl.py index 2190e1f9c..f17a978dd 100644 --- a/dbt/adapters/databricks/impl.py +++ b/dbt/adapters/databricks/impl.py @@ -57,6 +57,7 @@ AllPurposeClusterPythonJobHelper, JobClusterPythonJobHelper, ServerlessClusterPythonJobHelper, + SessionPythonJobHelper, WorkflowPythonJobHelper, ) from dbt.adapters.databricks.relation import ( @@ -801,6 +802,7 @@ def python_submission_helpers(self) -> dict[str, type[PythonJobHelper]]: "all_purpose_cluster": AllPurposeClusterPythonJobHelper, "serverless_cluster": ServerlessClusterPythonJobHelper, "workflow_job": WorkflowPythonJobHelper, + "session": SessionPythonJobHelper, } @log_code_execution @@ -809,6 +811,17 @@ def submit_python_job(self, parsed_model: dict, compiled_code: str) -> AdapterRe "user_folder_for_python", self.behavior.use_user_folder_for_python.setting, # type: ignore[attr-defined] ) + + # Auto-select session submission when in session mode + from dbt.adapters.databricks.credentials import DatabricksCredentials + from dbt.adapters.databricks.logging import logger + + creds = cast(DatabricksCredentials, self.config.credentials) + if creds.is_session_mode: + if parsed_model["config"].get("submission_method") is None: + parsed_model["config"]["submission_method"] = "session" + logger.debug("Auto-selected 'session' submission method for Python model") + return super().submit_python_job(parsed_model, compiled_code) @available diff --git a/dbt/adapters/databricks/python_models/python_submissions.py b/dbt/adapters/databricks/python_models/python_submissions.py index 23c2a27c1..f661f35f3 100644 --- a/dbt/adapters/databricks/python_models/python_submissions.py +++ b/dbt/adapters/databricks/python_models/python_submissions.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Optional from dbt.adapters.base import PythonJobHelper from dbt_common.exceptions import DbtRuntimeError @@ -12,6 +12,9 @@ from dbt.adapters.databricks.python_models.python_config import ParsedPythonModel from dbt.adapters.databricks.python_models.run_tracking import PythonRunTracker +if TYPE_CHECKING: + from pyspark.sql import SparkSession + DEFAULT_TIMEOUT = 60 * 60 * 24 @@ -29,7 +32,8 @@ class BaseDatabricksHelper(PythonJobHelper): tracker = PythonRunTracker() - 
def __init__(self, parsed_model: dict, credentials: DatabricksCredentials) -> None: + def __init__(self, parsed_model: dict, + credentials: DatabricksCredentials) -> None: self.credentials = credentials self.credentials.validate_creds() self.parsed_model = ParsedPythonModel(**parsed_model) @@ -62,9 +66,8 @@ def submit(self, compiled_code: str) -> None: class PythonCommandSubmitter(PythonSubmitter): """Submitter for Python models using the Command API.""" - def __init__( - self, api_client: DatabricksApiClient, tracker: PythonRunTracker, cluster_id: str - ) -> None: + def __init__(self, api_client: DatabricksApiClient, + tracker: PythonRunTracker, cluster_id: str) -> None: self.api_client = api_client self.tracker = tracker self.cluster_id = cluster_id @@ -77,8 +80,7 @@ def submit(self, compiled_code: str) -> None: command_exec: Optional[CommandExecution] = None try: command_exec = self.api_client.commands.execute( - self.cluster_id, context_id, compiled_code - ) + self.cluster_id, context_id, compiled_code) self.tracker.insert_command(command_exec) self.api_client.commands.poll_for_completion(command_exec) @@ -86,42 +88,46 @@ def submit(self, compiled_code: str) -> None: finally: if command_exec: self.tracker.remove_command(command_exec) - self.api_client.command_contexts.destroy(self.cluster_id, context_id) + self.api_client.command_contexts.destroy(self.cluster_id, + context_id) class PythonNotebookUploader: """Uploads a compiled Python model as a notebook to the Databricks workspace.""" - def __init__(self, api_client: DatabricksApiClient, parsed_model: ParsedPythonModel) -> None: + def __init__(self, api_client: DatabricksApiClient, + parsed_model: ParsedPythonModel) -> None: self.api_client = api_client self.catalog = parsed_model.catalog self.schema = parsed_model.schema_ self.identifier = parsed_model.identifier - self.job_grants = ( - parsed_model.config.python_job_config.grants - if parsed_model.config.python_job_config - else {} - ) + self.job_grants = (parsed_model.config.python_job_config.grants + if parsed_model.config.python_job_config else {}) self.notebook_access_control_list = parsed_model.config.notebook_access_control_list def upload(self, compiled_code: str) -> str: """Upload the compiled code to the Databricks workspace.""" - logger.debug( - f"[Notebook Upload Debug] Creating workspace dir for " - f"catalog={self.catalog}, schema={self.schema}" - ) - workdir = self.api_client.workspace.create_python_model_dir(self.catalog, self.schema) + logger.debug(f"[Notebook Upload Debug] Creating workspace dir for " + f"catalog={self.catalog}, schema={self.schema}") + workdir = self.api_client.workspace.create_python_model_dir( + self.catalog, self.schema) file_path = f"{workdir}{self.identifier}" - logger.debug(f"[Notebook Upload Debug] Uploading notebook to path: {file_path}") + logger.debug( + f"[Notebook Upload Debug] Uploading notebook to path: {file_path}") # Log notebook content length - logger.debug(f"[Notebook Upload Debug] Notebook content length: {len(compiled_code)} chars") + logger.debug( + f"[Notebook Upload Debug] Notebook content length: {len(compiled_code)} chars" + ) self.api_client.workspace.upload_notebook(file_path, compiled_code) - logger.debug(f"[Notebook Upload Debug] Successfully uploaded notebook to {file_path}") + logger.debug( + f"[Notebook Upload Debug] Successfully uploaded notebook to {file_path}" + ) if self.job_grants or self.notebook_access_control_list: - logger.debug("[Notebook Upload Debug] Setting permissions for notebook") + logger.debug( + 
"[Notebook Upload Debug] Setting permissions for notebook") self.set_notebook_permissions(file_path) return file_path @@ -131,14 +137,19 @@ def set_notebook_permissions(self, notebook_path: str) -> None: permission_builder = PythonPermissionBuilder(self.api_client) access_control_list = permission_builder.build_permissions( - self.job_grants, self.notebook_access_control_list, target_type="notebook" - ) + self.job_grants, + self.notebook_access_control_list, + target_type="notebook") if access_control_list: - logger.debug(f"Setting permissions on notebook: {notebook_path}") - self.api_client.notebook_permissions.put(notebook_path, access_control_list) + logger.debug( + f"Setting permissions on notebook: {notebook_path}") + self.api_client.notebook_permissions.put( + notebook_path, access_control_list) except Exception as e: - logger.error(f"Failed to set permissions on notebook {notebook_path}: {str(e)}") + logger.error( + f"Failed to set permissions on notebook {notebook_path}: {str(e)}" + ) raise DbtRuntimeError( f"Failed to set permissions on notebook: path={notebook_path}, error: {str(e)}" ) @@ -167,26 +178,30 @@ def __init__( def _get_job_owner_for_config(self) -> tuple[str, str]: """Get the owner of the job (and type) for the access control list.""" curr_user = self.api_client.curr_user.get_username() - is_service_principal = self.api_client.curr_user.is_service_principal(curr_user) + is_service_principal = self.api_client.curr_user.is_service_principal( + curr_user) source = "service_principal_name" if is_service_principal else "user_name" return curr_user, source @staticmethod - def _build_job_permission( - job_grants: list[dict[str, Any]], permission: str - ) -> list[dict[str, Any]]: + def _build_job_permission(job_grants: list[dict[str, Any]], + permission: str) -> list[dict[str, Any]]: """Build the access control list for the job.""" - return [{**grant, **{"permission_level": permission}} for grant in job_grants] + return [{ + **grant, + **{ + "permission_level": permission + } + } for grant in job_grants] def _filter_permissions( - self, acls: list[dict[str, Any]], valid_permissions: set[str] - ) -> list[dict[str, Any]]: + self, acls: list[dict[str, Any]], + valid_permissions: set[str]) -> list[dict[str, Any]]: return [ - acl - for acl in acls - if "permission_level" in acl and acl["permission_level"] in valid_permissions + acl for acl in acls if "permission_level" in acl + and acl["permission_level"] in valid_permissions ] def build_job_permissions( @@ -196,22 +211,19 @@ def build_job_permissions( ) -> list[dict[str, Any]]: access_control_list = [] owner, permissions_attribute = self._get_job_owner_for_config() - access_control_list.append( - { - permissions_attribute: owner, - "permission_level": "IS_OWNER", - } - ) + access_control_list.append({ + permissions_attribute: owner, + "permission_level": "IS_OWNER", + }) access_control_list.extend( - self._build_job_permission(job_grants.get("view", []), "CAN_VIEW") - ) + self._build_job_permission(job_grants.get("view", []), "CAN_VIEW")) access_control_list.extend( - self._build_job_permission(job_grants.get("run", []), "CAN_MANAGE_RUN") - ) + self._build_job_permission(job_grants.get("run", []), + "CAN_MANAGE_RUN")) access_control_list.extend( - self._build_job_permission(job_grants.get("manage", []), "CAN_MANAGE") - ) + self._build_job_permission(job_grants.get("manage", []), + "CAN_MANAGE")) combined_acls = access_control_list + acls return self._filter_permissions(combined_acls, self.JOB_PERMISSIONS) @@ -224,17 +236,21 @@ def 
build_notebook_permissions( access_control_list = [] access_control_list.extend( - self._build_job_permission(job_grants.get("view", []), "CAN_READ") - ) - access_control_list.extend(self._build_job_permission(job_grants.get("run", []), "CAN_RUN")) + self._build_job_permission(job_grants.get("view", []), "CAN_READ")) access_control_list.extend( - self._build_job_permission(job_grants.get("manage", []), "CAN_MANAGE") - ) + self._build_job_permission(job_grants.get("run", []), "CAN_RUN")) + access_control_list.extend( + self._build_job_permission(job_grants.get("manage", []), + "CAN_MANAGE")) combined_acls = access_control_list + acls - filtered_acls = self._filter_permissions(combined_acls, self.NOTEBOOK_PERMISSIONS) + filtered_acls = self._filter_permissions(combined_acls, + self.NOTEBOOK_PERMISSIONS) - return [acl for acl in filtered_acls if acl.get("permission_level") != "IS_OWNER"] + return [ + acl for acl in filtered_acls + if acl.get("permission_level") != "IS_OWNER" + ] def build_permissions( self, @@ -286,10 +302,12 @@ def __init__( packages = parsed_model.config.packages index_url = parsed_model.config.index_url additional_libraries = parsed_model.config.additional_libs - library_config = get_library_config(packages, index_url, additional_libraries) + library_config = get_library_config(packages, index_url, + additional_libraries) self.cluster_spec = {**cluster_spec, **library_config} self.job_grants = parsed_model.config.python_job_config.grants - self.additional_job_settings = parsed_model.config.python_job_config.dict() + self.additional_job_settings = parsed_model.config.python_job_config.dict( + ) self.environment_key = parsed_model.config.environment_key self.environment_deps = parsed_model.config.environment_dependencies @@ -305,25 +323,27 @@ def compile(self, path: str) -> PythonJobDetails: if self.environment_key: job_spec["environment_key"] = self.environment_key - if self.environment_deps and not self.additional_job_settings.get("environments"): - additional_job_config["environments"] = [ - { - "environment_key": self.environment_key, - "spec": {"environment_version": "4", "dependencies": self.environment_deps}, - } - ] + if self.environment_deps and not self.additional_job_settings.get( + "environments"): + additional_job_config["environments"] = [{ + "environment_key": + self.environment_key, + "spec": { + "client": "2", + "dependencies": self.environment_deps + }, + }] job_spec.update(self.cluster_spec) access_control_list = self.permission_builder.build_job_permissions( - self.job_grants, self.access_control_list - ) + self.job_grants, self.access_control_list) if access_control_list: job_spec["access_control_list"] = access_control_list job_spec["queue"] = {"enabled": True} - return PythonJobDetails( - run_name=self.run_name, job_spec=job_spec, additional_job_config=additional_job_config - ) + return PythonJobDetails(run_name=self.run_name, + job_spec=job_spec, + additional_job_config=additional_job_config) class PythonNotebookSubmitter(PythonSubmitter): @@ -356,7 +376,8 @@ def create( parsed_model, cluster_spec, ) - return PythonNotebookSubmitter(api_client, tracker, notebook_uploader, config_compiler) + return PythonNotebookSubmitter(api_client, tracker, notebook_uploader, + config_compiler) @override def submit(self, compiled_code: str) -> None: @@ -366,20 +387,21 @@ def submit(self, compiled_code: str) -> None: job_config = self.config_compiler.compile(file_path) run_id = self.api_client.job_runs.submit( - job_config.run_name, job_config.job_spec, 
**job_config.additional_job_config - ) + job_config.run_name, job_config.job_spec, + **job_config.additional_job_config) self.tracker.insert_run_id(run_id) try: if "access_control_list" in job_config.job_spec: try: - job_id = self.api_client.job_runs.get_job_id_from_run_id(run_id) + job_id = self.api_client.job_runs.get_job_id_from_run_id( + run_id) logger.debug(f"Setting permissions on job: {job_id}") self.api_client.workflow_permissions.patch( - job_id, job_config.job_spec["access_control_list"] - ) + job_id, job_config.job_spec["access_control_list"]) except Exception as e: - logger.error(f"Failed to set permissions on job {run_id}: {str(e)}") + logger.error( + f"Failed to set permissions on job {run_id}: {str(e)}") raise DbtRuntimeError( f"Failed to set permissions on job: run_id={run_id}, error: {str(e)}" ) @@ -414,7 +436,8 @@ class AllPurposeClusterPythonJobHelper(BaseDatabricksHelper): Top level helper for Python models using job runs or Command API on an all-purpose cluster. """ - def __init__(self, parsed_model: dict, credentials: DatabricksCredentials) -> None: + def __init__(self, parsed_model: dict, + credentials: DatabricksCredentials) -> None: self.credentials = credentials self.credentials.validate_creds() self.parsed_model = ParsedPythonModel(**parsed_model) @@ -428,8 +451,7 @@ def __init__(self, parsed_model: dict, credentials: DatabricksCredentials) -> No config = self.parsed_model.config self.create_notebook = config.create_notebook self.cluster_id = config.cluster_id or self.credentials.extract_cluster_id( - config.http_path or self.credentials.http_path or "" - ) + config.http_path or self.credentials.http_path or "") self.validate_config() self.command_submitter = self.build_submitter() @@ -444,22 +466,23 @@ def build_submitter(self) -> PythonSubmitter: {"existing_cluster_id": self.cluster_id}, ) else: - return PythonCommandSubmitter(self.api_client, self.tracker, self.cluster_id or "") + return PythonCommandSubmitter(self.api_client, self.tracker, + self.cluster_id or "") @override def validate_config(self) -> None: if not self.cluster_id: raise ValueError( "Databricks `http_path` or `cluster_id` of an all-purpose cluster is required " - "for the `all_purpose_cluster` submission method." 
- ) + "for the `all_purpose_cluster` submission method.") class ServerlessClusterPythonJobHelper(BaseDatabricksHelper): """Top level helper for Python models using job runs on a serverless cluster.""" def build_submitter(self) -> PythonSubmitter: - return PythonNotebookSubmitter.create(self.api_client, self.tracker, self.parsed_model, {}) + return PythonNotebookSubmitter.create(self.api_client, self.tracker, + self.parsed_model, {}) class PythonWorkflowConfigCompiler: @@ -478,18 +501,22 @@ def __init__( self.post_hook_tasks = post_hook_tasks @staticmethod - def create(parsed_model: ParsedPythonModel) -> "PythonWorkflowConfigCompiler": - cluster_settings = PythonWorkflowConfigCompiler.cluster_settings(parsed_model) + def create( + parsed_model: ParsedPythonModel) -> "PythonWorkflowConfigCompiler": + cluster_settings = PythonWorkflowConfigCompiler.cluster_settings( + parsed_model) config = parsed_model.config if config.python_job_config: - cluster_settings.update(config.python_job_config.additional_task_settings) + cluster_settings.update( + config.python_job_config.additional_task_settings) workflow_spec = config.python_job_config.dict() - workflow_spec["name"] = PythonWorkflowConfigCompiler.workflow_name(parsed_model) + workflow_spec["name"] = PythonWorkflowConfigCompiler.workflow_name( + parsed_model) existing_job_id = config.python_job_config.existing_job_id post_hook_tasks = config.python_job_config.post_hook_tasks - return PythonWorkflowConfigCompiler( - cluster_settings, workflow_spec, existing_job_id, post_hook_tasks - ) + return PythonWorkflowConfigCompiler(cluster_settings, + workflow_spec, existing_job_id, + post_hook_tasks) else: return PythonWorkflowConfigCompiler(cluster_settings, {}, "", []) @@ -499,7 +526,8 @@ def workflow_name(parsed_model: ParsedPythonModel) -> str: if parsed_model.config.python_job_config: name = parsed_model.config.python_job_config.name return ( - name or f"dbt__{parsed_model.catalog}-{parsed_model.schema_}-{parsed_model.identifier}" + name or + f"dbt__{parsed_model.catalog}-{parsed_model.schema_}-{parsed_model.identifier}" ) @staticmethod @@ -584,7 +612,8 @@ def __init__( @staticmethod def create( - api_client: DatabricksApiClient, tracker: PythonRunTracker, parsed_model: ParsedPythonModel + api_client: DatabricksApiClient, tracker: PythonRunTracker, + parsed_model: ParsedPythonModel ) -> "PythonNotebookWorkflowSubmitter": uploader = PythonNotebookUploader(api_client, parsed_model) config_compiler = PythonWorkflowConfigCompiler.create(parsed_model) @@ -615,34 +644,43 @@ def submit(self, compiled_code: str) -> None: file_path = self.uploader.upload(compiled_code) logger.debug(f"[Workflow Debug] Uploaded notebook to: {file_path}") - workflow_config, existing_job_id = self.config_compiler.compile(file_path) + workflow_config, existing_job_id = self.config_compiler.compile( + file_path) logger.debug(f"[Workflow Debug] Workflow config: {workflow_config}") logger.debug(f"[Workflow Debug] Existing job ID: {existing_job_id}") - job_id = self.workflow_creater.create_or_update(workflow_config, existing_job_id) + job_id = self.workflow_creater.create_or_update( + workflow_config, existing_job_id) logger.debug(f"[Workflow Debug] Created/updated job ID: {job_id}") access_control_list = self.permission_builder.build_job_permissions( - self.job_grants, self.acls - ) + self.job_grants, self.acls) logger.debug(f"[Workflow Debug] Setting ACL: {access_control_list}") self.api_client.workflow_permissions.put(job_id, access_control_list) - logger.debug(f"[Workflow Debug] 
Running job {job_id} with queueing enabled") + logger.debug( + f"[Workflow Debug] Running job {job_id} with queueing enabled") run_id = self.api_client.workflows.run(job_id, enable_queueing=True) - logger.debug(f"[Workflow Debug] Started workflow run with ID: {run_id}") + logger.debug( + f"[Workflow Debug] Started workflow run with ID: {run_id}") self.tracker.insert_run_id(run_id) try: - logger.debug(f"[Workflow Debug] Polling for completion of run {run_id}") + logger.debug( + f"[Workflow Debug] Polling for completion of run {run_id}") self.api_client.job_runs.poll_for_completion(run_id) - logger.debug(f"[Workflow Debug] Workflow run {run_id} completed successfully") + logger.debug( + f"[Workflow Debug] Workflow run {run_id} completed successfully" + ) except Exception as e: - logger.error(f"[Workflow Debug] Workflow run {run_id} failed with error: {e}") + logger.error( + f"[Workflow Debug] Workflow run {run_id} failed with error: {e}" + ) # Try to get more info about the failure try: run_info = self.api_client.job_runs.get_run_info(run_id) - logger.error(f"[Workflow Debug] Run info for failed run: {run_info}") + logger.error( + f"[Workflow Debug] Run info for failed run: {run_info}") except Exception: pass raise @@ -655,6 +693,191 @@ class WorkflowPythonJobHelper(BaseDatabricksHelper): @override def build_submitter(self) -> PythonSubmitter: - return PythonNotebookWorkflowSubmitter.create( - self.api_client, self.tracker, self.parsed_model + return PythonNotebookWorkflowSubmitter.create(self.api_client, + self.tracker, + self.parsed_model) + + +class SessionStateManager: + """Manages session state to prevent leakage between Python models.""" + + @staticmethod + def cleanup_temp_views(spark: "SparkSession") -> None: + """Drop temporary views created during model execution.""" + try: + # Get list of temp views from the current database + temp_views = [ + row.viewName for row in spark.sql("SHOW VIEWS").collect() + if hasattr(row, "isTemporary") and row.isTemporary + ] + for view in temp_views: + try: + spark.catalog.dropTempView(view) + logger.debug(f"Dropped temp view: {view}") + except Exception as e: + logger.warning(f"Failed to drop temp view {view}: {e}") + except Exception as e: + logger.debug(f"Could not list temp views for cleanup: {e}") + + @staticmethod + def get_clean_exec_globals(spark: "SparkSession") -> dict[str, Any]: + """Return a clean execution context with minimal state.""" + return { + "spark": spark, + "dbt": __import__("dbt"), + # Standard Python builtins are available by default + } + + +class SessionPythonSubmitter(PythonSubmitter): + """Submitter for Python models using direct execution in current SparkSession. + + NOTE: This does NOT collect data to the driver. The compiled code contains + df.write.saveAsTable() which writes directly to storage, just like API-based + submission methods. + """ + + def __init__(self, spark: "SparkSession"): + self._spark = spark + self._state_manager = SessionStateManager() + + @override + def submit(self, compiled_code: str) -> None: + logger.debug("Executing Python model directly in SparkSession.") + + try: + # Get clean execution context + exec_globals = self._state_manager.get_clean_exec_globals( + self._spark) + + # Log a preview of the code being executed + preview_len = min(500, len(compiled_code)) + logger.debug( + f"[Session Python] Executing code preview: {compiled_code[:preview_len]}..." + ) + + # Execute the compiled Python model code + # The compiled code will: + # 1. Execute model() function to get a DataFrame + # 2. 
Call df.write.saveAsTable() to persist to Delta + # 3. No collect() - data stays distributed + exec(compiled_code, exec_globals) + + logger.debug( + "[Session Python] Model execution completed successfully") + + except Exception as e: + logger.error(f"Python model execution failed: {e}") + raise DbtRuntimeError(f"Python model execution failed: {e}") from e + + finally: + # Clean up temp views to prevent state leakage + try: + self._state_manager.cleanup_temp_views(self._spark) + except Exception as cleanup_error: + logger.warning( + f"Failed to cleanup temp views: {cleanup_error}") + + +class SessionPythonJobHelper(PythonJobHelper): + """Helper for Python models executing directly in session mode. + + This helper executes Python models directly in the current SparkSession + without using the Databricks API. It's designed for running dbt on + job clusters where the SparkSession is already available. + """ + + tracker = PythonRunTracker() + + def __init__(self, parsed_model: dict, + credentials: DatabricksCredentials) -> None: + self.credentials = credentials + self.parsed_model = ParsedPythonModel(**parsed_model) + + # Get SparkSession directly - no API client needed + import os + from pyspark.sql import SparkSession + from pyspark import SparkContext + + # On Databricks, the SparkSession may already exist or we may need to create it + # Try multiple methods to get an existing session first + spark = None + + # Method 1: Try to get from active SparkContext and create SparkSession from it + # This is the most reliable method on Databricks + try: + sc = SparkContext._active_spark_context + if sc is not None: + # Create SparkSession from existing SparkContext + # This avoids the need for a master URL + spark = SparkSession(sc) + logger.debug( + "[Session Python] Got SparkSession from active SparkContext" + ) + except Exception as e: + logger.debug( + f"[Session Python] Could not get SparkSession from SparkContext: {e}" + ) + + # Method 2: Try getActiveSession() (available in PySpark 3.0+) + if spark is None: + try: + spark = SparkSession.getActiveSession() + if spark is not None: + logger.debug( + "[Session Python] Got SparkSession from getActiveSession()" + ) + except (AttributeError, Exception) as e: + logger.debug( + f"[Session Python] getActiveSession() not available or failed: {e}" + ) + + # Method 3: Try to get from global 'spark' variable (Databricks convention) + if spark is None: + try: + import __main__ + if hasattr(__main__, 'spark'): + spark = __main__.spark + logger.debug( + "[Session Python] Got SparkSession from __main__.spark" + ) + except Exception as e: + logger.debug( + f"[Session Python] Could not get SparkSession from __main__.spark: {e}" + ) + + # If no existing SparkSession found, provide a clear error message + if spark is None: + databricks_runtime = os.getenv("DATABRICKS_RUNTIME_VERSION") + if databricks_runtime: + raise DbtRuntimeError( + "[Session Python] Could not find an existing SparkSession. 
" + "This typically happens when using the native 'dbt task' in Databricks Jobs, " + "which does not provide a SparkSession context.\n\n" + "Session mode is only compatible with:\n" + " - Databricks Notebooks (where 'spark' is pre-initialized)\n" + " - Python tasks that initialize SparkSession before running dbt\n" + " - Environments where SparkSession is already available\n\n" + "For the native dbt task, use DBSQL mode instead (the default):\n" + " - Set 'method: dbsql' in your profile (or omit 'method' entirely)\n" + " - Configure 'host' and 'http_path' to connect to a SQL warehouse or cluster\n\n" + f"Databricks runtime version: {databricks_runtime}") + else: + raise DbtRuntimeError( + "[Session Python] Session mode requires a Databricks cluster environment " + "with an active SparkSession. " + "DATABRICKS_RUNTIME_VERSION environment variable not found. " + "Ensure you are running on a Databricks cluster in a context where " + "SparkSession is available (e.g., Notebook or Python task with Spark initialized)." + ) + + self._spark = spark + logger.debug( + f"[Session Python] Using SparkSession: {self._spark.sparkContext.applicationId}" ) + + self._submitter = SessionPythonSubmitter(self._spark) + + def submit(self, compiled_code: str) -> None: + """Submit the compiled Python model for execution.""" + self._submitter.submit(compiled_code) diff --git a/dbt/adapters/databricks/session.py b/dbt/adapters/databricks/session.py new file mode 100644 index 000000000..66b82ae1a --- /dev/null +++ b/dbt/adapters/databricks/session.py @@ -0,0 +1,355 @@ +""" +Session mode support for dbt-databricks. + +This module provides SparkSession-based execution for running dbt on Databricks job clusters +without requiring the DBSQL connector. It enables complete dbt pipelines (SQL + Python models) +to execute entirely within a single SparkSession. + +Key components: +- SessionCursorWrapper: Adapts DataFrame results to the cursor interface expected by dbt +- DatabricksSessionHandle: Wraps SparkSession to provide the handle interface +""" + +import sys +from collections.abc import Sequence +from types import TracebackType +from typing import TYPE_CHECKING, Any, Optional + +from dbt.adapters.contracts.connection import AdapterResponse +from dbt_common.exceptions import DbtRuntimeError + +from dbt.adapters.databricks.handle import SqlUtils +from dbt.adapters.databricks.logging import logger + +if TYPE_CHECKING: + from pyspark.sql import DataFrame, Row, SparkSession + from pyspark.sql.types import StructField + + +class SessionCursorWrapper: + """ + Wraps SparkSession DataFrame results to provide a cursor-like interface. + + This adapter allows dbt to use SparkSession.sql() results in the same way + it uses DBSQL cursor results, maintaining compatibility with the existing + connection management code. + """ + + def __init__(self, spark: "SparkSession"): + self._spark = spark + self._df: Optional["DataFrame"] = None + self._rows: Optional[list["Row"]] = None + self._query_id: str = "session-query" + self.open = True + + def execute( + self, + sql: str, + bindings: Optional[Sequence[Any]] = None + ) -> "SessionCursorWrapper": + """Execute a SQL statement and store the resulting DataFrame.""" + cleaned_sql = SqlUtils.clean_sql(sql) + + # Handle bindings by formatting as SQL literals and substituting + # This ensures string values are properly quoted (e.g., 'k. A.' instead of k. A.) 
+ if bindings: + translated = SqlUtils.translate_bindings(bindings) + if translated: + formatted_bindings = SqlUtils.format_bindings_for_sql(translated) + if formatted_bindings: + cleaned_sql = cleaned_sql % tuple(formatted_bindings) + + logger.debug(f"Session mode executing SQL: {cleaned_sql[:200]}...") + self._df = self._spark.sql(cleaned_sql) + self._rows = None # Reset cached rows + return self + + def fetchall(self) -> Sequence[tuple]: + """Fetch all rows from the result set.""" + if self._rows is None and self._df is not None: + self._rows = self._df.collect() + return [tuple(row) for row in (self._rows or [])] + + def fetchone(self) -> Optional[tuple]: + """Fetch the next row from the result set.""" + if self._rows is None and self._df is not None: + self._rows = self._df.collect() + if self._rows: + return tuple(self._rows.pop(0)) + return None + + def fetchmany(self, size: int) -> Sequence[tuple]: + """Fetch the next `size` rows from the result set.""" + if self._rows is None and self._df is not None: + self._rows = self._df.collect() + if not self._rows: + return [] + result = [tuple(row) for row in self._rows[:size]] + self._rows = self._rows[size:] + return result + + @property + def description(self) -> Optional[list[tuple]]: + """Return column descriptions in DB-API format.""" + if self._df is None: + return None + return [self._field_to_description(f) for f in self._df.schema.fields] + + @staticmethod + def _field_to_description(field: "StructField") -> tuple: + """Convert a StructField to a DB-API description tuple.""" + # DB-API description: (name, type_code, display_size, internal_size, + # precision, scale, null_ok) + return ( + field.name, + field.dataType.simpleString(), + None, + None, + None, + None, + field.nullable, + ) + + def get_response(self) -> AdapterResponse: + """Return an adapter response for the executed query.""" + return AdapterResponse(_message="OK", query_id=self._query_id) + + def cancel(self) -> None: + """Cancel the current operation (no-op for session mode).""" + logger.debug("SessionCursorWrapper.cancel() called (no-op)") + self.open = False + + def close(self) -> None: + """Close the cursor.""" + logger.debug("SessionCursorWrapper.close() called") + self.open = False + self._df = None + self._rows = None + + def __str__(self) -> str: + return f"SessionCursor(query-id={self._query_id})" + + def __enter__(self) -> "SessionCursorWrapper": + return self + + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> bool: + self.close() + return exc_val is None + + +class DatabricksSessionHandle: + """ + Handle for a Databricks SparkSession. + + Provides the same interface as DatabricksHandle but uses the active SparkSession + instead of the DBSQL connector. This enables dbt to run on job clusters without + requiring external API connections. + """ + + def __init__(self, spark: "SparkSession"): + self._spark = spark + self.open = True + self._cursor: Optional[SessionCursorWrapper] = None + self._dbr_version: Optional[tuple[int, int]] = None + + @staticmethod + def create( + catalog: Optional[str] = None, + schema: Optional[str] = None, + session_properties: Optional[dict[str, Any]] = None, + ) -> "DatabricksSessionHandle": + """ + Create a DatabricksSessionHandle using the active SparkSession. 
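+
+        The session is resolved in this order: from the active SparkContext, then
+        via SparkSession.getActiveSession(), and finally from the __main__.spark
+        global that Databricks notebooks provide; if none is found, a
+        DbtRuntimeError with remediation hints is raised.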
+ + Args: + catalog: Optional catalog to set as current + schema: Optional schema (not used - dbt uses fully qualified names and + creates schemas during execution) + session_properties: Optional session configuration properties + + Returns: + A new DatabricksSessionHandle instance + """ + import os + from pyspark.sql import SparkSession + from pyspark import SparkContext + + # On Databricks, the SparkSession may already exist or we may need to create it + # Try multiple methods to get an existing session first + spark = None + + # Method 1: Try to get from active SparkContext and create SparkSession from it + # This is the most reliable method on Databricks + try: + sc = SparkContext._active_spark_context + if sc is not None: + # Create SparkSession from existing SparkContext + # This avoids the need for a master URL + spark = SparkSession(sc) + logger.debug("Got SparkSession from active SparkContext") + except Exception as e: + logger.debug(f"Could not get SparkSession from SparkContext: {e}") + + # Method 2: Try getActiveSession() (available in PySpark 3.0+) + if spark is None: + try: + spark = SparkSession.getActiveSession() + if spark is not None: + logger.debug("Got SparkSession from getActiveSession()") + except (AttributeError, Exception) as e: + logger.debug( + f"getActiveSession() not available or failed: {e}") + + # Method 3: Try to get from global 'spark' variable (Databricks convention) + if spark is None: + try: + import __main__ + if hasattr(__main__, 'spark'): + spark = __main__.spark + logger.debug("Got SparkSession from __main__.spark") + except Exception as e: + logger.debug( + f"Could not get SparkSession from __main__.spark: {e}") + + # If no existing SparkSession found, provide a clear error message + if spark is None: + databricks_runtime = os.getenv("DATABRICKS_RUNTIME_VERSION") + if databricks_runtime: + raise DbtRuntimeError( + "Session mode could not find an existing SparkSession. " + "This typically happens when using the native 'dbt task' in Databricks Jobs, " + "which does not provide a SparkSession context.\n\n" + "Session mode is only compatible with:\n" + " - Databricks Notebooks (where 'spark' is pre-initialized)\n" + " - Python tasks that initialize SparkSession before running dbt\n" + " - Environments where SparkSession is already available\n\n" + "For the native dbt task, use DBSQL mode instead (the default):\n" + " - Set 'method: dbsql' in your profile (or omit 'method' entirely)\n" + " - Configure 'host' and 'http_path' to connect to a SQL warehouse or cluster\n\n" + f"Databricks runtime version: {databricks_runtime}") + else: + raise DbtRuntimeError( + "Session mode requires a Databricks cluster environment with an active SparkSession. " + "DATABRICKS_RUNTIME_VERSION environment variable not found. " + "Ensure you are running on a Databricks cluster in a context where " + "SparkSession is available (e.g., Notebook or Python task with Spark initialized)." + ) + + # Set catalog if provided + if catalog: + try: + spark.catalog.setCurrentCatalog(catalog) + logger.debug(f"Set current catalog to: {catalog}") + except Exception as e: + logger.warning(f"Failed to set catalog '{catalog}': {e}") + # Fall back to USE CATALOG for older Spark versions + spark.sql(f"USE CATALOG {catalog}") + + # Note: We intentionally do NOT call setCurrentDatabase(schema) here. + # The schema from the profile is used as a base/prefix for generated schema names + # (e.g., "dbt" -> "dbt_seeds", "dbt_bronze" via generate_schema_name macro). 
+ # dbt creates these schemas during execution via CREATE SCHEMA IF NOT EXISTS, + # and uses fully qualified names (catalog.schema.table) for all operations. + # Setting the current database would fail if the schema doesn't exist yet. + + # Apply session properties + if session_properties: + for key, value in session_properties.items(): + spark.conf.set(key, str(value)) + logger.debug(f"Set session property {key}={value}") + + handle = DatabricksSessionHandle(spark) + logger.debug(f"Created session handle: {handle}") + return handle + + @property + def dbr_version(self) -> tuple[int, int]: + """Get the DBR version of the current cluster.""" + if self._dbr_version is None: + try: + version_str = self._spark.conf.get( + "spark.databricks.clusterUsageTags.sparkVersion", "") + if version_str: + self._dbr_version = SqlUtils.extract_dbr_version( + version_str) + else: + # If we can't get the version, assume latest + logger.warning( + "Could not determine DBR version, assuming latest") + self._dbr_version = (sys.maxsize, sys.maxsize) + except Exception as e: + logger.warning( + f"Failed to get DBR version: {e}, assuming latest") + self._dbr_version = (sys.maxsize, sys.maxsize) + return self._dbr_version + + @property + def session_id(self) -> str: + """Get a unique identifier for this session.""" + try: + # Try to get the Spark application ID as a session identifier + return self._spark.sparkContext.applicationId or "session-unknown" + except Exception: + return "session-unknown" + + def execute( + self, + sql: str, + bindings: Optional[Sequence[Any]] = None) -> SessionCursorWrapper: + """Execute a SQL statement and return a cursor wrapper.""" + if not self.open: + raise DbtRuntimeError( + "Attempting to execute on a closed session handle") + + if self._cursor: + self._cursor.close() + + self._cursor = SessionCursorWrapper(self._spark) + return self._cursor.execute(sql, bindings) + + def list_schemas(self, + database: str, + schema: Optional[str] = None) -> SessionCursorWrapper: + """List schemas in the given database/catalog.""" + if schema: + sql = f"SHOW SCHEMAS IN {database} LIKE '{schema}'" + else: + sql = f"SHOW SCHEMAS IN {database}" + return self.execute(sql) + + def list_tables(self, database: str, schema: str) -> SessionCursorWrapper: + """List tables in the given database and schema.""" + sql = f"SHOW TABLES IN {database}.{schema}" + return self.execute(sql) + + def cancel(self) -> None: + """Cancel any in-progress operations.""" + logger.debug("DatabricksSessionHandle.cancel() called") + if self._cursor: + self._cursor.cancel() + self.open = False + + def close(self) -> None: + """Close the session handle.""" + logger.debug("DatabricksSessionHandle.close() called") + if self._cursor: + self._cursor.close() + self.open = False + # Note: We don't stop the SparkSession as it may be shared + + def rollback(self) -> None: + """Required for interface compatibility, but not implemented.""" + logger.debug("NotImplemented: rollback (session mode)") + + def __del__(self) -> None: + if self._cursor: + self._cursor.close() + self.close() + + def __str__(self) -> str: + return f"SessionHandle(session-id={self.session_id})" diff --git a/pyproject.toml b/pyproject.toml index 4c120c86a..c388c8f1d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ requires = ["hatchling"] build-backend = "hatchling.build" [project] -name = "dbt-databricks" +name = "dbt-databricks-session" dynamic = ["version"] description = "The Databricks adapter plugin for dbt" readme = "README.md" diff --git 
a/tests/unit/test_session.py b/tests/unit/test_session.py new file mode 100644 index 000000000..b74719463 --- /dev/null +++ b/tests/unit/test_session.py @@ -0,0 +1,423 @@ +"""Unit tests for session mode components.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from dbt.adapters.databricks.session import ( + DatabricksSessionHandle, + SessionCursorWrapper, +) + + +class TestSessionCursorWrapper: + """Tests for SessionCursorWrapper.""" + + @pytest.fixture + def mock_spark(self): + """Create a mock SparkSession.""" + spark = MagicMock() + return spark + + @pytest.fixture + def cursor(self, mock_spark): + """Create a SessionCursorWrapper with mock SparkSession.""" + return SessionCursorWrapper(mock_spark) + + def test_execute_cleans_sql(self, cursor, mock_spark): + """Test that execute cleans SQL (strips whitespace and trailing semicolon).""" + mock_df = MagicMock() + mock_spark.sql.return_value = mock_df + + result = cursor.execute(" SELECT 1; ") + + mock_spark.sql.assert_called_once_with("SELECT 1") + assert result is cursor + + def test_execute_with_bindings(self, cursor, mock_spark): + """Test that execute handles bindings with proper SQL literal formatting.""" + mock_df = MagicMock() + mock_spark.sql.return_value = mock_df + + cursor.execute("SELECT %s, %s", (1, "test")) + + # String values should be properly quoted + mock_spark.sql.assert_called_once_with("SELECT 1, 'test'") + + def test_fetchall_returns_tuples(self, cursor, mock_spark): + """Test that fetchall returns list of tuples.""" + mock_df = MagicMock() + mock_row1 = MagicMock() + mock_row1.__iter__ = lambda self: iter([1, "a"]) + mock_row2 = MagicMock() + mock_row2.__iter__ = lambda self: iter([2, "b"]) + mock_df.collect.return_value = [mock_row1, mock_row2] + mock_spark.sql.return_value = mock_df + + cursor.execute("SELECT * FROM test") + result = cursor.fetchall() + + assert result == [(1, "a"), (2, "b")] + + def test_fetchone_returns_single_tuple(self, cursor, mock_spark): + """Test that fetchone returns a single tuple.""" + mock_df = MagicMock() + mock_row = MagicMock() + mock_row.__iter__ = lambda self: iter([1, "a"]) + mock_df.collect.return_value = [mock_row] + mock_spark.sql.return_value = mock_df + + cursor.execute("SELECT * FROM test") + result = cursor.fetchone() + + assert result == (1, "a") + + def test_fetchone_returns_none_when_empty(self, cursor, mock_spark): + """Test that fetchone returns None when no rows.""" + mock_df = MagicMock() + mock_df.collect.return_value = [] + mock_spark.sql.return_value = mock_df + + cursor.execute("SELECT * FROM test") + result = cursor.fetchone() + + assert result is None + + def test_fetchmany_returns_limited_rows(self, cursor, mock_spark): + """Test that fetchmany returns limited number of rows.""" + mock_df = MagicMock() + mock_rows = [MagicMock() for _ in range(5)] + for i, row in enumerate(mock_rows): + row.__iter__ = (lambda i: lambda self: iter([i]))(i) + mock_df.collect.return_value = mock_rows + mock_spark.sql.return_value = mock_df + + cursor.execute("SELECT * FROM test") + result = cursor.fetchmany(2) + + assert len(result) == 2 + assert result == [(0, ), (1, )] + + def test_description_returns_column_info(self, cursor, mock_spark): + """Test that description returns column metadata.""" + mock_df = MagicMock() + mock_field1 = MagicMock() + mock_field1.name = "id" + mock_field1.dataType.simpleString.return_value = "int" + mock_field1.nullable = False + + mock_field2 = MagicMock() + mock_field2.name = "name" + 
mock_field2.dataType.simpleString.return_value = "string" + mock_field2.nullable = True + + mock_df.schema.fields = [mock_field1, mock_field2] + mock_spark.sql.return_value = mock_df + + cursor.execute("SELECT * FROM test") + desc = cursor.description + + assert len(desc) == 2 + assert desc[0][0] == "id" + assert desc[0][1] == "int" + assert desc[0][6] is False + assert desc[1][0] == "name" + assert desc[1][1] == "string" + assert desc[1][6] is True + + def test_description_returns_none_before_execute(self, cursor): + """Test that description returns None before execute.""" + assert cursor.description is None + + def test_get_response_returns_adapter_response(self, cursor, mock_spark): + """Test that get_response returns an AdapterResponse.""" + mock_df = MagicMock() + mock_spark.sql.return_value = mock_df + + cursor.execute("SELECT 1") + response = cursor.get_response() + + assert response._message == "OK" + assert response.query_id == "session-query" + + def test_close_sets_open_to_false(self, cursor): + """Test that close sets open to False.""" + assert cursor.open is True + cursor.close() + assert cursor.open is False + + def test_context_manager(self, mock_spark): + """Test that cursor works as context manager.""" + with SessionCursorWrapper(mock_spark) as cursor: + assert cursor.open is True + assert cursor.open is False + + def test_execute_with_string_special_characters(self, cursor, mock_spark): + """Test that execute properly quotes strings with special characters.""" + mock_df = MagicMock() + mock_spark.sql.return_value = mock_df + + cursor.execute("INSERT INTO test VALUES (%s, %s)", + ("k. A.", "O'Brien")) + + # Special characters should be properly quoted and escaped + mock_spark.sql.assert_called_once_with( + "INSERT INTO test VALUES ('k. 
A.', 'O''Brien')") + + def test_execute_with_null_value(self, cursor, mock_spark): + """Test that execute handles NULL values correctly.""" + mock_df = MagicMock() + mock_spark.sql.return_value = mock_df + + cursor.execute("INSERT INTO test VALUES (%s, %s)", (1, None)) + + # NULL should be formatted as SQL NULL + mock_spark.sql.assert_called_once_with( + "INSERT INTO test VALUES (1, NULL)") + + def test_execute_with_boolean_values(self, cursor, mock_spark): + """Test that execute handles boolean values correctly.""" + mock_df = MagicMock() + mock_spark.sql.return_value = mock_df + + cursor.execute("INSERT INTO test VALUES (%s, %s)", (True, False)) + + # Booleans should be formatted as TRUE/FALSE + mock_spark.sql.assert_called_once_with( + "INSERT INTO test VALUES (TRUE, FALSE)") + + def test_execute_with_numeric_values(self, cursor, mock_spark): + """Test that execute handles numeric values correctly.""" + mock_df = MagicMock() + mock_spark.sql.return_value = mock_df + + cursor.execute("INSERT INTO test VALUES (%s, %s, %s)", (42, 3.14, 100)) + + # Numbers should remain unquoted + mock_spark.sql.assert_called_once_with( + "INSERT INTO test VALUES (42, 3.14, 100)") + + def test_execute_with_comma_in_string(self, cursor, mock_spark): + """Test that execute properly handles strings containing commas.""" + mock_df = MagicMock() + mock_spark.sql.return_value = mock_df + + cursor.execute("INSERT INTO test VALUES (%s)", + ("Freie und Hansestadt Hamburg", )) + + # String with spaces and special characters should be properly quoted + mock_spark.sql.assert_called_once_with( + "INSERT INTO test VALUES ('Freie und Hansestadt Hamburg')") + + def test_execute_with_parentheses_in_string(self, cursor, mock_spark): + """Test that execute properly handles strings containing parentheses.""" + mock_df = MagicMock() + mock_spark.sql.return_value = mock_df + + cursor.execute("INSERT INTO test VALUES (%s)", + ("Haus in Planung (projektiert)", )) + + # String with parentheses should be properly quoted + mock_spark.sql.assert_called_once_with( + "INSERT INTO test VALUES ('Haus in Planung (projektiert)')") + + +class TestDatabricksSessionHandle: + """Tests for DatabricksSessionHandle.""" + + @pytest.fixture + def mock_spark(self): + """Create a mock SparkSession.""" + spark = MagicMock() + spark.sparkContext.applicationId = "app-123" + spark.conf.get.return_value = "14.3.x-scala2.12" + return spark + + @pytest.fixture + def handle(self, mock_spark): + """Create a DatabricksSessionHandle with mock SparkSession.""" + return DatabricksSessionHandle(mock_spark) + + def test_session_id_returns_application_id(self, handle): + """Test that session_id returns the Spark application ID.""" + assert handle.session_id == "app-123" + + def test_dbr_version_extracts_version(self, handle, mock_spark): + """Test that dbr_version extracts version from Spark config.""" + mock_spark.conf.get.return_value = "14.3.x-scala2.12" + + version = handle.dbr_version + + assert version == (14, 3) + + def test_dbr_version_caches_result(self, handle, mock_spark): + """Test that dbr_version caches the result.""" + mock_spark.conf.get.return_value = "14.3.x-scala2.12" + + _ = handle.dbr_version + _ = handle.dbr_version + + # Should only call conf.get once due to caching + assert mock_spark.conf.get.call_count == 1 + + def test_execute_returns_cursor(self, handle, mock_spark): + """Test that execute returns a SessionCursorWrapper.""" + mock_df = MagicMock() + mock_spark.sql.return_value = mock_df + + cursor = handle.execute("SELECT 1") + + assert 
isinstance(cursor, SessionCursorWrapper) + + def test_execute_closes_previous_cursor(self, handle, mock_spark): + """Test that execute closes any previous cursor.""" + mock_df = MagicMock() + mock_spark.sql.return_value = mock_df + + cursor1 = handle.execute("SELECT 1") + assert cursor1.open is True + + cursor2 = handle.execute("SELECT 2") + assert cursor1.open is False + assert cursor2.open is True + + def test_close_sets_open_to_false(self, handle): + """Test that close sets open to False.""" + assert handle.open is True + handle.close() + assert handle.open is False + + def test_list_schemas_executes_show_schemas(self, handle, mock_spark): + """Test that list_schemas executes SHOW SCHEMAS.""" + mock_df = MagicMock() + mock_spark.sql.return_value = mock_df + + handle.list_schemas("my_catalog") + + mock_spark.sql.assert_called_with("SHOW SCHEMAS IN my_catalog") + + def test_list_schemas_with_pattern(self, handle, mock_spark): + """Test that list_schemas includes LIKE pattern.""" + mock_df = MagicMock() + mock_spark.sql.return_value = mock_df + + handle.list_schemas("my_catalog", "my_schema") + + mock_spark.sql.assert_called_with( + "SHOW SCHEMAS IN my_catalog LIKE 'my_schema'") + + def test_list_tables_executes_show_tables(self, handle, mock_spark): + """Test that list_tables executes SHOW TABLES.""" + mock_df = MagicMock() + mock_spark.sql.return_value = mock_df + + handle.list_tables("my_catalog", "my_schema") + + mock_spark.sql.assert_called_with( + "SHOW TABLES IN my_catalog.my_schema") + + def test_create_gets_existing_spark_session(self): + """Test that create finds an existing SparkSession via SparkContext.""" + mock_spark = MagicMock() + mock_spark.sparkContext.applicationId = "app-456" + mock_sc = MagicMock() + + with patch.dict( + "sys.modules", + { + "pyspark": MagicMock(), + "pyspark.sql": MagicMock() + }, + ): + import sys + + # Mock SparkContext._active_spark_context to return a context + sys.modules["pyspark"].SparkContext._active_spark_context = mock_sc + # Mock SparkSession constructor to return our mock session + sys.modules["pyspark.sql"].SparkSession.return_value = mock_spark + + handle = DatabricksSessionHandle.create() + + # Should have created SparkSession from SparkContext + sys.modules["pyspark.sql"].SparkSession.assert_called_once_with( + mock_sc) + assert handle.session_id == "app-456" + + def test_create_sets_catalog(self): + """Test that create sets the catalog.""" + mock_spark = MagicMock() + mock_spark.sparkContext.applicationId = "app-456" + mock_sc = MagicMock() + + with patch.dict( + "sys.modules", + { + "pyspark": MagicMock(), + "pyspark.sql": MagicMock() + }, + ): + import sys + + sys.modules["pyspark"].SparkContext._active_spark_context = mock_sc + sys.modules["pyspark.sql"].SparkSession.return_value = mock_spark + + DatabricksSessionHandle.create(catalog="my_catalog") + + mock_spark.catalog.setCurrentCatalog.assert_called_once_with( + "my_catalog") + + def test_create_does_not_set_schema(self): + """Test that create does NOT set the schema/database. + + dbt uses the schema from the profile as a base/prefix for generated schema names + (e.g., "dbt" -> "dbt_seeds", "dbt_bronze" via generate_schema_name macro). + dbt creates these schemas during execution via CREATE SCHEMA IF NOT EXISTS, + and uses fully qualified names (catalog.schema.table) for all operations. + Setting the current database would fail if the schema doesn't exist yet. 
+ """ + mock_spark = MagicMock() + mock_spark.sparkContext.applicationId = "app-456" + mock_sc = MagicMock() + + with patch.dict( + "sys.modules", + { + "pyspark": MagicMock(), + "pyspark.sql": MagicMock() + }, + ): + import sys + + sys.modules["pyspark"].SparkContext._active_spark_context = mock_sc + sys.modules["pyspark.sql"].SparkSession.return_value = mock_spark + + DatabricksSessionHandle.create(schema="my_schema") + + # Schema should NOT be set - this is the fix for SCHEMA_NOT_FOUND error + mock_spark.catalog.setCurrentDatabase.assert_not_called() + + def test_create_sets_session_properties(self): + """Test that create sets session properties.""" + mock_spark = MagicMock() + mock_spark.sparkContext.applicationId = "app-456" + mock_sc = MagicMock() + + with patch.dict( + "sys.modules", + { + "pyspark": MagicMock(), + "pyspark.sql": MagicMock() + }, + ): + import sys + + sys.modules["pyspark"].SparkContext._active_spark_context = mock_sc + sys.modules["pyspark.sql"].SparkSession.return_value = mock_spark + + DatabricksSessionHandle.create(session_properties={ + "key1": "value1", + "key2": 123 + }) + + mock_spark.conf.set.assert_any_call("key1", "value1") + mock_spark.conf.set.assert_any_call("key2", "123") From 6d4c5f0253ffa77f6f7b3a41bf7156b271f63681 Mon Sep 17 00:00:00 2001 From: Alexey Egorov <5102843+alexeyegorov@users.noreply.github.com> Date: Mon, 26 Jan 2026 16:14:36 +0100 Subject: [PATCH 2/2] fix: formatting etc. --- dbt/adapters/databricks/connections.py | 198 +++++------ dbt/adapters/databricks/credentials.py | 131 +++---- dbt/adapters/databricks/handle.py | 11 +- .../python_models/python_submissions.py | 332 ++++++++---------- dbt/adapters/databricks/session.py | 89 +++-- tests/unit/test_session.py | 74 ++-- 6 files changed, 361 insertions(+), 474 deletions(-) diff --git a/dbt/adapters/databricks/connections.py b/dbt/adapters/databricks/connections.py index c28917b20..5211ecaa3 100644 --- a/dbt/adapters/databricks/connections.py +++ b/dbt/adapters/databricks/connections.py @@ -96,20 +96,16 @@ def from_context(query_header_context: Any) -> "QueryContextWrapper": compute_name = None model_query_tags_override = None materialized = None - relation_name = getattr(query_header_context, "relation_name", - "[unknown]") + relation_name = getattr(query_header_context, "relation_name", "[unknown]") # Extract config-related attributes safely - if hasattr(query_header_context, - "config") and query_header_context.config: + if hasattr(query_header_context, "config") and query_header_context.config: config = query_header_context.config compute_name = config.get("databricks_compute") - query_tags_str = config.extra.get("query_tags") if hasattr( - config, "extra") else None + query_tags_str = config.extra.get("query_tags") if hasattr(config, "extra") else None if query_tags_str: - model_query_tags_override = QueryTagsUtils.parse_query_tags( - query_tags_str) + model_query_tags_override = QueryTagsUtils.parse_query_tags(query_tags_str) if hasattr(config, "materialized"): materialized = config.materialized @@ -125,7 +121,6 @@ def from_context(query_header_context: Any) -> "QueryContextWrapper": class DatabricksMacroQueryStringSetter(MacroQueryStringSetter): - def _get_comment_macro(self) -> Optional[str]: if self.config.query_comment.comment == DEFAULT_QUERY_COMMENT: return DATABRICKS_QUERY_COMMENT @@ -158,16 +153,16 @@ class DatabricksConnectionManager(SparkConnectionManager): _session_capabilities: Optional[dict] = None _dbr_capabilities_cache: dict[str, DBRCapabilities] = {} - def 
__init__(self, profile: AdapterRequiredConfig, - mp_context: SpawnContext): + def __init__(self, profile: AdapterRequiredConfig, mp_context: SpawnContext): super().__init__(profile, mp_context) self._api_client: Optional[DatabricksApiClient] = None @property def api_client(self) -> DatabricksApiClient: if self._api_client is None: - self._api_client = DatabricksApiClient.create( - cast(DatabricksCredentials, self.profile.credentials), 15 * 60) + self._api_client = DatabricksApiClient( + cast(DatabricksCredentials, self.profile.credentials), 15 * 60 + ) return self._api_client def is_session_mode(self) -> bool: @@ -178,16 +173,15 @@ def is_session_mode(self) -> bool: def is_cluster(self) -> bool: conn = self.get_thread_connection() databricks_conn = cast(DatabricksDBTConnection, conn) - return is_cluster_http_path(databricks_conn.http_path, - conn.credentials.cluster_id) + return is_cluster_http_path(databricks_conn.http_path, conn.credentials.cluster_id) - def _get_capabilities_for_http_path(self, - http_path: str) -> DBRCapabilities: + def _get_capabilities_for_http_path(self, http_path: str) -> DBRCapabilities: return self._dbr_capabilities_cache.get(http_path, DBRCapabilities()) @classmethod - def _query_dbr_version(cls, creds: DatabricksCredentials, - http_path: str) -> Optional[tuple[int, int]]: + def _query_dbr_version( + cls, creds: DatabricksCredentials, http_path: str + ) -> Optional[tuple[int, int]]: is_cluster = is_cluster_http_path(http_path, creds.cluster_id) if not is_cluster: @@ -195,16 +189,14 @@ def _query_dbr_version(cls, creds: DatabricksCredentials, try: if cls.credentials_manager is None: - raise DbtRuntimeError( - "credentials_manager must be set before querying DBR version" - ) + raise DbtRuntimeError("credentials_manager must be set before querying DBR version") conn_args = SqlUtils.prepare_connection_arguments( - creds, cls.credentials_manager, http_path, {}) + creds, cls.credentials_manager, http_path, {} + ) with dbsql.connect(**conn_args) as conn: with conn.cursor() as cursor: - cursor.execute( - "SET spark.databricks.clusterUsageTags.sparkVersion") + cursor.execute("SET spark.databricks.clusterUsageTags.sparkVersion") result = cursor.fetchone() if result: return SqlUtils.extract_dbr_version(result[1]) @@ -214,8 +206,7 @@ def _query_dbr_version(cls, creds: DatabricksCredentials, return None @classmethod - def _cache_dbr_capabilities(cls, creds: DatabricksCredentials, - http_path: str) -> None: + def _cache_dbr_capabilities(cls, creds: DatabricksCredentials, http_path: str) -> None: if http_path not in cls._dbr_capabilities_cache: is_cluster = is_cluster_http_path(http_path, creds.cluster_id) dbr_version = cls._query_dbr_version(creds, http_path) @@ -236,14 +227,12 @@ def cancel_open(self) -> list[str]: def compare_dbr_version(self, major: int, minor: int) -> int: version = (major, minor) - handle: DatabricksHandle | DatabricksSessionHandle = self.get_thread_connection( - ).handle + handle: DatabricksHandle | DatabricksSessionHandle = self.get_thread_connection().handle dbr_version = handle.dbr_version return (dbr_version > version) - (dbr_version < version) def set_query_header(self, query_header_context: dict[str, Any]) -> None: - self.query_header = DatabricksMacroQueryStringSetter( - self.profile, query_header_context) + self.query_header = DatabricksMacroQueryStringSetter(self.profile, query_header_context) @contextmanager def exception_handler(self, sql: str) -> Iterator[None]: @@ -269,9 +258,9 @@ def exception_handler(self, sql: str) -> Iterator[None]: 
raise DbtDatabaseError(str(exc)) from exc # override/overload - def set_connection_name(self, - name: Optional[str] = None, - query_header_context: Any = None) -> Connection: + def set_connection_name( + self, name: Optional[str] = None, query_header_context: Any = None + ) -> Connection: conn_name: str = "master" if name is None else name wrapped = QueryContextWrapper.from_context(query_header_context) @@ -290,14 +279,12 @@ def set_connection_name(self, if conn.name != conn_name: orig_conn_name: str = conn.name or "" conn.name = conn_name - fire_event( - ConnectionReused(orig_conn_name=orig_conn_name, - conn_name=conn_name)) + fire_event(ConnectionReused(orig_conn_name=orig_conn_name, conn_name=conn_name)) return conn def _create_fresh_connection( - self, conn_name: str, query_header_context: QueryContextWrapper + self, conn_name: str, query_header_context: QueryContextWrapper ) -> DatabricksDBTConnection: conn = DatabricksDBTConnection( type=Identifier(self.TYPE), @@ -308,13 +295,10 @@ def _create_fresh_connection( credentials=self.profile.credentials, ) creds = cast(DatabricksCredentials, self.profile.credentials) - conn.http_path = QueryConfigUtils.get_http_path( - query_header_context, creds) - conn.thread_identifier = cast(tuple[int, int], - self.get_thread_identifier()) + conn.http_path = QueryConfigUtils.get_http_path(query_header_context, creds) + conn.thread_identifier = cast(tuple[int, int], self.get_thread_identifier()) conn._query_header_context = query_header_context - conn.capabilities = self._get_capabilities_for_http_path( - conn.http_path) + conn.capabilities = self._get_capabilities_for_http_path(conn.http_path) conn.handle = LazyHandle(self.open) @@ -322,9 +306,8 @@ def _create_fresh_connection( self.set_thread_connection(conn) fire_event( - NewConnection(conn_name=conn_name, - conn_type=self.TYPE, - node_info=get_node_info())) + NewConnection(conn_name=conn_name, conn_type=self.TYPE, node_info=get_node_info()) + ) return conn @@ -346,9 +329,7 @@ def add_query( close_cursor: bool = False, ) -> tuple[Connection, Any]: connection = self.get_thread_connection() - fire_event( - ConnectionUsed(conn_type=self.TYPE, - conn_name=cast_to_str(connection.name))) + fire_event(ConnectionUsed(conn_type=self.TYPE, conn_name=cast_to_str(connection.name))) with self.exception_handler(sql): cursor: Optional[CursorWrapper | SessionCursorWrapper] = None @@ -362,7 +343,8 @@ def add_query( conn_name=cast_to_str(connection.name), sql=log_sql, node_info=get_node_info(), - )) + ) + ) pre = time.time() @@ -370,14 +352,15 @@ def add_query( cursor = handle.execute(sql, bindings) response = self.get_response(cursor) # SQLQueryStatus in 1.10.x may not support query_id parameter - query_id = getattr(response, 'query_id', None) + query_id = getattr(response, "query_id", None) fire_event( SQLQueryStatus( status=str(response), elapsed=round((time.time() - pre), 2), node_info=get_node_info(), - query_id=response.query_id, - )) + query_id=query_id, + ) + ) return connection, cursor except Error: @@ -412,14 +395,13 @@ def execute( def _execute_with_cursor( self, log_sql: str, - f: Callable[[DatabricksHandle | DatabricksSessionHandle], - CursorWrapper | SessionCursorWrapper], + f: Callable[ + [DatabricksHandle | DatabricksSessionHandle], CursorWrapper | SessionCursorWrapper + ], ) -> "Table": connection = self.get_thread_connection() - fire_event( - ConnectionUsed(conn_type=self.TYPE, - conn_name=cast_to_str(connection.name))) + fire_event(ConnectionUsed(conn_type=self.TYPE, 
conn_name=cast_to_str(connection.name))) with self.exception_handler(log_sql): cursor: Optional[CursorWrapper | SessionCursorWrapper] = None @@ -429,7 +411,8 @@ def _execute_with_cursor( conn_name=cast_to_str(connection.name), sql=log_sql, node_info=get_node_info(), - )) + ) + ) pre = time.time() @@ -443,23 +426,21 @@ def _execute_with_cursor( status=str(response), elapsed=round((time.time() - pre), 2), node_info=get_node_info(), - )) + ) + ) return self.get_result_from_cursor(cursor, None) finally: if cursor: cursor.close() - def list_schemas(self, - database: str, - schema: Optional[str] = None) -> "Table": + def list_schemas(self, database: str, schema: Optional[str] = None) -> "Table": database = database.strip("`") if schema: schema = schema.strip("`").lower() return self._execute_with_cursor( f"GetSchemas(database={database}, schema={schema})", - lambda cursor: cursor.list_schemas(database=database, - schema=schema), + lambda cursor: cursor.list_schemas(database=database, schema=schema), ) def list_tables(self, database: str, schema: str) -> "Table": @@ -467,15 +448,13 @@ def list_tables(self, database: str, schema: str) -> "Table": schema = schema.strip("`").lower() return self._execute_with_cursor( f"GetTables(database={database}, schema={schema})", - lambda cursor: cursor.list_tables(database=database, schema=schema - ), + lambda cursor: cursor.list_tables(database=database, schema=schema), ) # override def release(self) -> None: with self.lock: - conn = cast(Optional[DatabricksDBTConnection], - self.get_if_exists()) + conn = cast(Optional[DatabricksDBTConnection], self.get_if_exists()) if conn is None: return @@ -486,12 +465,9 @@ def release(self) -> None: def cleanup_all(self) -> None: with self.lock: # Close the current thread connection if it exists - conn = cast(Optional[DatabricksDBTConnection], - self.get_if_exists()) + conn = cast(Optional[DatabricksDBTConnection], self.get_if_exists()) if conn: - fire_event( - ConnectionClosedInCleanup( - conn_name=cast_to_str(conn.name))) + fire_event(ConnectionClosedInCleanup(conn_name=cast_to_str(conn.name))) self.close(conn) # garbage collect these connections @@ -513,8 +489,9 @@ def open(cls, connection: Connection) -> Connection: return cls._open_dbsql(databricks_connection, creds) @classmethod - def _open_session(cls, databricks_connection: DatabricksDBTConnection, - creds: DatabricksCredentials) -> Connection: + def _open_session( + cls, databricks_connection: DatabricksDBTConnection, creds: DatabricksCredentials + ) -> Connection: """Open a connection using SparkSession mode.""" logger.debug("Opening connection in session mode") @@ -534,8 +511,7 @@ def connect() -> DatabricksSessionHandle: return handle except Exception as exc: logger.error(ConnectionCreateError(exc)) - raise DbtDatabaseError( - f"Failed to create session connection: {exc}") from exc + raise DbtDatabaseError(f"Failed to create session connection: {exc}") from exc # Session mode doesn't need retry logic as SparkSession is already available databricks_connection.handle = connect() @@ -543,8 +519,7 @@ def connect() -> DatabricksSessionHandle: return databricks_connection @classmethod - def _cache_session_capabilities(cls, - handle: DatabricksSessionHandle) -> None: + def _cache_session_capabilities(cls, handle: DatabricksSessionHandle) -> None: """Cache DBR capabilities for session mode (simplified for 1.10.x).""" if cls._session_capabilities is None: dbr_version = handle.dbr_version @@ -552,12 +527,12 @@ def _cache_session_capabilities(cls, "dbr_version": 
dbr_version, "is_sql_warehouse": False, } - logger.debug( - f"Cached session capabilities: DBR version {dbr_version}") + logger.debug(f"Cached session capabilities: DBR version {dbr_version}") @classmethod - def _open_dbsql(cls, databricks_connection: DatabricksDBTConnection, - creds: DatabricksCredentials) -> Connection: + def _open_dbsql( + cls, databricks_connection: DatabricksDBTConnection, creds: DatabricksCredentials + ) -> Connection: """Open a connection using DBSQL connector.""" timeout = creds.connect_timeout @@ -567,31 +542,28 @@ def _open_dbsql(cls, databricks_connection: DatabricksDBTConnection, cls.credentials_manager = credentials_manager # Get merged query tags if we have query header context - query_header_context = getattr(databricks_connection, - "_query_header_context", None) + query_header_context = getattr(databricks_connection, "_query_header_context", None) merged_query_tags = {} if query_header_context: - merged_query_tags = QueryConfigUtils.get_merged_query_tags( - query_header_context, creds) + merged_query_tags = QueryConfigUtils.get_merged_query_tags(query_header_context, creds) conn_args = SqlUtils.prepare_connection_arguments( - creds, cls.credentials_manager, databricks_connection.http_path, - merged_query_tags) + creds, cls.credentials_manager, databricks_connection.http_path, merged_query_tags + ) def connect() -> DatabricksHandle: try: # TODO: what is the error when a user specifies a catalog they don't have access to conn = DatabricksHandle.from_connection_args( conn_args, - is_cluster_http_path(databricks_connection.http_path, - creds.cluster_id), + is_cluster_http_path(databricks_connection.http_path, creds.cluster_id), ) if conn: databricks_connection.session_id = conn.session_id - cls._cache_dbr_capabilities( - creds, databricks_connection.http_path) + cls._cache_dbr_capabilities(creds, databricks_connection.http_path) databricks_connection.capabilities = cls._dbr_capabilities_cache[ - databricks_connection.http_path] + databricks_connection.http_path + ] return conn else: raise DbtDatabaseError("Failed to create connection") @@ -613,8 +585,7 @@ def exponential_backoff(attempt: int) -> int: logger=logger, retryable_exceptions=retryable_exceptions, retry_limit=creds.connect_retries, - retry_timeout=(timeout - if timeout is not None else exponential_backoff), + retry_timeout=(timeout if timeout is not None else exponential_backoff), ) # override @@ -649,8 +620,7 @@ class QueryConfigUtils: """ @staticmethod - def get_http_path(context: QueryContextWrapper, - creds: DatabricksCredentials) -> str: + def get_http_path(context: QueryContextWrapper, creds: DatabricksCredentials) -> str: """ Get the http_path for the compute specified for the node. If none is specified default will be used. @@ -662,8 +632,7 @@ def get_http_path(context: QueryContextWrapper, # Get the http_path for the named compute. 
http_path = None if creds.compute: - http_path = creds.compute.get(context.compute_name, - {}).get("http_path", None) + http_path = creds.compute.get(context.compute_name, {}).get("http_path", None) # no http_path for the named compute resource is an error condition if not http_path: @@ -692,24 +661,27 @@ def get_merged_query_tags( # Default tags that will only exists for queries tied to a specific model if query_header_context: - if hasattr(query_header_context, - "model_name") and query_header_context.model_name: + if hasattr(query_header_context, "model_name") and query_header_context.model_name: default_tags[QueryTagsUtils.DBT_MODEL_NAME_QUERY_TAG_KEY] = ( - query_header_context.model_name) - if hasattr(query_header_context, - "materialized") and query_header_context.materialized: + query_header_context.model_name + ) + if hasattr(query_header_context, "materialized") and query_header_context.materialized: default_tags[QueryTagsUtils.DBT_MATERIALIZED_QUERY_TAG_KEY] = ( - query_header_context.materialized) + query_header_context.materialized + ) # Parse connection tags from JSON string - connection_tags = (QueryTagsUtils.parse_query_tags(creds.query_tags) - if creds.query_tags else {}) + connection_tags = ( + QueryTagsUtils.parse_query_tags(creds.query_tags) if creds.query_tags else {} + ) # Extract model-level query tags from context model_tags = {} - if (query_header_context - and hasattr(query_header_context, "model_query_tags_override") - and query_header_context.model_query_tags_override): + if ( + query_header_context + and hasattr(query_header_context, "model_query_tags_override") + and query_header_context.model_query_tags_override + ): model_tags = query_header_context.model_query_tags_override return QueryTagsUtils.merge_query_tags( diff --git a/dbt/adapters/databricks/credentials.py b/dbt/adapters/databricks/credentials.py index fa0671cd3..1471fab25 100644 --- a/dbt/adapters/databricks/credentials.py +++ b/dbt/adapters/databricks/credentials.py @@ -27,8 +27,7 @@ CATALOG_KEY_IN_SESSION_PROPERTIES = "databricks.catalog" DBT_DATABRICKS_INVOCATION_ENV_REGEX = re.compile("^[A-z0-9\\-]+$") -EXTRACT_CLUSTER_ID_FROM_HTTP_PATH_REGEX = re.compile( - r"/?sql/protocolv1/o/\d+/(.*)") +EXTRACT_CLUSTER_ID_FROM_HTTP_PATH_REGEX = re.compile(r"/?sql/protocolv1/o/\d+/(.*)") DBT_DATABRICKS_HTTP_SESSION_HEADERS = "DBT_DATABRICKS_HTTP_SESSION_HEADERS" REDIRECT_URL = "http://localhost:8020" @@ -84,8 +83,7 @@ def __pre_deserialize__(cls, data: dict[Any, Any]) -> dict[Any, Any]: data = super().__pre_deserialize__(data) data.setdefault("database", None) data.setdefault("connection_parameters", {}) - data["connection_parameters"].setdefault( - "_retry_stop_after_attempts_count", 30) + data["connection_parameters"].setdefault("_retry_stop_after_attempts_count", 30) data["connection_parameters"].setdefault("_retry_delay_max", 60) return data @@ -99,48 +97,48 @@ def __post_init__(self) -> None: session_properties = self.session_properties or {} if CATALOG_KEY_IN_SESSION_PROPERTIES in session_properties: if self.database is None: - self.database = session_properties[ - CATALOG_KEY_IN_SESSION_PROPERTIES] + self.database = session_properties[CATALOG_KEY_IN_SESSION_PROPERTIES] del session_properties[CATALOG_KEY_IN_SESSION_PROPERTIES] else: raise DbtValidationError( f"Got duplicate keys: (`{CATALOG_KEY_IN_SESSION_PROPERTIES}` " - 'in session_properties) all map to "database"') + 'in session_properties) all map to "database"' + ) self.session_properties = session_properties if self.database is not None: 
database = self.database.strip() if not database: - raise DbtValidationError( - f"Invalid catalog name : `{self.database}`.") + raise DbtValidationError(f"Invalid catalog name : `{self.database}`.") self.database = database else: self.database = "hive_metastore" connection_parameters = self.connection_parameters or {} for key in ( - "server_hostname", - "http_path", - "access_token", - "client_id", - "client_secret", - "session_configuration", - "catalog", - "schema", - "_user_agent_entry", - "user_agent_entry", + "server_hostname", + "http_path", + "access_token", + "client_id", + "client_secret", + "session_configuration", + "catalog", + "schema", + "_user_agent_entry", + "user_agent_entry", ): if key in connection_parameters: - raise DbtValidationError( - f"The connection parameter `{key}` is reserved.") + raise DbtValidationError(f"The connection parameter `{key}` is reserved.") if "http_headers" in connection_parameters: http_headers = connection_parameters["http_headers"] if not isinstance(http_headers, dict) or any( - not isinstance(key, str) or not isinstance(value, str) - for key, value in http_headers.items()): + not isinstance(key, str) or not isinstance(value, str) + for key, value in http_headers.items() + ): raise DbtValidationError( "The connection parameter `http_headers` should be dict of strings: " - f"{http_headers}.") + f"{http_headers}." + ) if "_socket_timeout" not in connection_parameters: connection_parameters["_socket_timeout"] = 600 self.connection_parameters = connection_parameters @@ -150,15 +148,13 @@ def __post_init__(self) -> None: # Only create credentials manager for non-session mode if not self.is_session_mode: - self._credentials_manager = DatabricksCredentialManager.create_from( - self) + self._credentials_manager = DatabricksCredentialManager.create_from(self) def _init_connection_method(self) -> None: """Initialize and validate the connection method.""" if self.method is None: # Auto-detect session mode - if os.getenv(DBT_DATABRICKS_SESSION_MODE_ENV, - "").lower() == "true": + if os.getenv(DBT_DATABRICKS_SESSION_MODE_ENV, "").lower() == "true": self.method = CONNECTION_METHOD_SESSION elif os.getenv(DATABRICKS_RUNTIME_VERSION_ENV) and not self.host: # Running on Databricks cluster without host configured @@ -167,8 +163,7 @@ def _init_connection_method(self) -> None: self.method = CONNECTION_METHOD_DBSQL # Validate method value - if self.method not in (CONNECTION_METHOD_SESSION, - CONNECTION_METHOD_DBSQL): + if self.method not in (CONNECTION_METHOD_SESSION, CONNECTION_METHOD_DBSQL): raise DbtValidationError( f"Invalid connection method: '{self.method}'. " f"Must be '{CONNECTION_METHOD_SESSION}' or '{CONNECTION_METHOD_DBSQL}'." 
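A rough standalone sketch of the auto-detection order in _init_connection_method above; the constant names match this patch, but their literal string values and the helper function itself are illustrative assumptions:

import os
from typing import Optional

# Assumed literal values; the real constants are defined in credentials.py.
CONNECTION_METHOD_SESSION = "session"
CONNECTION_METHOD_DBSQL = "dbsql"
DBT_DATABRICKS_SESSION_MODE_ENV = "DBT_DATABRICKS_SESSION_MODE"
DATABRICKS_RUNTIME_VERSION_ENV = "DATABRICKS_RUNTIME_VERSION"


def resolve_method(method: Optional[str], host: Optional[str]) -> str:
    """Mirror the fallback logic as a pure function for illustration."""
    if method is not None:
        # An explicit profile setting always wins.
        return method
    if os.getenv(DBT_DATABRICKS_SESSION_MODE_ENV, "").lower() == "true":
        # Explicit opt-in via environment variable.
        return CONNECTION_METHOD_SESSION
    if os.getenv(DATABRICKS_RUNTIME_VERSION_ENV) and not host:
        # On a Databricks cluster with no host configured, assume session mode.
        return CONNECTION_METHOD_SESSION
    return CONNECTION_METHOD_DBSQL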
@@ -203,8 +198,7 @@ def _validate_dbsql_creds(self) -> None: """Validate credentials for DBSQL connector mode.""" for key in ["host", "http_path"]: if not getattr(self, key): - raise DbtConfigError( - f"The config '{key}' is required to connect to Databricks") + raise DbtConfigError(f"The config '{key}' is required to connect to Databricks") if not self.token and self.auth_type != "oauth": raise DbtConfigError( "The config `auth_type: oauth` is required when not using access token" @@ -213,13 +207,16 @@ def _validate_dbsql_creds(self) -> None: if not self.client_id and self.client_secret: raise DbtConfigError( "The config 'client_id' is required to connect " - "to Databricks when 'client_secret' is present") + "to Databricks when 'client_secret' is present" + ) if (not self.azure_client_id and self.azure_client_secret) or ( - self.azure_client_id and not self.azure_client_secret): + self.azure_client_id and not self.azure_client_secret + ): raise DbtConfigError( "The config 'azure_client_id' and 'azure_client_secret' " - "must be both present or both absent") + "must be both present or both absent" + ) @classmethod def get_invocation_env(cls) -> Optional[str]: @@ -228,23 +225,25 @@ def get_invocation_env(cls) -> Optional[str]: # Thrift doesn't allow nested () so we need to ensure # that the passed user agent is valid. if not DBT_DATABRICKS_INVOCATION_ENV_REGEX.search(invocation_env): - raise DbtValidationError( - f"Invalid invocation environment: {invocation_env}") + raise DbtValidationError(f"Invalid invocation environment: {invocation_env}") return invocation_env @classmethod - def get_all_http_headers( - cls, user_http_session_headers: dict[str, str]) -> dict[str, str]: + def get_all_http_headers(cls, user_http_session_headers: dict[str, str]) -> dict[str, str]: http_session_headers_str = GlobalState.get_http_session_headers() - http_session_headers_dict: dict[str, str] = ({ - k: - v if isinstance(v, str) else json.dumps(v) - for k, v in json.loads(http_session_headers_str).items() - } if http_session_headers_str is not None else {}) + http_session_headers_dict: dict[str, str] = ( + { + k: v if isinstance(v, str) else json.dumps(v) + for k, v in json.loads(http_session_headers_str).items() + } + if http_session_headers_str is not None + else {} + ) - intersect_http_header_keys = (user_http_session_headers.keys() - & http_session_headers_dict.keys()) + intersect_http_header_keys = ( + user_http_session_headers.keys() & http_session_headers_dict.keys() + ) if len(intersect_http_header_keys) > 0: raise DbtValidationError( @@ -266,19 +265,13 @@ def unique_field(self) -> str: return f"session://{self.database}/{self.schema}" return cast(str, self.host) - def connection_info(self, - *, - with_aliases: bool = False - ) -> Iterable[tuple[str, Any]]: + def connection_info(self, *, with_aliases: bool = False) -> Iterable[tuple[str, Any]]: as_dict = self.to_dict(omit_none=False) connection_keys = set(self._connection_keys(with_aliases=with_aliases)) aliases: list[str] = [] if with_aliases: - aliases = [ - k for k, v in self._ALIASES.items() if v in connection_keys - ] - for key in itertools.chain( - self._connection_keys(with_aliases=with_aliases), aliases): + aliases = [k for k, v in self._ALIASES.items() if v in connection_keys] + for key in itertools.chain(self._connection_keys(with_aliases=with_aliases), aliases): if key in as_dict: yield key, as_dict[key] @@ -291,9 +284,7 @@ def _connection_keys_session(self) -> tuple[str, ...]: connection_keys.append("session_properties") return 
tuple(connection_keys) - def _connection_keys(self, - *, - with_aliases: bool = False) -> tuple[str, ...]: + def _connection_keys(self, *, with_aliases: bool = False) -> tuple[str, ...]: # Assuming `DatabricksCredentials.connection_info(self, *, with_aliases: bool = False)` # is called from only: # @@ -326,8 +317,7 @@ def extract_cluster_id(cls, http_path: str) -> Optional[str]: @property def cluster_id(self) -> Optional[str]: - return self.extract_cluster_id( - self.http_path) # type: ignore[arg-type] + return self.extract_cluster_id(self.http_path) # type: ignore[arg-type] def authenticate(self) -> Optional["DatabricksCredentialManager"]: """Authenticate and return credentials manager. @@ -376,9 +366,7 @@ class DatabricksCredentialManager(DataClassDictMixin): auth_type: Optional[str] = None @classmethod - def create_from( - cls, credentials: DatabricksCredentials - ) -> "DatabricksCredentialManager": + def create_from(cls, credentials: DatabricksCredentials) -> "DatabricksCredentialManager": return DatabricksCredentialManager( host=credentials.host or "", token=credentials.token, @@ -443,10 +431,8 @@ def __post_init__(self) -> None: self._config = self.authenticate_with_external_browser() else: auth_methods = { - "oauth-m2m": - self.authenticate_with_oauth_m2m, - "legacy-azure-client-secret": - self.legacy_authenticate_with_azure_client_secret, + "oauth-m2m": self.authenticate_with_oauth_m2m, + "legacy-azure-client-secret": self.legacy_authenticate_with_azure_client_secret, } # If the secret starts with dose, high chance is it is a databricks secret @@ -468,20 +454,18 @@ def __post_init__(self) -> None: break # Exit loop if authentication is successful except Exception as e: exceptions.append((auth_type, e)) - next_auth_type = auth_sequence[i + 1] if i + 1 < len( - auth_sequence) else None + next_auth_type = auth_sequence[i + 1] if i + 1 < len(auth_sequence) else None if next_auth_type: logger.warning( f"Failed to authenticate with {auth_type}, " - f"trying {next_auth_type} next. Error: {e}") + f"trying {next_auth_type} next. Error: {e}" + ) else: logger.error( f"Failed to authenticate with {auth_type}. " f"No more authentication methods to try. Error: {e}" ) - raise Exception( - f"All authentication methods failed. Details: {exceptions}" - ) + raise Exception(f"All authentication methods failed. Details: {exceptions}") @property def api_client(self) -> WorkspaceClient: @@ -489,7 +473,6 @@ def api_client(self) -> WorkspaceClient: @property def credentials_provider(self) -> PySQLCredentialProvider: - def inner() -> Callable[[], dict[str, str]]: return self.header_factory diff --git a/dbt/adapters/databricks/handle.py b/dbt/adapters/databricks/handle.py index e50e049a3..c11beb200 100644 --- a/dbt/adapters/databricks/handle.py +++ b/dbt/adapters/databricks/handle.py @@ -19,7 +19,6 @@ if TYPE_CHECKING: pass - CursorOp = Callable[[Cursor], None] CursorExecOp = Callable[[Cursor], Cursor] CursorWrapperOp = Callable[["CursorWrapper"], None] @@ -292,20 +291,20 @@ def translate_bindings(bindings: Optional[Sequence[Any]]) -> Optional[Sequence[A def format_bindings_for_sql(bindings: Optional[Sequence[Any]]) -> Optional[Sequence[str]]: """ Format bindings as SQL literals for string substitution in session mode. - + This method properly quotes string values and handles special cases to ensure SQL injection safety and correct SQL syntax. Used when executing SQL via SparkSession.sql() which doesn't support parameterized queries. 
- + Args: bindings: Sequence of binding values (strings, numbers, None, etc.) - + Returns: Sequence of SQL literal strings, or None if bindings is None/empty """ if not bindings: return None - + formatted = [] for value in bindings: if value is None: @@ -326,7 +325,7 @@ def format_bindings_for_sql(bindings: Optional[Sequence[Any]]) -> Optional[Seque # For other types, convert to string and quote escaped = str(value).replace("'", "''") formatted.append(f"'{escaped}'") - + return formatted @staticmethod diff --git a/dbt/adapters/databricks/python_models/python_submissions.py b/dbt/adapters/databricks/python_models/python_submissions.py index f661f35f3..b71ca7e65 100644 --- a/dbt/adapters/databricks/python_models/python_submissions.py +++ b/dbt/adapters/databricks/python_models/python_submissions.py @@ -32,8 +32,7 @@ class BaseDatabricksHelper(PythonJobHelper): tracker = PythonRunTracker() - def __init__(self, parsed_model: dict, - credentials: DatabricksCredentials) -> None: + def __init__(self, parsed_model: dict, credentials: DatabricksCredentials) -> None: self.credentials = credentials self.credentials.validate_creds() self.parsed_model = ParsedPythonModel(**parsed_model) @@ -66,8 +65,9 @@ def submit(self, compiled_code: str) -> None: class PythonCommandSubmitter(PythonSubmitter): """Submitter for Python models using the Command API.""" - def __init__(self, api_client: DatabricksApiClient, - tracker: PythonRunTracker, cluster_id: str) -> None: + def __init__( + self, api_client: DatabricksApiClient, tracker: PythonRunTracker, cluster_id: str + ) -> None: self.api_client = api_client self.tracker = tracker self.cluster_id = cluster_id @@ -80,7 +80,8 @@ def submit(self, compiled_code: str) -> None: command_exec: Optional[CommandExecution] = None try: command_exec = self.api_client.commands.execute( - self.cluster_id, context_id, compiled_code) + self.cluster_id, context_id, compiled_code + ) self.tracker.insert_command(command_exec) self.api_client.commands.poll_for_completion(command_exec) @@ -88,46 +89,42 @@ def submit(self, compiled_code: str) -> None: finally: if command_exec: self.tracker.remove_command(command_exec) - self.api_client.command_contexts.destroy(self.cluster_id, - context_id) + self.api_client.command_contexts.destroy(self.cluster_id, context_id) class PythonNotebookUploader: """Uploads a compiled Python model as a notebook to the Databricks workspace.""" - def __init__(self, api_client: DatabricksApiClient, - parsed_model: ParsedPythonModel) -> None: + def __init__(self, api_client: DatabricksApiClient, parsed_model: ParsedPythonModel) -> None: self.api_client = api_client self.catalog = parsed_model.catalog self.schema = parsed_model.schema_ self.identifier = parsed_model.identifier - self.job_grants = (parsed_model.config.python_job_config.grants - if parsed_model.config.python_job_config else {}) + self.job_grants = ( + parsed_model.config.python_job_config.grants + if parsed_model.config.python_job_config + else {} + ) self.notebook_access_control_list = parsed_model.config.notebook_access_control_list def upload(self, compiled_code: str) -> str: """Upload the compiled code to the Databricks workspace.""" - logger.debug(f"[Notebook Upload Debug] Creating workspace dir for " - f"catalog={self.catalog}, schema={self.schema}") - workdir = self.api_client.workspace.create_python_model_dir( - self.catalog, self.schema) - file_path = f"{workdir}{self.identifier}" logger.debug( - f"[Notebook Upload Debug] Uploading notebook to path: {file_path}") + f"[Notebook Upload 
Debug] Creating workspace dir for " + f"catalog={self.catalog}, schema={self.schema}" + ) + workdir = self.api_client.workspace.create_python_model_dir(self.catalog, self.schema) + file_path = f"{workdir}{self.identifier}" + logger.debug(f"[Notebook Upload Debug] Uploading notebook to path: {file_path}") # Log notebook content length - logger.debug( - f"[Notebook Upload Debug] Notebook content length: {len(compiled_code)} chars" - ) + logger.debug(f"[Notebook Upload Debug] Notebook content length: {len(compiled_code)} chars") self.api_client.workspace.upload_notebook(file_path, compiled_code) - logger.debug( - f"[Notebook Upload Debug] Successfully uploaded notebook to {file_path}" - ) + logger.debug(f"[Notebook Upload Debug] Successfully uploaded notebook to {file_path}") if self.job_grants or self.notebook_access_control_list: - logger.debug( - "[Notebook Upload Debug] Setting permissions for notebook") + logger.debug("[Notebook Upload Debug] Setting permissions for notebook") self.set_notebook_permissions(file_path) return file_path @@ -137,19 +134,14 @@ def set_notebook_permissions(self, notebook_path: str) -> None: permission_builder = PythonPermissionBuilder(self.api_client) access_control_list = permission_builder.build_permissions( - self.job_grants, - self.notebook_access_control_list, - target_type="notebook") + self.job_grants, self.notebook_access_control_list, target_type="notebook" + ) if access_control_list: - logger.debug( - f"Setting permissions on notebook: {notebook_path}") - self.api_client.notebook_permissions.put( - notebook_path, access_control_list) + logger.debug(f"Setting permissions on notebook: {notebook_path}") + self.api_client.notebook_permissions.put(notebook_path, access_control_list) except Exception as e: - logger.error( - f"Failed to set permissions on notebook {notebook_path}: {str(e)}" - ) + logger.error(f"Failed to set permissions on notebook {notebook_path}: {str(e)}") raise DbtRuntimeError( f"Failed to set permissions on notebook: path={notebook_path}, error: {str(e)}" ) @@ -178,30 +170,26 @@ def __init__( def _get_job_owner_for_config(self) -> tuple[str, str]: """Get the owner of the job (and type) for the access control list.""" curr_user = self.api_client.curr_user.get_username() - is_service_principal = self.api_client.curr_user.is_service_principal( - curr_user) + is_service_principal = self.api_client.curr_user.is_service_principal(curr_user) source = "service_principal_name" if is_service_principal else "user_name" return curr_user, source @staticmethod - def _build_job_permission(job_grants: list[dict[str, Any]], - permission: str) -> list[dict[str, Any]]: + def _build_job_permission( + job_grants: list[dict[str, Any]], permission: str + ) -> list[dict[str, Any]]: """Build the access control list for the job.""" - return [{ - **grant, - **{ - "permission_level": permission - } - } for grant in job_grants] + return [{**grant, **{"permission_level": permission}} for grant in job_grants] def _filter_permissions( - self, acls: list[dict[str, Any]], - valid_permissions: set[str]) -> list[dict[str, Any]]: + self, acls: list[dict[str, Any]], valid_permissions: set[str] + ) -> list[dict[str, Any]]: return [ - acl for acl in acls if "permission_level" in acl - and acl["permission_level"] in valid_permissions + acl + for acl in acls + if "permission_level" in acl and acl["permission_level"] in valid_permissions ] def build_job_permissions( @@ -211,19 +199,22 @@ def build_job_permissions( ) -> list[dict[str, Any]]: access_control_list = [] owner, 
permissions_attribute = self._get_job_owner_for_config() - access_control_list.append({ - permissions_attribute: owner, - "permission_level": "IS_OWNER", - }) + access_control_list.append( + { + permissions_attribute: owner, + "permission_level": "IS_OWNER", + } + ) access_control_list.extend( - self._build_job_permission(job_grants.get("view", []), "CAN_VIEW")) + self._build_job_permission(job_grants.get("view", []), "CAN_VIEW") + ) access_control_list.extend( - self._build_job_permission(job_grants.get("run", []), - "CAN_MANAGE_RUN")) + self._build_job_permission(job_grants.get("run", []), "CAN_MANAGE_RUN") + ) access_control_list.extend( - self._build_job_permission(job_grants.get("manage", []), - "CAN_MANAGE")) + self._build_job_permission(job_grants.get("manage", []), "CAN_MANAGE") + ) combined_acls = access_control_list + acls return self._filter_permissions(combined_acls, self.JOB_PERMISSIONS) @@ -236,21 +227,17 @@ def build_notebook_permissions( access_control_list = [] access_control_list.extend( - self._build_job_permission(job_grants.get("view", []), "CAN_READ")) - access_control_list.extend( - self._build_job_permission(job_grants.get("run", []), "CAN_RUN")) + self._build_job_permission(job_grants.get("view", []), "CAN_READ") + ) + access_control_list.extend(self._build_job_permission(job_grants.get("run", []), "CAN_RUN")) access_control_list.extend( - self._build_job_permission(job_grants.get("manage", []), - "CAN_MANAGE")) + self._build_job_permission(job_grants.get("manage", []), "CAN_MANAGE") + ) combined_acls = access_control_list + acls - filtered_acls = self._filter_permissions(combined_acls, - self.NOTEBOOK_PERMISSIONS) + filtered_acls = self._filter_permissions(combined_acls, self.NOTEBOOK_PERMISSIONS) - return [ - acl for acl in filtered_acls - if acl.get("permission_level") != "IS_OWNER" - ] + return [acl for acl in filtered_acls if acl.get("permission_level") != "IS_OWNER"] def build_permissions( self, @@ -302,12 +289,10 @@ def __init__( packages = parsed_model.config.packages index_url = parsed_model.config.index_url additional_libraries = parsed_model.config.additional_libs - library_config = get_library_config(packages, index_url, - additional_libraries) + library_config = get_library_config(packages, index_url, additional_libraries) self.cluster_spec = {**cluster_spec, **library_config} self.job_grants = parsed_model.config.python_job_config.grants - self.additional_job_settings = parsed_model.config.python_job_config.dict( - ) + self.additional_job_settings = parsed_model.config.python_job_config.dict() self.environment_key = parsed_model.config.environment_key self.environment_deps = parsed_model.config.environment_dependencies @@ -323,27 +308,25 @@ def compile(self, path: str) -> PythonJobDetails: if self.environment_key: job_spec["environment_key"] = self.environment_key - if self.environment_deps and not self.additional_job_settings.get( - "environments"): - additional_job_config["environments"] = [{ - "environment_key": - self.environment_key, - "spec": { - "client": "2", - "dependencies": self.environment_deps - }, - }] + if self.environment_deps and not self.additional_job_settings.get("environments"): + additional_job_config["environments"] = [ + { + "environment_key": self.environment_key, + "spec": {"client": "2", "dependencies": self.environment_deps}, + } + ] job_spec.update(self.cluster_spec) access_control_list = self.permission_builder.build_job_permissions( - self.job_grants, self.access_control_list) + self.job_grants, self.access_control_list 
+ ) if access_control_list: job_spec["access_control_list"] = access_control_list job_spec["queue"] = {"enabled": True} - return PythonJobDetails(run_name=self.run_name, - job_spec=job_spec, - additional_job_config=additional_job_config) + return PythonJobDetails( + run_name=self.run_name, job_spec=job_spec, additional_job_config=additional_job_config + ) class PythonNotebookSubmitter(PythonSubmitter): @@ -376,8 +359,7 @@ def create( parsed_model, cluster_spec, ) - return PythonNotebookSubmitter(api_client, tracker, notebook_uploader, - config_compiler) + return PythonNotebookSubmitter(api_client, tracker, notebook_uploader, config_compiler) @override def submit(self, compiled_code: str) -> None: @@ -387,21 +369,20 @@ def submit(self, compiled_code: str) -> None: job_config = self.config_compiler.compile(file_path) run_id = self.api_client.job_runs.submit( - job_config.run_name, job_config.job_spec, - **job_config.additional_job_config) + job_config.run_name, job_config.job_spec, **job_config.additional_job_config + ) self.tracker.insert_run_id(run_id) try: if "access_control_list" in job_config.job_spec: try: - job_id = self.api_client.job_runs.get_job_id_from_run_id( - run_id) + job_id = self.api_client.job_runs.get_job_id_from_run_id(run_id) logger.debug(f"Setting permissions on job: {job_id}") self.api_client.workflow_permissions.patch( - job_id, job_config.job_spec["access_control_list"]) + job_id, job_config.job_spec["access_control_list"] + ) except Exception as e: - logger.error( - f"Failed to set permissions on job {run_id}: {str(e)}") + logger.error(f"Failed to set permissions on job {run_id}: {str(e)}") raise DbtRuntimeError( f"Failed to set permissions on job: run_id={run_id}, error: {str(e)}" ) @@ -436,8 +417,7 @@ class AllPurposeClusterPythonJobHelper(BaseDatabricksHelper): Top level helper for Python models using job runs or Command API on an all-purpose cluster. """ - def __init__(self, parsed_model: dict, - credentials: DatabricksCredentials) -> None: + def __init__(self, parsed_model: dict, credentials: DatabricksCredentials) -> None: self.credentials = credentials self.credentials.validate_creds() self.parsed_model = ParsedPythonModel(**parsed_model) @@ -451,7 +431,8 @@ def __init__(self, parsed_model: dict, config = self.parsed_model.config self.create_notebook = config.create_notebook self.cluster_id = config.cluster_id or self.credentials.extract_cluster_id( - config.http_path or self.credentials.http_path or "") + config.http_path or self.credentials.http_path or "" + ) self.validate_config() self.command_submitter = self.build_submitter() @@ -466,23 +447,22 @@ def build_submitter(self) -> PythonSubmitter: {"existing_cluster_id": self.cluster_id}, ) else: - return PythonCommandSubmitter(self.api_client, self.tracker, - self.cluster_id or "") + return PythonCommandSubmitter(self.api_client, self.tracker, self.cluster_id or "") @override def validate_config(self) -> None: if not self.cluster_id: raise ValueError( "Databricks `http_path` or `cluster_id` of an all-purpose cluster is required " - "for the `all_purpose_cluster` submission method.") + "for the `all_purpose_cluster` submission method." 
+ ) class ServerlessClusterPythonJobHelper(BaseDatabricksHelper): """Top level helper for Python models using job runs on a serverless cluster.""" def build_submitter(self) -> PythonSubmitter: - return PythonNotebookSubmitter.create(self.api_client, self.tracker, - self.parsed_model, {}) + return PythonNotebookSubmitter.create(self.api_client, self.tracker, self.parsed_model, {}) class PythonWorkflowConfigCompiler: @@ -501,22 +481,18 @@ def __init__( self.post_hook_tasks = post_hook_tasks @staticmethod - def create( - parsed_model: ParsedPythonModel) -> "PythonWorkflowConfigCompiler": - cluster_settings = PythonWorkflowConfigCompiler.cluster_settings( - parsed_model) + def create(parsed_model: ParsedPythonModel) -> "PythonWorkflowConfigCompiler": + cluster_settings = PythonWorkflowConfigCompiler.cluster_settings(parsed_model) config = parsed_model.config if config.python_job_config: - cluster_settings.update( - config.python_job_config.additional_task_settings) + cluster_settings.update(config.python_job_config.additional_task_settings) workflow_spec = config.python_job_config.dict() - workflow_spec["name"] = PythonWorkflowConfigCompiler.workflow_name( - parsed_model) + workflow_spec["name"] = PythonWorkflowConfigCompiler.workflow_name(parsed_model) existing_job_id = config.python_job_config.existing_job_id post_hook_tasks = config.python_job_config.post_hook_tasks - return PythonWorkflowConfigCompiler(cluster_settings, - workflow_spec, existing_job_id, - post_hook_tasks) + return PythonWorkflowConfigCompiler( + cluster_settings, workflow_spec, existing_job_id, post_hook_tasks + ) else: return PythonWorkflowConfigCompiler(cluster_settings, {}, "", []) @@ -526,8 +502,7 @@ def workflow_name(parsed_model: ParsedPythonModel) -> str: if parsed_model.config.python_job_config: name = parsed_model.config.python_job_config.name return ( - name or - f"dbt__{parsed_model.catalog}-{parsed_model.schema_}-{parsed_model.identifier}" + name or f"dbt__{parsed_model.catalog}-{parsed_model.schema_}-{parsed_model.identifier}" ) @staticmethod @@ -612,8 +587,7 @@ def __init__( @staticmethod def create( - api_client: DatabricksApiClient, tracker: PythonRunTracker, - parsed_model: ParsedPythonModel + api_client: DatabricksApiClient, tracker: PythonRunTracker, parsed_model: ParsedPythonModel ) -> "PythonNotebookWorkflowSubmitter": uploader = PythonNotebookUploader(api_client, parsed_model) config_compiler = PythonWorkflowConfigCompiler.create(parsed_model) @@ -644,43 +618,34 @@ def submit(self, compiled_code: str) -> None: file_path = self.uploader.upload(compiled_code) logger.debug(f"[Workflow Debug] Uploaded notebook to: {file_path}") - workflow_config, existing_job_id = self.config_compiler.compile( - file_path) + workflow_config, existing_job_id = self.config_compiler.compile(file_path) logger.debug(f"[Workflow Debug] Workflow config: {workflow_config}") logger.debug(f"[Workflow Debug] Existing job ID: {existing_job_id}") - job_id = self.workflow_creater.create_or_update( - workflow_config, existing_job_id) + job_id = self.workflow_creater.create_or_update(workflow_config, existing_job_id) logger.debug(f"[Workflow Debug] Created/updated job ID: {job_id}") access_control_list = self.permission_builder.build_job_permissions( - self.job_grants, self.acls) + self.job_grants, self.acls + ) logger.debug(f"[Workflow Debug] Setting ACL: {access_control_list}") self.api_client.workflow_permissions.put(job_id, access_control_list) - logger.debug( - f"[Workflow Debug] Running job {job_id} with queueing enabled") + 
logger.debug(f"[Workflow Debug] Running job {job_id} with queueing enabled") run_id = self.api_client.workflows.run(job_id, enable_queueing=True) - logger.debug( - f"[Workflow Debug] Started workflow run with ID: {run_id}") + logger.debug(f"[Workflow Debug] Started workflow run with ID: {run_id}") self.tracker.insert_run_id(run_id) try: - logger.debug( - f"[Workflow Debug] Polling for completion of run {run_id}") + logger.debug(f"[Workflow Debug] Polling for completion of run {run_id}") self.api_client.job_runs.poll_for_completion(run_id) - logger.debug( - f"[Workflow Debug] Workflow run {run_id} completed successfully" - ) + logger.debug(f"[Workflow Debug] Workflow run {run_id} completed successfully") except Exception as e: - logger.error( - f"[Workflow Debug] Workflow run {run_id} failed with error: {e}" - ) + logger.error(f"[Workflow Debug] Workflow run {run_id} failed with error: {e}") # Try to get more info about the failure try: run_info = self.api_client.job_runs.get_run_info(run_id) - logger.error( - f"[Workflow Debug] Run info for failed run: {run_info}") + logger.error(f"[Workflow Debug] Run info for failed run: {run_info}") except Exception: pass raise @@ -693,9 +658,9 @@ class WorkflowPythonJobHelper(BaseDatabricksHelper): @override def build_submitter(self) -> PythonSubmitter: - return PythonNotebookWorkflowSubmitter.create(self.api_client, - self.tracker, - self.parsed_model) + return PythonNotebookWorkflowSubmitter.create( + self.api_client, self.tracker, self.parsed_model + ) class SessionStateManager: @@ -707,7 +672,8 @@ def cleanup_temp_views(spark: "SparkSession") -> None: try: # Get list of temp views from the current database temp_views = [ - row.viewName for row in spark.sql("SHOW VIEWS").collect() + row.viewName + for row in spark.sql("SHOW VIEWS").collect() if hasattr(row, "isTemporary") and row.isTemporary ] for view in temp_views: @@ -747,8 +713,7 @@ def submit(self, compiled_code: str) -> None: try: # Get clean execution context - exec_globals = self._state_manager.get_clean_exec_globals( - self._spark) + exec_globals = self._state_manager.get_clean_exec_globals(self._spark) # Log a preview of the code being executed preview_len = min(500, len(compiled_code)) @@ -763,8 +728,7 @@ def submit(self, compiled_code: str) -> None: # 3. 
No collect() - data stays distributed exec(compiled_code, exec_globals) - logger.debug( - "[Session Python] Model execution completed successfully") + logger.debug("[Session Python] Model execution completed successfully") except Exception as e: logger.error(f"Python model execution failed: {e}") @@ -775,8 +739,7 @@ def submit(self, compiled_code: str) -> None: try: self._state_manager.cleanup_temp_views(self._spark) except Exception as cleanup_error: - logger.warning( - f"Failed to cleanup temp views: {cleanup_error}") + logger.warning(f"Failed to cleanup temp views: {cleanup_error}") class SessionPythonJobHelper(PythonJobHelper): @@ -789,15 +752,15 @@ class SessionPythonJobHelper(PythonJobHelper): tracker = PythonRunTracker() - def __init__(self, parsed_model: dict, - credentials: DatabricksCredentials) -> None: + def __init__(self, parsed_model: dict, credentials: DatabricksCredentials) -> None: self.credentials = credentials self.parsed_model = ParsedPythonModel(**parsed_model) # Get SparkSession directly - no API client needed import os - from pyspark.sql import SparkSession + from pyspark import SparkContext + from pyspark.sql import SparkSession # On Databricks, the SparkSession may already exist or we may need to create it # Try multiple methods to get an existing session first @@ -811,36 +774,27 @@ def __init__(self, parsed_model: dict, # Create SparkSession from existing SparkContext # This avoids the need for a master URL spark = SparkSession(sc) - logger.debug( - "[Session Python] Got SparkSession from active SparkContext" - ) + logger.debug("[Session Python] Got SparkSession from active SparkContext") except Exception as e: - logger.debug( - f"[Session Python] Could not get SparkSession from SparkContext: {e}" - ) + logger.debug(f"[Session Python] Could not get SparkSession from SparkContext: {e}") # Method 2: Try getActiveSession() (available in PySpark 3.0+) if spark is None: try: spark = SparkSession.getActiveSession() if spark is not None: - logger.debug( - "[Session Python] Got SparkSession from getActiveSession()" - ) + logger.debug("[Session Python] Got SparkSession from getActiveSession()") except (AttributeError, Exception) as e: - logger.debug( - f"[Session Python] getActiveSession() not available or failed: {e}" - ) + logger.debug(f"[Session Python] getActiveSession() not available or failed: {e}") # Method 3: Try to get from global 'spark' variable (Databricks convention) if spark is None: try: import __main__ - if hasattr(__main__, 'spark'): + + if hasattr(__main__, "spark"): spark = __main__.spark - logger.debug( - "[Session Python] Got SparkSession from __main__.spark" - ) + logger.debug("[Session Python] Got SparkSession from __main__.spark") except Exception as e: logger.debug( f"[Session Python] Could not get SparkSession from __main__.spark: {e}" @@ -850,26 +804,32 @@ def __init__(self, parsed_model: dict, if spark is None: databricks_runtime = os.getenv("DATABRICKS_RUNTIME_VERSION") if databricks_runtime: - raise DbtRuntimeError( - "[Session Python] Could not find an existing SparkSession. 
" - "This typically happens when using the native 'dbt task' in Databricks Jobs, " - "which does not provide a SparkSession context.\n\n" - "Session mode is only compatible with:\n" - " - Databricks Notebooks (where 'spark' is pre-initialized)\n" - " - Python tasks that initialize SparkSession before running dbt\n" - " - Environments where SparkSession is already available\n\n" - "For the native dbt task, use DBSQL mode instead (the default):\n" - " - Set 'method: dbsql' in your profile (or omit 'method' entirely)\n" - " - Configure 'host' and 'http_path' to connect to a SQL warehouse or cluster\n\n" - f"Databricks runtime version: {databricks_runtime}") + raise DbtRuntimeError(f""" + [Session Python] Could not find an existing SparkSession. + This typically happens when using the native 'dbt task' in Databricks Jobs, + which does not provide a SparkSession context. + Session mode is only compatible with: + - Databricks Notebooks (where 'spark' is pre-initialized) + - Python tasks that initialize SparkSession before running dbt + - Environments where SparkSession is already available + For the native dbt task, use DBSQL mode instead (the default): + - Set 'method: dbsql' in your profile (or omit 'method' entirely) + - Configure 'host' and 'http_path' to connect to a SQL warehouse + or cluster. + \n + Databricks runtime version: {databricks_runtime} + """) else: - raise DbtRuntimeError( - "[Session Python] Session mode requires a Databricks cluster environment " - "with an active SparkSession. " - "DATABRICKS_RUNTIME_VERSION environment variable not found. " - "Ensure you are running on a Databricks cluster in a context where " - "SparkSession is available (e.g., Notebook or Python task with Spark initialized)." - ) + raise DbtRuntimeError(f""" + [Session Python] Session mode requires a Databricks cluster environment + with an active SparkSession. + DATABRICKS_RUNTIME_VERSION environment variable not found. + Ensure you are running on a Databricks cluster in a context where + SparkSession is available (e.g., Notebook or Python task with + Spark initialized). 
+ \n + Databricks runtime version: {databricks_runtime} + """) self._spark = spark logger.debug( diff --git a/dbt/adapters/databricks/session.py b/dbt/adapters/databricks/session.py index 66b82ae1a..505e73dd4 100644 --- a/dbt/adapters/databricks/session.py +++ b/dbt/adapters/databricks/session.py @@ -37,16 +37,12 @@ class SessionCursorWrapper: def __init__(self, spark: "SparkSession"): self._spark = spark - self._df: Optional["DataFrame"] = None - self._rows: Optional[list["Row"]] = None + self._df: Optional[DataFrame] = None + self._rows: Optional[list[Row]] = None self._query_id: str = "session-query" self.open = True - def execute( - self, - sql: str, - bindings: Optional[Sequence[Any]] = None - ) -> "SessionCursorWrapper": + def execute(self, sql: str, bindings: Optional[Sequence[Any]] = None) -> "SessionCursorWrapper": """Execute a SQL statement and store the resulting DataFrame.""" cleaned_sql = SqlUtils.clean_sql(sql) @@ -176,8 +172,9 @@ def create( A new DatabricksSessionHandle instance """ import os - from pyspark.sql import SparkSession + from pyspark import SparkContext + from pyspark.sql import SparkSession # On Databricks, the SparkSession may already exist or we may need to create it # Try multiple methods to get an existing session first @@ -202,43 +199,51 @@ def create( if spark is not None: logger.debug("Got SparkSession from getActiveSession()") except (AttributeError, Exception) as e: - logger.debug( - f"getActiveSession() not available or failed: {e}") + logger.debug(f"getActiveSession() not available or failed: {e}") # Method 3: Try to get from global 'spark' variable (Databricks convention) if spark is None: try: import __main__ - if hasattr(__main__, 'spark'): + + if hasattr(__main__, "spark"): spark = __main__.spark logger.debug("Got SparkSession from __main__.spark") except Exception as e: - logger.debug( - f"Could not get SparkSession from __main__.spark: {e}") + logger.debug(f"Could not get SparkSession from __main__.spark: {e}") # If no existing SparkSession found, provide a clear error message if spark is None: databricks_runtime = os.getenv("DATABRICKS_RUNTIME_VERSION") if databricks_runtime: raise DbtRuntimeError( - "Session mode could not find an existing SparkSession. " - "This typically happens when using the native 'dbt task' in Databricks Jobs, " - "which does not provide a SparkSession context.\n\n" - "Session mode is only compatible with:\n" - " - Databricks Notebooks (where 'spark' is pre-initialized)\n" - " - Python tasks that initialize SparkSession before running dbt\n" - " - Environments where SparkSession is already available\n\n" - "For the native dbt task, use DBSQL mode instead (the default):\n" - " - Set 'method: dbsql' in your profile (or omit 'method' entirely)\n" - " - Configure 'host' and 'http_path' to connect to a SQL warehouse or cluster\n\n" - f"Databricks runtime version: {databricks_runtime}") - else: - raise DbtRuntimeError( - "Session mode requires a Databricks cluster environment with an active SparkSession. " - "DATABRICKS_RUNTIME_VERSION environment variable not found. " - "Ensure you are running on a Databricks cluster in a context where " - "SparkSession is available (e.g., Notebook or Python task with Spark initialized)." + f"""Session mode could not find an existing SparkSession. + This typically happens when using the native 'dbt task' in Databricks Jobs, + which does not provide a SparkSession context. 
+ \n + Session mode is only compatible with: + - Databricks Notebooks (where 'spark' is pre-initialized) + - Python tasks that initialize SparkSession before running dbt + - Environments where SparkSession is already available + \n + For the native dbt task, use DBSQL mode instead (the default): + - Set 'method: dbsql' in your profile (or omit 'method' entirely) + - Configure 'host' and 'http_path' to connect to a SQL warehouse or cluster + \n + Databricks runtime version: {databricks_runtime} + """ ) + else: + raise DbtRuntimeError(f""" + Session mode requires a Databricks cluster environment with + an active SparkSession. + DATABRICKS_RUNTIME_VERSION environment variable not found. + Ensure you are running on a Databricks cluster in a context where + SparkSession is available (e.g., Notebook or Python task with + Spark initialized). + \n + Databricks runtime version: {databricks_runtime} + """) # Set catalog if provided if catalog: @@ -273,18 +278,16 @@ def dbr_version(self) -> tuple[int, int]: if self._dbr_version is None: try: version_str = self._spark.conf.get( - "spark.databricks.clusterUsageTags.sparkVersion", "") + "spark.databricks.clusterUsageTags.sparkVersion", "" + ) if version_str: - self._dbr_version = SqlUtils.extract_dbr_version( - version_str) + self._dbr_version = SqlUtils.extract_dbr_version(version_str) else: # If we can't get the version, assume latest - logger.warning( - "Could not determine DBR version, assuming latest") + logger.warning("Could not determine DBR version, assuming latest") self._dbr_version = (sys.maxsize, sys.maxsize) except Exception as e: - logger.warning( - f"Failed to get DBR version: {e}, assuming latest") + logger.warning(f"Failed to get DBR version: {e}, assuming latest") self._dbr_version = (sys.maxsize, sys.maxsize) return self._dbr_version @@ -297,14 +300,10 @@ def session_id(self) -> str: except Exception: return "session-unknown" - def execute( - self, - sql: str, - bindings: Optional[Sequence[Any]] = None) -> SessionCursorWrapper: + def execute(self, sql: str, bindings: Optional[Sequence[Any]] = None) -> SessionCursorWrapper: """Execute a SQL statement and return a cursor wrapper.""" if not self.open: - raise DbtRuntimeError( - "Attempting to execute on a closed session handle") + raise DbtRuntimeError("Attempting to execute on a closed session handle") if self._cursor: self._cursor.close() @@ -312,9 +311,7 @@ def execute( self._cursor = SessionCursorWrapper(self._spark) return self._cursor.execute(sql, bindings) - def list_schemas(self, - database: str, - schema: Optional[str] = None) -> SessionCursorWrapper: + def list_schemas(self, database: str, schema: Optional[str] = None) -> SessionCursorWrapper: """List schemas in the given database/catalog.""" if schema: sql = f"SHOW SCHEMAS IN {database} LIKE '{schema}'" diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index b74719463..09ed7105d 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -96,7 +96,7 @@ def test_fetchmany_returns_limited_rows(self, cursor, mock_spark): result = cursor.fetchmany(2) assert len(result) == 2 - assert result == [(0, ), (1, )] + assert result == [(0,), (1,)] def test_description_returns_column_info(self, cursor, mock_spark): """Test that description returns column metadata.""" @@ -157,12 +157,10 @@ def test_execute_with_string_special_characters(self, cursor, mock_spark): mock_df = MagicMock() mock_spark.sql.return_value = mock_df - cursor.execute("INSERT INTO test VALUES (%s, %s)", - ("k. 
A.", "O'Brien")) + cursor.execute("INSERT INTO test VALUES (%s, %s)", ("k. A.", "O'Brien")) # Special characters should be properly quoted and escaped - mock_spark.sql.assert_called_once_with( - "INSERT INTO test VALUES ('k. A.', 'O''Brien')") + mock_spark.sql.assert_called_once_with("INSERT INTO test VALUES ('k. A.', 'O''Brien')") def test_execute_with_null_value(self, cursor, mock_spark): """Test that execute handles NULL values correctly.""" @@ -172,8 +170,7 @@ def test_execute_with_null_value(self, cursor, mock_spark): cursor.execute("INSERT INTO test VALUES (%s, %s)", (1, None)) # NULL should be formatted as SQL NULL - mock_spark.sql.assert_called_once_with( - "INSERT INTO test VALUES (1, NULL)") + mock_spark.sql.assert_called_once_with("INSERT INTO test VALUES (1, NULL)") def test_execute_with_boolean_values(self, cursor, mock_spark): """Test that execute handles boolean values correctly.""" @@ -183,8 +180,7 @@ def test_execute_with_boolean_values(self, cursor, mock_spark): cursor.execute("INSERT INTO test VALUES (%s, %s)", (True, False)) # Booleans should be formatted as TRUE/FALSE - mock_spark.sql.assert_called_once_with( - "INSERT INTO test VALUES (TRUE, FALSE)") + mock_spark.sql.assert_called_once_with("INSERT INTO test VALUES (TRUE, FALSE)") def test_execute_with_numeric_values(self, cursor, mock_spark): """Test that execute handles numeric values correctly.""" @@ -194,32 +190,31 @@ def test_execute_with_numeric_values(self, cursor, mock_spark): cursor.execute("INSERT INTO test VALUES (%s, %s, %s)", (42, 3.14, 100)) # Numbers should remain unquoted - mock_spark.sql.assert_called_once_with( - "INSERT INTO test VALUES (42, 3.14, 100)") + mock_spark.sql.assert_called_once_with("INSERT INTO test VALUES (42, 3.14, 100)") def test_execute_with_comma_in_string(self, cursor, mock_spark): """Test that execute properly handles strings containing commas.""" mock_df = MagicMock() mock_spark.sql.return_value = mock_df - cursor.execute("INSERT INTO test VALUES (%s)", - ("Freie und Hansestadt Hamburg", )) + cursor.execute("INSERT INTO test VALUES (%s)", ("Freie und Hansestadt Hamburg",)) # String with spaces and special characters should be properly quoted mock_spark.sql.assert_called_once_with( - "INSERT INTO test VALUES ('Freie und Hansestadt Hamburg')") + "INSERT INTO test VALUES ('Freie und Hansestadt Hamburg')" + ) def test_execute_with_parentheses_in_string(self, cursor, mock_spark): """Test that execute properly handles strings containing parentheses.""" mock_df = MagicMock() mock_spark.sql.return_value = mock_df - cursor.execute("INSERT INTO test VALUES (%s)", - ("Haus in Planung (projektiert)", )) + cursor.execute("INSERT INTO test VALUES (%s)", ("Haus in Planung (projektiert)",)) # String with parentheses should be properly quoted mock_spark.sql.assert_called_once_with( - "INSERT INTO test VALUES ('Haus in Planung (projektiert)')") + "INSERT INTO test VALUES ('Haus in Planung (projektiert)')" + ) class TestDatabricksSessionHandle: @@ -303,8 +298,7 @@ def test_list_schemas_with_pattern(self, handle, mock_spark): handle.list_schemas("my_catalog", "my_schema") - mock_spark.sql.assert_called_with( - "SHOW SCHEMAS IN my_catalog LIKE 'my_schema'") + mock_spark.sql.assert_called_with("SHOW SCHEMAS IN my_catalog LIKE 'my_schema'") def test_list_tables_executes_show_tables(self, handle, mock_spark): """Test that list_tables executes SHOW TABLES.""" @@ -313,8 +307,7 @@ def test_list_tables_executes_show_tables(self, handle, mock_spark): handle.list_tables("my_catalog", "my_schema") - 
mock_spark.sql.assert_called_with( - "SHOW TABLES IN my_catalog.my_schema") + mock_spark.sql.assert_called_with("SHOW TABLES IN my_catalog.my_schema") def test_create_gets_existing_spark_session(self): """Test that create finds an existing SparkSession via SparkContext.""" @@ -323,11 +316,8 @@ def test_create_gets_existing_spark_session(self): mock_sc = MagicMock() with patch.dict( - "sys.modules", - { - "pyspark": MagicMock(), - "pyspark.sql": MagicMock() - }, + "sys.modules", + {"pyspark": MagicMock(), "pyspark.sql": MagicMock()}, ): import sys @@ -339,8 +329,7 @@ def test_create_gets_existing_spark_session(self): handle = DatabricksSessionHandle.create() # Should have created SparkSession from SparkContext - sys.modules["pyspark.sql"].SparkSession.assert_called_once_with( - mock_sc) + sys.modules["pyspark.sql"].SparkSession.assert_called_once_with(mock_sc) assert handle.session_id == "app-456" def test_create_sets_catalog(self): @@ -350,11 +339,8 @@ def test_create_sets_catalog(self): mock_sc = MagicMock() with patch.dict( - "sys.modules", - { - "pyspark": MagicMock(), - "pyspark.sql": MagicMock() - }, + "sys.modules", + {"pyspark": MagicMock(), "pyspark.sql": MagicMock()}, ): import sys @@ -363,8 +349,7 @@ def test_create_sets_catalog(self): DatabricksSessionHandle.create(catalog="my_catalog") - mock_spark.catalog.setCurrentCatalog.assert_called_once_with( - "my_catalog") + mock_spark.catalog.setCurrentCatalog.assert_called_once_with("my_catalog") def test_create_does_not_set_schema(self): """Test that create does NOT set the schema/database. @@ -380,11 +365,8 @@ def test_create_does_not_set_schema(self): mock_sc = MagicMock() with patch.dict( - "sys.modules", - { - "pyspark": MagicMock(), - "pyspark.sql": MagicMock() - }, + "sys.modules", + {"pyspark": MagicMock(), "pyspark.sql": MagicMock()}, ): import sys @@ -403,21 +385,15 @@ def test_create_sets_session_properties(self): mock_sc = MagicMock() with patch.dict( - "sys.modules", - { - "pyspark": MagicMock(), - "pyspark.sql": MagicMock() - }, + "sys.modules", + {"pyspark": MagicMock(), "pyspark.sql": MagicMock()}, ): import sys sys.modules["pyspark"].SparkContext._active_spark_context = mock_sc sys.modules["pyspark.sql"].SparkSession.return_value = mock_spark - DatabricksSessionHandle.create(session_properties={ - "key1": "value1", - "key2": 123 - }) + DatabricksSessionHandle.create(session_properties={"key1": "value1", "key2": 123}) mock_spark.conf.set.assert_any_call("key1", "value1") mock_spark.conf.set.assert_any_call("key2", "123")
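
Illustrative note on binding handling: because SparkSession.sql() has no parameter binding, session mode renders query bindings as SQL literals (format_bindings_for_sql in handle.py above, exercised by the SessionCursorWrapper tests). A minimal standalone sketch of that rendering, using hypothetical helper names that are not part of the adapter:

from typing import Any, Optional, Sequence


def render_binding(value: Any) -> str:
    """Render one Python value as a Spark SQL literal."""
    if value is None:
        return "NULL"
    if isinstance(value, bool):  # bool before int/float: bool is a subclass of int
        return "TRUE" if value else "FALSE"
    if isinstance(value, (int, float)):
        return str(value)  # numbers stay unquoted
    escaped = str(value).replace("'", "''")  # double single quotes inside strings
    return f"'{escaped}'"


def render_sql(sql: str, bindings: Optional[Sequence[Any]] = None) -> str:
    """Substitute %s placeholders with rendered literals (the convention shown in the tests)."""
    if not bindings:
        return sql
    return sql % tuple(render_binding(v) for v in bindings)


print(render_sql("INSERT INTO test VALUES (%s, %s)", ("k. A.", "O'Brien")))
# -> INSERT INTO test VALUES ('k. A.', 'O''Brien')

Checking booleans before numerics matters because True/False would otherwise be rendered as 1/0; strings are the fallback case, with embedded single quotes doubled.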
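
For orientation, a usage sketch of the new session handle as it might be driven from a Databricks notebook where a SparkSession already exists; the keyword arguments mirror those exercised in tests/unit/test_session.py, while the catalog/schema names and queries here are illustrative only:

from dbt.adapters.databricks.session import DatabricksSessionHandle

# create() looks up an existing SparkSession/SparkContext and raises if none is found.
handle = DatabricksSessionHandle.create(
    catalog="main",  # optional: applied via spark.catalog.setCurrentCatalog
    session_properties={"spark.sql.shuffle.partitions": 8},  # applied via spark.conf.set
)

# execute() returns a SessionCursorWrapper backed by a Spark DataFrame.
cursor = handle.execute("SELECT current_catalog(), current_date()")
print(cursor.description)    # column metadata
print(cursor.fetchmany(10))  # rows come back as tuples

# Metadata lookups go through SHOW SCHEMAS / SHOW TABLES rather than the SQL connector.
handle.list_schemas("main")
handle.list_tables("main", "default")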