Skip to content

Commit 1a5d163

Browse files
committed
add tortoise integration tests
1 parent 2a5b151 commit 1a5d163

16 files changed

+1288
-62
lines changed

aws_advanced_python_wrapper/tortoise/backend/base/client.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class AwsWrapperAsyncConnector:
2929
"""Factory class for creating AWS wrapper connections."""
3030

3131
_executor: ThreadPoolExecutor = ThreadPoolExecutor(
32-
thread_name_prefix="AwsWrapperConnectorExecutor"
32+
thread_name_prefix="AwsWrapperAsyncExecutor"
3333
)
3434

3535
@staticmethod
@@ -141,19 +141,17 @@ def __del__(self):
141141
class TortoiseAwsClientConnectionWrapper(Generic[T_conn]):
142142
"""Manages acquiring from and releasing connections to a pool."""
143143

144-
__slots__ = ("client", "connection", "_pool_init_lock", "connect_func", "with_db")
144+
__slots__ = ("client", "connection", "connect_func", "with_db")
145145

146146
def __init__(
147147
self,
148148
client: BaseDBAsyncClient,
149-
pool_init_lock: asyncio.Lock,
150149
connect_func: Callable,
151150
with_db: bool = True
152151
) -> None:
153152
self.connect_func = connect_func
154153
self.client = client
155154
self.connection: T_conn | None = None
156-
self._pool_init_lock = pool_init_lock
157155
self.with_db = with_db
158156

159157
async def ensure_connection(self) -> None:
@@ -175,12 +173,11 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
175173
class TortoiseAwsClientTransactionContext(TransactionContext):
176174
"""Transaction context that uses a pool to acquire connections."""
177175

178-
__slots__ = ("client", "connection_name", "token", "_pool_init_lock")
176+
__slots__ = ("client", "connection_name", "token")
179177

180-
def __init__(self, client: TransactionalDBClient, pool_init_lock: asyncio.Lock) -> None:
178+
def __init__(self, client: TransactionalDBClient) -> None:
181179
self.client = client
182180
self.connection_name = client.connection_name
183-
self._pool_init_lock = pool_init_lock
184181

185182
async def ensure_connection(self) -> None:
186183
"""Ensure the connection pool is initialized."""

aws_advanced_python_wrapper/tortoise/backend/mysql/client.py

Lines changed: 17 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import asyncio
1616
from functools import wraps
1717
from itertools import count
18-
from typing import Any, Callable, Coroutine, Dict, SupportsInt, TypeVar
18+
from typing import Any, Callable, Coroutine, Dict, List, Optional, SupportsInt, Tuple, TypeVar
1919

2020
import mysql.connector
2121
import sqlparse
@@ -37,7 +37,7 @@
3737
TransactionManagementError,
3838
)
3939

40-
from aws_advanced_python_wrapper.connection_provider import ConnectionProviderManager
40+
from aws_advanced_python_wrapper.connection_provider import ConnectionProviderManager, ConnectionProvider
4141
from aws_advanced_python_wrapper.errors import AwsWrapperError, FailoverError
4242
from aws_advanced_python_wrapper.hostinfo import HostInfo
4343
from aws_advanced_python_wrapper.tortoise.backend.base.client import (
@@ -94,7 +94,7 @@ class AwsMySQLClient(BaseDBAsyncClient):
9494
support_for_posix_regex_queries=True,
9595
support_json_attributes=True,
9696
)
97-
_pool_initialized = False
97+
_provider: Optional[ConnectionProvider] = None
9898
_pool_init_class_lock = asyncio.Lock()
9999

100100
def __init__(
@@ -121,24 +121,19 @@ def __init__(
121121
# Extract MySQL-specific settings
122122
self.storage_engine = self.extra.pop("storage_engine", "innodb")
123123
self.charset = self.extra.pop("charset", "utf8mb4")
124-
self.pool_minsize = int(self.extra.pop("minsize", 1))
125-
self.pool_maxsize = int(self.extra.pop("maxsize", 5))
126124

127125
# Remove Tortoise-specific parameters
128126
self.extra.pop("connection_name", None)
129127
self.extra.pop("fetch_inserted", None)
130-
self.extra.pop("db", None)
131-
self.extra.pop("autocommit", None) # We need this to be true
128+
self.extra.pop("autocommit", None)
132129
self.extra.setdefault("sql_mode", "STRICT_TRANS_TABLES")
133130

134131
# Initialize connection templates
135132
self._init_connection_templates()
136133

137134
# Initialize state
138-
self._template = {}
139-
self._connection = None
140-
self._pool: SqlAlchemyTortoisePooledConnectionProvider = None
141-
self._pool_init_lock = asyncio.Lock()
135+
self._template: Dict[str, Any] = {}
136+
self._connection: Optional[Any] = None
142137

143138
def _init_connection_templates(self) -> None:
144139
"""Initialize connection templates for with/without database."""
@@ -154,39 +149,12 @@ def _init_connection_templates(self) -> None:
154149
self._template_with_db = {**base_template, "database": self.database}
155150
self._template_no_db = {**base_template, "database": None}
156151

157-
# Pool Management
158-
def _configure_pool(self, host_info: HostInfo, props: Dict[str, Any]) -> Dict[str, Any]:
159-
"""Configure connection pool settings."""
160-
return {"pool_size": self.pool_maxsize, "max_overflow": -1}
161-
162-
def _get_pool_key(self, host_info: HostInfo, props: Dict[str, Any]) -> str:
163-
"""Generate unique pool key for connection pooling."""
164-
url = host_info.url
165-
user = props["user"]
166-
db = props["database"]
167-
return f"{url}{user}{db}{self.connection_name}"
168-
169-
async def _init_pool_if_needed(self) -> None:
170-
"""Initialize connection pool only once across all instances."""
171-
if not AwsMySQLClient._pool_initialized:
172-
async with AwsMySQLClient._pool_init_class_lock:
173-
if not AwsMySQLClient._pool_initialized:
174-
self._pool = SqlAlchemyTortoisePooledConnectionProvider(
175-
pool_configurator=self._configure_pool,
176-
pool_mapping=self._get_pool_key,
177-
)
178-
ConnectionProviderManager.set_connection_provider(self._pool)
179-
AwsMySQLClient._pool_initialized = True
180-
181152
# Connection Management
182153
async def create_connection(self, with_db: bool) -> None:
183154
"""Initialize connection pool and configure database settings."""
184155
# Validate charset
185156
if self.charset.lower() not in [cs[0] for cs in MYSQL_CHARACTER_SETS if cs is not None]:
186157
raise DBConnectionError(f"Unknown character set: {self.charset}")
187-
188-
# Initialize connection pool only once
189-
await self._init_pool_if_needed()
190158

191159
# Set transaction support based on storage engine
192160
if self.storage_engine.lower() != "innodb":
@@ -195,8 +163,6 @@ async def create_connection(self, with_db: bool) -> None:
195163
# Set template based on database requirement
196164
self._template = self._template_with_db if with_db else self._template_no_db
197165

198-
logger.debug(f"Created connection pool {self._pool} with params: {self._template}")
199-
200166
async def close(self) -> None:
201167
"""Close connections - AWS wrapper handles cleanup internally."""
202168
pass
@@ -205,10 +171,10 @@ def acquire_connection(self):
205171
"""Acquire a connection from the pool."""
206172
return self._acquire_connection(with_db=True)
207173

208-
def _acquire_connection(self, with_db: bool):
174+
def _acquire_connection(self, with_db: bool) -> TortoiseAwsClientConnectionWrapper:
209175
"""Create connection wrapper for specified database mode."""
210176
return TortoiseAwsClientConnectionWrapper(
211-
self, self._pool_init_lock, mysql.connector.Connect, with_db=with_db
177+
self, mysql.connector.Connect, with_db=with_db
212178
)
213179

214180
# Database Operations
@@ -226,7 +192,7 @@ async def db_delete(self) -> None:
226192

227193
# Query Execution Methods
228194
@translate_exceptions
229-
async def execute_insert(self, query: str, values: list) -> int:
195+
async def execute_insert(self, query: str, values: List[Any]) -> int:
230196
"""Execute an INSERT query and return the last inserted row ID."""
231197
async with self.acquire_connection() as connection:
232198
logger.debug(f"{query}: {values}")
@@ -235,7 +201,7 @@ async def execute_insert(self, query: str, values: list) -> int:
235201
return cursor.lastrowid
236202

237203
@translate_exceptions
238-
async def execute_many(self, query: str, values: list[list]) -> None:
204+
async def execute_many(self, query: str, values: List[List[Any]]) -> None:
239205
"""Execute a query with multiple parameter sets."""
240206
async with self.acquire_connection() as connection:
241207
logger.debug(f"{query}: {values}")
@@ -245,7 +211,7 @@ async def execute_many(self, query: str, values: list[list]) -> None:
245211
else:
246212
await cursor.executemany(query, values)
247213

248-
async def _execute_many_with_transaction(self, cursor, connection, query: str, values: list[list]) -> None:
214+
async def _execute_many_with_transaction(self, cursor: Any, connection: Any, query: str, values: List[List[Any]]) -> None:
249215
"""Execute many queries within a transaction."""
250216
try:
251217
await connection.set_autocommit(False)
@@ -260,7 +226,7 @@ async def _execute_many_with_transaction(self, cursor, connection, query: str, v
260226
await connection.set_autocommit(True)
261227

262228
@translate_exceptions
263-
async def execute_query(self, query: str, values: list | None = None) -> tuple[int, list[dict]]:
229+
async def execute_query(self, query: str, values: Optional[List[Any]] = None) -> Tuple[int, List[Dict[str, Any]]]:
264230
"""Execute a query and return row count and results."""
265231
async with self.acquire_connection() as connection:
266232
logger.debug(f"{query}: {values}")
@@ -272,7 +238,7 @@ async def execute_query(self, query: str, values: list | None = None) -> tuple[i
272238
return cursor.rowcount, [dict(zip(fields, row)) for row in rows]
273239
return cursor.rowcount, []
274240

275-
async def execute_query_dict(self, query: str, values: list | None = None) -> list[dict]:
241+
async def execute_query_dict(self, query: str, values: Optional[List[Any]] = None) -> List[Dict[str, Any]]:
276242
"""Execute a query and return only the results as dictionaries."""
277243
return (await self.execute_query(query, values))[1]
278244

@@ -297,7 +263,7 @@ async def _execute_script(self, query: str, with_db: bool) -> None:
297263
# Transaction Support
298264
def _in_transaction(self) -> TransactionContext:
299265
"""Create a new transaction context."""
300-
return TortoiseAwsClientTransactionContext(TransactionWrapper(self), self._pool_init_lock)
266+
return TortoiseAwsClientTransactionContext(TransactionWrapper(self))
301267

302268

303269
class TransactionWrapper(AwsMySQLClient, TransactionalDBClient):
@@ -307,7 +273,7 @@ def __init__(self, connection: AwsMySQLClient) -> None:
307273
self.connection_name = connection.connection_name
308274
self._connection = connection._connection
309275
self._lock = asyncio.Lock()
310-
self._savepoint: str | None = None
276+
self._savepoint: Optional[str] = None
311277
self._finalized: bool = False
312278
self._parent = connection
313279

@@ -373,14 +339,14 @@ async def release_savepoint(self):
373339
self._finalized = True
374340

375341
@translate_exceptions
376-
async def execute_many(self, query: str, values: list[list]) -> None:
342+
async def execute_many(self, query: str, values: List[List[Any]]) -> None:
377343
"""Execute many queries without autocommit handling (already in transaction)."""
378344
async with self.acquire_connection() as connection:
379345
logger.debug(f"{query}: {values}")
380346
async with connection.cursor() as cursor:
381347
await cursor.executemany(query, values)
382348

383349

384-
def _gen_savepoint_name(_c=count()) -> str:
350+
def _gen_savepoint_name(_c: count = count()) -> str:
385351
"""Generate a unique savepoint name."""
386352
return f"tortoise_savepoint_{next(_c)}"

aws_advanced_python_wrapper/tortoise/sql_alchemy_tortoise_connection_provider.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,13 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
from types import ModuleType
15-
from typing import Callable, Dict
15+
from typing import Any, Callable, Dict, Optional
1616

1717
from sqlalchemy import Dialect
1818
from sqlalchemy.dialects.mysql import mysqlconnector
1919
from sqlalchemy.dialects.postgresql import psycopg
2020

21+
from aws_advanced_python_wrapper.connection_provider import ConnectionProviderManager
2122
from aws_advanced_python_wrapper.database_dialect import DatabaseDialect
2223
from aws_advanced_python_wrapper.driver_dialect import DriverDialect
2324
from aws_advanced_python_wrapper.errors import AwsWrapperError
@@ -83,3 +84,34 @@ def _get_pool_dialect(self, driver_dialect: DriverDialect) -> Dialect:
8384
dialect.dbapi = driver_dialect.get_driver_module()
8485

8586
return dialect
87+
88+
89+
def setup_tortoise_connection_provider(
90+
pool_configurator: Optional[Callable[[HostInfo, Properties], Dict[str, Any]]] = None,
91+
pool_mapping: Optional[Callable[[HostInfo, Properties], str]] = None
92+
) -> SqlAlchemyTortoisePooledConnectionProvider:
93+
"""
94+
Helper function to set up and configure the Tortoise connection provider.
95+
96+
Args:
97+
pool_configurator: Optional function to configure pool settings.
98+
Defaults to basic pool configuration.
99+
pool_mapping: Optional function to generate pool keys.
100+
Defaults to basic pool key generation.
101+
102+
Returns:
103+
Configured SqlAlchemyTortoisePooledConnectionProvider instance.
104+
"""
105+
def default_pool_configurator(host_info: HostInfo, props: Properties) -> Dict[str, Any]:
106+
return {"pool_size": 5, "max_overflow": -1}
107+
108+
def default_pool_mapping(host_info: HostInfo, props: Properties) -> str:
109+
return f"{host_info.url}{props.get('user', '')}{props.get('database', '')}"
110+
111+
provider = SqlAlchemyTortoisePooledConnectionProvider(
112+
pool_configurator=pool_configurator or default_pool_configurator,
113+
pool_mapping=pool_mapping or default_pool_mapping
114+
)
115+
116+
ConnectionProviderManager.set_connection_provider(provider)
117+
return provider

tests/integration/container/conftest.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def pytest_runtest_setup(item):
6969
else:
7070
TestEnvironment.get_current().set_current_driver(None)
7171

72-
logger.info("Starting test preparation for: " + test_name)
72+
logger.info(f"Starting test preparation for: {test_name}")
7373

7474
segment: Optional[Segment] = None
7575
if TestEnvironmentFeatures.TELEMETRY_TRACES_ENABLED in TestEnvironment.get_current().get_features():
@@ -107,7 +107,11 @@ def pytest_runtest_setup(item):
107107
logger.warning("conftest.ExceptionWhileObtainingInstanceIDs", ex)
108108
instances = list()
109109

110-
sleep(5)
110+
# Only sleep if we still need to retry
111+
if (len(instances) < request.get_num_of_instances()
112+
or len(instances) == 0
113+
or not rds_utility.is_db_instance_writer(instances[0])):
114+
sleep(5)
111115

112116
assert len(instances) > 0
113117
current_writer = instances[0]
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License").
4+
# You may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.

0 commit comments

Comments
 (0)