Skip to content

Commit 046b43a

Browse files
authored
Fix: Properly support engines that can share a single connection instance across threads (#4124)
1 parent 34ad12d commit 046b43a

File tree

7 files changed

+180
-62
lines changed

7 files changed

+180
-62
lines changed

sqlmesh/core/config/connection.py

Lines changed: 7 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import re
99
import typing as t
1010
from enum import Enum
11-
from functools import partial, lru_cache
11+
from functools import partial
1212

1313
import pydantic
1414
from pydantic import Field
@@ -51,6 +51,9 @@ class ConnectionConfig(abc.ABC, BaseConfig):
5151
pre_ping: bool
5252
pretty_sql: bool = False
5353

54+
# Whether to share a single connection across threads or create a new connection per thread.
55+
shared_connection: t.ClassVar[bool] = False
56+
5457
@property
5558
@abc.abstractmethod
5659
def _connection_kwargs_keys(self) -> t.Set[str]:
@@ -116,6 +119,7 @@ def create_engine_adapter(self, register_comments_override: bool = False) -> Eng
116119
register_comments=register_comments_override or self.register_comments,
117120
pre_ping=self.pre_ping,
118121
pretty_sql=self.pretty_sql,
122+
shared_connection=self.shared_connection,
119123
**self._extra_engine_config,
120124
)
121125

@@ -182,6 +186,8 @@ class BaseDuckDBConnectionConfig(ConnectionConfig):
182186

183187
token: t.Optional[str] = None
184188

189+
shared_connection: t.ClassVar[bool] = True
190+
185191
_data_file_to_adapter: t.ClassVar[t.Dict[str, EngineAdapter]] = {}
186192

187193
@model_validator(mode="before")
@@ -212,43 +218,6 @@ def _connection_kwargs_keys(self) -> t.Set[str]:
212218
def _connection_factory(self) -> t.Callable:
213219
import duckdb
214220

215-
if self.concurrent_tasks > 1:
216-
# ensures a single connection instance is used across threads rather than a new connection being established per thread
217-
# this is in line with https://duckdb.org/docs/guides/python/multiple_threads.html
218-
# the important thing is that the *cursor*'s are per thread, but the connection should be shared
219-
@lru_cache
220-
def _factory(*args: t.Any, **kwargs: t.Any) -> t.Any:
221-
class ConnWrapper:
222-
def __init__(self, conn: duckdb.DuckDBPyConnection):
223-
self.conn = conn
224-
225-
def __getattr__(self, attr: str) -> t.Any:
226-
return getattr(self.conn, attr)
227-
228-
def close(self) -> None:
229-
# This overrides conn.close() to be a no-op to work with ThreadLocalConnectionPool which assumes that a new connection should
230-
# be created per thread. However, DuckDB expects the same connection instance to be shared across threads. There is a pattern
231-
# in the SQLMesh codebase that `EngineAdapter.recycle()` is called after doing things like merging intervals. This in turn causes
232-
# `ThreadLocalConnectionPool.close_all(exclude_calling_thread=True)` to be called.
233-
#
234-
# The problem with sharing a connection across threads and then allowing it to be closed for every thread except the current one
235-
# is that it gets closed for the current one too because its shared. This causes any ":memory:" databases to be discarded.
236-
# ":memory:" databases are convienient and are used heavily in our test suite amongst other things.
237-
#
238-
# Ok, so why not have a connection per thread as is the default for ThreadLocalConnectionPool? Two reasons:
239-
# - It makes any ":memory:" databases unique to that thread. So if one thread creates tables, another thread cant see them
240-
# - If you use local files instead (eg point each connection to the same db file) then all the connection instances
241-
# fight over locks to the same file and performance tanks heavily
242-
#
243-
# From what I can tell, DuckDB expects the single process reading / writing the database from multiple
244-
# threads to /share the same connection/ and just use thread-local cursors. In order to support ":memory:" databases
245-
# and remove lock contention, the connection needs to live for the life of the application and not be closed
246-
pass
247-
248-
return ConnWrapper(duckdb.connect(*args, **kwargs))
249-
250-
return _factory
251-
252221
return duckdb.connect
253222

254223
@property

sqlmesh/core/engine_adapter/base.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,14 +119,18 @@ def __init__(
119119
register_comments: bool = True,
120120
pre_ping: bool = False,
121121
pretty_sql: bool = False,
122+
shared_connection: bool = False,
122123
**kwargs: t.Any,
123124
):
124125
self.dialect = dialect.lower() or self.DIALECT
125126
self._connection_pool = (
126127
connection_factory_or_pool
127128
if isinstance(connection_factory_or_pool, ConnectionPool)
128129
else create_connection_pool(
129-
connection_factory_or_pool, multithreaded, cursor_init=cursor_init
130+
connection_factory_or_pool,
131+
multithreaded,
132+
shared_connection=shared_connection,
133+
cursor_init=cursor_init,
130134
)
131135
)
132136
self._sql_gen_kwargs = sql_gen_kwargs or {}

sqlmesh/utils/connection_pool.py

Lines changed: 69 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -111,18 +111,16 @@ def _do_rollback(self) -> None:
111111
self.get().rollback()
112112

113113

114-
class ThreadLocalConnectionPool(_TransactionManagementMixin):
114+
class _ThreadLocalBase(_TransactionManagementMixin):
115115
def __init__(
116116
self,
117117
connection_factory: t.Callable[[], t.Any],
118118
cursor_init: t.Optional[t.Callable[[t.Any], None]] = None,
119119
):
120120
self._connection_factory = connection_factory
121-
self._thread_connections: t.Dict[t.Hashable, t.Any] = {}
122121
self._thread_cursors: t.Dict[t.Hashable, t.Any] = {}
123122
self._thread_transactions: t.Set[t.Hashable] = set()
124123
self._thread_attributes: t.Dict[t.Hashable, t.Dict[str, t.Any]] = defaultdict(dict)
125-
self._thread_connections_lock = Lock()
126124
self._thread_cursors_lock = Lock()
127125
self._thread_transactions_lock = Lock()
128126
self._cursor_init = cursor_init
@@ -136,13 +134,6 @@ def get_cursor(self) -> t.Any:
136134
self._cursor_init(self._thread_cursors[thread_id])
137135
return self._thread_cursors[thread_id]
138136

139-
def get(self) -> t.Any:
140-
thread_id = get_ident()
141-
with self._thread_connections_lock:
142-
if thread_id not in self._thread_connections:
143-
self._thread_connections[thread_id] = self._connection_factory()
144-
return self._thread_connections[thread_id]
145-
146137
def get_attribute(self, key: str) -> t.Optional[t.Any]:
147138
thread_id = get_ident()
148139
return self._thread_attributes[thread_id].get(key)
@@ -176,6 +167,28 @@ def close_cursor(self) -> None:
176167
_try_close(self._thread_cursors[thread_id], "cursor")
177168
self._thread_cursors.pop(thread_id)
178169

170+
def _discard_transaction(self, thread_id: t.Hashable) -> None:
171+
with self._thread_transactions_lock:
172+
self._thread_transactions.discard(thread_id)
173+
174+
175+
class ThreadLocalConnectionPool(_ThreadLocalBase):
176+
def __init__(
177+
self,
178+
connection_factory: t.Callable[[], t.Any],
179+
cursor_init: t.Optional[t.Callable[[t.Any], None]] = None,
180+
):
181+
super().__init__(connection_factory, cursor_init)
182+
self._thread_connections: t.Dict[t.Hashable, t.Any] = {}
183+
self._thread_connections_lock = Lock()
184+
185+
def get(self) -> t.Any:
186+
thread_id = get_ident()
187+
with self._thread_connections_lock:
188+
if thread_id not in self._thread_connections:
189+
self._thread_connections[thread_id] = self._connection_factory()
190+
return self._thread_connections[thread_id]
191+
179192
def close(self) -> None:
180193
thread_id = get_ident()
181194
with self._thread_cursors_lock, self._thread_connections_lock:
@@ -191,16 +204,51 @@ def close_all(self, exclude_calling_thread: bool = False) -> None:
191204
with self._thread_cursors_lock, self._thread_connections_lock:
192205
for thread_id, connection in self._thread_connections.copy().items():
193206
if not exclude_calling_thread or thread_id != calling_thread_id:
194-
# NOTE: the access to the connection instance itself is not thread-safe here.
195207
_try_close(connection, "connection")
196208
self._thread_connections.pop(thread_id)
197209
self._thread_cursors.pop(thread_id, None)
198210
self._discard_transaction(thread_id)
199211
self._thread_attributes.pop(thread_id, None)
200212

201-
def _discard_transaction(self, thread_id: t.Hashable) -> None:
202-
with self._thread_transactions_lock:
203-
self._thread_transactions.discard(thread_id)
213+
214+
class ThreadLocalSharedConnectionPool(_ThreadLocalBase):
215+
def __init__(
216+
self,
217+
connection_factory: t.Callable[[], t.Any],
218+
cursor_init: t.Optional[t.Callable[[t.Any], None]] = None,
219+
):
220+
super().__init__(connection_factory, cursor_init)
221+
self._connection: t.Optional[t.Any] = None
222+
self._connection_lock = Lock()
223+
224+
def get(self) -> t.Any:
225+
with self._connection_lock:
226+
if self._connection is None:
227+
self._connection = self._connection_factory()
228+
return self._connection
229+
230+
def close(self) -> None:
231+
thread_id = get_ident()
232+
with self._thread_cursors_lock, self._connection_lock:
233+
if thread_id in self._thread_cursors:
234+
_try_close(self._thread_cursors[thread_id], "cursor")
235+
self._thread_cursors.pop(thread_id)
236+
self._discard_transaction(thread_id)
237+
self._thread_attributes.pop(thread_id, None)
238+
239+
def close_all(self, exclude_calling_thread: bool = False) -> None:
240+
calling_thread_id = get_ident()
241+
with self._thread_cursors_lock, self._connection_lock:
242+
for thread_id, cursor in self._thread_cursors.copy().items():
243+
if not exclude_calling_thread or thread_id != calling_thread_id:
244+
_try_close(cursor, "cursor")
245+
self._thread_cursors.pop(thread_id)
246+
self._discard_transaction(thread_id)
247+
self._thread_attributes.pop(thread_id, None)
248+
249+
if not exclude_calling_thread:
250+
_try_close(self._connection, "connection")
251+
self._connection = None
204252

205253

206254
class SingletonConnectionPool(_TransactionManagementMixin):
@@ -269,13 +317,17 @@ def close_all(self, exclude_calling_thread: bool = False) -> None:
269317
def create_connection_pool(
270318
connection_factory: t.Callable[[], t.Any],
271319
multithreaded: bool,
320+
shared_connection: bool = False,
272321
cursor_init: t.Optional[t.Callable[[t.Any], None]] = None,
273322
) -> ConnectionPool:
274-
return (
275-
ThreadLocalConnectionPool(connection_factory, cursor_init=cursor_init)
323+
pool_class = (
324+
ThreadLocalSharedConnectionPool
325+
if multithreaded and shared_connection
326+
else ThreadLocalConnectionPool
276327
if multithreaded
277-
else SingletonConnectionPool(connection_factory, cursor_init=cursor_init)
328+
else SingletonConnectionPool
278329
)
330+
return pool_class(connection_factory, cursor_init=cursor_init)
279331

280332

281333
def _try_close(closeable: t.Any, kind: str) -> None:

tests/core/engine_adapter/integration/test_integration_duckdb.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from sqlglot import exp
66

77
from sqlmesh.core.config.connection import DuckDBConnectionConfig
8-
from sqlmesh.utils.connection_pool import ThreadLocalConnectionPool
8+
from sqlmesh.utils.connection_pool import ThreadLocalSharedConnectionPool
99

1010
pytestmark = [pytest.mark.duckdb, pytest.mark.engine, pytest.mark.slow]
1111

@@ -21,7 +21,7 @@ def test_multithread_concurrency(tmp_path, database: t.Optional[str]):
2121

2222
adapter = config.create_engine_adapter()
2323

24-
assert isinstance(adapter._connection_pool, ThreadLocalConnectionPool)
24+
assert isinstance(adapter._connection_pool, ThreadLocalSharedConnectionPool)
2525

2626
# this test loosely follows this example: https://duckdb.org/docs/guides/python/multiple_threads.html
2727
adapter.execute(

tests/core/test_connection_config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -607,7 +607,7 @@ def test_duckdb_attach_options():
607607

608608
def test_duckdb_multithreaded_connection_factory(make_config):
609609
from sqlmesh.core.engine_adapter import DuckDBEngineAdapter
610-
from sqlmesh.utils.connection_pool import ThreadLocalConnectionPool
610+
from sqlmesh.utils.connection_pool import ThreadLocalSharedConnectionPool
611611
from threading import Thread
612612

613613
config = make_config(type="duckdb")
@@ -620,7 +620,7 @@ def test_duckdb_multithreaded_connection_factory(make_config):
620620
config = make_config(type="duckdb", concurrent_tasks=8)
621621
adapter = config.create_engine_adapter()
622622
assert isinstance(adapter, DuckDBEngineAdapter)
623-
assert isinstance(adapter._connection_pool, ThreadLocalConnectionPool)
623+
assert isinstance(adapter._connection_pool, ThreadLocalSharedConnectionPool)
624624

625625
threads = []
626626
connection_objects = []

tests/core/test_context.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
from sqlmesh.core.plan import BuiltInPlanEvaluator, PlanBuilder
4040
from sqlmesh.core.state_sync.cache import CachingStateSync
4141
from sqlmesh.core.state_sync.db import EngineAdapterStateSync
42-
from sqlmesh.utils.connection_pool import SingletonConnectionPool, ThreadLocalConnectionPool
42+
from sqlmesh.utils.connection_pool import SingletonConnectionPool, ThreadLocalSharedConnectionPool
4343
from sqlmesh.utils.date import (
4444
make_inclusive_end,
4545
now,
@@ -1209,7 +1209,7 @@ def test_duckdb_state_connection_automatic_multithreaded_mode(tmp_path):
12091209
state_sync = context.state_sync.state_sync
12101210
assert isinstance(state_sync, EngineAdapterStateSync)
12111211
assert isinstance(state_sync.engine_adapter, DuckDBEngineAdapter)
1212-
assert isinstance(state_sync.engine_adapter._connection_pool, ThreadLocalConnectionPool)
1212+
assert isinstance(state_sync.engine_adapter._connection_pool, ThreadLocalSharedConnectionPool)
12131213

12141214

12151215
def test_requirements(copy_to_temp_path: t.Callable):

tests/utils/test_connection_pool.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from sqlmesh.utils.connection_pool import (
77
SingletonConnectionPool,
88
ThreadLocalConnectionPool,
9+
ThreadLocalSharedConnectionPool,
910
)
1011

1112

@@ -207,3 +208,95 @@ def thread():
207208
assert cursor_mock_thread_one.rollback.call_count == 1
208209

209210
assert cursor_mock_thread_two.begin.call_count == 1
211+
212+
213+
def test_thread_local_shared_connection_pool(mocker: MockerFixture):
214+
cursor_mock_thread_one = mocker.Mock()
215+
cursor_mock_thread_two = mocker.Mock()
216+
connection_mock = mocker.Mock()
217+
connection_mock.cursor.side_effect = [
218+
cursor_mock_thread_one,
219+
cursor_mock_thread_two,
220+
cursor_mock_thread_one,
221+
]
222+
223+
test_thread_id = get_ident()
224+
225+
connection_factory_mock = mocker.Mock(return_value=connection_mock)
226+
pool = ThreadLocalSharedConnectionPool(connection_factory_mock)
227+
228+
assert pool.get_cursor() == cursor_mock_thread_one
229+
assert pool.get_cursor() == cursor_mock_thread_one
230+
assert pool.get() == connection_mock
231+
assert pool.get() == connection_mock
232+
233+
def thread():
234+
assert pool.get_cursor() == cursor_mock_thread_two
235+
assert pool.get_cursor() == cursor_mock_thread_two
236+
assert pool.get() == connection_mock
237+
assert pool.get() == connection_mock
238+
239+
with ThreadPoolExecutor(max_workers=1) as executor:
240+
executor.submit(thread).result()
241+
242+
assert pool._connection is not None
243+
assert len(pool._thread_cursors) == 2
244+
245+
pool.close_all(exclude_calling_thread=True)
246+
247+
assert pool._connection is not None
248+
assert len(pool._thread_cursors) == 1
249+
assert test_thread_id in pool._thread_cursors
250+
251+
pool.close_cursor()
252+
pool.close()
253+
254+
assert pool.get_cursor() == cursor_mock_thread_one
255+
256+
pool.close_all()
257+
258+
assert connection_factory_mock.call_count == 1
259+
260+
assert cursor_mock_thread_one.close.call_count == 2
261+
assert connection_mock.cursor.call_count == 3
262+
assert connection_mock.close.call_count == 1
263+
264+
265+
def test_thread_local_shared_connection_pool_close(mocker: MockerFixture):
266+
connection_mock = mocker.Mock()
267+
cursor_mock = mocker.Mock()
268+
connection_mock.cursor.return_value = cursor_mock
269+
270+
connection_factory_mock = mocker.Mock(return_value=connection_mock)
271+
pool = ThreadLocalSharedConnectionPool(connection_factory_mock)
272+
273+
# First time we get a connection
274+
pool.get()
275+
pool.get()
276+
pool.get_cursor()
277+
pool.get_cursor()
278+
279+
# This shouldn't close the connection, only the cursor
280+
pool.close()
281+
pool.get()
282+
pool.get()
283+
pool.get_cursor()
284+
285+
pool.get_cursor()
286+
# This shouldn't close the connection either
287+
pool.close_all(exclude_calling_thread=True)
288+
289+
pool.get()
290+
pool.get()
291+
# Now this should close the connection
292+
pool.close_all()
293+
294+
# Re-open the connection
295+
pool.get()
296+
pool.get()
297+
# Close it again
298+
pool.close_all()
299+
300+
assert cursor_mock.close.call_count == 2
301+
assert connection_factory_mock.call_count == 2
302+
assert connection_mock.close.call_count == 2

0 commit comments

Comments
 (0)