Skip to content

Commit 7f5f1d2

Browse files
authored
fix: set catalog properly across connection and engine (#2428)
* fix: remove extra spark set catalog
* fix: consolidate spark catalog operations to connection level
* fix: properly handle if pandas df results are returned
1 parent 79085b8 commit 7f5f1d2

File tree

4 files changed

+35
-22
lines changed

4 files changed

+35
-22
lines changed

sqlmesh/core/engine_adapter/spark.py

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
SourceQuery,
2323
set_catalog,
2424
)
25+
from sqlmesh.engines.spark.db_api.spark_session import SparkSessionConnection
2526
from sqlmesh.utils import classproperty
2627
from sqlmesh.utils.errors import SQLMeshError
2728

@@ -58,9 +59,13 @@ class SparkEngineAdapter(GetCurrentCatalogFromFunctionMixin, HiveMetastoreTableP
5859
WAP_PREFIX = "wap_"
5960
BRANCH_PREFIX = "branch_"
6061

62+
@property
63+
def connection(self) -> SparkSessionConnection:
64+
return self._connection_pool.get()
65+
6166
@property
6267
def spark(self) -> PySparkSession:
63-
return self._connection_pool.get().spark
68+
return self.connection.spark
6469

6570
@property
6671
def _use_spark_session(self) -> bool:
@@ -319,7 +324,8 @@ def _get_data_objects(
319324
DataObject(
320325
catalog=self.get_current_catalog(),
321326
# This varies between Spark and Databricks
322-
schema=row.asDict().get("namespace") or row["database"],
327+
schema=(row.asDict() if not isinstance(row, dict) else row).get("namespace")
328+
or row["database"],
323329
name=row["tableName"],
324330
type=(
325331
DataObjectType.VIEW
@@ -330,26 +336,13 @@ def _get_data_objects(
330336
for row in results # type: ignore
331337
]
332338

333-
@property
334-
def _spark_major_minor(self) -> t.Tuple[int, int]:
335-
return tuple(int(x) for x in self.spark.version.split(".")[:2]) # type: ignore
336-
337339
def get_current_catalog(self) -> t.Optional[str]:
338340
if self._use_spark_session:
339-
if self._spark_major_minor >= (3, 4):
340-
return self.spark.catalog.currentCatalog()
341-
else:
342-
return self._default_catalog or "spark_catalog"
341+
return self.connection.get_current_catalog()
343342
return super().get_current_catalog()
344343

345344
def set_current_catalog(self, catalog_name: str) -> None:
346-
if self._spark_major_minor >= (3, 4):
347-
return self.spark.catalog.setCurrentCatalog(catalog_name)
348-
current_catalog = self.get_current_catalog()
349-
if current_catalog != catalog_name:
350-
logger.warning(
351-
"Spark <3.4 does not support certain cross catalog queries since the default catalog cannot be set <3.4"
352-
)
345+
self.connection.set_current_catalog(catalog_name)
353346

354347
def get_current_database(self) -> str:
355348
if self._use_spark_session:

sqlmesh/engines/spark/db_api/spark_session.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import logging
12
import typing as t
23
from threading import get_ident
34

@@ -6,6 +7,8 @@
67

78
from sqlmesh.engines.spark.db_api.errors import NotSupportedError, ProgrammingError
89

10+
logger = logging.getLogger(__name__)
11+
912

1013
class SparkSessionCursor:
1114
def __init__(self, spark: SparkSession):
@@ -65,18 +68,35 @@ def __init__(self, spark: SparkSession, catalog: t.Optional[str] = None):
6568
self.spark = spark
6669
self.catalog = catalog
6770

71+
@property
72+
def _spark_major_minor(self) -> t.Tuple[int, int]:
73+
return tuple(int(x) for x in self.spark.version.split(".")[:2]) # type: ignore
74+
75+
def get_current_catalog(self) -> t.Optional[str]:
76+
if self._spark_major_minor >= (3, 4):
77+
return self.spark.catalog.currentCatalog()
78+
return self.catalog or "spark_catalog"
79+
80+
def set_current_catalog(self, catalog_name: str) -> None:
81+
if self._spark_major_minor >= (3, 4):
82+
return self.spark.catalog.setCurrentCatalog(catalog_name)
83+
current_catalog = self.get_current_catalog()
84+
if current_catalog != catalog_name:
85+
logger.warning(
86+
"Spark <3.4 does not support certain cross catalog queries since the default catalog cannot be set <3.4"
87+
)
88+
6889
def cursor(self) -> SparkSessionCursor:
6990
try:
7091
self.spark.sparkContext.setLocalProperty("spark.scheduler.pool", f"pool_{get_ident()}")
7192
except NotImplementedError:
7293
# Databricks Connect does not support accessing the SparkContext
7394
pass
7495
if self.catalog:
75-
# Note: Spark 3.4+ Only API
7696
from py4j.protocol import Py4JError
7797

7898
try:
79-
self.spark.catalog.setCurrentCatalog(self.catalog)
99+
self.set_current_catalog(self.catalog)
80100
# Databricks does not support `setCurrentCatalog` with Unity catalog
81101
# and shared clusters so we use the Databricks Unity only SQL command instead
82102
except Py4JError:

tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,7 @@ def _make_function(
344344
)
345345
if isinstance(adapter, SparkEngineAdapter):
346346
mocker.patch(
347-
"sqlmesh.core.engine_adapter.spark.SparkEngineAdapter._spark_major_minor",
347+
"sqlmesh.engines.spark.db_api.spark_session.SparkSessionConnection._spark_major_minor",
348348
new_callable=PropertyMock(return_value=(3, 5)),
349349
)
350350
return adapter

tests/core/engine_adapter/test_spark.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -787,7 +787,7 @@ def check_table_exists(table_name: exp.Table) -> bool:
787787

788788
def test_wap_prepare(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture):
789789
adapter = make_mocked_engine_adapter(SparkEngineAdapter)
790-
adapter.spark.catalog.currentCatalog.return_value = "spark_catalog"
790+
adapter.connection.get_current_catalog.return_value = "spark_catalog"
791791
adapter.spark.catalog.currentDatabase.return_value = "default"
792792

793793
table_name = "test_db.test_table"
@@ -805,7 +805,7 @@ def test_wap_publish(make_mocked_engine_adapter: t.Callable, mocker: MockerFixtu
805805
iceberg_snapshot_id = 123
806806

807807
adapter = make_mocked_engine_adapter(SparkEngineAdapter)
808-
adapter.spark.catalog.currentCatalog.return_value = "spark_catalog"
808+
adapter.connection.get_current_catalog.return_value = "spark_catalog"
809809
adapter.spark.catalog.currentDatabase.return_value = "default"
810810
adapter.cursor.fetchall.return_value = [(iceberg_snapshot_id,)]
811811

0 commit comments

Comments (0)