Skip to content

Commit bd84b8b

Browse files
authored
fix: init catalogs at cursor level (#1630)
1 parent 2461127 commit bd84b8b

File tree

7 files changed

+67
-21
lines changed

7 files changed

+67
-21
lines changed

examples/sushi/config.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,3 +113,15 @@
113113
model_defaults=ModelDefaultsConfig(dialect="duckdb"),
114114
environment_suffix_target=EnvironmentSuffixTarget.TABLE,
115115
)
116+
117+
118+
CATALOGS = {
119+
"in_memory": ":memory:",
120+
"other_catalog": f":memory:",
121+
}
122+
123+
local_catalogs = Config(
124+
default_connection=DuckDBConnectionConfig(catalogs=CATALOGS),
125+
default_test_connection=DuckDBConnectionConfig(catalogs=CATALOGS),
126+
model_defaults=ModelDefaultsConfig(dialect="duckdb"),
127+
)

sqlmesh/core/config/connection.py

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,11 @@ def _cursor_kwargs(self) -> t.Optional[t.Dict[str, t.Any]]:
6262
"""Key-value arguments that will be passed during cursor construction."""
6363
return None
6464

65+
@property
66+
def _cursor_init(self) -> t.Optional[t.Callable[[t.Any], None]]:
67+
"""A function that is called to initialize the cursor"""
68+
return None
69+
6570
def create_engine_adapter(self) -> EngineAdapter:
6671
"""Returns a new instance of the Engine Adapter."""
6772
return self._engine_adapter(
@@ -73,6 +78,7 @@ def create_engine_adapter(self) -> EngineAdapter:
7378
),
7479
multithreaded=self.concurrent_tasks > 1,
7580
cursor_kwargs=self._cursor_kwargs,
81+
cursor_init=self._cursor_init,
7682
**self._extra_engine_config,
7783
)
7884

@@ -118,25 +124,29 @@ def _connection_factory(self) -> t.Callable:
118124

119125
return duckdb.connect
120126

121-
def create_engine_adapter(self) -> EngineAdapter:
122-
"""Returns a new instance of the Engine Adapter."""
127+
@property
128+
def _cursor_init(self) -> t.Optional[t.Callable[[t.Any], None]]:
129+
"""A function that is called to initialize the cursor"""
130+
import duckdb
123131
from duckdb import BinderException
124132

125-
engine_adapter = super().create_engine_adapter()
126-
for i, (alias, path) in enumerate((self.catalogs or {}).items()):
127-
try:
128-
engine_adapter.execute(f"ATTACH '{path}' AS {alias}")
129-
except BinderException as e:
130-
# If a user tries to create a catalog pointing at `:memory:` and with the name `memory`
131-
# then we don't want to raise since this happens by default. They are just doing this to
132-
# set it as the default catalog.
133-
if not (
134-
'database with name "memory" already exists' in str(e) and path == ":memory:"
135-
):
136-
raise e
137-
if i == 0 and not self.database:
138-
engine_adapter.set_current_catalog(alias)
139-
return engine_adapter
133+
def init_catalogs(cursor: duckdb.DuckDBPyConnection) -> None:
134+
for i, (alias, path) in enumerate((self.catalogs or {}).items()):
135+
try:
136+
cursor.execute(f"ATTACH '{path}' AS {alias}")
137+
except BinderException as e:
138+
# If a user tries to create a catalog pointing at `:memory:` and with the name `memory`
139+
# then we don't want to raise since this happens by default. They are just doing this to
140+
# set it as the default catalog.
141+
if not (
142+
'database with name "memory" already exists' in str(e)
143+
and path == ":memory:"
144+
):
145+
raise e
146+
if i == 0 and not self.database:
147+
cursor.execute(f"USE {alias}")
148+
149+
return init_catalogs
140150

141151

142152
class SnowflakeConnectionConfig(ConnectionConfig):

sqlmesh/core/engine_adapter/base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,11 +164,12 @@ def __init__(
164164
sql_gen_kwargs: t.Optional[t.Dict[str, Dialect | bool | str]] = None,
165165
multithreaded: bool = False,
166166
cursor_kwargs: t.Optional[t.Dict[str, t.Any]] = None,
167+
cursor_init: t.Optional[t.Callable[[t.Any], None]] = None,
167168
**kwargs: t.Any,
168169
):
169170
self.dialect = dialect.lower() or self.DIALECT
170171
self._connection_pool = create_connection_pool(
171-
connection_factory, multithreaded, cursor_kwargs=cursor_kwargs
172+
connection_factory, multithreaded, cursor_kwargs=cursor_kwargs, cursor_init=cursor_init
172173
)
173174
self.sql_gen_kwargs = sql_gen_kwargs or {}
174175
self._extra_config = kwargs

sqlmesh/core/engine_adapter/databricks.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def __init__(
4141
sql_gen_kwargs: t.Optional[t.Dict[str, Dialect | bool | str]] = None,
4242
multithreaded: bool = False,
4343
cursor_kwargs: t.Optional[t.Dict[str, t.Any]] = None,
44+
cursor_init: t.Optional[t.Callable[[t.Any], None]] = None,
4445
**kwargs: t.Any,
4546
):
4647
super().__init__(
@@ -49,6 +50,7 @@ def __init__(
4950
sql_gen_kwargs,
5051
multithreaded,
5152
cursor_kwargs,
53+
cursor_init,
5254
**kwargs,
5355
)
5456
self._spark: t.Optional[PySparkSession] = None

sqlmesh/utils/connection_pool.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ def __init__(
116116
self,
117117
connection_factory: t.Callable[[], t.Any],
118118
cursor_kwargs: t.Optional[t.Dict[str, t.Any]] = None,
119+
cursor_init: t.Optional[t.Callable[[t.Any], None]] = None,
119120
):
120121
self._connection_factory = connection_factory
121122
self._thread_connections: t.Dict[t.Hashable, t.Any] = {}
@@ -126,12 +127,15 @@ def __init__(
126127
self._thread_cursors_lock = Lock()
127128
self._thread_transactions_lock = Lock()
128129
self._cursor_kwargs = cursor_kwargs or {}
130+
self._cursor_init = cursor_init
129131

130132
def get_cursor(self) -> t.Any:
131133
thread_id = get_ident()
132134
with self._thread_cursors_lock:
133135
if thread_id not in self._thread_cursors:
134136
self._thread_cursors[thread_id] = self.get().cursor(**self._cursor_kwargs)
137+
if self._cursor_init:
138+
self._cursor_init(self._thread_cursors[thread_id])
135139
return self._thread_cursors[thread_id]
136140

137141
def get(self) -> t.Any:
@@ -206,17 +210,21 @@ def __init__(
206210
self,
207211
connection_factory: t.Callable[[], t.Any],
208212
cursor_kwargs: t.Optional[t.Dict[str, t.Any]] = None,
213+
cursor_init: t.Optional[t.Callable[[t.Any], None]] = None,
209214
):
210215
self._connection_factory = connection_factory
211216
self._connection: t.Optional[t.Any] = None
212217
self._cursor: t.Optional[t.Any] = None
213218
self._cursor_kwargs = cursor_kwargs or {}
214219
self._attributes: t.Dict[str, t.Any] = {}
215220
self._is_transaction_active: bool = False
221+
self._cursor_init = cursor_init
216222

217223
def get_cursor(self) -> t.Any:
218224
if not self._cursor:
219225
self._cursor = self.get().cursor(**self._cursor_kwargs)
226+
if self._cursor_init:
227+
self._cursor_init(self._cursor)
220228
return self._cursor
221229

222230
def get(self) -> t.Any:
@@ -266,11 +274,16 @@ def create_connection_pool(
266274
connection_factory: t.Callable[[], t.Any],
267275
multithreaded: bool,
268276
cursor_kwargs: t.Optional[t.Dict[str, t.Any]] = None,
277+
cursor_init: t.Optional[t.Callable[[t.Any], None]] = None,
269278
) -> ConnectionPool:
270279
return (
271-
ThreadLocalConnectionPool(connection_factory, cursor_kwargs=cursor_kwargs)
280+
ThreadLocalConnectionPool(
281+
connection_factory, cursor_kwargs=cursor_kwargs, cursor_init=cursor_init
282+
)
272283
if multithreaded
273-
else SingletonConnectionPool(connection_factory, cursor_kwargs=cursor_kwargs)
284+
else SingletonConnectionPool(
285+
connection_factory, cursor_kwargs=cursor_kwargs, cursor_init=cursor_init
286+
)
274287
)
275288

276289

tests/conftest.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,13 @@ def sushi_test_dbt_context(mocker: MockerFixture) -> Context:
178178
return context
179179

180180

181+
@pytest.fixture()
182+
def sushi_default_catalog(mocker: MockerFixture) -> Context:
183+
context, plan = init_and_plan_context("examples/sushi", mocker, "local_catalogs")
184+
context.apply(plan)
185+
return context
186+
187+
181188
def init_and_plan_context(
182189
paths: str | t.List[str],
183190
mocker: MockerFixture,

tests/core/test_integration.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -418,7 +418,8 @@ def test_plan_set_choice_is_reflected_in_missing_intervals(mocker: MockerFixture
418418
@pytest.mark.integration
419419
@pytest.mark.core_integration
420420
@pytest.mark.parametrize(
421-
"context_fixture", ["sushi_context", "sushi_dbt_context", "sushi_test_dbt_context"]
421+
"context_fixture",
422+
["sushi_context", "sushi_dbt_context", "sushi_test_dbt_context", "sushi_default_catalog"],
422423
)
423424
def test_model_add(context_fixture: Context, request):
424425
initial_add(request.getfixturevalue(context_fixture), "dev")

0 commit comments

Comments
 (0)