diff --git a/dbt/adapters/databricks/__version__.py b/dbt/adapters/databricks/__version__.py
index b7e5cc689..bf0b4d2c5 100644
--- a/dbt/adapters/databricks/__version__.py
+++ b/dbt/adapters/databricks/__version__.py
@@ -1 +1 @@
-version = "1.11.4"
+version = "1.10.15-5"
diff --git a/dbt/adapters/databricks/connections.py b/dbt/adapters/databricks/connections.py
index 291ac7677..5211ecaa3 100644
--- a/dbt/adapters/databricks/connections.py
+++ b/dbt/adapters/databricks/connections.py
@@ -49,12 +49,12 @@
 from dbt.adapters.databricks.handle import CursorWrapper, DatabricksHandle, SqlUtils
 from dbt.adapters.databricks.logging import logger
 from dbt.adapters.databricks.python_models.run_tracking import PythonRunTracker
+from dbt.adapters.databricks.session import DatabricksSessionHandle, SessionCursorWrapper
 from dbt.adapters.databricks.utils import QueryTagsUtils, is_cluster_http_path, redact_credentials

 if TYPE_CHECKING:
     from agate import Table

-
 DATABRICKS_QUERY_COMMENT = f"""
 {{%- set comment_dict = {{}} -%}}
 {{%- do comment_dict.update(
@@ -149,6 +149,8 @@ def has_capability(self, capability: DBRCapability) -> bool:
 class DatabricksConnectionManager(SparkConnectionManager):
     TYPE: str = "databricks"
     credentials_manager: Optional[DatabricksCredentialManager] = None
+    # Cache for session mode (1.10.x doesn't have DBRCapabilities, so we use a simple dict)
+    _session_capabilities: Optional[dict] = None
     _dbr_capabilities_cache: dict[str, DBRCapabilities] = {}

     def __init__(self, profile: AdapterRequiredConfig, mp_context: SpawnContext):
@@ -158,10 +160,16 @@ def __init__(self, profile: AdapterRequiredConfig, mp_context: SpawnContext):
     @property
     def api_client(self) -> DatabricksApiClient:
         if self._api_client is None:
-            credentials = cast(DatabricksCredentials, self.profile.credentials)
-            self._api_client = DatabricksApiClient(credentials, 15 * 60)
+            self._api_client = DatabricksApiClient(
+                cast(DatabricksCredentials, self.profile.credentials), 15 * 60
+            )
         return self._api_client

+    def is_session_mode(self) -> bool:
+        """Check if the connection is using session mode."""
+        credentials = cast(DatabricksCredentials, self.profile.credentials)
+        return credentials.is_session_mode
+
     def is_cluster(self) -> bool:
         conn = self.get_thread_connection()
         databricks_conn = cast(DatabricksDBTConnection, conn)
@@ -210,14 +218,16 @@ def _cache_dbr_capabilities(cls, creds: DatabricksCredentials, http_path: str) -

     def cancel_open(self) -> list[str]:
         cancelled = super().cancel_open()
-        logger.info("Cancelling open python jobs")
-        PythonRunTracker.cancel_runs(self.api_client)
+        # Only cancel Python jobs via the API if not in session mode
+        if not self.is_session_mode():
+            logger.info("Cancelling open python jobs")
+            PythonRunTracker.cancel_runs(self.api_client)
         return cancelled

     def compare_dbr_version(self, major: int, minor: int) -> int:
         version = (major, minor)

-        handle: DatabricksHandle = self.get_thread_connection().handle
+        handle: DatabricksHandle | DatabricksSessionHandle = self.get_thread_connection().handle
         dbr_version = handle.dbr_version
         return (dbr_version > version) - (dbr_version < version)

@@ -289,6 +299,7 @@ def _create_fresh_connection(
         conn.thread_identifier = cast(tuple[int, int], self.get_thread_identifier())
         conn._query_header_context = query_header_context
         conn.capabilities = self._get_capabilities_for_http_path(conn.http_path)
+
         conn.handle = LazyHandle(self.open)
         logger.debug(ConnectionCreate(str(conn)))

@@ -321,7 +332,7 @@ def add_query(
         fire_event(ConnectionUsed(conn_type=self.TYPE, conn_name=cast_to_str(connection.name)))

         with self.exception_handler(sql):
-            cursor: Optional[CursorWrapper] = None
+            cursor: Optional[CursorWrapper | SessionCursorWrapper] = None
             try:
                 log_sql = redact_credentials(sql)
                 if abridge_sql_log:
@@ -337,15 +348,17 @@ def add_query(

                 pre = time.time()

-                handle: DatabricksHandle = connection.handle
+                handle: DatabricksHandle | DatabricksSessionHandle = connection.handle
                 cursor = handle.execute(sql, bindings)

                 response = self.get_response(cursor)
+                # SQLQueryStatus in 1.10.x may not support the query_id parameter
+                query_id = getattr(response, "query_id", None)
                 fire_event(
                     SQLQueryStatus(
-                        status=str(cursor.get_response()),
+                        status=str(response),
                         elapsed=round((time.time() - pre), 2),
                         node_info=get_node_info(),
-                        query_id=response.query_id,
+                        query_id=query_id,
                     )
                 )

@@ -380,14 +393,18 @@ def execute(
                 cursor.close()

     def _execute_with_cursor(
-        self, log_sql: str, f: Callable[[DatabricksHandle], CursorWrapper]
+        self,
+        log_sql: str,
+        f: Callable[
+            [DatabricksHandle | DatabricksSessionHandle], CursorWrapper | SessionCursorWrapper
+        ],
     ) -> "Table":
         connection = self.get_thread_connection()

         fire_event(ConnectionUsed(conn_type=self.TYPE, conn_name=cast_to_str(connection.name)))

         with self.exception_handler(log_sql):
-            cursor: Optional[CursorWrapper] = None
+            cursor: Optional[CursorWrapper | SessionCursorWrapper] = None
             try:
                 fire_event(
                     SQLQuery(
@@ -399,14 +416,14 @@ def _execute_with_cursor(

                 pre = time.time()

-                handle: DatabricksHandle = connection.handle
+                handle: DatabricksHandle | DatabricksSessionHandle = connection.handle
                 cursor = f(handle)

                 response = self.get_response(cursor)
+                # SQLQueryStatus in 1.10.x may not support the query_id parameter
                 fire_event(
                     SQLQueryStatus(
                         status=str(response),
-                        query_id=response.query_id,
                         elapsed=round((time.time() - pre), 2),
                         node_info=get_node_info(),
                     )
                 )

@@ -464,9 +481,65 @@ def open(cls, connection: Connection) -> Connection:
             return connection

         creds: DatabricksCredentials = connection.credentials
+
+        # Dispatch based on connection method
+        if creds.is_session_mode:
+            return cls._open_session(databricks_connection, creds)
+        else:
+            return cls._open_dbsql(databricks_connection, creds)
+
+    @classmethod
+    def _open_session(
+        cls, databricks_connection: DatabricksDBTConnection, creds: DatabricksCredentials
+    ) -> Connection:
+        """Open a connection using SparkSession mode."""
+        logger.debug("Opening connection in session mode")
+
+        def connect() -> DatabricksSessionHandle:
+            try:
+                handle = DatabricksSessionHandle.create(
+                    catalog=creds.database,
+                    schema=creds.schema,
+                    session_properties=creds.session_properties,
+                )
+                databricks_connection.session_id = handle.session_id
+
+                # Cache capabilities for session mode (simplified for 1.10.x)
+                cls._cache_session_capabilities(handle)
+
+                logger.debug(f"Session mode connection opened: {handle}")
+                return handle
+            except Exception as exc:
+                logger.error(ConnectionCreateError(exc))
+                raise DbtDatabaseError(f"Failed to create session connection: {exc}") from exc
+
+        # Session mode doesn't need retry logic, as the SparkSession is already available
+        databricks_connection.handle = connect()
+        databricks_connection.state = ConnectionState.OPEN
+        return databricks_connection
+
+    @classmethod
+    def _cache_session_capabilities(cls, handle: DatabricksSessionHandle) -> None:
+        """Cache DBR capabilities for session mode (simplified for 1.10.x)."""
+        if cls._session_capabilities is None:
+            dbr_version = handle.dbr_version
+            cls._session_capabilities = {
+                "dbr_version": dbr_version,
+                "is_sql_warehouse": False,
+            }
+            logger.debug(f"Cached session capabilities: DBR version {dbr_version}")
+
+    @classmethod
+    def _open_dbsql(
+        cls, databricks_connection: DatabricksDBTConnection, creds: DatabricksCredentials
+    ) -> Connection:
+        """Open a connection using the DBSQL connector."""
         timeout = creds.connect_timeout

-        cls.credentials_manager = creds.authenticate()
+        credentials_manager = creds.authenticate()
+        # In DBSQL mode, authenticate() always returns a credentials manager
+        assert credentials_manager is not None, "Credentials manager is required for DBSQL mode"
+        cls.credentials_manager = credentials_manager

         # Get merged query tags if we have query header context
         query_header_context = getattr(databricks_connection, "_query_header_context", None)
@@ -507,7 +580,7 @@ def exponential_backoff(attempt: int) -> int:
         retryable_exceptions = [Error]

         return cls.retry_connection(
-            connection,
+            databricks_connection,
             connect=connect,
             logger=logger,
             retryable_exceptions=retryable_exceptions,
@@ -527,7 +600,7 @@ def close(cls, connection: Connection) -> Connection:

     @classmethod
     def get_response(cls, cursor: Any) -> AdapterResponse:
-        if isinstance(cursor, CursorWrapper):
+        if isinstance(cursor, (CursorWrapper, SessionCursorWrapper)):
            return cursor.get_response()
         else:
             return AdapterResponse("OK")
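A note on the `DatabricksHandle | DatabricksSessionHandle` unions above: the connection manager never branches on the concrete handle class at execution time; both types only need to share a small surface. A minimal sketch of that implicit contract, expressed as a `typing.Protocol` purely for illustration (the adapter itself defines no such protocol):

```python
# Illustrative only: the implicit interface the connection manager relies on.
# Neither handle class inherits from these; compatibility is purely structural.
from collections.abc import Sequence
from typing import Any, Optional, Protocol

from dbt.adapters.contracts.connection import AdapterResponse


class CursorLike(Protocol):
    def get_response(self) -> AdapterResponse: ...
    def close(self) -> None: ...


class HandleLike(Protocol):
    @property
    def dbr_version(self) -> tuple[int, int]: ...

    def execute(self, sql: str, bindings: Optional[Sequence[Any]] = None) -> CursorLike: ...
```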
diff --git a/dbt/adapters/databricks/credentials.py b/dbt/adapters/databricks/credentials.py
index 7e8af786b..1471fab25 100644
--- a/dbt/adapters/databricks/credentials.py
+++ b/dbt/adapters/databricks/credentials.py
@@ -1,12 +1,13 @@
 import itertools
 import json
+import os
 import re
 from collections.abc import Iterable
 from dataclasses import dataclass, field
 from typing import Any, Callable, Optional, cast

 from dbt.adapters.contracts.connection import Credentials
-from dbt_common.exceptions import DbtConfigError, DbtValidationError
+from dbt_common.exceptions import DbtConfigError, DbtRuntimeError, DbtValidationError
 from mashumaro import DataClassDictMixin
 from requests import PreparedRequest
 from requests.auth import AuthBase
@@ -16,6 +17,14 @@
 from dbt.adapters.databricks.global_state import GlobalState
 from dbt.adapters.databricks.logging import logger

+# Connection method constants
+CONNECTION_METHOD_SESSION = "session"
+CONNECTION_METHOD_DBSQL = "dbsql"
+
+# Environment variables for session mode
+DBT_DATABRICKS_SESSION_MODE_ENV = "DBT_DATABRICKS_SESSION_MODE"
+DATABRICKS_RUNTIME_VERSION_ENV = "DATABRICKS_RUNTIME_VERSION"
+
 CATALOG_KEY_IN_SESSION_PROPERTIES = "databricks.catalog"
 DBT_DATABRICKS_INVOCATION_ENV_REGEX = re.compile("^[A-z0-9\\-]+$")
 EXTRACT_CLUSTER_ID_FROM_HTTP_PATH_REGEX = re.compile(r"/?sql/protocolv1/o/\d+/(.*)")
@@ -48,6 +57,9 @@ class DatabricksCredentials(Credentials):
     connection_parameters: Optional[dict[str, Any]] = None
     auth_type: Optional[str] = None

+    # Connection method: "session" for SparkSession mode, "dbsql" for the DBSQL connector (default)
+    method: Optional[str] = None
+
     # Named compute resources specified in the profile. Used for
     # creating a connection when a model specifies a compute resource.
     compute: Optional[dict[str, Any]] = None
@@ -130,9 +142,60 @@ def __post_init__(self) -> None:
             if "_socket_timeout" not in connection_parameters:
                 connection_parameters["_socket_timeout"] = 600
         self.connection_parameters = connection_parameters
-        self._credentials_manager = DatabricksCredentialManager.create_from(self)
+
+        # Auto-detect and validate the connection method
+        self._init_connection_method()
+
+        # Only create a credentials manager for non-session mode
+        if not self.is_session_mode:
+            self._credentials_manager = DatabricksCredentialManager.create_from(self)
+
+    def _init_connection_method(self) -> None:
+        """Initialize and validate the connection method."""
+        if self.method is None:
+            # Auto-detect session mode
+            if os.getenv(DBT_DATABRICKS_SESSION_MODE_ENV, "").lower() == "true":
+                self.method = CONNECTION_METHOD_SESSION
+            elif os.getenv(DATABRICKS_RUNTIME_VERSION_ENV) and not self.host:
+                # Running on a Databricks cluster without a host configured
+                self.method = CONNECTION_METHOD_SESSION
+            else:
+                self.method = CONNECTION_METHOD_DBSQL
+
+        # Validate the method value
+        if self.method not in (CONNECTION_METHOD_SESSION, CONNECTION_METHOD_DBSQL):
+            raise DbtValidationError(
+                f"Invalid connection method: '{self.method}'. "
+                f"Must be '{CONNECTION_METHOD_SESSION}' or '{CONNECTION_METHOD_DBSQL}'."
+            )
+
+    @property
+    def is_session_mode(self) -> bool:
+        """Check if using session mode (SparkSession) for connections."""
+        return self.method == CONNECTION_METHOD_SESSION
+
+    def _validate_session_mode(self) -> None:
+        """Validate configuration for session mode."""
+        try:
+            from pyspark.sql import SparkSession  # noqa: F401
+        except ImportError:
+            raise DbtRuntimeError(
+                "Session mode requires PySpark. "
+                "Please ensure you are running on a Databricks cluster with PySpark available."
+            )
+
+        if self.schema is None:
+            raise DbtValidationError("Schema is required for session mode.")

     def validate_creds(self) -> None:
+        """Validate credentials based on the connection method."""
+        if self.is_session_mode:
+            self._validate_session_mode()
+        else:
+            self._validate_dbsql_creds()
+
+    def _validate_dbsql_creds(self) -> None:
+        """Validate credentials for DBSQL connector mode."""
         for key in ["host", "http_path"]:
             if not getattr(self, key):
                 raise DbtConfigError(f"The config '{key}' is required to connect to Databricks")
@@ -197,6 +260,9 @@ def type(self) -> str:

     @property
     def unique_field(self) -> str:
+        if self.is_session_mode:
+            # For session mode, use a unique identifier based on catalog and schema
+            return f"session://{self.database}/{self.schema}"
         return cast(str, self.host)

     def connection_info(self, *, with_aliases: bool = False) -> Iterable[tuple[str, Any]]:
@@ -209,6 +275,15 @@ def connection_info(self, *, with_aliases: bool = False) -> Iterable[tuple[str,
             if key in as_dict:
                 yield key, as_dict[key]

+    def _connection_keys_session(self) -> tuple[str, ...]:
+        """Connection keys for session mode."""
+        connection_keys = ["method", "schema"]
+        if self.database:
+            connection_keys.insert(1, "catalog")
+        if self.session_properties:
+            connection_keys.append("session_properties")
+        return tuple(connection_keys)
+
     def _connection_keys(self, *, with_aliases: bool = False) -> tuple[str, ...]:
         # Assuming `DatabricksCredentials.connection_info(self, *, with_aliases: bool = False)`
         # is called from only:
@@ -218,6 +293,11 @@ def _connection_keys(self, *, with_aliases: bool = False) -> tuple[str, ...]:
         #
         # Thus, if `with_aliases` is `True`, `DatabricksCredentials._connection_keys` should return
         # the internal key names; otherwise it can use aliases to show in `dbt debug`.
+
+        # Session mode has different connection keys
+        if self.is_session_mode:
+            return self._connection_keys_session()
+
         connection_keys = ["host", "http_path", "schema"]
         if with_aliases:
             connection_keys.insert(2, "database")
@@ -239,8 +319,15 @@ def extract_cluster_id(cls, http_path: str) -> Optional[str]:
     def cluster_id(self) -> Optional[str]:
         return self.extract_cluster_id(self.http_path)  # type: ignore[arg-type]

-    def authenticate(self) -> "DatabricksCredentialManager":
+    def authenticate(self) -> Optional["DatabricksCredentialManager"]:
+        """Authenticate and return a credentials manager.
+
+        For session mode, returns None, as no external authentication is needed.
+        For DBSQL mode, validates credentials and returns the credentials manager.
+        """
         self.validate_creds()
+        if self.is_session_mode:
+            return None
         assert self._credentials_manager is not None, "Credentials manager is not set."
         return self._credentials_manager
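For reference, the auto-detection in `_init_connection_method` can be exercised without a cluster. A rough sketch, assuming `DatabricksCredentials` can be constructed with just `database` and `schema` (real profiles supply more fields, and additional `__post_init__` validation may apply):

```python
# Sketch of method auto-detection; field set beyond database/schema is an assumption.
import os

from dbt.adapters.databricks.credentials import DatabricksCredentials

os.environ["DBT_DATABRICKS_SESSION_MODE"] = "true"

creds = DatabricksCredentials(database="main", schema="dbt")  # no host/http_path
assert creds.method == "session"
assert creds.is_session_mode
assert creds.unique_field == "session://main/dbt"
assert creds.authenticate() is None  # session mode needs no external auth

os.environ.pop("DBT_DATABRICKS_SESSION_MODE")
```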
diff --git a/dbt/adapters/databricks/handle.py b/dbt/adapters/databricks/handle.py
index aadf871c1..c11beb200 100644
--- a/dbt/adapters/databricks/handle.py
+++ b/dbt/adapters/databricks/handle.py
@@ -19,7 +19,6 @@
 if TYPE_CHECKING:
     pass

-
 CursorOp = Callable[[Cursor], None]
 CursorExecOp = Callable[[Cursor], Cursor]
 CursorWrapperOp = Callable[["CursorWrapper"], None]
@@ -288,6 +287,47 @@ def translate_bindings(bindings: Optional[Sequence[Any]]) -> Optional[Sequence[A
             return list(map(lambda x: float(x) if isinstance(x, decimal.Decimal) else x, bindings))
         return None

+    @staticmethod
+    def format_bindings_for_sql(bindings: Optional[Sequence[Any]]) -> Optional[Sequence[str]]:
+        """
+        Format bindings as SQL literals for string substitution in session mode.
+
+        This method quotes string values and handles special cases to reduce the risk
+        of SQL injection and to produce correct SQL syntax. Used when executing SQL via
+        SparkSession.sql(), which doesn't support parameterized queries.
+
+        Args:
+            bindings: Sequence of binding values (strings, numbers, None, etc.)
+
+        Returns:
+            Sequence of SQL literal strings, or None if bindings is None/empty
+        """
+        if not bindings:
+            return None
+
+        formatted = []
+        for value in bindings:
+            if value is None:
+                formatted.append("NULL")
+            elif isinstance(value, bool):
+                formatted.append("TRUE" if value else "FALSE")
+            elif isinstance(value, str):
+                # Escape single quotes by doubling them, then wrap in quotes
+                escaped = value.replace("'", "''")
+                formatted.append(f"'{escaped}'")
+            elif isinstance(value, (int, float, decimal.Decimal)):
+                # Numbers don't need quotes
+                if isinstance(value, decimal.Decimal):
+                    formatted.append(str(float(value)))
+                else:
+                    formatted.append(str(value))
+            else:
+                # For other types, convert to string and quote
+                escaped = str(value).replace("'", "''")
+                formatted.append(f"'{escaped}'")
+
+        return formatted
+
     @staticmethod
     def clean_sql(sql: str) -> str:
         cleaned = sql.strip()
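The intended behavior of `format_bindings_for_sql`, written as a small self-check that follows the branches above:

```python
# Expected literal formatting from SqlUtils.format_bindings_for_sql.
import decimal

from dbt.adapters.databricks.handle import SqlUtils

assert SqlUtils.format_bindings_for_sql(None) is None
assert SqlUtils.format_bindings_for_sql(
    [None, True, "O'Brien", 42, decimal.Decimal("1.5")]
) == ["NULL", "TRUE", "'O''Brien'", "42", "1.5"]
```

Note that the `bool` check has to come before the numeric check, since `bool` is a subclass of `int` in Python; the code above orders the branches accordingly.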
diff --git a/dbt/adapters/databricks/impl.py b/dbt/adapters/databricks/impl.py
index 2190e1f9c..f17a978dd 100644
--- a/dbt/adapters/databricks/impl.py
+++ b/dbt/adapters/databricks/impl.py
@@ -57,6 +57,7 @@
     AllPurposeClusterPythonJobHelper,
     JobClusterPythonJobHelper,
     ServerlessClusterPythonJobHelper,
+    SessionPythonJobHelper,
     WorkflowPythonJobHelper,
 )
 from dbt.adapters.databricks.relation import (
@@ -801,6 +802,7 @@ def python_submission_helpers(self) -> dict[str, type[PythonJobHelper]]:
             "all_purpose_cluster": AllPurposeClusterPythonJobHelper,
             "serverless_cluster": ServerlessClusterPythonJobHelper,
             "workflow_job": WorkflowPythonJobHelper,
+            "session": SessionPythonJobHelper,
         }

     @log_code_execution
@@ -809,6 +811,17 @@ def submit_python_job(self, parsed_model: dict, compiled_code: str) -> AdapterRe
             "user_folder_for_python",
             self.behavior.use_user_folder_for_python.setting,  # type: ignore[attr-defined]
         )
+
+        # Auto-select the session submission method when in session mode
+        from dbt.adapters.databricks.credentials import DatabricksCredentials
+        from dbt.adapters.databricks.logging import logger
+
+        creds = cast(DatabricksCredentials, self.config.credentials)
+        if creds.is_session_mode:
+            if parsed_model["config"].get("submission_method") is None:
+                parsed_model["config"]["submission_method"] = "session"
+                logger.debug("Auto-selected 'session' submission method for Python model")
+
         return super().submit_python_job(parsed_model, compiled_code)

     @available
diff --git a/dbt/adapters/databricks/python_models/python_submissions.py b/dbt/adapters/databricks/python_models/python_submissions.py
index 23c2a27c1..b71ca7e65 100644
--- a/dbt/adapters/databricks/python_models/python_submissions.py
+++ b/dbt/adapters/databricks/python_models/python_submissions.py
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any, Optional

 from dbt.adapters.base import PythonJobHelper
 from dbt_common.exceptions import DbtRuntimeError
@@ -12,6 +12,9 @@
 from dbt.adapters.databricks.python_models.python_config import ParsedPythonModel
 from dbt.adapters.databricks.python_models.run_tracking import PythonRunTracker

+if TYPE_CHECKING:
+    from pyspark.sql import SparkSession
+
 DEFAULT_TIMEOUT = 60 * 60 * 24
@@ -309,7 +312,7 @@ def compile(self, path: str) -> PythonJobDetails:
         additional_job_config["environments"] = [
             {
                 "environment_key": self.environment_key,
-                "spec": {"environment_version": "4", "dependencies": self.environment_deps},
+                "spec": {"client": "2", "dependencies": self.environment_deps},
             }
         ]
         job_spec.update(self.cluster_spec)
@@ -658,3 +661,183 @@ def build_submitter(self) -> PythonSubmitter:
         return PythonNotebookWorkflowSubmitter.create(
             self.api_client, self.tracker, self.parsed_model
         )
+
+
+class SessionStateManager:
+    """Manages session state to prevent leakage between Python models."""
+
+    @staticmethod
+    def cleanup_temp_views(spark: "SparkSession") -> None:
+        """Drop temporary views created during model execution."""
+        try:
+            # Get the list of temp views from the current database
+            temp_views = [
+                row.viewName
+                for row in spark.sql("SHOW VIEWS").collect()
+                if hasattr(row, "isTemporary") and row.isTemporary
+            ]
+            for view in temp_views:
+                try:
+                    spark.catalog.dropTempView(view)
+                    logger.debug(f"Dropped temp view: {view}")
+                except Exception as e:
+                    logger.warning(f"Failed to drop temp view {view}: {e}")
+        except Exception as e:
+            logger.debug(f"Could not list temp views for cleanup: {e}")
+
+    @staticmethod
+    def get_clean_exec_globals(spark: "SparkSession") -> dict[str, Any]:
+        """Return a clean execution context with minimal state."""
+        return {
+            "spark": spark,
+            "dbt": __import__("dbt"),
+            # Standard Python builtins are available by default
+        }
+
+
+class SessionPythonSubmitter(PythonSubmitter):
+    """Submitter for Python models using direct execution in the current SparkSession.
+
+    NOTE: This does NOT collect data to the driver. The compiled code contains
+    df.write.saveAsTable(), which writes directly to storage, just like the API-based
+    submission methods.
+    """
+
+    def __init__(self, spark: "SparkSession"):
+        self._spark = spark
+        self._state_manager = SessionStateManager()
+
+    @override
+    def submit(self, compiled_code: str) -> None:
+        logger.debug("Executing Python model directly in SparkSession.")
+
+        try:
+            # Get a clean execution context
+            exec_globals = self._state_manager.get_clean_exec_globals(self._spark)
+
+            # Log a preview of the code being executed
+            preview_len = min(500, len(compiled_code))
+            logger.debug(
+                f"[Session Python] Executing code preview: {compiled_code[:preview_len]}..."
+            )
+
+            # Execute the compiled Python model code.
+            # The compiled code will:
+            # 1. Execute the model() function to get a DataFrame
+            # 2. Call df.write.saveAsTable() to persist to Delta
+            # 3. Never call collect() - the data stays distributed
+            exec(compiled_code, exec_globals)
+
+            logger.debug("[Session Python] Model execution completed successfully")
+
+        except Exception as e:
+            logger.error(f"Python model execution failed: {e}")
+            raise DbtRuntimeError(f"Python model execution failed: {e}") from e
+
+        finally:
+            # Clean up temp views to prevent state leakage
+            try:
+                self._state_manager.cleanup_temp_views(self._spark)
+            except Exception as cleanup_error:
+                logger.warning(f"Failed to cleanup temp views: {cleanup_error}")
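To make the exec() path above concrete, here is a hypothetical direct use of `SessionPythonSubmitter` (normally dbt drives this through the job helper; an active SparkSession bound to `spark` is assumed, as in a Databricks notebook):

```python
# Hypothetical direct invocation; assumes an active SparkSession named `spark`.
# The compiled string stands in for dbt's generated model harness, which likewise
# ends in a saveAsTable() call rather than collecting data to the driver.
from dbt.adapters.databricks.python_models.python_submissions import SessionPythonSubmitter

compiled = """
df = spark.range(3).toDF("n")
df.write.mode("overwrite").saveAsTable("main.dbt.example_model")
"""

SessionPythonSubmitter(spark).submit(compiled)
```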
+ """ + + tracker = PythonRunTracker() + + def __init__(self, parsed_model: dict, credentials: DatabricksCredentials) -> None: + self.credentials = credentials + self.parsed_model = ParsedPythonModel(**parsed_model) + + # Get SparkSession directly - no API client needed + import os + + from pyspark import SparkContext + from pyspark.sql import SparkSession + + # On Databricks, the SparkSession may already exist or we may need to create it + # Try multiple methods to get an existing session first + spark = None + + # Method 1: Try to get from active SparkContext and create SparkSession from it + # This is the most reliable method on Databricks + try: + sc = SparkContext._active_spark_context + if sc is not None: + # Create SparkSession from existing SparkContext + # This avoids the need for a master URL + spark = SparkSession(sc) + logger.debug("[Session Python] Got SparkSession from active SparkContext") + except Exception as e: + logger.debug(f"[Session Python] Could not get SparkSession from SparkContext: {e}") + + # Method 2: Try getActiveSession() (available in PySpark 3.0+) + if spark is None: + try: + spark = SparkSession.getActiveSession() + if spark is not None: + logger.debug("[Session Python] Got SparkSession from getActiveSession()") + except (AttributeError, Exception) as e: + logger.debug(f"[Session Python] getActiveSession() not available or failed: {e}") + + # Method 3: Try to get from global 'spark' variable (Databricks convention) + if spark is None: + try: + import __main__ + + if hasattr(__main__, "spark"): + spark = __main__.spark + logger.debug("[Session Python] Got SparkSession from __main__.spark") + except Exception as e: + logger.debug( + f"[Session Python] Could not get SparkSession from __main__.spark: {e}" + ) + + # If no existing SparkSession found, provide a clear error message + if spark is None: + databricks_runtime = os.getenv("DATABRICKS_RUNTIME_VERSION") + if databricks_runtime: + raise DbtRuntimeError(f""" + [Session Python] Could not find an existing SparkSession. + This typically happens when using the native 'dbt task' in Databricks Jobs, + which does not provide a SparkSession context. + Session mode is only compatible with: + - Databricks Notebooks (where 'spark' is pre-initialized) + - Python tasks that initialize SparkSession before running dbt + - Environments where SparkSession is already available + For the native dbt task, use DBSQL mode instead (the default): + - Set 'method: dbsql' in your profile (or omit 'method' entirely) + - Configure 'host' and 'http_path' to connect to a SQL warehouse + or cluster. + \n + Databricks runtime version: {databricks_runtime} + """) + else: + raise DbtRuntimeError(f""" + [Session Python] Session mode requires a Databricks cluster environment + with an active SparkSession. + DATABRICKS_RUNTIME_VERSION environment variable not found. + Ensure you are running on a Databricks cluster in a context where + SparkSession is available (e.g., Notebook or Python task with + Spark initialized). 
+
+        self._spark = spark
+        logger.debug(
+            f"[Session Python] Using SparkSession: {self._spark.sparkContext.applicationId}"
+        )
+
+        self._submitter = SessionPythonSubmitter(self._spark)
+
+    def submit(self, compiled_code: str) -> None:
+        """Submit the compiled Python model for execution."""
+        self._submitter.submit(compiled_code)
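For completeness: a Python model can also opt into this helper explicitly through its config, and in session mode `submit_python_job` (see the impl.py hunk above) fills the value in automatically when it is left unset. A minimal model sketch:

```python
# models/my_python_model.py - a dbt Python model requesting session execution.
def model(dbt, session):
    dbt.config(submission_method="session", materialized="table")
    # `session` is the active SparkSession; the result is written by dbt's harness.
    return session.range(5)
```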
diff --git a/dbt/adapters/databricks/session.py b/dbt/adapters/databricks/session.py
new file mode 100644
index 000000000..505e73dd4
--- /dev/null
+++ b/dbt/adapters/databricks/session.py
@@ -0,0 +1,352 @@
+"""
+Session mode support for dbt-databricks.
+
+This module provides SparkSession-based execution for running dbt on Databricks job clusters
+without requiring the DBSQL connector. It enables complete dbt pipelines (SQL + Python models)
+to execute entirely within a single SparkSession.
+
+Key components:
+- SessionCursorWrapper: Adapts DataFrame results to the cursor interface expected by dbt
+- DatabricksSessionHandle: Wraps a SparkSession to provide the handle interface
+"""
+
+import sys
+from collections.abc import Sequence
+from types import TracebackType
+from typing import TYPE_CHECKING, Any, Optional
+
+from dbt.adapters.contracts.connection import AdapterResponse
+from dbt_common.exceptions import DbtRuntimeError
+
+from dbt.adapters.databricks.handle import SqlUtils
+from dbt.adapters.databricks.logging import logger
+
+if TYPE_CHECKING:
+    from pyspark.sql import DataFrame, Row, SparkSession
+    from pyspark.sql.types import StructField
+
+
+class SessionCursorWrapper:
+    """
+    Wraps SparkSession DataFrame results to provide a cursor-like interface.
+
+    This adapter allows dbt to use SparkSession.sql() results in the same way
+    it uses DBSQL cursor results, maintaining compatibility with the existing
+    connection management code.
+    """
+
+    def __init__(self, spark: "SparkSession"):
+        self._spark = spark
+        self._df: Optional["DataFrame"] = None
+        self._rows: Optional[list["Row"]] = None
+        self._query_id: str = "session-query"
+        self.open = True
+
+    def execute(self, sql: str, bindings: Optional[Sequence[Any]] = None) -> "SessionCursorWrapper":
+        """Execute a SQL statement and store the resulting DataFrame."""
+        cleaned_sql = SqlUtils.clean_sql(sql)
+
+        # Handle bindings by formatting them as SQL literals and substituting them in.
+        # This ensures string values are properly quoted (e.g., 'k. A.' instead of k. A.)
+        if bindings:
+            translated = SqlUtils.translate_bindings(bindings)
+            if translated:
+                formatted_bindings = SqlUtils.format_bindings_for_sql(translated)
+                if formatted_bindings:
+                    cleaned_sql = cleaned_sql % tuple(formatted_bindings)
+
+        logger.debug(f"Session mode executing SQL: {cleaned_sql[:200]}...")
+        self._df = self._spark.sql(cleaned_sql)
+        self._rows = None  # Reset cached rows
+        return self
+
+    def fetchall(self) -> Sequence[tuple]:
+        """Fetch all rows from the result set."""
+        if self._rows is None and self._df is not None:
+            self._rows = self._df.collect()
+        return [tuple(row) for row in (self._rows or [])]
+
+    def fetchone(self) -> Optional[tuple]:
+        """Fetch the next row from the result set."""
+        if self._rows is None and self._df is not None:
+            self._rows = self._df.collect()
+        if self._rows:
+            return tuple(self._rows.pop(0))
+        return None
+
+    def fetchmany(self, size: int) -> Sequence[tuple]:
+        """Fetch the next `size` rows from the result set."""
+        if self._rows is None and self._df is not None:
+            self._rows = self._df.collect()
+        if not self._rows:
+            return []
+        result = [tuple(row) for row in self._rows[:size]]
+        self._rows = self._rows[size:]
+        return result
+
+    @property
+    def description(self) -> Optional[list[tuple]]:
+        """Return column descriptions in DB-API format."""
+        if self._df is None:
+            return None
+        return [self._field_to_description(f) for f in self._df.schema.fields]
+
+    @staticmethod
+    def _field_to_description(field: "StructField") -> tuple:
+        """Convert a StructField to a DB-API description tuple."""
+        # DB-API description: (name, type_code, display_size, internal_size,
+        # precision, scale, null_ok)
+        return (
+            field.name,
+            field.dataType.simpleString(),
+            None,
+            None,
+            None,
+            None,
+            field.nullable,
+        )
+
+    def get_response(self) -> AdapterResponse:
+        """Return an adapter response for the executed query."""
+        return AdapterResponse(_message="OK", query_id=self._query_id)
+
+    def cancel(self) -> None:
+        """Cancel the current operation (no-op for session mode)."""
+        logger.debug("SessionCursorWrapper.cancel() called (no-op)")
+        self.open = False
+
+    def close(self) -> None:
+        """Close the cursor."""
+        logger.debug("SessionCursorWrapper.close() called")
+        self.open = False
+        self._df = None
+        self._rows = None
+
+    def __str__(self) -> str:
+        return f"SessionCursor(query-id={self._query_id})"
+
+    def __enter__(self) -> "SessionCursorWrapper":
+        return self
+
+    def __exit__(
+        self,
+        exc_type: Optional[type[BaseException]],
+        exc_val: Optional[BaseException],
+        exc_tb: Optional[TracebackType],
+    ) -> bool:
+        self.close()
+        return exc_val is None
+
+
+class DatabricksSessionHandle:
+    """
+    Handle for a Databricks SparkSession.
+
+    Provides the same interface as DatabricksHandle but uses the active SparkSession
+    instead of the DBSQL connector. This enables dbt to run on job clusters without
+    requiring external API connections.
+    """
+
+    def __init__(self, spark: "SparkSession"):
+        self._spark = spark
+        self.open = True
+        self._cursor: Optional[SessionCursorWrapper] = None
+        self._dbr_version: Optional[tuple[int, int]] = None
+
+    @staticmethod
+    def create(
+        catalog: Optional[str] = None,
+        schema: Optional[str] = None,
+        session_properties: Optional[dict[str, Any]] = None,
+    ) -> "DatabricksSessionHandle":
+        """
+        Create a DatabricksSessionHandle using the active SparkSession.
+
+        Args:
+            catalog: Optional catalog to set as current
+            schema: Optional schema (not used - dbt uses fully qualified names and
+                creates schemas during execution)
+            session_properties: Optional session configuration properties
+
+        Returns:
+            A new DatabricksSessionHandle instance
+        """
+        import os
+
+        from pyspark import SparkContext
+        from pyspark.sql import SparkSession
+
+        # On Databricks, the SparkSession may already exist, or we may need to create it.
+        # Try several methods to find an existing session first.
+        spark = None
+
+        # Method 1: Get the active SparkContext and create a SparkSession from it.
+        # This is the most reliable method on Databricks.
+        try:
+            sc = SparkContext._active_spark_context
+            if sc is not None:
+                # Creating a SparkSession from an existing SparkContext
+                # avoids the need for a master URL
+                spark = SparkSession(sc)
+                logger.debug("Got SparkSession from active SparkContext")
+        except Exception as e:
+            logger.debug(f"Could not get SparkSession from SparkContext: {e}")
+
+        # Method 2: Try getActiveSession() (available in PySpark 3.0+)
+        if spark is None:
+            try:
+                spark = SparkSession.getActiveSession()
+                if spark is not None:
+                    logger.debug("Got SparkSession from getActiveSession()")
+            except Exception as e:
+                logger.debug(f"getActiveSession() not available or failed: {e}")
+
+        # Method 3: Try the global 'spark' variable (Databricks convention)
+        if spark is None:
+            try:
+                import __main__
+
+                if hasattr(__main__, "spark"):
+                    spark = __main__.spark
+                    logger.debug("Got SparkSession from __main__.spark")
+            except Exception as e:
+                logger.debug(f"Could not get SparkSession from __main__.spark: {e}")
+
+        # If no existing SparkSession was found, raise a clear error message
+        if spark is None:
+            databricks_runtime = os.getenv("DATABRICKS_RUNTIME_VERSION")
+            if databricks_runtime:
+                raise DbtRuntimeError(
+                    f"""Session mode could not find an existing SparkSession.
+                    This typically happens when using the native 'dbt task' in Databricks Jobs,
+                    which does not provide a SparkSession context.
+
+                    Session mode is only compatible with:
+                    - Databricks Notebooks (where 'spark' is pre-initialized)
+                    - Python tasks that initialize a SparkSession before running dbt
+                    - Environments where a SparkSession is already available
+
+                    For the native dbt task, use DBSQL mode instead (the default):
+                    - Set 'method: dbsql' in your profile (or omit 'method' entirely)
+                    - Configure 'host' and 'http_path' to connect to a SQL warehouse or cluster
+
+                    Databricks runtime version: {databricks_runtime}
+                    """
+                )
+            else:
+                raise DbtRuntimeError("""
+                Session mode requires a Databricks cluster environment with
+                an active SparkSession.
+                The DATABRICKS_RUNTIME_VERSION environment variable was not found.
+                Ensure you are running on a Databricks cluster in a context where
+                a SparkSession is available (e.g., a Notebook or a Python task with
+                Spark initialized).
+                """)
+
+        # Set the catalog if provided
+        if catalog:
+            try:
+                spark.catalog.setCurrentCatalog(catalog)
+                logger.debug(f"Set current catalog to: {catalog}")
+            except Exception as e:
+                logger.warning(f"Failed to set catalog '{catalog}': {e}")
+                # Fall back to USE CATALOG for older Spark versions
+                spark.sql(f"USE CATALOG {catalog}")
+
+        # Note: We intentionally do NOT call setCurrentDatabase(schema) here.
+        # The schema from the profile is used as a base/prefix for generated schema names
+        # (e.g., "dbt" -> "dbt_seeds", "dbt_bronze" via the generate_schema_name macro).
+        # dbt creates these schemas during execution via CREATE SCHEMA IF NOT EXISTS,
+        # and uses fully qualified names (catalog.schema.table) for all operations.
+        # Setting the current database would fail if the schema doesn't exist yet.
+
+        # Apply session properties
+        if session_properties:
+            for key, value in session_properties.items():
+                spark.conf.set(key, str(value))
+                logger.debug(f"Set session property {key}={value}")
+
+        handle = DatabricksSessionHandle(spark)
+        logger.debug(f"Created session handle: {handle}")
+        return handle
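Putting the handle and cursor together, this is roughly how a notebook session would drive them by hand (illustrative only; dbt normally goes through the connection manager, and an active SparkSession is assumed to exist):

```python
# Illustrative use in a Databricks notebook; assumes an active SparkSession.
from dbt.adapters.databricks.session import DatabricksSessionHandle

handle = DatabricksSessionHandle.create(catalog="main")
cur = handle.execute("SELECT 1 AS n")
print(cur.description)   # e.g. [('n', 'int', None, None, None, None, False)]
print(cur.fetchall())    # [(1,)]
print(handle.dbr_version)  # e.g. (14, 3)
handle.close()
```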
+
+    @property
+    def dbr_version(self) -> tuple[int, int]:
+        """Get the DBR version of the current cluster."""
+        if self._dbr_version is None:
+            try:
+                version_str = self._spark.conf.get(
+                    "spark.databricks.clusterUsageTags.sparkVersion", ""
+                )
+                if version_str:
+                    self._dbr_version = SqlUtils.extract_dbr_version(version_str)
+                else:
+                    # If we can't get the version, assume the latest
+                    logger.warning("Could not determine DBR version, assuming latest")
+                    self._dbr_version = (sys.maxsize, sys.maxsize)
+            except Exception as e:
+                logger.warning(f"Failed to get DBR version: {e}, assuming latest")
+                self._dbr_version = (sys.maxsize, sys.maxsize)
+        return self._dbr_version
+
+    @property
+    def session_id(self) -> str:
+        """Get a unique identifier for this session."""
+        try:
+            # Use the Spark application ID as a session identifier
+            return self._spark.sparkContext.applicationId or "session-unknown"
+        except Exception:
+            return "session-unknown"
+
+    def execute(self, sql: str, bindings: Optional[Sequence[Any]] = None) -> SessionCursorWrapper:
+        """Execute a SQL statement and return a cursor wrapper."""
+        if not self.open:
+            raise DbtRuntimeError("Attempting to execute on a closed session handle")
+
+        if self._cursor:
+            self._cursor.close()
+
+        self._cursor = SessionCursorWrapper(self._spark)
+        return self._cursor.execute(sql, bindings)
+
+    def list_schemas(self, database: str, schema: Optional[str] = None) -> SessionCursorWrapper:
+        """List schemas in the given database/catalog."""
+        if schema:
+            sql = f"SHOW SCHEMAS IN {database} LIKE '{schema}'"
+        else:
+            sql = f"SHOW SCHEMAS IN {database}"
+        return self.execute(sql)
+
+    def list_tables(self, database: str, schema: str) -> SessionCursorWrapper:
+        """List tables in the given database and schema."""
+        sql = f"SHOW TABLES IN {database}.{schema}"
+        return self.execute(sql)
+
+    def cancel(self) -> None:
+        """Cancel any in-progress operations."""
+        logger.debug("DatabricksSessionHandle.cancel() called")
+        if self._cursor:
+            self._cursor.cancel()
+        self.open = False
+
+    def close(self) -> None:
+        """Close the session handle."""
+        logger.debug("DatabricksSessionHandle.close() called")
+        if self._cursor:
+            self._cursor.close()
+        self.open = False
+        # Note: We don't stop the SparkSession, as it may be shared
+
+    def rollback(self) -> None:
+        """Required for interface compatibility, but not implemented."""
+        logger.debug("NotImplemented: rollback (session mode)")
+
+    def __del__(self) -> None:
+        if self._cursor:
+            self._cursor.close()
+        self.close()
+
+    def __str__(self) -> str:
+        return f"SessionHandle(session-id={self.session_id})"
diff --git a/pyproject.toml b/pyproject.toml
index 4c120c86a..c388c8f1d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -3,7 +3,7 @@ requires = ["hatchling"]
 build-backend = "hatchling.build"

 [project]
-name = "dbt-databricks"
+name = "dbt-databricks-session"
 dynamic = ["version"]
 description = "The Databricks adapter plugin for dbt"
 readme = "README.md"
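The package rename keeps this fork from clobbering an upstream `dbt-databricks` install in the same environment. Assuming the fork is published under the new name (otherwise install from source):

```
pip install dbt-databricks-session
```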
diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py
new file mode 100644
index 000000000..09ed7105d
--- /dev/null
+++ b/tests/unit/test_session.py
@@ -0,0 +1,399 @@
+"""Unit tests for session mode components."""
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from dbt.adapters.databricks.session import (
+    DatabricksSessionHandle,
+    SessionCursorWrapper,
+)
+
+
+class TestSessionCursorWrapper:
+    """Tests for SessionCursorWrapper."""
+
+    @pytest.fixture
+    def mock_spark(self):
+        """Create a mock SparkSession."""
+        spark = MagicMock()
+        return spark
+
+    @pytest.fixture
+    def cursor(self, mock_spark):
+        """Create a SessionCursorWrapper with a mock SparkSession."""
+        return SessionCursorWrapper(mock_spark)
+
+    def test_execute_cleans_sql(self, cursor, mock_spark):
+        """Test that execute cleans SQL (strips whitespace and trailing semicolon)."""
+        mock_df = MagicMock()
+        mock_spark.sql.return_value = mock_df
+
+        result = cursor.execute(" SELECT 1; ")
+
+        mock_spark.sql.assert_called_once_with("SELECT 1")
+        assert result is cursor
+
+    def test_execute_with_bindings(self, cursor, mock_spark):
+        """Test that execute handles bindings with proper SQL literal formatting."""
+        mock_df = MagicMock()
+        mock_spark.sql.return_value = mock_df
+
+        cursor.execute("SELECT %s, %s", (1, "test"))
+
+        # String values should be properly quoted
+        mock_spark.sql.assert_called_once_with("SELECT 1, 'test'")
+
+    def test_fetchall_returns_tuples(self, cursor, mock_spark):
+        """Test that fetchall returns a list of tuples."""
+        mock_df = MagicMock()
+        mock_row1 = MagicMock()
+        mock_row1.__iter__ = lambda self: iter([1, "a"])
+        mock_row2 = MagicMock()
+        mock_row2.__iter__ = lambda self: iter([2, "b"])
+        mock_df.collect.return_value = [mock_row1, mock_row2]
+        mock_spark.sql.return_value = mock_df
+
+        cursor.execute("SELECT * FROM test")
+        result = cursor.fetchall()
+
+        assert result == [(1, "a"), (2, "b")]
+
+    def test_fetchone_returns_single_tuple(self, cursor, mock_spark):
+        """Test that fetchone returns a single tuple."""
+        mock_df = MagicMock()
+        mock_row = MagicMock()
+        mock_row.__iter__ = lambda self: iter([1, "a"])
+        mock_df.collect.return_value = [mock_row]
+        mock_spark.sql.return_value = mock_df
+
+        cursor.execute("SELECT * FROM test")
+        result = cursor.fetchone()
+
+        assert result == (1, "a")
+
+    def test_fetchone_returns_none_when_empty(self, cursor, mock_spark):
+        """Test that fetchone returns None when there are no rows."""
+        mock_df = MagicMock()
+        mock_df.collect.return_value = []
+        mock_spark.sql.return_value = mock_df
+
+        cursor.execute("SELECT * FROM test")
+        result = cursor.fetchone()
+
+        assert result is None
+
+    def test_fetchmany_returns_limited_rows(self, cursor, mock_spark):
+        """Test that fetchmany returns a limited number of rows."""
+        mock_df = MagicMock()
+        mock_rows = [MagicMock() for _ in range(5)]
+        for i, row in enumerate(mock_rows):
+            row.__iter__ = (lambda i: lambda self: iter([i]))(i)
+        mock_df.collect.return_value = mock_rows
+        mock_spark.sql.return_value = mock_df
+
+        cursor.execute("SELECT * FROM test")
+        result = cursor.fetchmany(2)
+
+        assert len(result) == 2
+        assert result == [(0,), (1,)]
+
+    def test_description_returns_column_info(self, cursor, mock_spark):
+        """Test that description returns column metadata."""
+        mock_df = MagicMock()
+        mock_field1 = MagicMock()
+        mock_field1.name = "id"
+        mock_field1.dataType.simpleString.return_value = "int"
+        mock_field1.nullable = False
+
+        mock_field2 = MagicMock()
+        mock_field2.name = "name"
+        mock_field2.dataType.simpleString.return_value = "string"
+        mock_field2.nullable = True
+
+        mock_df.schema.fields = [mock_field1, mock_field2]
+        mock_spark.sql.return_value = mock_df
+
+        cursor.execute("SELECT * FROM test")
+        desc = cursor.description
+
+        assert len(desc) == 2
+        assert desc[0][0] == "id"
+        assert desc[0][1] == "int"
+        assert desc[0][6] is False
+        assert desc[1][0] == "name"
+        assert desc[1][1] == "string"
+        assert desc[1][6] is True
+
+    def test_description_returns_none_before_execute(self, cursor):
+        """Test that description returns None before execute."""
+        assert cursor.description is None
+
+    def test_get_response_returns_adapter_response(self, cursor, mock_spark):
+        """Test that get_response returns an AdapterResponse."""
+        mock_df = MagicMock()
+        mock_spark.sql.return_value = mock_df
+
+        cursor.execute("SELECT 1")
+        response = cursor.get_response()
+
+        assert response._message == "OK"
+        assert response.query_id == "session-query"
+
+    def test_close_sets_open_to_false(self, cursor):
+        """Test that close sets open to False."""
+        assert cursor.open is True
+        cursor.close()
+        assert cursor.open is False
+
+    def test_context_manager(self, mock_spark):
+        """Test that the cursor works as a context manager."""
+        with SessionCursorWrapper(mock_spark) as cursor:
+            assert cursor.open is True
+        assert cursor.open is False
+
+    def test_execute_with_string_special_characters(self, cursor, mock_spark):
+        """Test that execute properly quotes strings with special characters."""
+        mock_df = MagicMock()
+        mock_spark.sql.return_value = mock_df
+
+        cursor.execute("INSERT INTO test VALUES (%s, %s)", ("k. A.", "O'Brien"))
+
+        # Special characters should be properly quoted and escaped
+        mock_spark.sql.assert_called_once_with("INSERT INTO test VALUES ('k. A.', 'O''Brien')")
+
+    def test_execute_with_null_value(self, cursor, mock_spark):
+        """Test that execute handles NULL values correctly."""
+        mock_df = MagicMock()
+        mock_spark.sql.return_value = mock_df
+
+        cursor.execute("INSERT INTO test VALUES (%s, %s)", (1, None))
+
+        # NULL should be formatted as SQL NULL
+        mock_spark.sql.assert_called_once_with("INSERT INTO test VALUES (1, NULL)")
+
+    def test_execute_with_boolean_values(self, cursor, mock_spark):
+        """Test that execute handles boolean values correctly."""
+        mock_df = MagicMock()
+        mock_spark.sql.return_value = mock_df
+
+        cursor.execute("INSERT INTO test VALUES (%s, %s)", (True, False))
+
+        # Booleans should be formatted as TRUE/FALSE
+        mock_spark.sql.assert_called_once_with("INSERT INTO test VALUES (TRUE, FALSE)")
+
+    def test_execute_with_numeric_values(self, cursor, mock_spark):
+        """Test that execute handles numeric values correctly."""
+        mock_df = MagicMock()
+        mock_spark.sql.return_value = mock_df
+
+        cursor.execute("INSERT INTO test VALUES (%s, %s, %s)", (42, 3.14, 100))
+
+        # Numbers should remain unquoted
+        mock_spark.sql.assert_called_once_with("INSERT INTO test VALUES (42, 3.14, 100)")
+
+    def test_execute_with_comma_in_string(self, cursor, mock_spark):
+        """Test that execute properly handles strings containing commas."""
+        mock_df = MagicMock()
+        mock_spark.sql.return_value = mock_df
+
+        cursor.execute("INSERT INTO test VALUES (%s)", ("Freie und Hansestadt Hamburg",))
+
+        # A string with spaces and special characters should be properly quoted
+        mock_spark.sql.assert_called_once_with(
+            "INSERT INTO test VALUES ('Freie und Hansestadt Hamburg')"
+        )
+
+    def test_execute_with_parentheses_in_string(self, cursor, mock_spark):
+        """Test that execute properly handles strings containing parentheses."""
+        mock_df = MagicMock()
+        mock_spark.sql.return_value = mock_df
+
+        cursor.execute("INSERT INTO test VALUES (%s)", ("Haus in Planung (projektiert)",))
+
+        # A string with parentheses should be properly quoted
+        mock_spark.sql.assert_called_once_with(
+            "INSERT INTO test VALUES ('Haus in Planung (projektiert)')"
+        )
+
+
+class TestDatabricksSessionHandle:
+    """Tests for DatabricksSessionHandle."""
+
+    @pytest.fixture
+    def mock_spark(self):
+        """Create a mock SparkSession."""
+        spark = MagicMock()
+        spark.sparkContext.applicationId = "app-123"
+        spark.conf.get.return_value = "14.3.x-scala2.12"
+        return spark
+
+    @pytest.fixture
+    def handle(self, mock_spark):
+        """Create a DatabricksSessionHandle with a mock SparkSession."""
+        return DatabricksSessionHandle(mock_spark)
+
+    def test_session_id_returns_application_id(self, handle):
+        """Test that session_id returns the Spark application ID."""
+        assert handle.session_id == "app-123"
+
+    def test_dbr_version_extracts_version(self, handle, mock_spark):
+        """Test that dbr_version extracts the version from the Spark config."""
+        mock_spark.conf.get.return_value = "14.3.x-scala2.12"
+
+        version = handle.dbr_version
+
+        assert version == (14, 3)
+
+    def test_dbr_version_caches_result(self, handle, mock_spark):
+        """Test that dbr_version caches the result."""
+        mock_spark.conf.get.return_value = "14.3.x-scala2.12"
+
+        _ = handle.dbr_version
+        _ = handle.dbr_version
+
+        # Should only call conf.get once due to caching
+        assert mock_spark.conf.get.call_count == 1
+
+    def test_execute_returns_cursor(self, handle, mock_spark):
+        """Test that execute returns a SessionCursorWrapper."""
+        mock_df = MagicMock()
+        mock_spark.sql.return_value = mock_df
+
+        cursor = handle.execute("SELECT 1")
+
+        assert isinstance(cursor, SessionCursorWrapper)
+
+    def test_execute_closes_previous_cursor(self, handle, mock_spark):
+        """Test that execute closes any previous cursor."""
+        mock_df = MagicMock()
+        mock_spark.sql.return_value = mock_df
+
+        cursor1 = handle.execute("SELECT 1")
+        assert cursor1.open is True
+
+        cursor2 = handle.execute("SELECT 2")
+        assert cursor1.open is False
+        assert cursor2.open is True
+
+    def test_close_sets_open_to_false(self, handle):
+        """Test that close sets open to False."""
+        assert handle.open is True
+        handle.close()
+        assert handle.open is False
+
+    def test_list_schemas_executes_show_schemas(self, handle, mock_spark):
+        """Test that list_schemas executes SHOW SCHEMAS."""
+        mock_df = MagicMock()
+        mock_spark.sql.return_value = mock_df
+
+        handle.list_schemas("my_catalog")
+
+        mock_spark.sql.assert_called_with("SHOW SCHEMAS IN my_catalog")
+
+    def test_list_schemas_with_pattern(self, handle, mock_spark):
+        """Test that list_schemas includes a LIKE pattern."""
+        mock_df = MagicMock()
+        mock_spark.sql.return_value = mock_df
+
+        handle.list_schemas("my_catalog", "my_schema")
+
+        mock_spark.sql.assert_called_with("SHOW SCHEMAS IN my_catalog LIKE 'my_schema'")
+
+    def test_list_tables_executes_show_tables(self, handle, mock_spark):
+        """Test that list_tables executes SHOW TABLES."""
+        mock_df = MagicMock()
+        mock_spark.sql.return_value = mock_df
+
+        handle.list_tables("my_catalog", "my_schema")
+
+        mock_spark.sql.assert_called_with("SHOW TABLES IN my_catalog.my_schema")
+
+    def test_create_gets_existing_spark_session(self):
+        """Test that create finds an existing SparkSession via SparkContext."""
+        mock_spark = MagicMock()
+        mock_spark.sparkContext.applicationId = "app-456"
+        mock_sc = MagicMock()
+
+        with patch.dict(
+            "sys.modules",
+            {"pyspark": MagicMock(), "pyspark.sql": MagicMock()},
+        ):
+            import sys
+
+            # Mock SparkContext._active_spark_context to return a context
+            sys.modules["pyspark"].SparkContext._active_spark_context = mock_sc
+            # Mock the SparkSession constructor to return our mock session
+            sys.modules["pyspark.sql"].SparkSession.return_value = mock_spark
+
+            handle = DatabricksSessionHandle.create()
+
+            # Should have created the SparkSession from the SparkContext
+            sys.modules["pyspark.sql"].SparkSession.assert_called_once_with(mock_sc)
+            assert handle.session_id == "app-456"
+
+    def test_create_sets_catalog(self):
+        """Test that create sets the catalog."""
+        mock_spark = MagicMock()
+        mock_spark.sparkContext.applicationId = "app-456"
+        mock_sc = MagicMock()
+
+        with patch.dict(
+            "sys.modules",
+            {"pyspark": MagicMock(), "pyspark.sql": MagicMock()},
+        ):
+            import sys
+
+            sys.modules["pyspark"].SparkContext._active_spark_context = mock_sc
+            sys.modules["pyspark.sql"].SparkSession.return_value = mock_spark
+
+            DatabricksSessionHandle.create(catalog="my_catalog")
+
+            mock_spark.catalog.setCurrentCatalog.assert_called_once_with("my_catalog")
+
+    def test_create_does_not_set_schema(self):
+        """Test that create does NOT set the schema/database.
+
+        dbt uses the schema from the profile as a base/prefix for generated schema names
+        (e.g., "dbt" -> "dbt_seeds", "dbt_bronze" via the generate_schema_name macro).
+        dbt creates these schemas during execution via CREATE SCHEMA IF NOT EXISTS,
+        and uses fully qualified names (catalog.schema.table) for all operations.
+        Setting the current database would fail if the schema doesn't exist yet.
+        """
+        mock_spark = MagicMock()
+        mock_spark.sparkContext.applicationId = "app-456"
+        mock_sc = MagicMock()
+
+        with patch.dict(
+            "sys.modules",
+            {"pyspark": MagicMock(), "pyspark.sql": MagicMock()},
+        ):
+            import sys
+
+            sys.modules["pyspark"].SparkContext._active_spark_context = mock_sc
+            sys.modules["pyspark.sql"].SparkSession.return_value = mock_spark
+
+            DatabricksSessionHandle.create(schema="my_schema")
+
+            # The schema should NOT be set - this is the fix for the SCHEMA_NOT_FOUND error
+            mock_spark.catalog.setCurrentDatabase.assert_not_called()
+
+    def test_create_sets_session_properties(self):
+        """Test that create sets session properties."""
+        mock_spark = MagicMock()
+        mock_spark.sparkContext.applicationId = "app-456"
+        mock_sc = MagicMock()
+
+        with patch.dict(
+            "sys.modules",
+            {"pyspark": MagicMock(), "pyspark.sql": MagicMock()},
+        ):
+            import sys
+
+            sys.modules["pyspark"].SparkContext._active_spark_context = mock_sc
+            sys.modules["pyspark.sql"].SparkSession.return_value = mock_spark
+
+            DatabricksSessionHandle.create(session_properties={"key1": "value1", "key2": 123})
+
+            mock_spark.conf.set.assert_any_call("key1", "value1")
+            mock_spark.conf.set.assert_any_call("key2", "123")