1+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+ #
3+ # Licensed under the Apache License, Version 2.0 (the "License").
4+ # You may not use this file except in compliance with the License.
5+ # You may obtain a copy of the License at
6+ #
7+ # http://www.apache.org/licenses/LICENSE-2.0
8+ #
9+ # Unless required by applicable law or agreed to in writing, software
10+ # distributed under the License is distributed on an "AS IS" BASIS,
11+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+ # See the License for the specific language governing permissions and
13+ # limitations under the License.
14+
15+ import asyncio
16+ import mysql .connector
17+ from contextlib import asynccontextmanager
18+ from typing import Any , Callable , Generic
19+
20+ from tortoise .backends .base .client import BaseDBAsyncClient , T_conn , TransactionalDBClient , TransactionContext
21+ from tortoise .connection import connections
22+ from tortoise .exceptions import TransactionManagementError
23+
24+ from aws_advanced_python_wrapper import AwsWrapperConnection
25+
26+
27+ async def ConnectWithAwsWrapper (connect_func : Callable , ** kwargs ) -> AwsWrapperConnection :
28+ """Create an AWS wrapper connection with async cursor support."""
29+ connection = await asyncio .to_thread (
30+ AwsWrapperConnection .connect ,
31+ connect_func ,
32+ ** kwargs ,
33+ )
34+ return AwsConnectionAsyncWrapper (connection )
35+
36+
37+ class AwsCursorAsyncWrapper :
38+ """Wraps a sync cursor to provide async interface."""
39+
40+ def __init__ (self , sync_cursor ):
41+ self ._cursor = sync_cursor
42+
43+ async def execute (self , query , params = None ):
44+ """Execute a query asynchronously."""
45+ return await asyncio .to_thread (self ._cursor .execute , query , params )
46+
47+ async def executemany (self , query , params_list ):
48+ """Execute multiple queries asynchronously."""
49+ return await asyncio .to_thread (self ._cursor .executemany , query , params_list )
50+
51+ async def fetchall (self ):
52+ """Fetch all results asynchronously."""
53+ return await asyncio .to_thread (self ._cursor .fetchall )
54+
55+ async def fetchone (self ):
56+ """Fetch one result asynchronously."""
57+ return await asyncio .to_thread (self ._cursor .fetchone )
58+
59+ async def close (self ):
60+ """Close cursor asynchronously."""
61+ return await asyncio .to_thread (self ._cursor .close )
62+
63+ def __getattr__ (self , name ):
64+ """Delegate non-async attributes to the wrapped cursor."""
65+ return getattr (self ._cursor , name )
66+
67+
68+ class AwsConnectionAsyncWrapper (AwsWrapperConnection ):
69+ """AWS wrapper connection with async cursor support."""
70+
71+ def __init__ (self , connection : AwsWrapperConnection ):
72+ self ._wrapped_connection = connection
73+
74+ @asynccontextmanager
75+ async def cursor (self ):
76+ """Create an async cursor context manager."""
77+ cursor_obj = await asyncio .to_thread (self ._wrapped_connection .cursor )
78+ try :
79+ yield AwsCursorAsyncWrapper (cursor_obj )
80+ finally :
81+ await asyncio .to_thread (cursor_obj .close )
82+
83+ async def rollback (self ):
84+ """Rollback the current transaction."""
85+ return await asyncio .to_thread (self ._wrapped_connection .rollback )
86+
87+ async def commit (self ):
88+ """Commit the current transaction."""
89+ return await asyncio .to_thread (self ._wrapped_connection .commit )
90+
91+ async def set_autocommit (self , value : bool ):
92+ """Set autocommit mode."""
93+ return await asyncio .to_thread (lambda : setattr (self ._wrapped_connection , 'autocommit' , value ))
94+
95+ def __getattr__ (self , name ):
96+ """Delegate all other attributes/methods to the wrapped connection."""
97+ return getattr (self ._wrapped_connection , name )
98+
99+ def __del__ (self ):
100+ """Delegate cleanup to wrapped connection."""
101+ if hasattr (self , '_wrapped_connection' ):
102+ # Let the wrapped connection handle its own cleanup
103+ pass
104+
105+
106+ class TortoiseAwsClientConnectionWrapper (Generic [T_conn ]):
107+ """Manages acquiring from and releasing connections to a pool."""
108+
109+ __slots__ = ("client" , "connection" , "_pool_init_lock" , "connect_func" , "with_db" )
110+
111+ def __init__ (
112+ self ,
113+ client : BaseDBAsyncClient ,
114+ pool_init_lock : asyncio .Lock ,
115+ connect_func : Callable ,
116+ with_db : bool = True
117+ ) -> None :
118+ self .connect_func = connect_func
119+ self .client = client
120+ self .connection : T_conn | None = None
121+ self ._pool_init_lock = pool_init_lock
122+ self .with_db = with_db
123+
124+ async def ensure_connection (self ) -> None :
125+ """Ensure the connection pool is initialized."""
126+ await self .client .create_connection (with_db = self .with_db )
127+
128+ async def __aenter__ (self ) -> T_conn :
129+ """Acquire connection from pool."""
130+ await self .ensure_connection ()
131+ self .connection = await ConnectWithAwsWrapper (self .connect_func , ** self .client ._template )
132+ return self .connection
133+
134+ async def __aexit__ (self , exc_type : Any , exc_val : Any , exc_tb : Any ) -> None :
135+ """Close connection and release back to pool."""
136+ if self .connection :
137+ await asyncio .to_thread (self .connection .close )
138+
139+
140+ class TortoiseAwsClientTransactionContext (TransactionContext ):
141+ """Transaction context that uses a pool to acquire connections."""
142+
143+ __slots__ = ("client" , "connection_name" , "token" , "_pool_init_lock" )
144+
145+ def __init__ (self , client : TransactionalDBClient , pool_init_lock : asyncio .Lock ) -> None :
146+ self .client = client
147+ self .connection_name = client .connection_name
148+ self ._pool_init_lock = pool_init_lock
149+
150+ async def ensure_connection (self ) -> None :
151+ """Ensure the connection pool is initialized."""
152+ await self .client ._parent .create_connection (with_db = True )
153+
154+ async def __aenter__ (self ) -> TransactionalDBClient :
155+ """Enter transaction context."""
156+ await self .ensure_connection ()
157+
158+ # Set the context variable so the current task sees a TransactionWrapper connection
159+ self .token = connections .set (self .connection_name , self .client )
160+
161+ # Create connection and begin transaction
162+ self .client ._connection = await ConnectWithAwsWrapper (
163+ mysql .connector .Connect ,
164+ ** self .client ._parent ._template
165+ )
166+ await self .client .begin ()
167+ return self .client
168+
169+ async def __aexit__ (self , exc_type : Any , exc_val : Any , exc_tb : Any ) -> None :
170+ """Exit transaction context with proper cleanup."""
171+ try :
172+ if not self .client ._finalized :
173+ if exc_type :
174+ # Can't rollback a transaction that already failed
175+ if exc_type is not TransactionManagementError :
176+ await self .client .rollback ()
177+ else :
178+ await self .client .commit ()
179+ finally :
180+ connections .reset (self .token )
181+ await asyncio .to_thread (self .client ._connection .close )
0 commit comments