
Commit e9c3c55

feat: databricks allow disabling spark session (#2703)
1 parent be4d496 commit e9c3c55

4 files changed: +52, -36 lines changed

docs/integrations/engines/databricks.md

Lines changed: 1 addition & 0 deletions
@@ -38,6 +38,7 @@ Note: If using Databricks Connect please note the [requirements](https://docs.da
 | `databricks_connect_cluster_id` | Databricks Connect Only: Databricks Connect cluster ID. Uses `http_path` if not set. Cannot be a Databricks SQL Warehouse. | string | N |
 | `force_databricks_connect` | When running locally, force the use of Databricks Connect for all model operations (so don't use SQL Connector for SQL models) | bool | N |
 | `disable_databricks_connect` | When running locally, disable the use of Databricks Connect for all model operations (so use SQL Connector for all models) | bool | N |
+| `disable_spark_session` | Do not use SparkSession if it is available (like when running in a notebook). | bool | N |
 
 ## Airflow Scheduler
 **Engine Name:** `databricks` / `databricks-submit` / `databricks-sql`.
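
To illustrate the new option, here is a minimal sketch of a connection that opts out of the notebook SparkSession. It constructs `DatabricksConnectionConfig` directly, and all credential values are placeholders (a YAML gateway config would set the same keys):

from sqlmesh.core.config.connection import DatabricksConnectionConfig

# Sketch with placeholder values. With disable_spark_session=True the config
# validator no longer short-circuits when a live SparkSession is available,
# so the SQL Connector fields below are required and used instead.
connection = DatabricksConnectionConfig(
    server_hostname="dbc-example.cloud.databricks.com",  # placeholder
    http_path="/sql/1.0/warehouses/abc123",  # placeholder
    access_token="dapi-example-token",  # placeholder
    disable_spark_session=True,
)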

sqlmesh/core/config/connection.py

Lines changed: 17 additions & 10 deletions
@@ -511,6 +511,7 @@ class DatabricksConnectionConfig(ConnectionConfig):
            Defaults to deriving the cluster id from the `http_path` value.
        force_databricks_connect: Force all queries to run using Databricks Connect instead of the SQL connector.
        disable_databricks_connect: Even if databricks connect is installed, do not use it.
+        disable_spark_session: Do not use SparkSession if it is available (like when running in a notebook).
        pre_ping: Whether or not to pre-ping the connection before starting a new transaction to ensure it is still alive.
    """
 
@@ -525,6 +526,7 @@ class DatabricksConnectionConfig(ConnectionConfig):
     databricks_connect_cluster_id: t.Optional[str] = None
     force_databricks_connect: bool = False
     disable_databricks_connect: bool = False
+    disable_spark_session: bool = False
 
     concurrent_tasks: int = 1
     register_comments: bool = True
@@ -538,12 +540,11 @@ class DatabricksConnectionConfig(ConnectionConfig):
     @model_validator(mode="before")
     @model_validator_v1_args
     def _databricks_connect_validator(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
-        from sqlmesh import RuntimeEnv
         from sqlmesh.core.engine_adapter.databricks import DatabricksEngineAdapter
 
-        runtime_env = RuntimeEnv.get()
-
-        if runtime_env.is_databricks:
+        if DatabricksEngineAdapter.can_access_spark_session(
+            bool(values.get("disable_spark_session"))
+        ):
             return values
         server_hostname, http_path, access_token = (
             values.get("server_hostname"),
@@ -554,7 +555,9 @@ def _databricks_connect_validator(cls, values: t.Dict[str, t.Any]) -> t.Dict[str
             raise ValueError(
                 "`server_hostname`, `http_path`, and `access_token` are required for Databricks connections when not running in a notebook"
             )
-        if DatabricksEngineAdapter.can_access_spark_session:
+        if DatabricksEngineAdapter.can_access_databricks_connect(
+            bool(values.get("disable_databricks_connect"))
+        ):
             if not values.get("databricks_connect_server_hostname"):
                 values["databricks_connect_server_hostname"] = f"https://{server_hostname}"
             if not values.get("databricks_connect_access_token"):
@@ -585,14 +588,18 @@ def _extra_engine_config(self) -> t.Dict[str, t.Any]:
         return {
             k: v
             for k, v in self.dict().items()
-            if k.startswith("databricks_connect_") or k in ("catalog", "disable_databricks_connect")
+            if k.startswith("databricks_connect_")
+            or k in ("catalog", "disable_databricks_connect", "disable_spark_session")
         }
 
     @property
     def use_spark_session_only(self) -> bool:
-        from sqlmesh import RuntimeEnv
+        from sqlmesh.core.engine_adapter.databricks import DatabricksEngineAdapter
 
-        return RuntimeEnv.get().is_databricks or self.force_databricks_connect
+        return (
+            DatabricksEngineAdapter.can_access_spark_session(self.disable_spark_session)
+            or self.force_databricks_connect
+        )
 
     @property
     def _connection_factory(self) -> t.Callable:
@@ -607,14 +614,14 @@ def _connection_factory(self) -> t.Callable:
 
     @property
     def _static_connection_kwargs(self) -> t.Dict[str, t.Any]:
-        from sqlmesh import RuntimeEnv
+        from sqlmesh.core.engine_adapter.databricks import DatabricksEngineAdapter
 
         if not self.use_spark_session_only:
             return {
                 "_user_agent_entry": "sqlmesh",
             }
 
-        if RuntimeEnv.get().is_databricks:
+        if DatabricksEngineAdapter.can_access_spark_session(self.disable_spark_session):
             from pyspark.sql import SparkSession
 
             return dict(
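
Taken together, the validator and property changes give the two disable flags a clear precedence. A standalone sketch of the resulting decision order (an illustrative helper, not code from this commit):

def pick_execution_path(
    is_databricks_env: bool,  # what RuntimeEnv.get().is_databricks reports
    disable_spark_session: bool,
    connect_installed: bool,  # whether `databricks.connect` is importable
    disable_databricks_connect: bool,
) -> str:
    # Mirrors can_access_spark_session: a live SparkSession wins unless
    # the user explicitly disables it.
    if is_databricks_env and not disable_spark_session:
        return "spark_session"
    # Mirrors can_access_databricks_connect: used only when installed and
    # not explicitly disabled.
    if connect_installed and not disable_databricks_connect:
        return "databricks_connect"
    # Otherwise fall back to the Databricks SQL Connector.
    return "sql_connector"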

sqlmesh/core/engine_adapter/databricks.py

Lines changed: 29 additions & 22 deletions
@@ -14,7 +14,6 @@
 )
 from sqlmesh.core.engine_adapter.spark import SparkEngineAdapter
 from sqlmesh.core.schema_diff import SchemaDiffer
-from sqlmesh.utils import classproperty
 from sqlmesh.utils.errors import SQLMeshError
 
 if t.TYPE_CHECKING:
@@ -47,12 +46,20 @@ def __init__(self, *args: t.Any, **kwargs: t.Any):
         super().__init__(*args, **kwargs)
         self._spark: t.Optional[PySparkSession] = None
 
-    @classproperty
-    def can_access_spark_session(cls) -> bool:
+    @classmethod
+    def can_access_spark_session(cls, disable_spark_session: bool) -> bool:
         from sqlmesh import RuntimeEnv
 
-        if RuntimeEnv.get().is_databricks:
-            return True
+        if disable_spark_session:
+            return False
+
+        return RuntimeEnv.get().is_databricks
+
+    @classmethod
+    def can_access_databricks_connect(cls, disable_databricks_connect: bool) -> bool:
+        if disable_databricks_connect:
+            return False
+
         try:
             from databricks.connect import DatabricksSession  # noqa
 
@@ -62,19 +69,15 @@ def can_access_spark_session(cls) -> bool:
 
     @property
     def _use_spark_session(self) -> bool:
-        from sqlmesh import RuntimeEnv
-
-        if RuntimeEnv.get().is_databricks:
+        if self.can_access_spark_session(bool(self._extra_config.get("disable_spark_session"))):
             return True
-        return (
-            self.can_access_spark_session
-            and {
-                "databricks_connect_server_hostname",
-                "databricks_connect_access_token",
-                "databricks_connect_cluster_id",
-            }.issubset(self._extra_config)
-            and not self._extra_config.get("disable_databricks_connect")
-        )
+        return self.can_access_databricks_connect(
+            bool(self._extra_config.get("disable_databricks_connect"))
+        ) and {
+            "databricks_connect_server_hostname",
+            "databricks_connect_access_token",
+            "databricks_connect_cluster_id",
+        }.issubset(self._extra_config)
 
     @property
     def is_spark_session_cursor(self) -> bool:
@@ -97,11 +100,15 @@ def spark(self) -> PySparkSession:
         from databricks.connect import DatabricksSession
 
         if self._spark is None:
-            self._spark = DatabricksSession.builder.remote(
-                host=self._extra_config["databricks_connect_server_hostname"],
-                token=self._extra_config["databricks_connect_access_token"],
-                cluster_id=self._extra_config["databricks_connect_cluster_id"],
-            ).getOrCreate()
+            self._spark = (
+                DatabricksSession.builder.remote(
+                    host=self._extra_config["databricks_connect_server_hostname"],
+                    token=self._extra_config["databricks_connect_access_token"],
+                    cluster_id=self._extra_config["databricks_connect_cluster_id"],
+                )
+                .userAgent("sqlmesh")
+                .getOrCreate()
+            )
         catalog = self._extra_config.get("catalog")
         if catalog:
             self.set_current_catalog(catalog)
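
In isolation, the two classmethods now behave as follows (a hedged sketch; the environment-dependent cases follow from the diff above):

from sqlmesh.core.engine_adapter.databricks import DatabricksEngineAdapter

# The disable flags always win, regardless of environment:
assert DatabricksEngineAdapter.can_access_spark_session(True) is False
assert DatabricksEngineAdapter.can_access_databricks_connect(True) is False

# With the flags off, the answers depend on the environment:
# can_access_spark_session(False) returns RuntimeEnv.get().is_databricks, and
# can_access_databricks_connect(False) returns whether `databricks.connect`
# can be imported.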

sqlmesh/engines/spark/db_api/spark_session.py

Lines changed: 5 additions & 4 deletions
@@ -89,8 +89,12 @@ def set_current_catalog(self, catalog_name: str) -> None:
     def cursor(self) -> SparkSessionCursor:
         try:
             self.spark.sparkContext.setLocalProperty("spark.scheduler.pool", f"pool_{get_ident()}")
+            self.spark.conf.set("spark.sql.sources.partitionOverwriteMode", "dynamic")
+            self.spark.conf.set("hive.exec.dynamic.partition", "true")
+            self.spark.conf.set("hive.exec.dynamic.partition.mode", "nonstrict")
         except NotImplementedError:
-            # Databricks Connect does not support accessing the SparkContext
+            # Databricks Connect does not support accessing the SparkContext nor does it support
+            # setting dynamic partition overwrite since it uses replace where
             pass
         if self.catalog:
             from py4j.protocol import Py4JError
@@ -101,9 +105,6 @@ def cursor(self) -> SparkSessionCursor:
             # and shared clusters so we use the Databricks Unity only SQL command instead
             except Py4JError:
                 self.spark.sql(f"USE CATALOG {self.catalog}")
-        self.spark.conf.set("spark.sql.sources.partitionOverwriteMode", "dynamic")
-        self.spark.conf.set("hive.exec.dynamic.partition", "true")
-        self.spark.conf.set("hive.exec.dynamic.partition.mode", "nonstrict")
         return SparkSessionCursor(self.spark)
 
     def commit(self) -> None:
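
The relocation of the conf.set calls leans on ordering: under Databricks Connect, the sparkContext access raises NotImplementedError before any of the dynamic-partition settings run (Connect handles insert-overwrite via `replace where` instead). A minimal sketch of the guard pattern, with a hypothetical helper name:

from pyspark.sql import SparkSession


def apply_session_tuning(spark: SparkSession) -> None:
    # Best-effort tuning; skipped entirely under Databricks Connect, which
    # raises NotImplementedError on sparkContext access before any of the
    # conf.set calls below are reached.
    try:
        spark.sparkContext.setLocalProperty("spark.scheduler.pool", "pool_example")
        spark.conf.set("spark.sql.sources.partitionOverwriteMode", "dynamic")
        spark.conf.set("hive.exec.dynamic.partition", "true")
        spark.conf.set("hive.exec.dynamic.partition.mode", "nonstrict")
    except NotImplementedError:
        pass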
