@@ -1424,26 +1424,27 @@ def test_databricks(make_config):
14241424 )
14251425
14261426
1427- def test_databricks_shared_connection (make_config ):
1428- """Databricks should use a shared connection pool to prevent OAuth CSRF races.
1427+ def test_databricks__u2m_oauth__shared_connection_pool (make_config ):
1428+ """Databricks should use a shared connection pool when using OAuth to prevent CSRF races.
14291429
14301430 When concurrent_tasks > 1, ThreadLocalConnectionPool creates one connection per
14311431 thread. For U2M OAuth, each thread triggers its own browser-based OAuth flow;
14321432 these race on the CSRF state parameter and cause MismatchingStateError.
14331433
1434- Setting shared_connection = True causes ThreadLocalSharedConnectionPool to be
1435- used instead: a single connection is created (behind a lock) and each thread
1436- gets its own cursor, so only one OAuth flow is ever initiated.
1434+ For non-U2M OAuth authentication types (e.g. access_token and M2M OAuth) then
1435+ ThreadLocalConnectionPool should still be used.
14371436
1438- See: https://github.com/tobymao/sqlmesh/issues/5646
1437+ See:
1438+ https://github.com/tobymao/sqlmesh/issues/5646
1439+ https://github.com/SQLMesh/sqlmesh/issues/5858
14391440 """
14401441 from sqlmesh .utils .connection_pool import ThreadLocalSharedConnectionPool
14411442
14421443 config = make_config (
14431444 type = "databricks" ,
14441445 server_hostname = "dbc-test.cloud.databricks.com" ,
14451446 http_path = "sql/test/foo" ,
1446- access_token = "test-token " ,
1447+ auth_type = "databricks-oauth " ,
14471448 concurrent_tasks = 4 ,
14481449 )
14491450 assert isinstance (config , DatabricksConnectionConfig )
@@ -1453,6 +1454,41 @@ def test_databricks_shared_connection(make_config):
14531454 assert isinstance (adapter ._connection_pool , ThreadLocalSharedConnectionPool )
14541455
14551456
1457+ def test_databricks__m2m_oauth__connection_pool (make_config ):
1458+ from sqlmesh .utils .connection_pool import ThreadLocalConnectionPool
1459+
1460+ config = make_config (
1461+ type = "databricks" ,
1462+ server_hostname = "dbc-test.cloud.databricks.com" ,
1463+ http_path = "sql/test/foo" ,
1464+ auth_type = "databricks-oauth" ,
1465+ oauth_client_id = "oauth_client_id" ,
1466+ concurrent_tasks = 4 ,
1467+ )
1468+ assert isinstance (config , DatabricksConnectionConfig )
1469+ assert config .shared_connection is False
1470+
1471+ adapter = config .create_engine_adapter ()
1472+ assert isinstance (adapter ._connection_pool , ThreadLocalConnectionPool )
1473+
1474+
1475+ def test_databricks__access_token__connection_pool (make_config ):
1476+ from sqlmesh .utils .connection_pool import ThreadLocalConnectionPool
1477+
1478+ config = make_config (
1479+ type = "databricks" ,
1480+ server_hostname = "dbc-test.cloud.databricks.com" ,
1481+ http_path = "sql/test/foo" ,
1482+ access_token = "any-token" ,
1483+ concurrent_tasks = 4 ,
1484+ )
1485+ assert isinstance (config , DatabricksConnectionConfig )
1486+ assert config .shared_connection is False
1487+
1488+ adapter = config .create_engine_adapter ()
1489+ assert isinstance (adapter ._connection_pool , ThreadLocalConnectionPool )
1490+
1491+
14561492def test_engine_import_validator ():
14571493 with pytest .raises (
14581494 ConfigError ,
0 commit comments