@@ -116,6 +116,7 @@ def __init__(
116116 self ,
117117 connection_factory : t .Callable [[], t .Any ],
118118 cursor_kwargs : t .Optional [t .Dict [str , t .Any ]] = None ,
119+ cursor_init : t .Optional [t .Callable [[t .Any ], None ]] = None ,
119120 ):
120121 self ._connection_factory = connection_factory
121122 self ._thread_connections : t .Dict [t .Hashable , t .Any ] = {}
@@ -126,12 +127,15 @@ def __init__(
126127 self ._thread_cursors_lock = Lock ()
127128 self ._thread_transactions_lock = Lock ()
128129 self ._cursor_kwargs = cursor_kwargs or {}
130+ self ._cursor_init = cursor_init
129131
130132 def get_cursor (self ) -> t .Any :
131133 thread_id = get_ident ()
132134 with self ._thread_cursors_lock :
133135 if thread_id not in self ._thread_cursors :
134136 self ._thread_cursors [thread_id ] = self .get ().cursor (** self ._cursor_kwargs )
137+ if self ._cursor_init :
138+ self ._cursor_init (self ._thread_cursors [thread_id ])
135139 return self ._thread_cursors [thread_id ]
136140
137141 def get (self ) -> t .Any :
@@ -206,17 +210,21 @@ def __init__(
206210 self ,
207211 connection_factory : t .Callable [[], t .Any ],
208212 cursor_kwargs : t .Optional [t .Dict [str , t .Any ]] = None ,
213+ cursor_init : t .Optional [t .Callable [[t .Any ], None ]] = None ,
209214 ):
210215 self ._connection_factory = connection_factory
211216 self ._connection : t .Optional [t .Any ] = None
212217 self ._cursor : t .Optional [t .Any ] = None
213218 self ._cursor_kwargs = cursor_kwargs or {}
214219 self ._attributes : t .Dict [str , t .Any ] = {}
215220 self ._is_transaction_active : bool = False
221+ self ._cursor_init = cursor_init
216222
217223 def get_cursor (self ) -> t .Any :
218224 if not self ._cursor :
219225 self ._cursor = self .get ().cursor (** self ._cursor_kwargs )
226+ if self ._cursor_init :
227+ self ._cursor_init (self ._cursor )
220228 return self ._cursor
221229
222230 def get (self ) -> t .Any :
@@ -266,11 +274,16 @@ def create_connection_pool(
266274 connection_factory : t .Callable [[], t .Any ],
267275 multithreaded : bool ,
268276 cursor_kwargs : t .Optional [t .Dict [str , t .Any ]] = None ,
277+ cursor_init : t .Optional [t .Callable [[t .Any ], None ]] = None ,
269278) -> ConnectionPool :
270279 return (
271- ThreadLocalConnectionPool (connection_factory , cursor_kwargs = cursor_kwargs )
280+ ThreadLocalConnectionPool (
281+ connection_factory , cursor_kwargs = cursor_kwargs , cursor_init = cursor_init
282+ )
272283 if multithreaded
273- else SingletonConnectionPool (connection_factory , cursor_kwargs = cursor_kwargs )
284+ else SingletonConnectionPool (
285+ connection_factory , cursor_kwargs = cursor_kwargs , cursor_init = cursor_init
286+ )
274287 )
275288
276289
0 commit comments