1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15+ from __future__ import annotations
16+
1517import asyncio
16- import mysql .connector
1718from contextlib import asynccontextmanager
18- from typing import Any , Callable , Generic
19+ from typing import Any , Callable , Dict , Generic , cast
1920
20- from tortoise .backends .base .client import BaseDBAsyncClient , T_conn , TransactionalDBClient , TransactionContext
21+ import mysql .connector
22+ from tortoise .backends .base .client import (BaseDBAsyncClient , T_conn ,
23+ TransactionalDBClient ,
24+ TransactionContext )
2125from tortoise .connection import connections
2226from tortoise .exceptions import TransactionManagementError
2327
2630
2731class AwsWrapperAsyncConnector :
2832 """Class for creating and closing AWS wrapper connections."""
29-
33+
3034 @staticmethod
31- async def ConnectWithAwsWrapper (connect_func : Callable , ** kwargs ) -> AwsWrapperConnection :
35+ async def connect_with_aws_wrapper (connect_func : Callable , ** kwargs ) -> AwsConnectionAsyncWrapper :
3236 """Create an AWS wrapper connection with async cursor support."""
3337 connection = await asyncio .to_thread (
3438 AwsWrapperConnection .connect , connect_func , ** kwargs
3539 )
3640 return AwsConnectionAsyncWrapper (connection )
37-
41+
3842 @staticmethod
39- async def CloseAwsWrapper (connection : AwsWrapperConnection ) -> None :
43+ async def close_aws_wrapper (connection : AwsWrapperConnection ) -> None :
4044 """Close an AWS wrapper connection asynchronously."""
4145 await asyncio .to_thread (connection .close )
4246
4347
4448class AwsCursorAsyncWrapper :
4549 """Wraps sync AwsCursor cursor with async support."""
46-
50+
4751 def __init__ (self , sync_cursor ):
4852 self ._cursor = sync_cursor
49-
53+
5054 async def execute (self , query , params = None ):
5155 """Execute a query asynchronously."""
5256 return await asyncio .to_thread (self ._cursor .execute , query , params )
53-
57+
5458 async def executemany (self , query , params_list ):
5559 """Execute multiple queries asynchronously."""
5660 return await asyncio .to_thread (self ._cursor .executemany , query , params_list )
57-
61+
5862 async def fetchall (self ):
5963 """Fetch all results asynchronously."""
6064 return await asyncio .to_thread (self ._cursor .fetchall )
61-
65+
6266 async def fetchone (self ):
6367 """Fetch one result asynchronously."""
6468 return await asyncio .to_thread (self ._cursor .fetchone )
65-
69+
6670 async def close (self ):
6771 """Close cursor asynchronously."""
6872 return await asyncio .to_thread (self ._cursor .close )
69-
73+
7074 def __getattr__ (self , name ):
7175 """Delegate non-async attributes to the wrapped cursor."""
7276 return getattr (self ._cursor , name )
7377
7478
7579class AwsConnectionAsyncWrapper (AwsWrapperConnection ):
7680 """Wraps sync AwsConnection with async cursor support."""
77-
81+
7882 def __init__ (self , connection : AwsWrapperConnection ):
7983 self ._wrapped_connection = connection
8084
@@ -90,40 +94,50 @@ async def cursor(self):
9094 async def rollback (self ):
9195 """Rollback the current transaction."""
9296 return await asyncio .to_thread (self ._wrapped_connection .rollback )
93-
97+
9498 async def commit (self ):
9599 """Commit the current transaction."""
96100 return await asyncio .to_thread (self ._wrapped_connection .commit )
97-
101+
98102 async def set_autocommit (self , value : bool ):
99103 """Set autocommit mode."""
100104 return await asyncio .to_thread (setattr , self ._wrapped_connection , 'autocommit' , value )
101105
102106 def __getattr__ (self , name ):
103107 """Delegate all other attributes/methods to the wrapped connection."""
104108 return getattr (self ._wrapped_connection , name )
105-
109+
106110 def __del__ (self ):
107111 """Delegate cleanup to wrapped connection."""
108112 if hasattr (self , '_wrapped_connection' ):
109113 # Let the wrapped connection handle its own cleanup
110114 pass
111115
112116
117+ class AwsBaseDBAsyncClient (BaseDBAsyncClient ):
118+ _template : Dict [str , Any ]
119+
120+
121+ class AwsTransactionalDBClient (TransactionalDBClient ):
122+ _template : Dict [str , Any ]
123+ _parent : AwsBaseDBAsyncClient
124+ pass
125+
126+
113127class TortoiseAwsClientConnectionWrapper (Generic [T_conn ]):
114128 """Manages acquiring from and releasing connections to a pool."""
115129
116130 __slots__ = ("client" , "connection" , "connect_func" , "with_db" )
117131
118132 def __init__ (
119- self ,
120- client : BaseDBAsyncClient ,
121- connect_func : Callable ,
133+ self ,
134+ client : AwsBaseDBAsyncClient ,
135+ connect_func : Callable ,
122136 with_db : bool = True
123137 ) -> None :
124138 self .connect_func = connect_func
125139 self .client = client
126- self .connection : T_conn | None = None
140+ self .connection : AwsConnectionAsyncWrapper | None = None
127141 self .with_db = with_db
128142
129143 async def ensure_connection (self ) -> None :
@@ -133,22 +147,22 @@ async def ensure_connection(self) -> None:
133147 async def __aenter__ (self ) -> T_conn :
134148 """Acquire connection from pool."""
135149 await self .ensure_connection ()
136- self .connection = await AwsWrapperAsyncConnector .ConnectWithAwsWrapper (self .connect_func , ** self .client ._template )
137- return self .connection
150+ self .connection = await AwsWrapperAsyncConnector .connect_with_aws_wrapper (self .connect_func , ** self .client ._template )
151+ return cast ( "T_conn" , self .connection )
138152
139153 async def __aexit__ (self , exc_type : Any , exc_val : Any , exc_tb : Any ) -> None :
140154 """Close connection and release back to pool."""
141155 if self .connection :
142- await AwsWrapperAsyncConnector .CloseAwsWrapper (self .connection )
156+ await AwsWrapperAsyncConnector .close_aws_wrapper (self .connection )
143157
144158
145159class TortoiseAwsClientTransactionContext (TransactionContext ):
146160 """Transaction context that uses a pool to acquire connections."""
147161
148162 __slots__ = ("client" , "connection_name" , "token" )
149163
150- def __init__ (self , client : TransactionalDBClient ) -> None :
151- self .client = client
164+ def __init__ (self , client : AwsTransactionalDBClient ) -> None :
165+ self .client : AwsTransactionalDBClient = client
152166 self .connection_name = client .connection_name
153167
154168 async def ensure_connection (self ) -> None :
@@ -158,13 +172,13 @@ async def ensure_connection(self) -> None:
158172 async def __aenter__ (self ) -> TransactionalDBClient :
159173 """Enter transaction context."""
160174 await self .ensure_connection ()
161-
175+
162176 # Set the context variable so the current task sees a TransactionWrapper connection
163177 self .token = connections .set (self .connection_name , self .client )
164-
178+
165179 # Create connection and begin transaction
166- self .client ._connection = await AwsWrapperAsyncConnector .ConnectWithAwsWrapper (
167- mysql .connector .Connect ,
180+ self .client ._connection = await AwsWrapperAsyncConnector .connect_with_aws_wrapper (
181+ mysql .connector .Connect ,
168182 ** self .client ._parent ._template
169183 )
170184 await self .client .begin ()
@@ -181,5 +195,5 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
181195 else :
182196 await self .client .commit ()
183197 finally :
184- await AwsWrapperAsyncConnector .CloseAwsWrapper (self .client ._connection )
198+ await AwsWrapperAsyncConnector .close_aws_wrapper (self .client ._connection )
185199 connections .reset (self .token )
0 commit comments