Skip to content

Commit cad4e9b

Browse files
committed
feat: tortoise
1 parent 2df8d33 commit cad4e9b

File tree

15 files changed

+1652
-46
lines changed

15 files changed

+1652
-46
lines changed

aws_advanced_python_wrapper/sql_alchemy_connection_provider.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def accepts_host_info(self, host_info: HostInfo, props: Properties) -> bool:
8888
if self._accept_url_func:
8989
return self._accept_url_func(host_info, props)
9090
url_type = SqlAlchemyPooledConnectionProvider._rds_utils.identify_rds_type(host_info.host)
91-
return RdsUrlType.RDS_INSTANCE == url_type
91+
return RdsUrlType.RDS_INSTANCE == url_type or RdsUrlType.RDS_WRITER_CLUSTER
9292

9393
def accepts_strategy(self, role: HostRole, strategy: str) -> bool:
9494
return strategy == SqlAlchemyPooledConnectionProvider._LEAST_CONNECTIONS or strategy in self._accepted_strategies
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License").
4+
# You may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from tortoise.backends.base.config_generator import DB_LOOKUP
16+
17+
# Register AWS MySQL backend
18+
DB_LOOKUP["aws-mysql"] = {
19+
"engine": "aws_advanced_python_wrapper.tortoise.backend.mysql",
20+
"vmap": {
21+
"path": "database",
22+
"hostname": "host",
23+
"port": "port",
24+
"username": "user",
25+
"password": "password",
26+
},
27+
"defaults": {"port": 3306, "charset": "utf8mb4", "sql_mode": "STRICT_TRANS_TABLES"},
28+
"cast": {
29+
"minsize": int,
30+
"maxsize": int,
31+
"connect_timeout": float,
32+
"echo": bool,
33+
"use_unicode": bool,
34+
"pool_recycle": int,
35+
"ssl": bool,
36+
},
37+
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License").
4+
# You may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License").
4+
# You may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License").
4+
# You may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import asyncio
16+
import mysql.connector
17+
from contextlib import asynccontextmanager
18+
from typing import Any, Callable, Generic
19+
20+
from tortoise.backends.base.client import BaseDBAsyncClient, T_conn, TransactionalDBClient, TransactionContext
21+
from tortoise.connection import connections
22+
from tortoise.exceptions import TransactionManagementError
23+
24+
from aws_advanced_python_wrapper import AwsWrapperConnection
25+
26+
27+
async def ConnectWithAwsWrapper(connect_func: Callable, **kwargs) -> AwsWrapperConnection:
28+
"""Create an AWS wrapper connection with async cursor support."""
29+
connection = await asyncio.to_thread(
30+
AwsWrapperConnection.connect,
31+
connect_func,
32+
**kwargs,
33+
)
34+
return AwsConnectionAsyncWrapper(connection)
35+
36+
37+
class AwsCursorAsyncWrapper:
38+
"""Wraps a sync cursor to provide async interface."""
39+
40+
def __init__(self, sync_cursor):
41+
self._cursor = sync_cursor
42+
43+
async def execute(self, query, params=None):
44+
"""Execute a query asynchronously."""
45+
return await asyncio.to_thread(self._cursor.execute, query, params)
46+
47+
async def executemany(self, query, params_list):
48+
"""Execute multiple queries asynchronously."""
49+
return await asyncio.to_thread(self._cursor.executemany, query, params_list)
50+
51+
async def fetchall(self):
52+
"""Fetch all results asynchronously."""
53+
return await asyncio.to_thread(self._cursor.fetchall)
54+
55+
async def fetchone(self):
56+
"""Fetch one result asynchronously."""
57+
return await asyncio.to_thread(self._cursor.fetchone)
58+
59+
async def close(self):
60+
"""Close cursor asynchronously."""
61+
return await asyncio.to_thread(self._cursor.close)
62+
63+
def __getattr__(self, name):
64+
"""Delegate non-async attributes to the wrapped cursor."""
65+
return getattr(self._cursor, name)
66+
67+
68+
class AwsConnectionAsyncWrapper(AwsWrapperConnection):
69+
"""AWS wrapper connection with async cursor support."""
70+
71+
def __init__(self, connection: AwsWrapperConnection):
72+
self._wrapped_connection = connection
73+
74+
@asynccontextmanager
75+
async def cursor(self):
76+
"""Create an async cursor context manager."""
77+
cursor_obj = await asyncio.to_thread(self._wrapped_connection.cursor)
78+
try:
79+
yield AwsCursorAsyncWrapper(cursor_obj)
80+
finally:
81+
await asyncio.to_thread(cursor_obj.close)
82+
83+
async def rollback(self):
84+
"""Rollback the current transaction."""
85+
return await asyncio.to_thread(self._wrapped_connection.rollback)
86+
87+
async def commit(self):
88+
"""Commit the current transaction."""
89+
return await asyncio.to_thread(self._wrapped_connection.commit)
90+
91+
async def set_autocommit(self, value: bool):
92+
"""Set autocommit mode."""
93+
return await asyncio.to_thread(lambda: setattr(self._wrapped_connection, 'autocommit', value))
94+
95+
def __getattr__(self, name):
96+
"""Delegate all other attributes/methods to the wrapped connection."""
97+
return getattr(self._wrapped_connection, name)
98+
99+
def __del__(self):
100+
"""Delegate cleanup to wrapped connection."""
101+
if hasattr(self, '_wrapped_connection'):
102+
# Let the wrapped connection handle its own cleanup
103+
pass
104+
105+
106+
class TortoiseAwsClientConnectionWrapper(Generic[T_conn]):
107+
"""Manages acquiring from and releasing connections to a pool."""
108+
109+
__slots__ = ("client", "connection", "_pool_init_lock", "connect_func", "with_db")
110+
111+
def __init__(
112+
self,
113+
client: BaseDBAsyncClient,
114+
pool_init_lock: asyncio.Lock,
115+
connect_func: Callable,
116+
with_db: bool = True
117+
) -> None:
118+
self.connect_func = connect_func
119+
self.client = client
120+
self.connection: T_conn | None = None
121+
self._pool_init_lock = pool_init_lock
122+
self.with_db = with_db
123+
124+
async def ensure_connection(self) -> None:
125+
"""Ensure the connection pool is initialized."""
126+
await self.client.create_connection(with_db=self.with_db)
127+
128+
async def __aenter__(self) -> T_conn:
129+
"""Acquire connection from pool."""
130+
await self.ensure_connection()
131+
self.connection = await ConnectWithAwsWrapper(self.connect_func, **self.client._template)
132+
return self.connection
133+
134+
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
135+
"""Close connection and release back to pool."""
136+
if self.connection:
137+
await asyncio.to_thread(self.connection.close)
138+
139+
140+
class TortoiseAwsClientTransactionContext(TransactionContext):
141+
"""Transaction context that uses a pool to acquire connections."""
142+
143+
__slots__ = ("client", "connection_name", "token", "_pool_init_lock")
144+
145+
def __init__(self, client: TransactionalDBClient, pool_init_lock: asyncio.Lock) -> None:
146+
self.client = client
147+
self.connection_name = client.connection_name
148+
self._pool_init_lock = pool_init_lock
149+
150+
async def ensure_connection(self) -> None:
151+
"""Ensure the connection pool is initialized."""
152+
await self.client._parent.create_connection(with_db=True)
153+
154+
async def __aenter__(self) -> TransactionalDBClient:
155+
"""Enter transaction context."""
156+
await self.ensure_connection()
157+
158+
# Set the context variable so the current task sees a TransactionWrapper connection
159+
self.token = connections.set(self.connection_name, self.client)
160+
161+
# Create connection and begin transaction
162+
self.client._connection = await ConnectWithAwsWrapper(
163+
mysql.connector.Connect,
164+
**self.client._parent._template
165+
)
166+
await self.client.begin()
167+
return self.client
168+
169+
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
170+
"""Exit transaction context with proper cleanup."""
171+
try:
172+
if not self.client._finalized:
173+
if exc_type:
174+
# Can't rollback a transaction that already failed
175+
if exc_type is not TransactionManagementError:
176+
await self.client.rollback()
177+
else:
178+
await self.client.commit()
179+
finally:
180+
connections.reset(self.token)
181+
await asyncio.to_thread(self.client._connection.close)
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License").
4+
# You may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from .client import AwsMySQLClient
15+
16+
client_class = AwsMySQLClient

0 commit comments

Comments
 (0)