1414
1515import asyncio
1616import mysql .connector
17+ from concurrent .futures import ThreadPoolExecutor
1718from contextlib import asynccontextmanager
1819from typing import Any , Callable , Generic
1920
2425from aws_advanced_python_wrapper import AwsWrapperConnection
2526
2627
27- async def ConnectWithAwsWrapper (connect_func : Callable , ** kwargs ) -> AwsWrapperConnection :
28- """Create an AWS wrapper connection with async cursor support."""
29- connection = await asyncio .to_thread (
30- AwsWrapperConnection .connect ,
31- connect_func ,
32- ** kwargs ,
28+ class AwsWrapperAsyncConnector :
29+ """Factory class for creating AWS wrapper connections."""
30+
31+ _executor : ThreadPoolExecutor = ThreadPoolExecutor (
32+ thread_name_prefix = "AwsWrapperConnectorExecutor"
3333 )
34- return AwsConnectionAsyncWrapper (connection )
34+
35+ @staticmethod
36+ async def ConnectWithAwsWrapper (connect_func : Callable , ** kwargs ) -> AwsWrapperConnection :
37+ """Create an AWS wrapper connection with async cursor support."""
38+ loop = asyncio .get_event_loop ()
39+ connection = await loop .run_in_executor (
40+ AwsWrapperAsyncConnector ._executor ,
41+ lambda : AwsWrapperConnection .connect (connect_func , ** kwargs )
42+ )
43+ return AwsConnectionAsyncWrapper (connection )
44+
45+ @staticmethod
46+ async def CloseAwsWrapper (connection : AwsWrapperConnection ) -> None :
47+ """Close an AWS wrapper connection asynchronously."""
48+ loop = asyncio .get_event_loop ()
49+ await loop .run_in_executor (
50+ AwsWrapperAsyncConnector ._executor ,
51+ connection .close
52+ )
3553
3654
3755class AwsCursorAsyncWrapper :
3856 """Wraps a sync cursor to provide async interface."""
3957
58+ _executor : ThreadPoolExecutor = ThreadPoolExecutor (
59+ thread_name_prefix = "AwsCursorAsyncWrapperExecutor"
60+ )
61+
4062 def __init__ (self , sync_cursor ):
4163 self ._cursor = sync_cursor
4264
4365 async def execute (self , query , params = None ):
4466 """Execute a query asynchronously."""
45- return await asyncio .to_thread (self ._cursor .execute , query , params )
67+ loop = asyncio .get_event_loop ()
68+ return await loop .run_in_executor (self ._executor , self ._cursor .execute , query , params )
4669
4770 async def executemany (self , query , params_list ):
4871 """Execute multiple queries asynchronously."""
49- return await asyncio .to_thread (self ._cursor .executemany , query , params_list )
72+ loop = asyncio .get_event_loop ()
73+ return await loop .run_in_executor (self ._executor , self ._cursor .executemany , query , params_list )
5074
5175 async def fetchall (self ):
5276 """Fetch all results asynchronously."""
53- return await asyncio .to_thread (self ._cursor .fetchall )
77+ loop = asyncio .get_event_loop ()
78+ return await loop .run_in_executor (self ._executor , self ._cursor .fetchall )
5479
5580 async def fetchone (self ):
5681 """Fetch one result asynchronously."""
57- return await asyncio .to_thread (self ._cursor .fetchone )
82+ loop = asyncio .get_event_loop ()
83+ return await loop .run_in_executor (self ._executor , self ._cursor .fetchone )
5884
5985 async def close (self ):
6086 """Close cursor asynchronously."""
61- return await asyncio .to_thread (self ._cursor .close )
87+ loop = asyncio .get_event_loop ()
88+ return await loop .run_in_executor (self ._executor , self ._cursor .close )
6289
6390 def __getattr__ (self , name ):
6491 """Delegate non-async attributes to the wrapped cursor."""
@@ -68,29 +95,37 @@ def __getattr__(self, name):
6895class AwsConnectionAsyncWrapper (AwsWrapperConnection ):
6996 """AWS wrapper connection with async cursor support."""
7097
98+ _executor : ThreadPoolExecutor = ThreadPoolExecutor (
99+ thread_name_prefix = "AwsConnectionAsyncWrapperExecutor"
100+ )
101+
71102 def __init__ (self , connection : AwsWrapperConnection ):
72103 self ._wrapped_connection = connection
73104
74105 @asynccontextmanager
75106 async def cursor (self ):
76107 """Create an async cursor context manager."""
77- cursor_obj = await asyncio .to_thread (self ._wrapped_connection .cursor )
108+ loop = asyncio .get_event_loop ()
109+ cursor_obj = await loop .run_in_executor (self ._executor , self ._wrapped_connection .cursor )
78110 try :
79111 yield AwsCursorAsyncWrapper (cursor_obj )
80112 finally :
81- await asyncio . to_thread ( cursor_obj .close )
113+ await loop . run_in_executor ( self . _executor , cursor_obj .close )
82114
83115 async def rollback (self ):
84116 """Rollback the current transaction."""
85- return await asyncio .to_thread (self ._wrapped_connection .rollback )
117+ loop = asyncio .get_event_loop ()
118+ return await loop .run_in_executor (self ._executor , self ._wrapped_connection .rollback )
86119
87120 async def commit (self ):
88121 """Commit the current transaction."""
89- return await asyncio .to_thread (self ._wrapped_connection .commit )
122+ loop = asyncio .get_event_loop ()
123+ return await loop .run_in_executor (self ._executor , self ._wrapped_connection .commit )
90124
91125 async def set_autocommit (self , value : bool ):
92126 """Set autocommit mode."""
93- return await asyncio .to_thread (lambda : setattr (self ._wrapped_connection , 'autocommit' , value ))
127+ loop = asyncio .get_event_loop ()
128+ return await loop .run_in_executor (self ._executor , lambda : setattr (self ._wrapped_connection , 'autocommit' , value ))
94129
95130 def __getattr__ (self , name ):
96131 """Delegate all other attributes/methods to the wrapped connection."""
@@ -128,13 +163,13 @@ async def ensure_connection(self) -> None:
128163 async def __aenter__ (self ) -> T_conn :
129164 """Acquire connection from pool."""
130165 await self .ensure_connection ()
131- self .connection = await ConnectWithAwsWrapper (self .connect_func , ** self .client ._template )
166+ self .connection = await AwsWrapperAsyncConnector . ConnectWithAwsWrapper (self .connect_func , ** self .client ._template )
132167 return self .connection
133168
134169 async def __aexit__ (self , exc_type : Any , exc_val : Any , exc_tb : Any ) -> None :
135170 """Close connection and release back to pool."""
136171 if self .connection :
137- await asyncio . to_thread (self .connection . close )
172+ await AwsWrapperAsyncConnector . CloseAwsWrapper (self .connection )
138173
139174
140175class TortoiseAwsClientTransactionContext (TransactionContext ):
@@ -159,7 +194,7 @@ async def __aenter__(self) -> TransactionalDBClient:
159194 self .token = connections .set (self .connection_name , self .client )
160195
161196 # Create connection and begin transaction
162- self .client ._connection = await ConnectWithAwsWrapper (
197+ self .client ._connection = await AwsWrapperAsyncConnector . ConnectWithAwsWrapper (
163198 mysql .connector .Connect ,
164199 ** self .client ._parent ._template
165200 )
@@ -178,4 +213,4 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
178213 await self .client .commit ()
179214 finally :
180215 connections .reset (self .token )
181- await asyncio . to_thread (self .client ._connection . close )
216+ await AwsWrapperAsyncConnector . CloseAwsWrapper (self .client ._connection )
0 commit comments