@@ -111,18 +111,16 @@ def _do_rollback(self) -> None:
111111 self .get ().rollback ()
112112
113113
114- class ThreadLocalConnectionPool (_TransactionManagementMixin ):
114+ class _ThreadLocalBase (_TransactionManagementMixin ):
115115 def __init__ (
116116 self ,
117117 connection_factory : t .Callable [[], t .Any ],
118118 cursor_init : t .Optional [t .Callable [[t .Any ], None ]] = None ,
119119 ):
120120 self ._connection_factory = connection_factory
121- self ._thread_connections : t .Dict [t .Hashable , t .Any ] = {}
122121 self ._thread_cursors : t .Dict [t .Hashable , t .Any ] = {}
123122 self ._thread_transactions : t .Set [t .Hashable ] = set ()
124123 self ._thread_attributes : t .Dict [t .Hashable , t .Dict [str , t .Any ]] = defaultdict (dict )
125- self ._thread_connections_lock = Lock ()
126124 self ._thread_cursors_lock = Lock ()
127125 self ._thread_transactions_lock = Lock ()
128126 self ._cursor_init = cursor_init
@@ -136,13 +134,6 @@ def get_cursor(self) -> t.Any:
136134 self ._cursor_init (self ._thread_cursors [thread_id ])
137135 return self ._thread_cursors [thread_id ]
138136
139- def get (self ) -> t .Any :
140- thread_id = get_ident ()
141- with self ._thread_connections_lock :
142- if thread_id not in self ._thread_connections :
143- self ._thread_connections [thread_id ] = self ._connection_factory ()
144- return self ._thread_connections [thread_id ]
145-
146137 def get_attribute (self , key : str ) -> t .Optional [t .Any ]:
147138 thread_id = get_ident ()
148139 return self ._thread_attributes [thread_id ].get (key )
@@ -176,6 +167,28 @@ def close_cursor(self) -> None:
176167 _try_close (self ._thread_cursors [thread_id ], "cursor" )
177168 self ._thread_cursors .pop (thread_id )
178169
170+ def _discard_transaction (self , thread_id : t .Hashable ) -> None :
171+ with self ._thread_transactions_lock :
172+ self ._thread_transactions .discard (thread_id )
173+
174+
175+ class ThreadLocalConnectionPool (_ThreadLocalBase ):
176+ def __init__ (
177+ self ,
178+ connection_factory : t .Callable [[], t .Any ],
179+ cursor_init : t .Optional [t .Callable [[t .Any ], None ]] = None ,
180+ ):
181+ super ().__init__ (connection_factory , cursor_init )
182+ self ._thread_connections : t .Dict [t .Hashable , t .Any ] = {}
183+ self ._thread_connections_lock = Lock ()
184+
185+ def get (self ) -> t .Any :
186+ thread_id = get_ident ()
187+ with self ._thread_connections_lock :
188+ if thread_id not in self ._thread_connections :
189+ self ._thread_connections [thread_id ] = self ._connection_factory ()
190+ return self ._thread_connections [thread_id ]
191+
179192 def close (self ) -> None :
180193 thread_id = get_ident ()
181194 with self ._thread_cursors_lock , self ._thread_connections_lock :
@@ -191,16 +204,51 @@ def close_all(self, exclude_calling_thread: bool = False) -> None:
191204 with self ._thread_cursors_lock , self ._thread_connections_lock :
192205 for thread_id , connection in self ._thread_connections .copy ().items ():
193206 if not exclude_calling_thread or thread_id != calling_thread_id :
194- # NOTE: the access to the connection instance itself is not thread-safe here.
195207 _try_close (connection , "connection" )
196208 self ._thread_connections .pop (thread_id )
197209 self ._thread_cursors .pop (thread_id , None )
198210 self ._discard_transaction (thread_id )
199211 self ._thread_attributes .pop (thread_id , None )
200212
201- def _discard_transaction (self , thread_id : t .Hashable ) -> None :
202- with self ._thread_transactions_lock :
203- self ._thread_transactions .discard (thread_id )
213+
214+ class ThreadLocalSharedConnectionPool (_ThreadLocalBase ):
215+ def __init__ (
216+ self ,
217+ connection_factory : t .Callable [[], t .Any ],
218+ cursor_init : t .Optional [t .Callable [[t .Any ], None ]] = None ,
219+ ):
220+ super ().__init__ (connection_factory , cursor_init )
221+ self ._connection : t .Optional [t .Any ] = None
222+ self ._connection_lock = Lock ()
223+
224+ def get (self ) -> t .Any :
225+ with self ._connection_lock :
226+ if self ._connection is None :
227+ self ._connection = self ._connection_factory ()
228+ return self ._connection
229+
230+ def close (self ) -> None :
231+ thread_id = get_ident ()
232+ with self ._thread_cursors_lock , self ._connection_lock :
233+ if thread_id in self ._thread_cursors :
234+ _try_close (self ._thread_cursors [thread_id ], "cursor" )
235+ self ._thread_cursors .pop (thread_id )
236+ self ._discard_transaction (thread_id )
237+ self ._thread_attributes .pop (thread_id , None )
238+
239+ def close_all (self , exclude_calling_thread : bool = False ) -> None :
240+ calling_thread_id = get_ident ()
241+ with self ._thread_cursors_lock , self ._connection_lock :
242+ for thread_id , cursor in self ._thread_cursors .copy ().items ():
243+ if not exclude_calling_thread or thread_id != calling_thread_id :
244+ _try_close (cursor , "cursor" )
245+ self ._thread_cursors .pop (thread_id )
246+ self ._discard_transaction (thread_id )
247+ self ._thread_attributes .pop (thread_id , None )
248+
249+ if not exclude_calling_thread :
250+ _try_close (self ._connection , "connection" )
251+ self ._connection = None
204252
205253
206254class SingletonConnectionPool (_TransactionManagementMixin ):
@@ -269,13 +317,17 @@ def close_all(self, exclude_calling_thread: bool = False) -> None:
269317def create_connection_pool (
270318 connection_factory : t .Callable [[], t .Any ],
271319 multithreaded : bool ,
320+ shared_connection : bool = False ,
272321 cursor_init : t .Optional [t .Callable [[t .Any ], None ]] = None ,
273322) -> ConnectionPool :
274- return (
275- ThreadLocalConnectionPool (connection_factory , cursor_init = cursor_init )
323+ pool_class = (
324+ ThreadLocalSharedConnectionPool
325+ if multithreaded and shared_connection
326+ else ThreadLocalConnectionPool
276327 if multithreaded
277- else SingletonConnectionPool ( connection_factory , cursor_init = cursor_init )
328+ else SingletonConnectionPool
278329 )
330+ return pool_class (connection_factory , cursor_init = cursor_init )
279331
280332
281333def _try_close (closeable : t .Any , kind : str ) -> None :
0 commit comments