Skip to content

Commit 27a52ad

Browse files
committed
fix: Set DatabricksConnectionConfig to only use a shared_connection when U2M OAuth authentication is used.
Signed-off-by: davem-bis <68955845+davem-bis@users.noreply.github.com>
1 parent 444c50d commit 27a52ad

2 files changed

Lines changed: 58 additions & 12 deletions

File tree

sqlmesh/core/config/connection.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from functools import partial
1313

1414
import pydantic
15-
from pydantic import Field
15+
from pydantic import Field, computed_field
1616
from pydantic_core import from_json
1717
from packaging import version
1818
from sqlglot import exp
@@ -108,7 +108,10 @@ class ConnectionConfig(abc.ABC, BaseConfig):
108108
catalog_type_overrides: t.Optional[t.Dict[str, str]] = None
109109

110110
# Whether to share a single connection across threads or create a new connection per thread.
111-
shared_connection: t.ClassVar[bool] = False
111+
@computed_field
112+
@property
113+
def shared_connection(self) -> bool:
114+
return False
112115

113116
@property
114117
@abc.abstractmethod
@@ -309,7 +312,10 @@ class BaseDuckDBConnectionConfig(ConnectionConfig):
309312

310313
token: t.Optional[str] = None
311314

312-
shared_connection: t.ClassVar[bool] = True
315+
@computed_field
316+
@property
317+
def shared_connection(self) -> bool:
318+
return True
313319

314320
_data_file_to_adapter: t.ClassVar[t.Dict[str, EngineAdapter]] = {}
315321

@@ -818,11 +824,15 @@ class DatabricksConnectionConfig(ConnectionConfig):
818824
DISPLAY_NAME: t.ClassVar[t.Literal["Databricks"]] = "Databricks"
819825
DISPLAY_ORDER: t.ClassVar[t.Literal[3]] = 3
820826

821-
shared_connection: t.ClassVar[bool] = True
822-
823827
_concurrent_tasks_validator = concurrent_tasks_validator
824828
_http_headers_validator = http_headers_validator
825829

830+
@computed_field
831+
@property
832+
def shared_connection(self) -> bool:
833+
"""Share the connection only for U2M OAuth, i.e. when auth_type is set and no oauth_client_id is given."""
834+
return self.auth_type is not None and self.oauth_client_id is None
835+
826836
@model_validator(mode="before")
827837
def _databricks_connect_validator(cls, data: t.Any) -> t.Any:
828838
# SQLQueryContextLogger will output any error SQL queries even if they are in a try/except block.

tests/core/test_connection_config.py

Lines changed: 43 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1424,26 +1424,27 @@ def test_databricks(make_config):
14241424
)
14251425

14261426

1427-
def test_databricks_shared_connection(make_config):
1428-
"""Databricks should use a shared connection pool to prevent OAuth CSRF races.
1427+
def test_databricks__u2m_oauth__shared_connection_pool(make_config):
1428+
"""Databricks should use a shared connection pool when using U2M OAuth to prevent CSRF races.
14291429
14301430
When concurrent_tasks > 1, ThreadLocalConnectionPool creates one connection per
14311431
thread. For U2M OAuth, each thread triggers its own browser-based OAuth flow;
14321432
these race on the CSRF state parameter and cause MismatchingStateError.
14331433
1434-
Setting shared_connection = True causes ThreadLocalSharedConnectionPool to be
1435-
used instead: a single connection is created (behind a lock) and each thread
1436-
gets its own cursor, so only one OAuth flow is ever initiated.
1434+
For authentication types other than U2M OAuth (e.g. access_token and M2M OAuth),
1435+
ThreadLocalConnectionPool should still be used.
14371436
1438-
See: https://github.com/tobymao/sqlmesh/issues/5646
1437+
See:
1438+
https://github.com/tobymao/sqlmesh/issues/5646
1439+
https://github.com/SQLMesh/sqlmesh/issues/5858
14391440
"""
14401441
from sqlmesh.utils.connection_pool import ThreadLocalSharedConnectionPool
14411442

14421443
config = make_config(
14431444
type="databricks",
14441445
server_hostname="dbc-test.cloud.databricks.com",
14451446
http_path="sql/test/foo",
1446-
access_token="test-token",
1447+
auth_type="databricks-oauth",
14471448
concurrent_tasks=4,
14481449
)
14491450
assert isinstance(config, DatabricksConnectionConfig)
@@ -1453,6 +1454,41 @@ def test_databricks_shared_connection(make_config):
14531454
assert isinstance(adapter._connection_pool, ThreadLocalSharedConnectionPool)
14541455

14551456

1457+
def test_databricks__m2m_oauth__connection_pool(make_config):
1458+
from sqlmesh.utils.connection_pool import ThreadLocalConnectionPool
1459+
1460+
config = make_config(
1461+
type="databricks",
1462+
server_hostname="dbc-test.cloud.databricks.com",
1463+
http_path="sql/test/foo",
1464+
auth_type="databricks-oauth",
1465+
oauth_client_id="oauth_client_id",
1466+
concurrent_tasks=4,
1467+
)
1468+
assert isinstance(config, DatabricksConnectionConfig)
1469+
assert config.shared_connection is False
1470+
1471+
adapter = config.create_engine_adapter()
1472+
assert isinstance(adapter._connection_pool, ThreadLocalConnectionPool)
1473+
1474+
1475+
def test_databricks__access_token__connection_pool(make_config):
1476+
from sqlmesh.utils.connection_pool import ThreadLocalConnectionPool
1477+
1478+
config = make_config(
1479+
type="databricks",
1480+
server_hostname="dbc-test.cloud.databricks.com",
1481+
http_path="sql/test/foo",
1482+
access_token="any-token",
1483+
concurrent_tasks=4,
1484+
)
1485+
assert isinstance(config, DatabricksConnectionConfig)
1486+
assert config.shared_connection is False
1487+
1488+
adapter = config.create_engine_adapter()
1489+
assert isinstance(adapter._connection_pool, ThreadLocalConnectionPool)
1490+
1491+
14561492
def test_engine_import_validator():
14571493
with pytest.raises(
14581494
ConfigError,

0 commit comments

Comments
 (0)