From cad4e9b5574e36405d800dfcfb254fff7ac9f80d Mon Sep 17 00:00:00 2001 From: Favian Samatha Date: Mon, 24 Nov 2025 07:42:56 -0800 Subject: [PATCH 01/11] feat: tortoise --- .../sql_alchemy_connection_provider.py | 2 +- .../tortoise/__init__.py | 37 ++ .../tortoise/backend/__init__.py | 13 + .../tortoise/backend/base/__init__.py | 13 + .../tortoise/backend/base/client.py | 181 +++++++ .../tortoise/backend/mysql/__init__.py | 16 + .../tortoise/backend/mysql/client.py | 385 ++++++++++++++ .../tortoise/backend/mysql/executor.py | 21 + .../backend/mysql/schema_generator.py | 21 + aws_advanced_python_wrapper/tortoise/utils.py | 33 ++ aws_advanced_python_wrapper/wrapper.py | 5 + poetry.lock | 158 ++++-- pyproject.toml | 6 +- tests/unit/test_tortoise_base_client.py | 319 ++++++++++++ tests/unit/test_tortoise_mysql_client.py | 488 ++++++++++++++++++ 15 files changed, 1652 insertions(+), 46 deletions(-) create mode 100644 aws_advanced_python_wrapper/tortoise/__init__.py create mode 100644 aws_advanced_python_wrapper/tortoise/backend/__init__.py create mode 100644 aws_advanced_python_wrapper/tortoise/backend/base/__init__.py create mode 100644 aws_advanced_python_wrapper/tortoise/backend/base/client.py create mode 100644 aws_advanced_python_wrapper/tortoise/backend/mysql/__init__.py create mode 100644 aws_advanced_python_wrapper/tortoise/backend/mysql/client.py create mode 100644 aws_advanced_python_wrapper/tortoise/backend/mysql/executor.py create mode 100644 aws_advanced_python_wrapper/tortoise/backend/mysql/schema_generator.py create mode 100644 aws_advanced_python_wrapper/tortoise/utils.py create mode 100644 tests/unit/test_tortoise_base_client.py create mode 100644 tests/unit/test_tortoise_mysql_client.py diff --git a/aws_advanced_python_wrapper/sql_alchemy_connection_provider.py b/aws_advanced_python_wrapper/sql_alchemy_connection_provider.py index 2f3c746d..f2f57407 100644 --- a/aws_advanced_python_wrapper/sql_alchemy_connection_provider.py +++ b/aws_advanced_python_wrapper/sql_alchemy_connection_provider.py @@ -88,7 +88,7 @@ def accepts_host_info(self, host_info: HostInfo, props: Properties) -> bool: if self._accept_url_func: return self._accept_url_func(host_info, props) url_type = SqlAlchemyPooledConnectionProvider._rds_utils.identify_rds_type(host_info.host) - return RdsUrlType.RDS_INSTANCE == url_type + return RdsUrlType.RDS_INSTANCE == url_type or RdsUrlType.RDS_WRITER_CLUSTER def accepts_strategy(self, role: HostRole, strategy: str) -> bool: return strategy == SqlAlchemyPooledConnectionProvider._LEAST_CONNECTIONS or strategy in self._accepted_strategies diff --git a/aws_advanced_python_wrapper/tortoise/__init__.py b/aws_advanced_python_wrapper/tortoise/__init__.py new file mode 100644 index 00000000..ec7d689a --- /dev/null +++ b/aws_advanced_python_wrapper/tortoise/__init__.py @@ -0,0 +1,37 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from tortoise.backends.base.config_generator import DB_LOOKUP + +# Register AWS MySQL backend +DB_LOOKUP["aws-mysql"] = { + "engine": "aws_advanced_python_wrapper.tortoise.backend.mysql", + "vmap": { + "path": "database", + "hostname": "host", + "port": "port", + "username": "user", + "password": "password", + }, + "defaults": {"port": 3306, "charset": "utf8mb4", "sql_mode": "STRICT_TRANS_TABLES"}, + "cast": { + "minsize": int, + "maxsize": int, + "connect_timeout": float, + "echo": bool, + "use_unicode": bool, + "pool_recycle": int, + "ssl": bool, + }, +} \ No newline at end of file diff --git a/aws_advanced_python_wrapper/tortoise/backend/__init__.py b/aws_advanced_python_wrapper/tortoise/backend/__init__.py new file mode 100644 index 00000000..7e87cab9 --- /dev/null +++ b/aws_advanced_python_wrapper/tortoise/backend/__init__.py @@ -0,0 +1,13 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. \ No newline at end of file diff --git a/aws_advanced_python_wrapper/tortoise/backend/base/__init__.py b/aws_advanced_python_wrapper/tortoise/backend/base/__init__.py new file mode 100644 index 00000000..7e87cab9 --- /dev/null +++ b/aws_advanced_python_wrapper/tortoise/backend/base/__init__.py @@ -0,0 +1,13 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. \ No newline at end of file diff --git a/aws_advanced_python_wrapper/tortoise/backend/base/client.py b/aws_advanced_python_wrapper/tortoise/backend/base/client.py new file mode 100644 index 00000000..6c30e8cc --- /dev/null +++ b/aws_advanced_python_wrapper/tortoise/backend/base/client.py @@ -0,0 +1,181 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import mysql.connector +from contextlib import asynccontextmanager +from typing import Any, Callable, Generic + +from tortoise.backends.base.client import BaseDBAsyncClient, T_conn, TransactionalDBClient, TransactionContext +from tortoise.connection import connections +from tortoise.exceptions import TransactionManagementError + +from aws_advanced_python_wrapper import AwsWrapperConnection + + +async def ConnectWithAwsWrapper(connect_func: Callable, **kwargs) -> AwsWrapperConnection: + """Create an AWS wrapper connection with async cursor support.""" + connection = await asyncio.to_thread( + AwsWrapperConnection.connect, + connect_func, + **kwargs, + ) + return AwsConnectionAsyncWrapper(connection) + + +class AwsCursorAsyncWrapper: + """Wraps a sync cursor to provide async interface.""" + + def __init__(self, sync_cursor): + self._cursor = sync_cursor + + async def execute(self, query, params=None): + """Execute a query asynchronously.""" + return await asyncio.to_thread(self._cursor.execute, query, params) + + async def executemany(self, query, params_list): + """Execute multiple queries asynchronously.""" + return await asyncio.to_thread(self._cursor.executemany, query, params_list) + + async def fetchall(self): + """Fetch all results asynchronously.""" + return await asyncio.to_thread(self._cursor.fetchall) + + async def fetchone(self): + """Fetch one result asynchronously.""" + return await asyncio.to_thread(self._cursor.fetchone) + + async def close(self): + """Close cursor asynchronously.""" + return await asyncio.to_thread(self._cursor.close) + + def __getattr__(self, name): + """Delegate non-async attributes to the wrapped cursor.""" + return getattr(self._cursor, name) + + +class AwsConnectionAsyncWrapper(AwsWrapperConnection): + """AWS wrapper connection with async cursor support.""" + + def __init__(self, connection: AwsWrapperConnection): + self._wrapped_connection = connection + + @asynccontextmanager + async def cursor(self): + """Create an async cursor context manager.""" + cursor_obj = await asyncio.to_thread(self._wrapped_connection.cursor) + try: + yield AwsCursorAsyncWrapper(cursor_obj) + finally: + await asyncio.to_thread(cursor_obj.close) + + async def rollback(self): + """Rollback the current transaction.""" + return await asyncio.to_thread(self._wrapped_connection.rollback) + + async def commit(self): + """Commit the current transaction.""" + return await asyncio.to_thread(self._wrapped_connection.commit) + + async def set_autocommit(self, value: bool): + """Set autocommit mode.""" + return await asyncio.to_thread(lambda: setattr(self._wrapped_connection, 'autocommit', value)) + + def __getattr__(self, name): + """Delegate all other attributes/methods to the wrapped connection.""" + return getattr(self._wrapped_connection, name) + + def __del__(self): + """Delegate cleanup to wrapped connection.""" + if hasattr(self, '_wrapped_connection'): + # Let the wrapped connection handle its own cleanup + pass + + +class TortoiseAwsClientConnectionWrapper(Generic[T_conn]): + """Manages acquiring from and releasing connections to a pool.""" + + __slots__ = ("client", "connection", "_pool_init_lock", "connect_func", "with_db") + + def __init__( + self, + client: BaseDBAsyncClient, + pool_init_lock: asyncio.Lock, + connect_func: Callable, + with_db: bool = True + ) -> None: + self.connect_func = connect_func + self.client = client + self.connection: T_conn | None = None + self._pool_init_lock = pool_init_lock + self.with_db = with_db + + async def ensure_connection(self) -> None: + """Ensure the connection pool is initialized.""" + await self.client.create_connection(with_db=self.with_db) + + async def __aenter__(self) -> T_conn: + """Acquire connection from pool.""" + await self.ensure_connection() + self.connection = await ConnectWithAwsWrapper(self.connect_func, **self.client._template) + return self.connection + + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + """Close connection and release back to pool.""" + if self.connection: + await asyncio.to_thread(self.connection.close) + + +class TortoiseAwsClientTransactionContext(TransactionContext): + """Transaction context that uses a pool to acquire connections.""" + + __slots__ = ("client", "connection_name", "token", "_pool_init_lock") + + def __init__(self, client: TransactionalDBClient, pool_init_lock: asyncio.Lock) -> None: + self.client = client + self.connection_name = client.connection_name + self._pool_init_lock = pool_init_lock + + async def ensure_connection(self) -> None: + """Ensure the connection pool is initialized.""" + await self.client._parent.create_connection(with_db=True) + + async def __aenter__(self) -> TransactionalDBClient: + """Enter transaction context.""" + await self.ensure_connection() + + # Set the context variable so the current task sees a TransactionWrapper connection + self.token = connections.set(self.connection_name, self.client) + + # Create connection and begin transaction + self.client._connection = await ConnectWithAwsWrapper( + mysql.connector.Connect, + **self.client._parent._template + ) + await self.client.begin() + return self.client + + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + """Exit transaction context with proper cleanup.""" + try: + if not self.client._finalized: + if exc_type: + # Can't rollback a transaction that already failed + if exc_type is not TransactionManagementError: + await self.client.rollback() + else: + await self.client.commit() + finally: + connections.reset(self.token) + await asyncio.to_thread(self.client._connection.close) \ No newline at end of file diff --git a/aws_advanced_python_wrapper/tortoise/backend/mysql/__init__.py b/aws_advanced_python_wrapper/tortoise/backend/mysql/__init__.py new file mode 100644 index 00000000..d66f3284 --- /dev/null +++ b/aws_advanced_python_wrapper/tortoise/backend/mysql/__init__.py @@ -0,0 +1,16 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .client import AwsMySQLClient + +client_class = AwsMySQLClient \ No newline at end of file diff --git a/aws_advanced_python_wrapper/tortoise/backend/mysql/client.py b/aws_advanced_python_wrapper/tortoise/backend/mysql/client.py new file mode 100644 index 00000000..d1307f80 --- /dev/null +++ b/aws_advanced_python_wrapper/tortoise/backend/mysql/client.py @@ -0,0 +1,385 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from functools import wraps +from itertools import count +from typing import Any, Callable, Coroutine, Dict, SupportsInt, TypeVar + +import mysql.connector +import sqlparse +from mysql.connector import errors +from mysql.connector.charsets import MYSQL_CHARACTER_SETS +from pypika_tortoise import MySQLQuery +from tortoise import timezone +from tortoise.backends.base.client import ( + BaseDBAsyncClient, + Capabilities, + ConnectionWrapper, + NestedTransactionContext, + TransactionalDBClient, + TransactionContext, +) +from tortoise.exceptions import ( + DBConnectionError, + IntegrityError, + OperationalError, + TransactionManagementError, +) + +from aws_advanced_python_wrapper.connection_provider import ConnectionProviderManager +from aws_advanced_python_wrapper.hostinfo import HostInfo +from aws_advanced_python_wrapper.sql_alchemy_connection_provider import SqlAlchemyPooledConnectionProvider +from aws_advanced_python_wrapper.tortoise.backend.base.client import ( + TortoiseAwsClientConnectionWrapper, + TortoiseAwsClientTransactionContext, +) +from aws_advanced_python_wrapper.tortoise.backend.mysql.executor import AwsMySQLExecutor +from aws_advanced_python_wrapper.tortoise.backend.mysql.schema_generator import AwsMySQLSchemaGenerator +from aws_advanced_python_wrapper.utils.log import Logger +from aws_advanced_python_wrapper.errors import AwsWrapperError + +logger = Logger(__name__) +T = TypeVar("T") +FuncType = Callable[..., Coroutine[None, None, T]] + + +def translate_exceptions(func: FuncType) -> FuncType: + """Decorator to translate MySQL connector exceptions to Tortoise exceptions.""" + @wraps(func) + async def translate_exceptions_(self, *args) -> T: + try: + try: + return await func(self, *args) + except AwsWrapperError as aws_err: + if aws_err.__cause__: + raise aws_err.__cause__ + raise + except errors.IntegrityError as exc: + raise IntegrityError(exc) + except ( + errors.OperationalError, + errors.ProgrammingError, + errors.DataError, + errors.InternalError, + errors.NotSupportedError, + errors.DatabaseError + ) as exc: + raise OperationalError(exc) + + return translate_exceptions_ + +class AwsMySQLClient(BaseDBAsyncClient): + """AWS Advanced Python Wrapper MySQL client for Tortoise ORM.""" + query_class = MySQLQuery + executor_class = AwsMySQLExecutor + schema_generator = AwsMySQLSchemaGenerator + capabilities = Capabilities( + dialect="mysql", + requires_limit=True, + inline_comment=True, + support_index_hint=True, + support_for_posix_regex_queries=True, + support_json_attributes=True, + ) + _pool_initialized = False + _pool_init_class_lock = asyncio.Lock() + + def __init__( + self, + *, + user: str, + password: str, + database: str, + host: str, + port: SupportsInt, + **kwargs, + ): + """Initialize AWS MySQL client with connection parameters.""" + super().__init__(**kwargs) + + # Basic connection parameters + self.user = user + self.password = password + self.database = database + self.host = host + self.port = int(port) + self.extra = kwargs.copy() + + # Extract MySQL-specific settings + self.storage_engine = self.extra.pop("storage_engine", "innodb") + self.charset = self.extra.pop("charset", "utf8mb4") + self.pool_minsize = int(self.extra.pop("minsize", 1)) + self.pool_maxsize = int(self.extra.pop("maxsize", 5)) + + # Remove Tortoise-specific parameters + self.extra.pop("connection_name", None) + self.extra.pop("fetch_inserted", None) + self.extra.pop("db", None) + self.extra.pop("autocommit", None) + self.extra.setdefault("sql_mode", "STRICT_TRANS_TABLES") + + # Initialize connection templates + self._init_connection_templates() + + # Initialize state + self._template = {} + self._connection = None + self._pool = None + self._pool_init_lock = asyncio.Lock() + + def _init_connection_templates(self) -> None: + """Initialize connection templates for with/without database.""" + base_template = { + "user": self.user, + "password": self.password, + "host": self.host, + "port": self.port, + "autocommit": True, + **self.extra + } + + self._template_with_db = {**base_template, "database": self.database} + self._template_no_db = {**base_template, "database": None} + + # Pool Management + def _configure_pool(self, host_info: HostInfo, props: Dict[str, Any]) -> Dict[str, Any]: + """Configure connection pool settings.""" + return {"pool_size": self.pool_maxsize} + + @staticmethod + def _get_pool_key(host_info: HostInfo, props: Dict[str, Any]) -> str: + """Generate unique pool key for connection pooling.""" + url = host_info.url + user = props["user"] + db = props["database"] + return f"{url}{user}{db}" + + async def _init_pool_if_needed(self) -> None: + """Initialize connection pool only once across all instances.""" + if not AwsMySQLClient._pool_initialized: + async with AwsMySQLClient._pool_init_class_lock: + if not AwsMySQLClient._pool_initialized: + self._pool = SqlAlchemyPooledConnectionProvider( + pool_configurator=self._configure_pool, + pool_mapping=self._get_pool_key, + ) + ConnectionProviderManager.set_connection_provider(self._pool) + AwsMySQLClient._pool_initialized = True + + # Connection Management + async def create_connection(self, with_db: bool) -> None: + """Initialize connection pool and configure database settings.""" + # Validate charset + if self.charset.lower() not in [cs[0] for cs in MYSQL_CHARACTER_SETS if cs is not None]: + raise DBConnectionError(f"Unknown character set: {self.charset}") + + # Initialize connection pool only once + await self._init_pool_if_needed() + + # Set transaction support based on storage engine + if self.storage_engine.lower() != "innodb": + self.capabilities.__dict__["supports_transactions"] = False + + # Set template based on database requirement + self._template = self._template_with_db if with_db else self._template_no_db + + logger.debug("Created connection %s pool with params: %s", self._pool, self._template) + + async def close(self) -> None: + """Close connections - AWS wrapper handles cleanup internally.""" + pass + + def acquire_connection(self): + """Acquire a connection from the pool.""" + return self._acquire_connection(with_db=True) + + def _acquire_connection(self, with_db: bool): + """Create connection wrapper for specified database mode.""" + return TortoiseAwsClientConnectionWrapper( + self, self._pool_init_lock, mysql.connector.Connect, with_db=with_db + ) + + # Database Operations + async def db_create(self) -> None: + """Create the database.""" + await self.create_connection(with_db=False) + await self._execute_script(f"CREATE DATABASE {self.database};", False) + await self.close() + + async def db_delete(self) -> None: + """Delete the database.""" + await self.create_connection(with_db=False) + await self._execute_script(f"DROP DATABASE {self.database};", False) + await self.close() + + # Query Execution Methods + @translate_exceptions + async def execute_insert(self, query: str, values: list) -> int: + """Execute an INSERT query and return the last inserted row ID.""" + async with self.acquire_connection() as connection: + logger.debug("%s: %s", query, values) + async with connection.cursor() as cursor: + await cursor.execute(query, values) + return cursor.lastrowid + + @translate_exceptions + async def execute_many(self, query: str, values: list[list]) -> None: + """Execute a query with multiple parameter sets.""" + async with self.acquire_connection() as connection: + logger.debug("%s: %s", query, values) + async with connection.cursor() as cursor: + if self.capabilities.supports_transactions: + await self._execute_many_with_transaction(cursor, connection, query, values) + else: + await cursor.executemany(query, values) + + async def _execute_many_with_transaction(self, cursor, connection, query: str, values: list[list]) -> None: + """Execute many queries within a transaction.""" + try: + await connection.set_autocommit(False) + try: + await cursor.executemany(query, values) + except Exception: + await connection.rollback() + raise + else: + await connection.commit() + finally: + await connection.set_autocommit(True) + + @translate_exceptions + async def execute_query(self, query: str, values: list | None = None) -> tuple[int, list[dict]]: + """Execute a query and return row count and results.""" + async with self.acquire_connection() as connection: + logger.debug("%s: %s", query, values) + async with connection.cursor() as cursor: + await cursor.execute(query, values) + rows = await cursor.fetchall() + if rows: + fields = [desc[0] for desc in cursor.description] + return cursor.rowcount, [dict(zip(fields, row)) for row in rows] + return cursor.rowcount, [] + + async def execute_query_dict(self, query: str, values: list | None = None) -> list[dict]: + """Execute a query and return only the results as dictionaries.""" + return (await self.execute_query(query, values))[1] + + async def execute_script(self, query: str) -> None: + """Execute a script query.""" + await self._execute_script(query, True) + + @translate_exceptions + async def _execute_script(self, query: str, with_db: bool) -> None: + """Execute a multi-statement query by parsing and running statements sequentially.""" + async with self._acquire_connection(with_db) as connection: + logger.debug(query) + async with connection.cursor() as cursor: + # Parse multi-statement queries since MySQL Connector doesn't handle them well + statements = sqlparse.split(query) + for statement in statements: + statement = statement.strip() + if statement: + await cursor.execute(statement) + + # Transaction Support + def _in_transaction(self) -> TransactionContext: + """Create a new transaction context.""" + return TortoiseAwsClientTransactionContext(TransactionWrapper(self), self._pool_init_lock) + + +class TransactionWrapper(AwsMySQLClient, TransactionalDBClient): + """Transaction wrapper for AWS MySQL client.""" + + def __init__(self, connection: AwsMySQLClient) -> None: + self.connection_name = connection.connection_name + self._connection = connection._connection + self._lock = asyncio.Lock() + self._savepoint: str | None = None + self._finalized: bool = False + self._parent = connection + + def _in_transaction(self) -> TransactionContext: + """Create a nested transaction context.""" + return NestedTransactionContext(TransactionWrapper(self)) + + def acquire_connection(self): + """Acquire the transaction connection.""" + return ConnectionWrapper(self._lock, self) + + # Transaction Control Methods + @translate_exceptions + async def begin(self) -> None: + """Begin the transaction.""" + await self._connection.set_autocommit(False) + self._finalized = False + + async def commit(self) -> None: + """Commit the transaction.""" + if self._finalized: + raise TransactionManagementError("Transaction already finalized") + await self._connection.commit() + await self._connection.set_autocommit(True) + self._finalized = True + + async def rollback(self) -> None: + """Rollback the transaction.""" + if self._finalized: + raise TransactionManagementError("Transaction already finalized") + await self._connection.rollback() + await self._connection.set_autocommit(True) + self._finalized = True + + # Savepoint Management + @translate_exceptions + async def savepoint(self) -> None: + """Create a savepoint.""" + self._savepoint = _gen_savepoint_name() + async with self._connection.cursor() as cursor: + await cursor.execute(f"SAVEPOINT {self._savepoint}") + + async def savepoint_rollback(self): + """Rollback to the savepoint.""" + if self._finalized: + raise TransactionManagementError("Transaction already finalized") + if self._savepoint is None: + raise TransactionManagementError("No savepoint to rollback to") + async with self._connection.cursor() as cursor: + await cursor.execute(f"ROLLBACK TO SAVEPOINT {self._savepoint}") + self._savepoint = None + self._finalized = True + + async def release_savepoint(self): + """Release the savepoint.""" + if self._finalized: + raise TransactionManagementError("Transaction already finalized") + if self._savepoint is None: + raise TransactionManagementError("No savepoint to release") + async with self._connection.cursor() as cursor: + await cursor.execute(f"RELEASE SAVEPOINT {self._savepoint}") + self._savepoint = None + self._finalized = True + + @translate_exceptions + async def execute_many(self, query: str, values: list[list]) -> None: + """Execute many queries without autocommit handling (already in transaction).""" + async with self.acquire_connection() as connection: + logger.debug("%s: %s", query, values) + async with connection.cursor() as cursor: + await cursor.executemany(query, values) + + +def _gen_savepoint_name(_c=count()) -> str: + """Generate a unique savepoint name.""" + return f"tortoise_savepoint_{next(_c)}" \ No newline at end of file diff --git a/aws_advanced_python_wrapper/tortoise/backend/mysql/executor.py b/aws_advanced_python_wrapper/tortoise/backend/mysql/executor.py new file mode 100644 index 00000000..53015ee5 --- /dev/null +++ b/aws_advanced_python_wrapper/tortoise/backend/mysql/executor.py @@ -0,0 +1,21 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from aws_advanced_python_wrapper.tortoise.utils import load_mysql_module + +MySQLExecutor = load_mysql_module("executor.py", "MySQLExecutor") + +class AwsMySQLExecutor(MySQLExecutor): + """AWS MySQL Executor for Tortoise ORM.""" + pass diff --git a/aws_advanced_python_wrapper/tortoise/backend/mysql/schema_generator.py b/aws_advanced_python_wrapper/tortoise/backend/mysql/schema_generator.py new file mode 100644 index 00000000..bc59e203 --- /dev/null +++ b/aws_advanced_python_wrapper/tortoise/backend/mysql/schema_generator.py @@ -0,0 +1,21 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from aws_advanced_python_wrapper.tortoise.utils import load_mysql_module + +MySQLSchemaGenerator = load_mysql_module("schema_generator.py", "MySQLSchemaGenerator") + +class AwsMySQLSchemaGenerator(MySQLSchemaGenerator): + """AWS MySQL Executor for Tortoise ORM.""" + pass diff --git a/aws_advanced_python_wrapper/tortoise/utils.py b/aws_advanced_python_wrapper/tortoise/utils.py new file mode 100644 index 00000000..6fb8d995 --- /dev/null +++ b/aws_advanced_python_wrapper/tortoise/utils.py @@ -0,0 +1,33 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import importlib.util +import sys +from pathlib import Path +from typing import Type + + +def load_mysql_module(module_file: str, class_name: str) -> Type: + """Load MySQL backend module without __init__.py.""" + import tortoise + tortoise_path = Path(tortoise.__file__).parent + module_path = tortoise_path / "backends" / "mysql" / module_file + module_name = f"tortoise.backends.mysql.{module_file[:-3]}" + + if module_name not in sys.modules: + spec = importlib.util.spec_from_file_location(module_name, module_path) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + + return getattr(sys.modules[module_name], class_name) diff --git a/aws_advanced_python_wrapper/wrapper.py b/aws_advanced_python_wrapper/wrapper.py index e04d293f..b2835ba6 100644 --- a/aws_advanced_python_wrapper/wrapper.py +++ b/aws_advanced_python_wrapper/wrapper.py @@ -264,6 +264,11 @@ def rowcount(self) -> int: @property def arraysize(self) -> int: return self.target_cursor.arraysize + + # Optional for PEP249 + @property + def lastrowid(self) -> int: + return self.target_cursor.lastrowid def close(self) -> None: self._plugin_manager.execute(self.target_cursor, "Cursor.close", diff --git a/poetry.lock b/poetry.lock index 5653d237..a00e1521 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,18 +1,24 @@ # This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand. [[package]] -name = "astor" -version = "0.8.1" -description = "Read/rewrite/write Python ASTs" +name = "aiosqlite" +version = "0.21.0" +description = "asyncio bridge to the standard sqlite3 module" optional = false -python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,>=2.7" -groups = ["dev"] -markers = "python_version < \"3.9\"" +python-versions = ">=3.9" +groups = ["main"] files = [ - {file = "astor-0.8.1-py2.py3-none-any.whl", hash = "sha256:070a54e890cefb5b3739d19f30f5a5ec840ffc9c50ffa7d23cc9fc1a38ebbfc5"}, - {file = "astor-0.8.1.tar.gz", hash = "sha256:6a6effda93f4e1ce9f618779b2dd1d9d84f1e32812c23a29b3fff6fd7f63fa5e"}, + {file = "aiosqlite-0.21.0-py3-none-any.whl", hash = "sha256:2549cf4057f95f53dcba16f2b64e8e2791d7e1adedb13197dd8ed77bb226d7d0"}, + {file = "aiosqlite-0.21.0.tar.gz", hash = "sha256:131bb8056daa3bc875608c631c678cda73922a2d4ba8aec373b19f18c17e7aa3"}, ] +[package.dependencies] +typing_extensions = ">=4.0" + +[package.extras] +dev = ["attribution (==1.7.1)", "black (==24.3.0)", "build (>=1.2)", "coverage[toml] (==7.6.10)", "flake8 (==7.0.0)", "flake8-bugbear (==24.12.12)", "flit (==3.10.1)", "mypy (==1.14.1)", "ufmt (==2.5.1)", "usort (==1.0.8.post1)"] +docs = ["sphinx (==8.1.3)", "sphinx-mdinclude (==0.6.1)"] + [[package]] name = "aws-xray-sdk" version = "2.15.0" @@ -29,36 +35,6 @@ files = [ botocore = ">=1.11.3" wrapt = "*" -[[package]] -name = "backports-zoneinfo" -version = "0.2.1" -description = "Backport of the standard library zoneinfo module" -optional = false -python-versions = ">=3.6" -groups = ["dev", "test"] -markers = "python_version < \"3.9\"" -files = [ - {file = "backports.zoneinfo-0.2.1-cp36-cp36m-macosx_10_14_x86_64.whl", hash = "sha256:da6013fd84a690242c310d77ddb8441a559e9cb3d3d59ebac9aca1a57b2e18bc"}, - {file = "backports.zoneinfo-0.2.1-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:89a48c0d158a3cc3f654da4c2de1ceba85263fafb861b98b59040a5086259722"}, - {file = "backports.zoneinfo-0.2.1-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:1c5742112073a563c81f786e77514969acb58649bcdf6cdf0b4ed31a348d4546"}, - {file = "backports.zoneinfo-0.2.1-cp36-cp36m-win32.whl", hash = "sha256:e8236383a20872c0cdf5a62b554b27538db7fa1bbec52429d8d106effbaeca08"}, - {file = "backports.zoneinfo-0.2.1-cp36-cp36m-win_amd64.whl", hash = "sha256:8439c030a11780786a2002261569bdf362264f605dfa4d65090b64b05c9f79a7"}, - {file = "backports.zoneinfo-0.2.1-cp37-cp37m-macosx_10_14_x86_64.whl", hash = "sha256:f04e857b59d9d1ccc39ce2da1021d196e47234873820cbeaad210724b1ee28ac"}, - {file = "backports.zoneinfo-0.2.1-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:17746bd546106fa389c51dbea67c8b7c8f0d14b5526a579ca6ccf5ed72c526cf"}, - {file = "backports.zoneinfo-0.2.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:5c144945a7752ca544b4b78c8c41544cdfaf9786f25fe5ffb10e838e19a27570"}, - {file = "backports.zoneinfo-0.2.1-cp37-cp37m-win32.whl", hash = "sha256:e55b384612d93be96506932a786bbcde5a2db7a9e6a4bb4bffe8b733f5b9036b"}, - {file = "backports.zoneinfo-0.2.1-cp37-cp37m-win_amd64.whl", hash = "sha256:a76b38c52400b762e48131494ba26be363491ac4f9a04c1b7e92483d169f6582"}, - {file = "backports.zoneinfo-0.2.1-cp38-cp38-macosx_10_14_x86_64.whl", hash = "sha256:8961c0f32cd0336fb8e8ead11a1f8cd99ec07145ec2931122faaac1c8f7fd987"}, - {file = "backports.zoneinfo-0.2.1-cp38-cp38-manylinux1_i686.whl", hash = "sha256:e81b76cace8eda1fca50e345242ba977f9be6ae3945af8d46326d776b4cf78d1"}, - {file = "backports.zoneinfo-0.2.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:7b0a64cda4145548fed9efc10322770f929b944ce5cee6c0dfe0c87bf4c0c8c9"}, - {file = "backports.zoneinfo-0.2.1-cp38-cp38-win32.whl", hash = "sha256:1b13e654a55cd45672cb54ed12148cd33628f672548f373963b0bff67b217328"}, - {file = "backports.zoneinfo-0.2.1-cp38-cp38-win_amd64.whl", hash = "sha256:4a0f800587060bf8880f954dbef70de6c11bbe59c673c3d818921f042f9954a6"}, - {file = "backports.zoneinfo-0.2.1.tar.gz", hash = "sha256:fadbfe37f74051d024037f223b8e001611eac868b5c5b06144ef4d8b799862f2"}, -] - -[package.extras] -tzdata = ["tzdata"] - [[package]] name = "beautifulsoup4" version = "4.12.3" @@ -483,7 +459,6 @@ files = [ ] [package.dependencies] -astor = {version = "*", markers = "python_version < \"3.9\""} classify-imports = "*" flake8 = "*" @@ -720,6 +695,18 @@ files = [ {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, ] +[[package]] +name = "iso8601" +version = "2.1.0" +description = "Simple module to parse ISO 8601 dates" +optional = false +python-versions = ">=3.7,<4.0" +groups = ["main"] +files = [ + {file = "iso8601-2.1.0-py3-none-any.whl", hash = "sha256:aac4145c4dcb66ad8b648a02830f5e2ff6c24af20f4f482689be402db2429242"}, + {file = "iso8601-2.1.0.tar.gz", hash = "sha256:6b1d3829ee8921c4301998c909f7829fa9ed3cbdac0d3b16af2d743aed1ba8df"}, +] + [[package]] name = "isort" version = "5.13.2" @@ -1205,7 +1192,6 @@ files = [ ] [package.dependencies] -"backports.zoneinfo" = {version = ">=0.2.0", markers = "python_version < \"3.9\""} typing-extensions = {version = ">=4.6", markers = "python_version < \"3.13\""} tzdata = {version = "*", markers = "sys_platform == \"win32\""} @@ -1324,6 +1310,18 @@ files = [ {file = "pyflakes-3.1.0.tar.gz", hash = "sha256:a0aae034c444db0071aa077972ba4768d40c830d9539fd45bf4cd3f8f6992efc"}, ] +[[package]] +name = "pypika-tortoise" +version = "0.6.2" +description = "Forked from pypika and streamline just for tortoise-orm" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "pypika_tortoise-0.6.2-py3-none-any.whl", hash = "sha256:425462b02ede0a5ed7b812ec12427419927ed6b19282c55667d1cbc9a440d3cb"}, + {file = "pypika_tortoise-0.6.2.tar.gz", hash = "sha256:f95ab59d9b6454db2e8daa0934728458350a1f3d56e81d9d1debc8eebeff26b3"}, +] + [[package]] name = "pytest" version = "7.4.4" @@ -1347,6 +1345,25 @@ tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""} [package.extras] testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] +[[package]] +name = "pytest-asyncio" +version = "0.21.2" +description = "Pytest support for asyncio" +optional = false +python-versions = ">=3.7" +groups = ["test"] +files = [ + {file = "pytest_asyncio-0.21.2-py3-none-any.whl", hash = "sha256:ab664c88bb7998f711d8039cacd4884da6430886ae8bbd4eded552ed2004f16b"}, + {file = "pytest_asyncio-0.21.2.tar.gz", hash = "sha256:d67738fc232b94b326b9d060750beb16e0074210b98dd8b58a5239fa2a154f45"}, +] + +[package.dependencies] +pytest = ">=7.0.0" + +[package.extras] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"] +testing = ["coverage (>=6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy (>=0.931)", "pytest-trio (>=0.7.0)"] + [[package]] name = "pytest-html" version = "4.1.1" @@ -1435,6 +1452,18 @@ files = [ [package.dependencies] six = ">=1.5" +[[package]] +name = "pytz" +version = "2025.2" +description = "World timezone definitions, modern and historical" +optional = false +python-versions = "*" +groups = ["main"] +files = [ + {file = "pytz-2025.2-py2.py3-none-any.whl", hash = "sha256:5ddf76296dd8c44c26eb8f4b6f35488f3ccbf6fbbd7adee0b7262d43f0ec2f00"}, + {file = "pytz-2025.2.tar.gz", hash = "sha256:360b9e3dbb49a209c21ad61809c7fb453643e048b38924c765813546746e81c3"}, +] + [[package]] name = "requests" version = "2.32.4" @@ -1607,6 +1636,23 @@ postgresql-psycopgbinary = ["psycopg[binary] (>=3.0.7)"] pymysql = ["pymysql"] sqlcipher = ["sqlcipher3_binary"] +[[package]] +name = "sqlparse" +version = "0.4.4" +description = "A non-validating SQL parser." +optional = false +python-versions = ">=3.5" +groups = ["main"] +files = [ + {file = "sqlparse-0.4.4-py3-none-any.whl", hash = "sha256:5430a4fe2ac7d0f93e66f1efc6e1338a41884b7ddf2a350cedd20ccc4d9d28f3"}, + {file = "sqlparse-0.4.4.tar.gz", hash = "sha256:d446183e84b8349fa3061f0fe7f06ca94ba65b426946ffebe6e3e8295332420c"}, +] + +[package.extras] +dev = ["build", "flake8"] +doc = ["sphinx"] +test = ["pytest", "pytest-cov"] + [[package]] name = "tabulate" version = "0.9.0" @@ -1647,6 +1693,32 @@ files = [ {file = "tomli-2.0.2.tar.gz", hash = "sha256:d46d457a85337051c36524bc5349dd91b1877838e2979ac5ced3e710ed8a60ed"}, ] +[[package]] +name = "tortoise-orm" +version = "0.25.1" +description = "Easy async ORM for python, built with relations in mind" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "tortoise_orm-0.25.1-py3-none-any.whl", hash = "sha256:df0ef7e06eb0650a7e5074399a51ee6e532043308c612db2cac3882486a3fd9f"}, + {file = "tortoise_orm-0.25.1.tar.gz", hash = "sha256:4d5bfd13d5750935ffe636a6b25597c5c8f51c47e5b72d7509d712eda1a239fe"}, +] + +[package.dependencies] +aiosqlite = ">=0.16.0,<1.0.0" +iso8601 = {version = ">=2.1.0,<3.0.0", markers = "python_version < \"4.0\""} +pypika-tortoise = {version = ">=0.6.1,<1.0.0", markers = "python_version < \"4.0\""} +pytz = "*" + +[package.extras] +accel = ["ciso8601 ; sys_platform != \"win32\" and implementation_name == \"cpython\"", "orjson", "uvloop ; sys_platform != \"win32\" and implementation_name == \"cpython\""] +aiomysql = ["aiomysql"] +asyncmy = ["asyncmy (>=0.2.8,<1.0.0) ; python_version < \"4.0\""] +asyncodbc = ["asyncodbc (>=0.1.1,<1.0.0) ; python_version < \"4.0\""] +asyncpg = ["asyncpg"] +psycopg = ["psycopg[binary,pool] (>=3.0.12,<4.0.0)"] + [[package]] name = "toxiproxy-python" version = "0.1.1" @@ -2160,7 +2232,7 @@ description = "HTTP library with thread-safe connection pooling, file post, and optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" groups = ["main", "test"] -markers = "python_version < \"3.10\"" +markers = "python_version == \"3.9\"" files = [ {file = "urllib3-1.26.20-py2.py3-none-any.whl", hash = "sha256:0ed14ccfbf1c30a9072c7ca157e4319b70d65f623e91e7b32fadb2853431016e"}, {file = "urllib3-1.26.20.tar.gz", hash = "sha256:40c2dc0c681e47eb8f90e7e27bf6ff7df2e677421fd46756da1161c39ca70d32"}, @@ -2292,5 +2364,5 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.1" -python-versions = "^3.8.1" -content-hash = "4d6dc3f0080c62b28f78487a4cbefa724da3af5e9fd100a6abd376bcae3baadf" +python-versions = "^3.9" +content-hash = "e290dd3792a38a361c52cf353dbd00f5eeb850beef88339648991b41eeddb883" diff --git a/pyproject.toml b/pyproject.toml index ae8796e2..bb516992 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ classifiers = [ ] [tool.poetry.dependencies] -python = "^3.8.1" +python = "^3.9" resourcebundle = "2.1.0" boto3 = "^1.34.111" toml = "^0.10.2" @@ -32,6 +32,8 @@ types_aws_xray_sdk = "^2.13.0" opentelemetry-api = "^1.22.0" opentelemetry-sdk = "^1.22.0" requests = "^2.32.2" +tortoise-orm = "^0.25.1" +sqlparse = "^0.4.4" [tool.poetry.group.dev.dependencies] mypy = "^1.9.0" @@ -63,6 +65,7 @@ mysql-connector-python = "9.0.0" opentelemetry-exporter-otlp = "^1.22.0" opentelemetry-exporter-otlp-proto-grpc = "^1.22.0" opentelemetry-sdk-extension-aws = "^2.0.1" +pytest-asyncio = "^0.21.0" [tool.isort] sections = "FUTURE,STDLIB,THIRDPARTY,FIRSTPARTY,LOCALFOLDER" @@ -77,4 +80,3 @@ filterwarnings = [ 'ignore:cache could not write path', 'ignore:could not create cache path' ] - diff --git a/tests/unit/test_tortoise_base_client.py b/tests/unit/test_tortoise_base_client.py new file mode 100644 index 00000000..17c9e592 --- /dev/null +++ b/tests/unit/test_tortoise_base_client.py @@ -0,0 +1,319 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from aws_advanced_python_wrapper.tortoise.backend.base.client import ( + AwsCursorAsyncWrapper, + AwsConnectionAsyncWrapper, + TortoiseAwsClientConnectionWrapper, + TortoiseAwsClientTransactionContext, + ConnectWithAwsWrapper, +) + + +class TestAwsCursorAsyncWrapper: + def test_init(self): + mock_cursor = MagicMock() + wrapper = AwsCursorAsyncWrapper(mock_cursor) + assert wrapper._cursor == mock_cursor + + @pytest.mark.asyncio + async def test_execute(self): + mock_cursor = MagicMock() + wrapper = AwsCursorAsyncWrapper(mock_cursor) + + with patch('asyncio.to_thread') as mock_to_thread: + mock_to_thread.return_value = "result" + result = await wrapper.execute("SELECT 1", ["param"]) + + mock_to_thread.assert_called_once_with(mock_cursor.execute, "SELECT 1", ["param"]) + assert result == "result" + + @pytest.mark.asyncio + async def test_executemany(self): + mock_cursor = MagicMock() + wrapper = AwsCursorAsyncWrapper(mock_cursor) + + with patch('asyncio.to_thread') as mock_to_thread: + mock_to_thread.return_value = "result" + result = await wrapper.executemany("INSERT", [["param1"], ["param2"]]) + + mock_to_thread.assert_called_once_with(mock_cursor.executemany, "INSERT", [["param1"], ["param2"]]) + assert result == "result" + + @pytest.mark.asyncio + async def test_fetchall(self): + mock_cursor = MagicMock() + wrapper = AwsCursorAsyncWrapper(mock_cursor) + + with patch('asyncio.to_thread') as mock_to_thread: + mock_to_thread.return_value = [("row1",), ("row2",)] + result = await wrapper.fetchall() + + mock_to_thread.assert_called_once_with(mock_cursor.fetchall) + assert result == [("row1",), ("row2",)] + + @pytest.mark.asyncio + async def test_fetchone(self): + mock_cursor = MagicMock() + wrapper = AwsCursorAsyncWrapper(mock_cursor) + + with patch('asyncio.to_thread') as mock_to_thread: + mock_to_thread.return_value = ("row1",) + result = await wrapper.fetchone() + + mock_to_thread.assert_called_once_with(mock_cursor.fetchone) + assert result == ("row1",) + + @pytest.mark.asyncio + async def test_close(self): + mock_cursor = MagicMock() + wrapper = AwsCursorAsyncWrapper(mock_cursor) + + with patch('asyncio.to_thread') as mock_to_thread: + await wrapper.close() + mock_to_thread.assert_called_once_with(mock_cursor.close) + + def test_getattr(self): + mock_cursor = MagicMock() + mock_cursor.rowcount = 5 + wrapper = AwsCursorAsyncWrapper(mock_cursor) + + assert wrapper.rowcount == 5 + + +class TestAwsConnectionAsyncWrapper: + def test_init(self): + mock_connection = MagicMock() + wrapper = AwsConnectionAsyncWrapper(mock_connection) + assert wrapper._wrapped_connection == mock_connection + + @pytest.mark.asyncio + async def test_cursor_context_manager(self): + mock_connection = MagicMock() + mock_cursor = MagicMock() + mock_connection.cursor.return_value = mock_cursor + + wrapper = AwsConnectionAsyncWrapper(mock_connection) + + with patch('asyncio.to_thread') as mock_to_thread: + mock_to_thread.side_effect = [mock_cursor, None] # cursor creation, then close + + async with wrapper.cursor() as cursor: + assert isinstance(cursor, AwsCursorAsyncWrapper) + assert cursor._cursor == mock_cursor + + assert mock_to_thread.call_count == 2 + + @pytest.mark.asyncio + async def test_rollback(self): + mock_connection = MagicMock() + wrapper = AwsConnectionAsyncWrapper(mock_connection) + + with patch('asyncio.to_thread') as mock_to_thread: + mock_to_thread.return_value = "rollback_result" + result = await wrapper.rollback() + + mock_to_thread.assert_called_once_with(mock_connection.rollback) + assert result == "rollback_result" + + @pytest.mark.asyncio + async def test_commit(self): + mock_connection = MagicMock() + wrapper = AwsConnectionAsyncWrapper(mock_connection) + + with patch('asyncio.to_thread') as mock_to_thread: + mock_to_thread.return_value = "commit_result" + result = await wrapper.commit() + + mock_to_thread.assert_called_once_with(mock_connection.commit) + assert result == "commit_result" + + @pytest.mark.asyncio + async def test_set_autocommit(self): + mock_connection = MagicMock() + wrapper = AwsConnectionAsyncWrapper(mock_connection) + + with patch('asyncio.to_thread') as mock_to_thread: + await wrapper.set_autocommit(True) + + mock_to_thread.assert_called_once() + # Verify the lambda function sets autocommit + lambda_func = mock_to_thread.call_args[0][0] + lambda_func() + assert mock_connection.autocommit == True + + def test_getattr(self): + mock_connection = MagicMock() + mock_connection.some_attr = "test_value" + wrapper = AwsConnectionAsyncWrapper(mock_connection) + + assert wrapper.some_attr == "test_value" + + +class TestTortoiseAwsClientConnectionWrapper: + def test_init(self): + mock_client = MagicMock() + mock_lock = asyncio.Lock() + mock_connect_func = MagicMock() + + wrapper = TortoiseAwsClientConnectionWrapper(mock_client, mock_lock, mock_connect_func) + + assert wrapper.client == mock_client + assert wrapper._pool_init_lock == mock_lock + assert wrapper.connect_func == mock_connect_func + assert wrapper.connection is None + + @pytest.mark.asyncio + async def test_ensure_connection(self): + mock_client = MagicMock() + mock_client.create_connection = AsyncMock() + mock_lock = asyncio.Lock() + + wrapper = TortoiseAwsClientConnectionWrapper(mock_client, mock_lock, MagicMock()) + + await wrapper.ensure_connection() + mock_client.create_connection.assert_called_once_with(with_db=True) + + @pytest.mark.asyncio + async def test_context_manager(self): + mock_client = MagicMock() + mock_client._template = {"host": "localhost"} + mock_client.create_connection = AsyncMock() + mock_lock = asyncio.Lock() + mock_connect_func = MagicMock() + mock_connection = MagicMock() + + wrapper = TortoiseAwsClientConnectionWrapper(mock_client, mock_lock, mock_connect_func) + + with patch('aws_advanced_python_wrapper.tortoise.backend.base.client.ConnectWithAwsWrapper') as mock_connect: + mock_connect.return_value = mock_connection + + with patch('asyncio.to_thread') as mock_to_thread: + async with wrapper as conn: + assert conn == mock_connection + assert wrapper.connection == mock_connection + + # Verify close was called on exit + mock_to_thread.assert_called_once_with(mock_connection.close) + + +class TestTortoiseAwsClientTransactionContext: + def test_init(self): + mock_client = MagicMock() + mock_client.connection_name = "test_conn" + mock_lock = asyncio.Lock() + + context = TortoiseAwsClientTransactionContext(mock_client, mock_lock) + + assert context.client == mock_client + assert context.connection_name == "test_conn" + assert context._pool_init_lock == mock_lock + + @pytest.mark.asyncio + async def test_ensure_connection(self): + mock_client = MagicMock() + mock_client._parent.create_connection = AsyncMock() + mock_lock = asyncio.Lock() + + context = TortoiseAwsClientTransactionContext(mock_client, mock_lock) + + await context.ensure_connection() + mock_client._parent.create_connection.assert_called_once_with(with_db=True) + + @pytest.mark.asyncio + async def test_context_manager_commit(self): + mock_client = MagicMock() + mock_client._parent._template = {"host": "localhost"} + mock_client._parent.create_connection = AsyncMock() + mock_client.connection_name = "test_conn" + mock_client._finalized = False + mock_client.begin = AsyncMock() + mock_client.commit = AsyncMock() + mock_lock = asyncio.Lock() + mock_connection = MagicMock() + + context = TortoiseAwsClientTransactionContext(mock_client, mock_lock) + + mock_to_thread_ref = None + + with patch('aws_advanced_python_wrapper.tortoise.backend.base.client.ConnectWithAwsWrapper') as mock_connect: + mock_connect.return_value = mock_connection + with patch('tortoise.connection.connections') as mock_connections: + mock_connections.set.return_value = "test_token" + + with patch('asyncio.to_thread') as mock_to_thread: + mock_to_thread_ref = mock_to_thread + async with context as client: + assert client == mock_client + assert mock_client._connection == mock_connection + + # Verify commit was called and connection was closed + mock_client.commit.assert_called_once() + mock_to_thread_ref.assert_called_once_with(mock_connection.close) + + @pytest.mark.asyncio + async def test_context_manager_rollback_on_exception(self): + mock_client = MagicMock() + mock_client._parent._template = {"host": "localhost"} + mock_client._parent.create_connection = AsyncMock() + mock_client.connection_name = "test_conn" + mock_client._finalized = False + mock_client.begin = AsyncMock() + mock_client.rollback = AsyncMock() + mock_lock = asyncio.Lock() + mock_connection = MagicMock() + + context = TortoiseAwsClientTransactionContext(mock_client, mock_lock) + + mock_to_thread_ref = None + + with patch('aws_advanced_python_wrapper.tortoise.backend.base.client.ConnectWithAwsWrapper') as mock_connect: + mock_connect.return_value = mock_connection + with patch('tortoise.connection.connections') as mock_connections: + mock_connections.set.return_value = "test_token" + + with patch('asyncio.to_thread') as mock_to_thread: + mock_to_thread_ref = mock_to_thread + try: + async with context as client: + raise ValueError("Test exception") + except ValueError: + pass + + # Verify rollback was called and connection was closed + mock_client.rollback.assert_called_once() + mock_to_thread_ref.assert_called_once_with(mock_connection.close) + + +class TestConnectWithAwsWrapper: + @pytest.mark.asyncio + async def test_connect_with_aws_wrapper(self): + mock_connect_func = MagicMock() + mock_connection = MagicMock() + kwargs = {"host": "localhost", "user": "test"} + + with patch('aws_advanced_python_wrapper.AwsWrapperConnection.connect') as mock_aws_connect: + mock_aws_connect.return_value = mock_connection + with patch('asyncio.to_thread') as mock_to_thread: + mock_to_thread.return_value = mock_connection + + result = await ConnectWithAwsWrapper(mock_connect_func, **kwargs) + + mock_to_thread.assert_called_once() + assert isinstance(result, AwsConnectionAsyncWrapper) + assert result._wrapped_connection == mock_connection \ No newline at end of file diff --git a/tests/unit/test_tortoise_mysql_client.py b/tests/unit/test_tortoise_mysql_client.py new file mode 100644 index 00000000..82c4dffa --- /dev/null +++ b/tests/unit/test_tortoise_mysql_client.py @@ -0,0 +1,488 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from mysql.connector import errors +from tortoise.exceptions import DBConnectionError, IntegrityError, OperationalError, TransactionManagementError + +from aws_advanced_python_wrapper.tortoise.backend.mysql.client import ( + AwsMySQLClient, + TransactionWrapper, + translate_exceptions, + _gen_savepoint_name, +) + + +class TestTranslateExceptions: + @pytest.mark.asyncio + async def test_translate_operational_error(self): + @translate_exceptions + async def test_func(self): + raise errors.OperationalError("Test error") + + with pytest.raises(OperationalError): + await test_func(None) + + @pytest.mark.asyncio + async def test_translate_integrity_error(self): + @translate_exceptions + async def test_func(self): + raise errors.IntegrityError("Test error") + + with pytest.raises(IntegrityError): + await test_func(None) + + @pytest.mark.asyncio + async def test_no_exception(self): + @translate_exceptions + async def test_func(self): + return "success" + + result = await test_func(None) + assert result == "success" + + +class TestAwsMySQLClient: + def test_init(self): + client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + charset="utf8mb4", + storage_engine="InnoDB", + minsize=2, + maxsize=10, + connection_name="test_conn" + ) + + assert client.user == "test_user" + assert client.password == "test_pass" + assert client.database == "test_db" + assert client.host == "localhost" + assert client.port == 3306 + assert client.charset == "utf8mb4" + assert client.storage_engine == "InnoDB" + assert client.pool_minsize == 2 + assert client.pool_maxsize == 10 + + def test_init_defaults(self): + client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + connection_name="test_conn" + ) + + assert client.charset == "utf8mb4" + assert client.storage_engine == "innodb" + assert client.pool_minsize == 1 + assert client.pool_maxsize == 5 + + def test_configure_pool(self): + client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + maxsize=15, + connection_name="test_conn" + ) + + mock_host_info = MagicMock() + mock_props = {} + + result = client._configure_pool(mock_host_info, mock_props) + assert result == {"pool_size": 15} + + def test_get_pool_key(self): + mock_host_info = MagicMock() + mock_host_info.url = "mysql://localhost:3306" + mock_props = {"user": "testuser", "database": "testdb"} + + result = AwsMySQLClient._get_pool_key(mock_host_info, mock_props) + assert result == "mysql://localhost:3306testusertestdb" + + @pytest.mark.asyncio + async def test_create_connection_invalid_charset(self): + client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + charset="invalid_charset", + connection_name="test_conn" + ) + + with pytest.raises(DBConnectionError, match="Unknown character set"): + await client.create_connection(with_db=True) + + @pytest.mark.asyncio + async def test_create_connection_success(self): + client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + connection_name="test_conn" + ) + + with patch.object(client, '_init_pool_if_needed') as mock_init_pool: + mock_init_pool.return_value = None + + with patch('aws_advanced_python_wrapper.tortoise.backend.mysql.client.logger'): + await client.create_connection(with_db=True) + + assert client._template["user"] == "test_user" + assert client._template["database"] == "test_db" + assert client._template["autocommit"] is True + mock_init_pool.assert_called_once() + + @pytest.mark.asyncio + async def test_close(self): + client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + connection_name="test_conn" + ) + + # close() method now does nothing (AWS wrapper handles cleanup) + await client.close() + # No assertions needed since method is a no-op + + def test_acquire_connection(self): + client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + connection_name="test_conn" + ) + + connection_wrapper = client.acquire_connection() + assert connection_wrapper.client == client + + @pytest.mark.asyncio + async def test_execute_insert(self): + client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + connection_name="test_conn" + ) + + mock_connection = MagicMock() + mock_cursor = MagicMock() + mock_cursor.lastrowid = 123 + + with patch.object(client, 'acquire_connection') as mock_acquire: + mock_acquire.return_value.__aenter__ = AsyncMock(return_value=mock_connection) + mock_acquire.return_value.__aexit__ = AsyncMock() + mock_connection.cursor.return_value.__aenter__ = AsyncMock(return_value=mock_cursor) + mock_connection.cursor.return_value.__aexit__ = AsyncMock() + mock_cursor.execute = AsyncMock() + + with patch('aws_advanced_python_wrapper.tortoise.backend.mysql.client.logger'): + result = await client.execute_insert("INSERT INTO test VALUES (?)", ["value"]) + + assert result == 123 + mock_cursor.execute.assert_called_once_with("INSERT INTO test VALUES (?)", ["value"]) + + @pytest.mark.asyncio + async def test_execute_many_with_transactions(self): + client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + connection_name="test_conn", + storage_engine="innodb" + ) + + mock_connection = MagicMock() + mock_cursor = MagicMock() + + with patch.object(client, 'acquire_connection') as mock_acquire: + mock_acquire.return_value.__aenter__ = AsyncMock(return_value=mock_connection) + mock_acquire.return_value.__aexit__ = AsyncMock() + mock_connection.cursor.return_value.__aenter__ = AsyncMock(return_value=mock_cursor) + mock_connection.cursor.return_value.__aexit__ = AsyncMock() + + with patch.object(client, '_execute_many_with_transaction') as mock_execute_many_tx: + mock_execute_many_tx.return_value = None + + with patch('aws_advanced_python_wrapper.tortoise.backend.mysql.client.logger'): + await client.execute_many("INSERT INTO test VALUES (?)", [["val1"], ["val2"]]) + + mock_execute_many_tx.assert_called_once_with( + mock_cursor, mock_connection, "INSERT INTO test VALUES (?)", [["val1"], ["val2"]] + ) + + @pytest.mark.asyncio + async def test_execute_query(self): + client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + connection_name="test_conn" + ) + + mock_connection = MagicMock() + mock_cursor = MagicMock() + mock_cursor.rowcount = 2 + mock_cursor.description = [("id",), ("name",)] + + with patch.object(client, 'acquire_connection') as mock_acquire: + mock_acquire.return_value.__aenter__ = AsyncMock(return_value=mock_connection) + mock_acquire.return_value.__aexit__ = AsyncMock() + mock_connection.cursor.return_value.__aenter__ = AsyncMock(return_value=mock_cursor) + mock_connection.cursor.return_value.__aexit__ = AsyncMock() + mock_cursor.execute = AsyncMock() + mock_cursor.fetchall = AsyncMock(return_value=[(1, "test"), (2, "test2")]) + + with patch('aws_advanced_python_wrapper.tortoise.backend.mysql.client.logger'): + rowcount, results = await client.execute_query("SELECT * FROM test") + + assert rowcount == 2 + assert len(results) == 2 + assert results[0] == {"id": 1, "name": "test"} + + @pytest.mark.asyncio + async def test_execute_query_dict(self): + client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + connection_name="test_conn" + ) + + with patch.object(client, 'execute_query') as mock_execute_query: + mock_execute_query.return_value = (2, [{"id": 1}, {"id": 2}]) + + results = await client.execute_query_dict("SELECT * FROM test") + + assert results == [{"id": 1}, {"id": 2}] + + def test_in_transaction(self): + client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + connection_name="test_conn" + ) + + transaction_context = client._in_transaction() + assert transaction_context is not None + + +class TestTransactionWrapper: + def test_init(self): + parent_client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + connection_name="test_conn" + ) + + wrapper = TransactionWrapper(parent_client) + + assert wrapper.connection_name == "test_conn" + assert wrapper._parent == parent_client + assert wrapper._finalized is False + assert wrapper._savepoint is None + + @pytest.mark.asyncio + async def test_begin(self): + parent_client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + connection_name="test_conn" + ) + + wrapper = TransactionWrapper(parent_client) + mock_connection = MagicMock() + mock_connection.set_autocommit = AsyncMock() + wrapper._connection = mock_connection + + await wrapper.begin() + + mock_connection.set_autocommit.assert_called_once_with(False) + assert wrapper._finalized is False + + @pytest.mark.asyncio + async def test_commit(self): + parent_client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + connection_name="test_conn" + ) + + wrapper = TransactionWrapper(parent_client) + mock_connection = MagicMock() + mock_connection.commit = AsyncMock() + mock_connection.set_autocommit = AsyncMock() + wrapper._connection = mock_connection + + await wrapper.commit() + + mock_connection.commit.assert_called_once() + mock_connection.set_autocommit.assert_called_once_with(True) + assert wrapper._finalized is True + + @pytest.mark.asyncio + async def test_commit_already_finalized(self): + parent_client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + connection_name="test_conn" + ) + + wrapper = TransactionWrapper(parent_client) + wrapper._finalized = True + + with pytest.raises(TransactionManagementError, match="Transaction already finalized"): + await wrapper.commit() + + @pytest.mark.asyncio + async def test_rollback(self): + parent_client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + connection_name="test_conn" + ) + + wrapper = TransactionWrapper(parent_client) + mock_connection = MagicMock() + mock_connection.rollback = AsyncMock() + wrapper._connection = mock_connection + + await wrapper.rollback() + + mock_connection.rollback.assert_called_once() + assert wrapper._finalized is True + + @pytest.mark.asyncio + async def test_savepoint(self): + parent_client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + connection_name="test_conn" + ) + + wrapper = TransactionWrapper(parent_client) + mock_connection = MagicMock() + mock_cursor = MagicMock() + mock_cursor.execute = AsyncMock() + + wrapper._connection = mock_connection + mock_connection.cursor.return_value.__aenter__ = AsyncMock(return_value=mock_cursor) + mock_connection.cursor.return_value.__aexit__ = AsyncMock() + + await wrapper.savepoint() + + assert wrapper._savepoint is not None + assert wrapper._savepoint.startswith("tortoise_savepoint_") + mock_cursor.execute.assert_called_once() + + @pytest.mark.asyncio + async def test_savepoint_rollback(self): + parent_client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + connection_name="test_conn" + ) + + wrapper = TransactionWrapper(parent_client) + wrapper._savepoint = "test_savepoint" + mock_connection = MagicMock() + mock_cursor = MagicMock() + mock_cursor.execute = AsyncMock() + + wrapper._connection = mock_connection + mock_connection.cursor.return_value.__aenter__ = AsyncMock(return_value=mock_cursor) + mock_connection.cursor.return_value.__aexit__ = AsyncMock() + + await wrapper.savepoint_rollback() + + mock_cursor.execute.assert_called_once_with("ROLLBACK TO SAVEPOINT test_savepoint") + assert wrapper._savepoint is None + assert wrapper._finalized is True + + @pytest.mark.asyncio + async def test_savepoint_rollback_no_savepoint(self): + parent_client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + connection_name="test_conn" + ) + + wrapper = TransactionWrapper(parent_client) + wrapper._savepoint = None + + with pytest.raises(TransactionManagementError, match="No savepoint to rollback to"): + await wrapper.savepoint_rollback() + + +class TestGenSavepointName: + def test_gen_savepoint_name(self): + name1 = _gen_savepoint_name() + name2 = _gen_savepoint_name() + + assert name1.startswith("tortoise_savepoint_") + assert name2.startswith("tortoise_savepoint_") + assert name1 != name2 # Should be unique \ No newline at end of file From d4a1e1e6ed8a8c0620e08acdaccb486666fec8a4 Mon Sep 17 00:00:00 2001 From: Favian Samatha Date: Mon, 1 Dec 2025 12:30:45 -0800 Subject: [PATCH 02/11] Add thread pool executors --- .../tortoise/backend/base/client.py | 77 ++++++++++++++----- 1 file changed, 56 insertions(+), 21 deletions(-) diff --git a/aws_advanced_python_wrapper/tortoise/backend/base/client.py b/aws_advanced_python_wrapper/tortoise/backend/base/client.py index 6c30e8cc..cb81e18d 100644 --- a/aws_advanced_python_wrapper/tortoise/backend/base/client.py +++ b/aws_advanced_python_wrapper/tortoise/backend/base/client.py @@ -14,6 +14,7 @@ import asyncio import mysql.connector +from concurrent.futures import ThreadPoolExecutor from contextlib import asynccontextmanager from typing import Any, Callable, Generic @@ -24,41 +25,67 @@ from aws_advanced_python_wrapper import AwsWrapperConnection -async def ConnectWithAwsWrapper(connect_func: Callable, **kwargs) -> AwsWrapperConnection: - """Create an AWS wrapper connection with async cursor support.""" - connection = await asyncio.to_thread( - AwsWrapperConnection.connect, - connect_func, - **kwargs, +class AwsWrapperAsyncConnector: + """Factory class for creating AWS wrapper connections.""" + + _executor: ThreadPoolExecutor = ThreadPoolExecutor( + thread_name_prefix="AwsWrapperConnectorExecutor" ) - return AwsConnectionAsyncWrapper(connection) + + @staticmethod + async def ConnectWithAwsWrapper(connect_func: Callable, **kwargs) -> AwsWrapperConnection: + """Create an AWS wrapper connection with async cursor support.""" + loop = asyncio.get_event_loop() + connection = await loop.run_in_executor( + AwsWrapperAsyncConnector._executor, + lambda: AwsWrapperConnection.connect(connect_func, **kwargs) + ) + return AwsConnectionAsyncWrapper(connection) + + @staticmethod + async def CloseAwsWrapper(connection: AwsWrapperConnection) -> None: + """Close an AWS wrapper connection asynchronously.""" + loop = asyncio.get_event_loop() + await loop.run_in_executor( + AwsWrapperAsyncConnector._executor, + connection.close + ) class AwsCursorAsyncWrapper: """Wraps a sync cursor to provide async interface.""" + _executor: ThreadPoolExecutor = ThreadPoolExecutor( + thread_name_prefix="AwsCursorAsyncWrapperExecutor" + ) + def __init__(self, sync_cursor): self._cursor = sync_cursor async def execute(self, query, params=None): """Execute a query asynchronously.""" - return await asyncio.to_thread(self._cursor.execute, query, params) + loop = asyncio.get_event_loop() + return await loop.run_in_executor(self._executor, self._cursor.execute, query, params) async def executemany(self, query, params_list): """Execute multiple queries asynchronously.""" - return await asyncio.to_thread(self._cursor.executemany, query, params_list) + loop = asyncio.get_event_loop() + return await loop.run_in_executor(self._executor, self._cursor.executemany, query, params_list) async def fetchall(self): """Fetch all results asynchronously.""" - return await asyncio.to_thread(self._cursor.fetchall) + loop = asyncio.get_event_loop() + return await loop.run_in_executor(self._executor, self._cursor.fetchall) async def fetchone(self): """Fetch one result asynchronously.""" - return await asyncio.to_thread(self._cursor.fetchone) + loop = asyncio.get_event_loop() + return await loop.run_in_executor(self._executor, self._cursor.fetchone) async def close(self): """Close cursor asynchronously.""" - return await asyncio.to_thread(self._cursor.close) + loop = asyncio.get_event_loop() + return await loop.run_in_executor(self._executor, self._cursor.close) def __getattr__(self, name): """Delegate non-async attributes to the wrapped cursor.""" @@ -68,29 +95,37 @@ def __getattr__(self, name): class AwsConnectionAsyncWrapper(AwsWrapperConnection): """AWS wrapper connection with async cursor support.""" + _executor: ThreadPoolExecutor = ThreadPoolExecutor( + thread_name_prefix="AwsConnectionAsyncWrapperExecutor" + ) + def __init__(self, connection: AwsWrapperConnection): self._wrapped_connection = connection @asynccontextmanager async def cursor(self): """Create an async cursor context manager.""" - cursor_obj = await asyncio.to_thread(self._wrapped_connection.cursor) + loop = asyncio.get_event_loop() + cursor_obj = await loop.run_in_executor(self._executor, self._wrapped_connection.cursor) try: yield AwsCursorAsyncWrapper(cursor_obj) finally: - await asyncio.to_thread(cursor_obj.close) + await loop.run_in_executor(self._executor, cursor_obj.close) async def rollback(self): """Rollback the current transaction.""" - return await asyncio.to_thread(self._wrapped_connection.rollback) + loop = asyncio.get_event_loop() + return await loop.run_in_executor(self._executor, self._wrapped_connection.rollback) async def commit(self): """Commit the current transaction.""" - return await asyncio.to_thread(self._wrapped_connection.commit) + loop = asyncio.get_event_loop() + return await loop.run_in_executor(self._executor, self._wrapped_connection.commit) async def set_autocommit(self, value: bool): """Set autocommit mode.""" - return await asyncio.to_thread(lambda: setattr(self._wrapped_connection, 'autocommit', value)) + loop = asyncio.get_event_loop() + return await loop.run_in_executor(self._executor, lambda: setattr(self._wrapped_connection, 'autocommit', value)) def __getattr__(self, name): """Delegate all other attributes/methods to the wrapped connection.""" @@ -128,13 +163,13 @@ async def ensure_connection(self) -> None: async def __aenter__(self) -> T_conn: """Acquire connection from pool.""" await self.ensure_connection() - self.connection = await ConnectWithAwsWrapper(self.connect_func, **self.client._template) + self.connection = await AwsWrapperAsyncConnector.ConnectWithAwsWrapper(self.connect_func, **self.client._template) return self.connection async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: """Close connection and release back to pool.""" if self.connection: - await asyncio.to_thread(self.connection.close) + await AwsWrapperAsyncConnector.CloseAwsWrapper(self.connection) class TortoiseAwsClientTransactionContext(TransactionContext): @@ -159,7 +194,7 @@ async def __aenter__(self) -> TransactionalDBClient: self.token = connections.set(self.connection_name, self.client) # Create connection and begin transaction - self.client._connection = await ConnectWithAwsWrapper( + self.client._connection = await AwsWrapperAsyncConnector.ConnectWithAwsWrapper( mysql.connector.Connect, **self.client._parent._template ) @@ -178,4 +213,4 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: await self.client.commit() finally: connections.reset(self.token) - await asyncio.to_thread(self.client._connection.close) \ No newline at end of file + await AwsWrapperAsyncConnector.CloseAwsWrapper(self.client._connection) \ No newline at end of file From d224974d83c44088bbcda2c444b7ca3044f229c6 Mon Sep 17 00:00:00 2001 From: Favian Samatha Date: Tue, 2 Dec 2025 08:47:08 -0800 Subject: [PATCH 03/11] fix logs --- .../tortoise/backend/mysql/client.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/aws_advanced_python_wrapper/tortoise/backend/mysql/client.py b/aws_advanced_python_wrapper/tortoise/backend/mysql/client.py index d1307f80..fbf9f536 100644 --- a/aws_advanced_python_wrapper/tortoise/backend/mysql/client.py +++ b/aws_advanced_python_wrapper/tortoise/backend/mysql/client.py @@ -156,7 +156,7 @@ def _init_connection_templates(self) -> None: # Pool Management def _configure_pool(self, host_info: HostInfo, props: Dict[str, Any]) -> Dict[str, Any]: """Configure connection pool settings.""" - return {"pool_size": self.pool_maxsize} + return {"pool_size": self.pool_maxsize, "max_overflow": -1} @staticmethod def _get_pool_key(host_info: HostInfo, props: Dict[str, Any]) -> str: @@ -195,7 +195,7 @@ async def create_connection(self, with_db: bool) -> None: # Set template based on database requirement self._template = self._template_with_db if with_db else self._template_no_db - logger.debug("Created connection %s pool with params: %s", self._pool, self._template) + logger.debug(f"Created connection pool {self._pool} with params: {self._template}") async def close(self) -> None: """Close connections - AWS wrapper handles cleanup internally.""" @@ -229,7 +229,7 @@ async def db_delete(self) -> None: async def execute_insert(self, query: str, values: list) -> int: """Execute an INSERT query and return the last inserted row ID.""" async with self.acquire_connection() as connection: - logger.debug("%s: %s", query, values) + logger.debug(f"{query}: {values}") async with connection.cursor() as cursor: await cursor.execute(query, values) return cursor.lastrowid @@ -238,7 +238,7 @@ async def execute_insert(self, query: str, values: list) -> int: async def execute_many(self, query: str, values: list[list]) -> None: """Execute a query with multiple parameter sets.""" async with self.acquire_connection() as connection: - logger.debug("%s: %s", query, values) + logger.debug(f"{query}: {values}") async with connection.cursor() as cursor: if self.capabilities.supports_transactions: await self._execute_many_with_transaction(cursor, connection, query, values) @@ -263,7 +263,7 @@ async def _execute_many_with_transaction(self, cursor, connection, query: str, v async def execute_query(self, query: str, values: list | None = None) -> tuple[int, list[dict]]: """Execute a query and return row count and results.""" async with self.acquire_connection() as connection: - logger.debug("%s: %s", query, values) + logger.debug(f"{query}: {values}") async with connection.cursor() as cursor: await cursor.execute(query, values) rows = await cursor.fetchall() @@ -284,7 +284,8 @@ async def execute_script(self, query: str) -> None: async def _execute_script(self, query: str, with_db: bool) -> None: """Execute a multi-statement query by parsing and running statements sequentially.""" async with self._acquire_connection(with_db) as connection: - logger.debug(query) + logger.debug(f"Executing script: {query}") + print(f"Executing script: {query}") async with connection.cursor() as cursor: # Parse multi-statement queries since MySQL Connector doesn't handle them well statements = sqlparse.split(query) @@ -375,7 +376,7 @@ async def release_savepoint(self): async def execute_many(self, query: str, values: list[list]) -> None: """Execute many queries without autocommit handling (already in transaction).""" async with self.acquire_connection() as connection: - logger.debug("%s: %s", query, values) + logger.debug(f"{query}: {values}") async with connection.cursor() as cursor: await cursor.executemany(query, values) From 2a5b1519413ba9760dc785b8a4bc23beb2a88409 Mon Sep 17 00:00:00 2001 From: Favian Samatha Date: Tue, 2 Dec 2025 20:33:36 -0800 Subject: [PATCH 04/11] add sql alchemy tortoise connection provider --- aws_advanced_python_wrapper/driver_dialect.py | 5 ++ .../driver_dialect_manager.py | 6 +- .../mysql_driver_dialect.py | 9 +- .../pg_driver_dialect.py | 11 +-- ...dvanced_python_wrapper_messages.properties | 2 + .../sqlalchemy_driver_dialect.py | 12 ++- .../tortoise/backend/mysql/client.py | 22 ++--- ...ql_alchemy_tortoise_connection_provider.py | 85 +++++++++++++++++++ .../container/utils/test_telemetry_info.py | 13 +-- 9 files changed, 127 insertions(+), 38 deletions(-) create mode 100644 aws_advanced_python_wrapper/tortoise/sql_alchemy_tortoise_connection_provider.py diff --git a/aws_advanced_python_wrapper/driver_dialect.py b/aws_advanced_python_wrapper/driver_dialect.py index c9df891a..f1b338e7 100644 --- a/aws_advanced_python_wrapper/driver_dialect.py +++ b/aws_advanced_python_wrapper/driver_dialect.py @@ -19,6 +19,7 @@ if TYPE_CHECKING: from aws_advanced_python_wrapper.hostinfo import HostInfo from aws_advanced_python_wrapper.pep249 import Connection, Cursor + from types import ModuleType from abc import ABC from concurrent.futures import Executor, ThreadPoolExecutor, TimeoutError @@ -164,3 +165,7 @@ def ping(self, conn: Connection) -> bool: return True except Exception: return False + + def get_driver_module(self) -> ModuleType: + raise UnsupportedOperationError( + Messages.get_formatted("DriverDialect.UnsupportedOperationError", self._driver_name, "get_driver_module")) diff --git a/aws_advanced_python_wrapper/driver_dialect_manager.py b/aws_advanced_python_wrapper/driver_dialect_manager.py index dd71de40..420008dc 100644 --- a/aws_advanced_python_wrapper/driver_dialect_manager.py +++ b/aws_advanced_python_wrapper/driver_dialect_manager.py @@ -24,8 +24,7 @@ from aws_advanced_python_wrapper.errors import AwsWrapperError from aws_advanced_python_wrapper.utils.log import Logger from aws_advanced_python_wrapper.utils.messages import Messages -from aws_advanced_python_wrapper.utils.properties import (Properties, - WrapperProperties) +from aws_advanced_python_wrapper.utils.properties import Properties, WrapperProperties from aws_advanced_python_wrapper.utils.utils import Utils logger = Logger(__name__) @@ -52,7 +51,8 @@ class DriverDialectManager(DriverDialectProvider): } pool_connection_driver_dialect: Dict[str, str] = { - "SqlAlchemyPooledConnectionProvider": "aws_advanced_python_wrapper.sqlalchemy_driver_dialect.SqlAlchemyDriverDialect" + "SqlAlchemyPooledConnectionProvider": "aws_advanced_python_wrapper.sqlalchemy_driver_dialect.SqlAlchemyDriverDialect", + "SqlAlchemyTortoisePooledConnectionProvider": "aws_advanced_python_wrapper.sqlalchemy_driver_dialect.SqlAlchemyDriverDialect", } @staticmethod diff --git a/aws_advanced_python_wrapper/mysql_driver_dialect.py b/aws_advanced_python_wrapper/mysql_driver_dialect.py index dd7055c5..6b0d7a95 100644 --- a/aws_advanced_python_wrapper/mysql_driver_dialect.py +++ b/aws_advanced_python_wrapper/mysql_driver_dialect.py @@ -19,6 +19,7 @@ if TYPE_CHECKING: from aws_advanced_python_wrapper.hostinfo import HostInfo from aws_advanced_python_wrapper.pep249 import Connection + from types import ModuleType from concurrent.futures import Executor, ThreadPoolExecutor, TimeoutError from inspect import signature @@ -28,12 +29,11 @@ from aws_advanced_python_wrapper.errors import UnsupportedOperationError from aws_advanced_python_wrapper.utils.decorators import timeout from aws_advanced_python_wrapper.utils.messages import Messages -from aws_advanced_python_wrapper.utils.properties import (Properties, - PropertiesUtils, - WrapperProperties) +from aws_advanced_python_wrapper.utils.properties import Properties, PropertiesUtils, WrapperProperties CMYSQL_ENABLED = False +import mysql.connector from mysql.connector import MySQLConnection # noqa: E402 from mysql.connector.cursor import MySQLCursor # noqa: E402 @@ -201,3 +201,6 @@ def prepare_connect_info(self, host_info: HostInfo, original_props: Properties) def supports_connect_timeout(self) -> bool: return True + + def get_driver_module(self) -> ModuleType: + return mysql.connector diff --git a/aws_advanced_python_wrapper/pg_driver_dialect.py b/aws_advanced_python_wrapper/pg_driver_dialect.py index fbe441f3..7d02e118 100644 --- a/aws_advanced_python_wrapper/pg_driver_dialect.py +++ b/aws_advanced_python_wrapper/pg_driver_dialect.py @@ -14,6 +14,7 @@ from __future__ import annotations +from inspect import signature from typing import TYPE_CHECKING, Any, Callable, Set import psycopg @@ -21,16 +22,13 @@ if TYPE_CHECKING: from aws_advanced_python_wrapper.hostinfo import HostInfo from aws_advanced_python_wrapper.pep249 import Connection - -from inspect import signature + from types import ModuleType from aws_advanced_python_wrapper.driver_dialect import DriverDialect from aws_advanced_python_wrapper.driver_dialect_codes import DriverDialectCodes from aws_advanced_python_wrapper.errors import UnsupportedOperationError from aws_advanced_python_wrapper.utils.messages import Messages -from aws_advanced_python_wrapper.utils.properties import (Properties, - PropertiesUtils, - WrapperProperties) +from aws_advanced_python_wrapper.utils.properties import Properties, PropertiesUtils, WrapperProperties class PgDriverDialect(DriverDialect): @@ -175,3 +173,6 @@ def supports_tcp_keepalive(self) -> bool: def supports_abort_connection(self) -> bool: return True + + def get_driver_module(self) -> ModuleType: + return psycopg \ No newline at end of file diff --git a/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties b/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties index 65f16806..5c95e152 100644 --- a/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties +++ b/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties @@ -385,6 +385,8 @@ WeightedRandomHostSelector.WeightedRandomInvalidDefaultWeight=[WeightedRandomHos SqlAlchemyPooledConnectionProvider.PoolNone=[SqlAlchemyPooledConnectionProvider] Attempted to find or create a pool for '{}' but the result of the attempt evaluated to None. SqlAlchemyPooledConnectionProvider.UnableToCreateDefaultKey=[SqlAlchemyPooledConnectionProvider] Unable to create a default key for internal connection pools. By default, the user parameter is used, but the given user evaluated to None or the empty string (""). Please ensure you have passed a valid user in the connection properties. +SqlAlchemyTortoiseConnectionProvider.UnableToDetermineDialect=[SqlAlchemyTortoiseConnectionProvider] Unable to resolve sql alchemy dialect for the following driver dialect '{}'. + SqlAlchemyDriverDialect.SetValueOnNoneConnection=[SqlAlchemyDriverDialect] Attempted to set the '{}' value on a pooled connection, but no underlying driver connection was found. This can happen if the pooled connection has previously been closed. StaleDnsHelper.ClusterEndpointDns=[StaleDnsPlugin] Cluster endpoint {} resolves to {}. diff --git a/aws_advanced_python_wrapper/sqlalchemy_driver_dialect.py b/aws_advanced_python_wrapper/sqlalchemy_driver_dialect.py index 290a8f86..d167be2d 100644 --- a/aws_advanced_python_wrapper/sqlalchemy_driver_dialect.py +++ b/aws_advanced_python_wrapper/sqlalchemy_driver_dialect.py @@ -16,17 +16,18 @@ from typing import TYPE_CHECKING, Any -from aws_advanced_python_wrapper.driver_dialect import DriverDialect -from aws_advanced_python_wrapper.errors import AwsWrapperError -from aws_advanced_python_wrapper.utils.messages import Messages - if TYPE_CHECKING: from aws_advanced_python_wrapper.hostinfo import HostInfo from aws_advanced_python_wrapper.pep249 import Connection from aws_advanced_python_wrapper.utils.properties import Properties + from types import ModuleType from sqlalchemy import PoolProxiedConnection +from aws_advanced_python_wrapper.driver_dialect import DriverDialect +from aws_advanced_python_wrapper.errors import AwsWrapperError +from aws_advanced_python_wrapper.utils.messages import Messages + class SqlAlchemyDriverDialect(DriverDialect): _driver_name: str = "SQLAlchemy" @@ -125,3 +126,6 @@ def transfer_session_state(self, from_conn: Connection, to_conn: Connection): return None return self._underlying_driver.transfer_session_state(from_driver_conn, to_driver_conn) + + def get_driver_module(self) -> ModuleType: + return self._underlying_driver.get_driver_module() diff --git a/aws_advanced_python_wrapper/tortoise/backend/mysql/client.py b/aws_advanced_python_wrapper/tortoise/backend/mysql/client.py index fbf9f536..397eb5c2 100644 --- a/aws_advanced_python_wrapper/tortoise/backend/mysql/client.py +++ b/aws_advanced_python_wrapper/tortoise/backend/mysql/client.py @@ -22,7 +22,6 @@ from mysql.connector import errors from mysql.connector.charsets import MYSQL_CHARACTER_SETS from pypika_tortoise import MySQLQuery -from tortoise import timezone from tortoise.backends.base.client import ( BaseDBAsyncClient, Capabilities, @@ -39,16 +38,16 @@ ) from aws_advanced_python_wrapper.connection_provider import ConnectionProviderManager +from aws_advanced_python_wrapper.errors import AwsWrapperError, FailoverError from aws_advanced_python_wrapper.hostinfo import HostInfo -from aws_advanced_python_wrapper.sql_alchemy_connection_provider import SqlAlchemyPooledConnectionProvider from aws_advanced_python_wrapper.tortoise.backend.base.client import ( TortoiseAwsClientConnectionWrapper, TortoiseAwsClientTransactionContext, ) from aws_advanced_python_wrapper.tortoise.backend.mysql.executor import AwsMySQLExecutor from aws_advanced_python_wrapper.tortoise.backend.mysql.schema_generator import AwsMySQLSchemaGenerator +from aws_advanced_python_wrapper.tortoise.sql_alchemy_tortoise_connection_provider import SqlAlchemyTortoisePooledConnectionProvider from aws_advanced_python_wrapper.utils.log import Logger -from aws_advanced_python_wrapper.errors import AwsWrapperError logger = Logger(__name__) T = TypeVar("T") @@ -62,10 +61,12 @@ async def translate_exceptions_(self, *args) -> T: try: try: return await func(self, *args) - except AwsWrapperError as aws_err: + except AwsWrapperError as aws_err: # Unwrap any AwsWrappedErrors if aws_err.__cause__: raise aws_err.__cause__ raise + except FailoverError as exc: # Raise any failover errors + raise except errors.IntegrityError as exc: raise IntegrityError(exc) except ( @@ -127,16 +128,16 @@ def __init__( self.extra.pop("connection_name", None) self.extra.pop("fetch_inserted", None) self.extra.pop("db", None) - self.extra.pop("autocommit", None) + self.extra.pop("autocommit", None) # We need this to be true self.extra.setdefault("sql_mode", "STRICT_TRANS_TABLES") # Initialize connection templates self._init_connection_templates() - + # Initialize state self._template = {} self._connection = None - self._pool = None + self._pool: SqlAlchemyTortoisePooledConnectionProvider = None self._pool_init_lock = asyncio.Lock() def _init_connection_templates(self) -> None: @@ -158,20 +159,19 @@ def _configure_pool(self, host_info: HostInfo, props: Dict[str, Any]) -> Dict[st """Configure connection pool settings.""" return {"pool_size": self.pool_maxsize, "max_overflow": -1} - @staticmethod - def _get_pool_key(host_info: HostInfo, props: Dict[str, Any]) -> str: + def _get_pool_key(self, host_info: HostInfo, props: Dict[str, Any]) -> str: """Generate unique pool key for connection pooling.""" url = host_info.url user = props["user"] db = props["database"] - return f"{url}{user}{db}" + return f"{url}{user}{db}{self.connection_name}" async def _init_pool_if_needed(self) -> None: """Initialize connection pool only once across all instances.""" if not AwsMySQLClient._pool_initialized: async with AwsMySQLClient._pool_init_class_lock: if not AwsMySQLClient._pool_initialized: - self._pool = SqlAlchemyPooledConnectionProvider( + self._pool = SqlAlchemyTortoisePooledConnectionProvider( pool_configurator=self._configure_pool, pool_mapping=self._get_pool_key, ) diff --git a/aws_advanced_python_wrapper/tortoise/sql_alchemy_tortoise_connection_provider.py b/aws_advanced_python_wrapper/tortoise/sql_alchemy_tortoise_connection_provider.py new file mode 100644 index 00000000..31938b07 --- /dev/null +++ b/aws_advanced_python_wrapper/tortoise/sql_alchemy_tortoise_connection_provider.py @@ -0,0 +1,85 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from types import ModuleType +from typing import Callable, Dict + +from sqlalchemy import Dialect +from sqlalchemy.dialects.mysql import mysqlconnector +from sqlalchemy.dialects.postgresql import psycopg + +from aws_advanced_python_wrapper.database_dialect import DatabaseDialect +from aws_advanced_python_wrapper.driver_dialect import DriverDialect +from aws_advanced_python_wrapper.errors import AwsWrapperError +from aws_advanced_python_wrapper.hostinfo import HostInfo +from aws_advanced_python_wrapper.sql_alchemy_connection_provider import SqlAlchemyPooledConnectionProvider +from aws_advanced_python_wrapper.utils.log import Logger +from aws_advanced_python_wrapper.utils.messages import Messages +from aws_advanced_python_wrapper.utils.properties import Properties + +logger = Logger(__name__) + + + +class SqlAlchemyTortoisePooledConnectionProvider(SqlAlchemyPooledConnectionProvider): + """ + Tortoise-specific pooled connection provider that handles failover by disposing pools. + """ + + _sqlalchemy_dialect_map : Dict[str, ModuleType] = { + "MySQLDriverDialect": mysqlconnector, + "PostgresDriverDialect": psycopg + } + + def accepts_host_info(self, host_info: HostInfo, props: Properties) -> bool: + if self._accept_url_func: + return self._accept_url_func(host_info, props) + url_type = SqlAlchemyPooledConnectionProvider._rds_utils.identify_rds_type(host_info.host) + return url_type.is_rds + + def _create_pool( + self, + target_func: Callable, + driver_dialect: DriverDialect, + database_dialect: DatabaseDialect, + host_info: HostInfo, + props: Properties): + kwargs = dict() if self._pool_configurator is None else self._pool_configurator(host_info, props) + prepared_properties = driver_dialect.prepare_connect_info(host_info, props) + database_dialect.prepare_conn_props(prepared_properties) + kwargs["creator"] = self._get_connection_func(target_func, prepared_properties) + dialect = self._get_pool_dialect(driver_dialect) + if not dialect: + raise AwsWrapperError(Messages.get_formatted("SqlAlchemyTortoisePooledConnectionProvider.NoDialect", driver_dialect.__class__.__name__)) + + ''' + We need to pass in pre_ping and dialect to QueuePool the queue pool to enable health checks. + Without this health check, we could be using dead connections after a failover. + ''' + kwargs["pre_ping"] = True + kwargs["dialect"] = dialect + return self._create_sql_alchemy_pool(**kwargs) + + def _get_pool_dialect(self, driver_dialect: DriverDialect) -> Dialect: + dialect = None + driver_dialect_class_name = driver_dialect.__class__.__name__ + if driver_dialect_class_name == "SqlAlchemyDriverDialect": + driver_dialect_class_name = driver_dialect_class_name._underlying_driver_dialect.__class__.__name__ + module = self._sqlalchemy_dialect_map.get(driver_dialect_class_name) + + if not module: + return dialect + dialect = module.dialect() + dialect.dbapi = driver_dialect.get_driver_module() + + return dialect diff --git a/tests/integration/container/utils/test_telemetry_info.py b/tests/integration/container/utils/test_telemetry_info.py index 24267dc8..e37941ef 100644 --- a/tests/integration/container/utils/test_telemetry_info.py +++ b/tests/integration/container/utils/test_telemetry_info.py @@ -11,18 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# -# Licensed under the Apache License, Version 2.0 (the "License"). -# You may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. + import typing from typing import Any, Dict From 1a5d163293616bbf88b29f93bea7b5a8ddb926ba Mon Sep 17 00:00:00 2001 From: Favian Samatha Date: Tue, 2 Dec 2025 20:34:00 -0800 Subject: [PATCH 05/11] add tortoise integration tests --- .../tortoise/backend/base/client.py | 11 +- .../tortoise/backend/mysql/client.py | 68 ++--- ...ql_alchemy_tortoise_connection_provider.py | 34 ++- tests/integration/container/conftest.py | 8 +- .../container/tortoise/__init__.py | 13 + .../container/tortoise/test_tortoise_basic.py | 255 +++++++++++++++++ .../tortoise/test_tortoise_common.py | 98 +++++++ .../tortoise/test_tortoise_custom_endpoint.py | 108 +++++++ .../tortoise/test_tortoise_failover.py | 249 +++++++++++++++++ .../test_tortoise_iam_authentication.py | 41 +++ .../tortoise/test_tortoise_models.py | 44 +++ .../tortoise/test_tortoise_multi_plugins.py | 264 ++++++++++++++++++ .../tortoise/test_tortoise_secrets_manager.py | 74 +++++ .../container/utils/connection_utils.py | 28 ++ .../utils/test_environment_request.py | 2 +- .../integration/container/utils/test_utils.py | 53 ++++ 16 files changed, 1288 insertions(+), 62 deletions(-) create mode 100644 tests/integration/container/tortoise/__init__.py create mode 100644 tests/integration/container/tortoise/test_tortoise_basic.py create mode 100644 tests/integration/container/tortoise/test_tortoise_common.py create mode 100644 tests/integration/container/tortoise/test_tortoise_custom_endpoint.py create mode 100644 tests/integration/container/tortoise/test_tortoise_failover.py create mode 100644 tests/integration/container/tortoise/test_tortoise_iam_authentication.py create mode 100644 tests/integration/container/tortoise/test_tortoise_models.py create mode 100644 tests/integration/container/tortoise/test_tortoise_multi_plugins.py create mode 100644 tests/integration/container/tortoise/test_tortoise_secrets_manager.py create mode 100644 tests/integration/container/utils/test_utils.py diff --git a/aws_advanced_python_wrapper/tortoise/backend/base/client.py b/aws_advanced_python_wrapper/tortoise/backend/base/client.py index cb81e18d..f4fe3c3c 100644 --- a/aws_advanced_python_wrapper/tortoise/backend/base/client.py +++ b/aws_advanced_python_wrapper/tortoise/backend/base/client.py @@ -29,7 +29,7 @@ class AwsWrapperAsyncConnector: """Factory class for creating AWS wrapper connections.""" _executor: ThreadPoolExecutor = ThreadPoolExecutor( - thread_name_prefix="AwsWrapperConnectorExecutor" + thread_name_prefix="AwsWrapperAsyncExecutor" ) @staticmethod @@ -141,19 +141,17 @@ def __del__(self): class TortoiseAwsClientConnectionWrapper(Generic[T_conn]): """Manages acquiring from and releasing connections to a pool.""" - __slots__ = ("client", "connection", "_pool_init_lock", "connect_func", "with_db") + __slots__ = ("client", "connection", "connect_func", "with_db") def __init__( self, client: BaseDBAsyncClient, - pool_init_lock: asyncio.Lock, connect_func: Callable, with_db: bool = True ) -> None: self.connect_func = connect_func self.client = client self.connection: T_conn | None = None - self._pool_init_lock = pool_init_lock self.with_db = with_db async def ensure_connection(self) -> None: @@ -175,12 +173,11 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: class TortoiseAwsClientTransactionContext(TransactionContext): """Transaction context that uses a pool to acquire connections.""" - __slots__ = ("client", "connection_name", "token", "_pool_init_lock") + __slots__ = ("client", "connection_name", "token") - def __init__(self, client: TransactionalDBClient, pool_init_lock: asyncio.Lock) -> None: + def __init__(self, client: TransactionalDBClient) -> None: self.client = client self.connection_name = client.connection_name - self._pool_init_lock = pool_init_lock async def ensure_connection(self) -> None: """Ensure the connection pool is initialized.""" diff --git a/aws_advanced_python_wrapper/tortoise/backend/mysql/client.py b/aws_advanced_python_wrapper/tortoise/backend/mysql/client.py index 397eb5c2..c4857acb 100644 --- a/aws_advanced_python_wrapper/tortoise/backend/mysql/client.py +++ b/aws_advanced_python_wrapper/tortoise/backend/mysql/client.py @@ -15,7 +15,7 @@ import asyncio from functools import wraps from itertools import count -from typing import Any, Callable, Coroutine, Dict, SupportsInt, TypeVar +from typing import Any, Callable, Coroutine, Dict, List, Optional, SupportsInt, Tuple, TypeVar import mysql.connector import sqlparse @@ -37,7 +37,7 @@ TransactionManagementError, ) -from aws_advanced_python_wrapper.connection_provider import ConnectionProviderManager +from aws_advanced_python_wrapper.connection_provider import ConnectionProviderManager, ConnectionProvider from aws_advanced_python_wrapper.errors import AwsWrapperError, FailoverError from aws_advanced_python_wrapper.hostinfo import HostInfo from aws_advanced_python_wrapper.tortoise.backend.base.client import ( @@ -94,7 +94,7 @@ class AwsMySQLClient(BaseDBAsyncClient): support_for_posix_regex_queries=True, support_json_attributes=True, ) - _pool_initialized = False + _provider: Optional[ConnectionProvider] = None _pool_init_class_lock = asyncio.Lock() def __init__( @@ -121,24 +121,19 @@ def __init__( # Extract MySQL-specific settings self.storage_engine = self.extra.pop("storage_engine", "innodb") self.charset = self.extra.pop("charset", "utf8mb4") - self.pool_minsize = int(self.extra.pop("minsize", 1)) - self.pool_maxsize = int(self.extra.pop("maxsize", 5)) # Remove Tortoise-specific parameters self.extra.pop("connection_name", None) self.extra.pop("fetch_inserted", None) - self.extra.pop("db", None) - self.extra.pop("autocommit", None) # We need this to be true + self.extra.pop("autocommit", None) self.extra.setdefault("sql_mode", "STRICT_TRANS_TABLES") # Initialize connection templates self._init_connection_templates() # Initialize state - self._template = {} - self._connection = None - self._pool: SqlAlchemyTortoisePooledConnectionProvider = None - self._pool_init_lock = asyncio.Lock() + self._template: Dict[str, Any] = {} + self._connection: Optional[Any] = None def _init_connection_templates(self) -> None: """Initialize connection templates for with/without database.""" @@ -154,39 +149,12 @@ def _init_connection_templates(self) -> None: self._template_with_db = {**base_template, "database": self.database} self._template_no_db = {**base_template, "database": None} - # Pool Management - def _configure_pool(self, host_info: HostInfo, props: Dict[str, Any]) -> Dict[str, Any]: - """Configure connection pool settings.""" - return {"pool_size": self.pool_maxsize, "max_overflow": -1} - - def _get_pool_key(self, host_info: HostInfo, props: Dict[str, Any]) -> str: - """Generate unique pool key for connection pooling.""" - url = host_info.url - user = props["user"] - db = props["database"] - return f"{url}{user}{db}{self.connection_name}" - - async def _init_pool_if_needed(self) -> None: - """Initialize connection pool only once across all instances.""" - if not AwsMySQLClient._pool_initialized: - async with AwsMySQLClient._pool_init_class_lock: - if not AwsMySQLClient._pool_initialized: - self._pool = SqlAlchemyTortoisePooledConnectionProvider( - pool_configurator=self._configure_pool, - pool_mapping=self._get_pool_key, - ) - ConnectionProviderManager.set_connection_provider(self._pool) - AwsMySQLClient._pool_initialized = True - # Connection Management async def create_connection(self, with_db: bool) -> None: """Initialize connection pool and configure database settings.""" # Validate charset if self.charset.lower() not in [cs[0] for cs in MYSQL_CHARACTER_SETS if cs is not None]: raise DBConnectionError(f"Unknown character set: {self.charset}") - - # Initialize connection pool only once - await self._init_pool_if_needed() # Set transaction support based on storage engine if self.storage_engine.lower() != "innodb": @@ -195,8 +163,6 @@ async def create_connection(self, with_db: bool) -> None: # Set template based on database requirement self._template = self._template_with_db if with_db else self._template_no_db - logger.debug(f"Created connection pool {self._pool} with params: {self._template}") - async def close(self) -> None: """Close connections - AWS wrapper handles cleanup internally.""" pass @@ -205,10 +171,10 @@ def acquire_connection(self): """Acquire a connection from the pool.""" return self._acquire_connection(with_db=True) - def _acquire_connection(self, with_db: bool): + def _acquire_connection(self, with_db: bool) -> TortoiseAwsClientConnectionWrapper: """Create connection wrapper for specified database mode.""" return TortoiseAwsClientConnectionWrapper( - self, self._pool_init_lock, mysql.connector.Connect, with_db=with_db + self, mysql.connector.Connect, with_db=with_db ) # Database Operations @@ -226,7 +192,7 @@ async def db_delete(self) -> None: # Query Execution Methods @translate_exceptions - async def execute_insert(self, query: str, values: list) -> int: + async def execute_insert(self, query: str, values: List[Any]) -> int: """Execute an INSERT query and return the last inserted row ID.""" async with self.acquire_connection() as connection: logger.debug(f"{query}: {values}") @@ -235,7 +201,7 @@ async def execute_insert(self, query: str, values: list) -> int: return cursor.lastrowid @translate_exceptions - async def execute_many(self, query: str, values: list[list]) -> None: + async def execute_many(self, query: str, values: List[List[Any]]) -> None: """Execute a query with multiple parameter sets.""" async with self.acquire_connection() as connection: logger.debug(f"{query}: {values}") @@ -245,7 +211,7 @@ async def execute_many(self, query: str, values: list[list]) -> None: else: await cursor.executemany(query, values) - async def _execute_many_with_transaction(self, cursor, connection, query: str, values: list[list]) -> None: + async def _execute_many_with_transaction(self, cursor: Any, connection: Any, query: str, values: List[List[Any]]) -> None: """Execute many queries within a transaction.""" try: await connection.set_autocommit(False) @@ -260,7 +226,7 @@ async def _execute_many_with_transaction(self, cursor, connection, query: str, v await connection.set_autocommit(True) @translate_exceptions - async def execute_query(self, query: str, values: list | None = None) -> tuple[int, list[dict]]: + async def execute_query(self, query: str, values: Optional[List[Any]] = None) -> Tuple[int, List[Dict[str, Any]]]: """Execute a query and return row count and results.""" async with self.acquire_connection() as connection: logger.debug(f"{query}: {values}") @@ -272,7 +238,7 @@ async def execute_query(self, query: str, values: list | None = None) -> tuple[i return cursor.rowcount, [dict(zip(fields, row)) for row in rows] return cursor.rowcount, [] - async def execute_query_dict(self, query: str, values: list | None = None) -> list[dict]: + async def execute_query_dict(self, query: str, values: Optional[List[Any]] = None) -> List[Dict[str, Any]]: """Execute a query and return only the results as dictionaries.""" return (await self.execute_query(query, values))[1] @@ -297,7 +263,7 @@ async def _execute_script(self, query: str, with_db: bool) -> None: # Transaction Support def _in_transaction(self) -> TransactionContext: """Create a new transaction context.""" - return TortoiseAwsClientTransactionContext(TransactionWrapper(self), self._pool_init_lock) + return TortoiseAwsClientTransactionContext(TransactionWrapper(self)) class TransactionWrapper(AwsMySQLClient, TransactionalDBClient): @@ -307,7 +273,7 @@ def __init__(self, connection: AwsMySQLClient) -> None: self.connection_name = connection.connection_name self._connection = connection._connection self._lock = asyncio.Lock() - self._savepoint: str | None = None + self._savepoint: Optional[str] = None self._finalized: bool = False self._parent = connection @@ -373,7 +339,7 @@ async def release_savepoint(self): self._finalized = True @translate_exceptions - async def execute_many(self, query: str, values: list[list]) -> None: + async def execute_many(self, query: str, values: List[List[Any]]) -> None: """Execute many queries without autocommit handling (already in transaction).""" async with self.acquire_connection() as connection: logger.debug(f"{query}: {values}") @@ -381,6 +347,6 @@ async def execute_many(self, query: str, values: list[list]) -> None: await cursor.executemany(query, values) -def _gen_savepoint_name(_c=count()) -> str: +def _gen_savepoint_name(_c: count = count()) -> str: """Generate a unique savepoint name.""" return f"tortoise_savepoint_{next(_c)}" \ No newline at end of file diff --git a/aws_advanced_python_wrapper/tortoise/sql_alchemy_tortoise_connection_provider.py b/aws_advanced_python_wrapper/tortoise/sql_alchemy_tortoise_connection_provider.py index 31938b07..4f9b5cfd 100644 --- a/aws_advanced_python_wrapper/tortoise/sql_alchemy_tortoise_connection_provider.py +++ b/aws_advanced_python_wrapper/tortoise/sql_alchemy_tortoise_connection_provider.py @@ -12,12 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. from types import ModuleType -from typing import Callable, Dict +from typing import Any, Callable, Dict, Optional from sqlalchemy import Dialect from sqlalchemy.dialects.mysql import mysqlconnector from sqlalchemy.dialects.postgresql import psycopg +from aws_advanced_python_wrapper.connection_provider import ConnectionProviderManager from aws_advanced_python_wrapper.database_dialect import DatabaseDialect from aws_advanced_python_wrapper.driver_dialect import DriverDialect from aws_advanced_python_wrapper.errors import AwsWrapperError @@ -83,3 +84,34 @@ def _get_pool_dialect(self, driver_dialect: DriverDialect) -> Dialect: dialect.dbapi = driver_dialect.get_driver_module() return dialect + + +def setup_tortoise_connection_provider( + pool_configurator: Optional[Callable[[HostInfo, Properties], Dict[str, Any]]] = None, + pool_mapping: Optional[Callable[[HostInfo, Properties], str]] = None +) -> SqlAlchemyTortoisePooledConnectionProvider: + """ + Helper function to set up and configure the Tortoise connection provider. + + Args: + pool_configurator: Optional function to configure pool settings. + Defaults to basic pool configuration. + pool_mapping: Optional function to generate pool keys. + Defaults to basic pool key generation. + + Returns: + Configured SqlAlchemyTortoisePooledConnectionProvider instance. + """ + def default_pool_configurator(host_info: HostInfo, props: Properties) -> Dict[str, Any]: + return {"pool_size": 5, "max_overflow": -1} + + def default_pool_mapping(host_info: HostInfo, props: Properties) -> str: + return f"{host_info.url}{props.get('user', '')}{props.get('database', '')}" + + provider = SqlAlchemyTortoisePooledConnectionProvider( + pool_configurator=pool_configurator or default_pool_configurator, + pool_mapping=pool_mapping or default_pool_mapping + ) + + ConnectionProviderManager.set_connection_provider(provider) + return provider diff --git a/tests/integration/container/conftest.py b/tests/integration/container/conftest.py index 2e23eeba..eed03977 100644 --- a/tests/integration/container/conftest.py +++ b/tests/integration/container/conftest.py @@ -69,7 +69,7 @@ def pytest_runtest_setup(item): else: TestEnvironment.get_current().set_current_driver(None) - logger.info("Starting test preparation for: " + test_name) + logger.info(f"Starting test preparation for: {test_name}") segment: Optional[Segment] = None if TestEnvironmentFeatures.TELEMETRY_TRACES_ENABLED in TestEnvironment.get_current().get_features(): @@ -107,7 +107,11 @@ def pytest_runtest_setup(item): logger.warning("conftest.ExceptionWhileObtainingInstanceIDs", ex) instances = list() - sleep(5) + # Only sleep if we still need to retry + if (len(instances) < request.get_num_of_instances() + or len(instances) == 0 + or not rds_utility.is_db_instance_writer(instances[0])): + sleep(5) assert len(instances) > 0 current_writer = instances[0] diff --git a/tests/integration/container/tortoise/__init__.py b/tests/integration/container/tortoise/__init__.py new file mode 100644 index 00000000..bd4acb2b --- /dev/null +++ b/tests/integration/container/tortoise/__init__.py @@ -0,0 +1,13 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/integration/container/tortoise/test_tortoise_basic.py b/tests/integration/container/tortoise/test_tortoise_basic.py new file mode 100644 index 00000000..4cbbbfde --- /dev/null +++ b/tests/integration/container/tortoise/test_tortoise_basic.py @@ -0,0 +1,255 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import time + +import pytest +import pytest_asyncio +from tortoise.transactions import atomic, in_transaction + +from tests.integration.container.tortoise.test_tortoise_common import \ + setup_tortoise +from tests.integration.container.tortoise.test_tortoise_models import ( + UniqueName, User) + + +class TestTortoiseBasic: + """ + Test class for Tortoise ORM integration with AWS Advanced Python Wrapper. + Contains tests related to tortoise operations + """ + + @pytest_asyncio.fixture + async def setup_tortoise_basic(self, conn_utils): + """Setup Tortoise with default plugins.""" + async for result in setup_tortoise(conn_utils): + yield result + + + + @pytest.mark.asyncio + async def test_basic_crud_operations(self, setup_tortoise_basic): + """Test basic CRUD operations with AWS wrapper.""" + # Create + user = await User.create(name="John Doe", email="john@example.com") + assert user.id is not None + assert user.name == "John Doe" + + # Read + found_user = await User.get(id=user.id) + assert found_user.name == "John Doe" + assert found_user.email == "john@example.com" + + # Update + found_user.name = "Jane Doe" + await found_user.save() + + updated_user = await User.get(id=user.id) + assert updated_user.name == "Jane Doe" + + # Delete + await updated_user.delete() + + with pytest.raises(Exception): + await User.get(id=user.id) + + @pytest.mark.asyncio + async def test_basic_crud_operations_config(self, setup_tortoise_basic): + """Test basic CRUD operations with AWS wrapper.""" + # Create + user = await User.create(name="John Doe", email="john@example.com") + assert user.id is not None + assert user.name == "John Doe" + + # Read + found_user = await User.get(id=user.id) + assert found_user.name == "John Doe" + assert found_user.email == "john@example.com" + + # Update + found_user.name = "Jane Doe" + await found_user.save() + + updated_user = await User.get(id=user.id) + assert updated_user.name == "Jane Doe" + + # Delete + await updated_user.delete() + + with pytest.raises(Exception): + await User.get(id=user.id) + + @pytest.mark.asyncio + async def test_transaction_support(self, setup_tortoise_basic): + """Test transaction handling with AWS wrapper.""" + async with in_transaction() as conn: + user1 = await User.create(name="User 1", email="user1@example.com", using_db=conn) + user2 = await User.create(name="User 2", email="user2@example.com", using_db=conn) + + # Verify users exist within transaction + users = await User.filter(name__in=["User 1", "User 2"]).using_db(conn) + assert len(users) == 2 + + @pytest.mark.asyncio + async def test_transaction_rollback(self, setup_tortoise_basic): + """Test transaction rollback with AWS wrapper.""" + try: + async with in_transaction() as conn: + await User.create(name="Test User", email="test@example.com", using_db=conn) + # Force rollback by raising exception + raise ValueError("Test rollback") + except ValueError: + pass + + # Verify user was not created due to rollback + users = await User.filter(name="Test User") + assert len(users) == 0 + + @pytest.mark.asyncio + async def test_bulk_operations(self, setup_tortoise_basic): + """Test bulk operations with AWS wrapper.""" + # Bulk create + users_data = [ + {"name": f"User {i}", "email": f"user{i}@example.com"} + for i in range(5) + ] + await User.bulk_create([User(**data) for data in users_data]) + + # Verify bulk creation + users = await User.filter(name__startswith="User") + assert len(users) == 5 + + # Bulk update + await User.filter(name__startswith="User").update(name="Updated User") + + updated_users = await User.filter(name="Updated User") + assert len(updated_users) == 5 + + @pytest.mark.asyncio + async def test_query_operations(self, setup_tortoise_basic): + """Test various query operations with AWS wrapper.""" + # Setup test data + await User.create(name="Alice", email="alice@example.com") + await User.create(name="Bob", email="bob@example.com") + await User.create(name="Charlie", email="charlie@example.com") + + # Test filtering + alice = await User.get(name="Alice") + assert alice.email == "alice@example.com" + + # Test count + count = await User.all().count() + assert count >= 3 + + # Test ordering + users = await User.all().order_by("name") + assert users[0].name == "Alice" + + # Test exists + exists = await User.filter(name="Alice").exists() + assert exists is True + + # Test values + emails = await User.all().values_list("email", flat=True) + assert "alice@example.com" in emails + + @pytest.mark.asyncio + async def test_bulk_create_with_ids(self, setup_tortoise_basic): + """Test bulk create operations with ID verification.""" + # Bulk create 1000 UniqueName objects with no name (null values) + await UniqueName.bulk_create([UniqueName() for _ in range(1000)]) + + # Get all created records with id and name + all_ = await UniqueName.all().values("id", "name") + + # Get the starting ID + inc = all_[0]["id"] + + # Sort by ID for comparison + all_sorted = sorted(all_, key=lambda x: x["id"]) + + # Verify the IDs are sequential and names are None + expected = [{"id": val + inc, "name": None} for val in range(1000)] + + assert len(all_sorted) == 1000 + assert all_sorted == expected + + @pytest.mark.asyncio + async def test_concurrency_read(self, setup_tortoise_basic): + """Test concurrent read operations with AWS wrapper.""" + + await User.create(name="Test User", email="test@example.com") + user1 = await User.first() + + # Perform 100 concurrent reads + all_read = await asyncio.gather(*[User.first() for _ in range(100)]) + + # All reads should return the same user + assert all_read == [user1 for _ in range(100)] + + @pytest.mark.asyncio + async def test_concurrency_create(self, setup_tortoise_basic): + """Test concurrent create operations with AWS wrapper.""" + + # Perform 100 concurrent creates with unique emails + all_write = await asyncio.gather(*[ + User.create(name="Test", email=f"test{i}@example.com") + for i in range(100) + ]) + + # Read all created users + all_read = await User.all() + + # All created users should exist in the database + assert set(all_write) == set(all_read) + + @pytest.mark.asyncio + async def test_atomic_decorator(self, setup_tortoise_basic): + """Test atomic decorator for transaction handling with AWS wrapper.""" + + @atomic() + async def create_users_atomically(): + user1 = await User.create(name="Atomic User 1", email="atomic1@example.com") + user2 = await User.create(name="Atomic User 2", email="atomic2@example.com") + return user1, user2 + + # Execute atomic operation + user1, user2 = await create_users_atomically() + + # Verify both users were created + assert user1.id is not None + assert user2.id is not None + + # Verify users exist in database + found_users = await User.filter(name__startswith="Atomic User") + assert len(found_users) == 2 + + @pytest.mark.asyncio + async def test_atomic_decorator_rollback(self, setup_tortoise_basic): + """Test atomic decorator rollback on exception with AWS wrapper.""" + + @atomic() + async def create_users_with_error(): + await User.create(name="Atomic Rollback User", email="rollback@example.com") + # Force rollback by raising exception + raise ValueError("Intentional error for rollback test") + + # Execute atomic operation that should fail + with pytest.raises(ValueError): + await create_users_with_error() + + # Verify user was not created due to rollback + users = await User.filter(name="Atomic Rollback User") + assert len(users) == 0 \ No newline at end of file diff --git a/tests/integration/container/tortoise/test_tortoise_common.py b/tests/integration/container/tortoise/test_tortoise_common.py new file mode 100644 index 00000000..3dd44277 --- /dev/null +++ b/tests/integration/container/tortoise/test_tortoise_common.py @@ -0,0 +1,98 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import pytest_asyncio +from tortoise import Tortoise + +# Import to register the aws-mysql backend +import aws_advanced_python_wrapper.tortoise +from tests.integration.container.tortoise.test_tortoise_models import * +from tests.integration.container.utils.rds_test_utility import RdsTestUtility +from tests.integration.container.utils.test_environment import TestEnvironment + + +async def clear_test_models(): + """Clear all test models by calling .all().delete() on each.""" + from tortoise.models import Model + + import tests.integration.container.tortoise.test_tortoise_models as models_module + + for attr_name in dir(models_module): + attr = getattr(models_module, attr_name) + if isinstance(attr, type) and issubclass(attr, Model) and attr != Model: + await attr.all().delete() + +async def setup_tortoise(conn_utils, plugins="aurora_connection_tracker", **kwargs): + """Setup Tortoise with AWS MySQL backend and configurable plugins.""" + db_url = conn_utils.get_aws_tortoise_url( + TestEnvironment.get_current().get_engine(), + plugins=plugins, + **kwargs, + ) + + config = { + "connections": { + "default": db_url + }, + "apps": { + "models": { + "models": ["tests.integration.container.tortoise.test_tortoise_models"], + "default_connection": "default", + } + } + } + + from aws_advanced_python_wrapper.tortoise.sql_alchemy_tortoise_connection_provider import \ + setup_tortoise_connection_provider + setup_tortoise_connection_provider() + await Tortoise.init(config=config) + await Tortoise.generate_schemas() + + await clear_test_models() + + yield + + await clear_test_models() + await Tortoise.close_connections() + + +async def run_basic_read_operations(name_prefix="Test", email_prefix="test"): + """Common test logic for basic read operations.""" + user = await User.create(name=f"{name_prefix} User", email=f"{email_prefix}@example.com") + + found_user = await User.get(id=user.id) + assert found_user.name == f"{name_prefix} User" + assert found_user.email == f"{email_prefix}@example.com" + + users = await User.filter(name=f"{name_prefix} User") + assert len(users) == 1 + assert users[0].id == user.id + + +async def run_basic_write_operations(name_prefix="Write", email_prefix="write"): + """Common test logic for basic write operations.""" + user = await User.create(name=f"{name_prefix} Test", email=f"{email_prefix}@example.com") + assert user.id is not None + + user.name = f"Updated {name_prefix}" + await user.save() + + updated_user = await User.get(id=user.id) + assert updated_user.name == f"Updated {name_prefix}" + + await updated_user.delete() + + with pytest.raises(Exception): + await User.get(id=user.id) diff --git a/tests/integration/container/tortoise/test_tortoise_custom_endpoint.py b/tests/integration/container/tortoise/test_tortoise_custom_endpoint.py new file mode 100644 index 00000000..9a1697de --- /dev/null +++ b/tests/integration/container/tortoise/test_tortoise_custom_endpoint.py @@ -0,0 +1,108 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from time import perf_counter_ns, sleep +from uuid import uuid4 + +import pytest +import pytest_asyncio +from boto3 import client +from botocore.exceptions import ClientError + +from tests.integration.container.tortoise.test_tortoise_common import ( + run_basic_read_operations, run_basic_write_operations, setup_tortoise) +from tests.integration.container.tortoise.test_tortoise_models import User +from tests.integration.container.utils.test_environment import TestEnvironment + + +class TestTortoiseCustomEndpoint: + """Test class for Tortoise ORM with custom endpoint plugin.""" + endpoint_id = f"test-tortoise-endpoint-{uuid4()}" + endpoint_info = {} + + @pytest.fixture(scope='class') + def create_custom_endpoint(self): + """Create a custom endpoint for testing.""" + env_info = TestEnvironment.get_current().get_info() + region = env_info.get_region() + rds_client = client('rds', region_name=region) + + instances = env_info.get_database_info().get_instances() + instance_ids = [instances[0].get_instance_id()] + + try: + rds_client.create_db_cluster_endpoint( + DBClusterEndpointIdentifier=self.endpoint_id, + DBClusterIdentifier=TestEnvironment.get_current().get_cluster_name(), + EndpointType="ANY", + StaticMembers=instance_ids + ) + + self._wait_until_endpoint_available(rds_client) + yield self.endpoint_info["Endpoint"] + finally: + try: + rds_client.delete_db_cluster_endpoint(DBClusterEndpointIdentifier=self.endpoint_id) + except ClientError as e: + if e.response['Error']['Code'] != 'DBClusterEndpointNotFoundFault': + pass # Ignore if endpoint doesn't exist + rds_client.close() + + def _wait_until_endpoint_available(self, rds_client): + """Wait for the custom endpoint to become available.""" + end_ns = perf_counter_ns() + 5 * 60 * 1_000_000_000 # 5 minutes + available = False + + while perf_counter_ns() < end_ns: + response = rds_client.describe_db_cluster_endpoints( + DBClusterEndpointIdentifier=self.endpoint_id, + Filters=[ + { + "Name": "db-cluster-endpoint-type", + "Values": ["custom"] + } + ] + ) + + response_endpoints = response["DBClusterEndpoints"] + if len(response_endpoints) != 1: + sleep(3) + continue + + response_endpoint = response_endpoints[0] + TestTortoiseCustomEndpoint.endpoint_info = response_endpoint + available = "available" == response_endpoint["Status"] + if available: + break + + sleep(3) + + if not available: + pytest.fail(f"Timed out waiting for custom endpoint to become available: {self.endpoint_id}") + + @pytest_asyncio.fixture + async def setup_tortoise_custom_endpoint(self, conn_utils, create_custom_endpoint): + """Setup Tortoise with custom endpoint plugin.""" + async for result in setup_tortoise(conn_utils, plugins="custom_endpoint,aurora_connection_tracker"): + yield result + + @pytest.mark.asyncio + async def test_basic_read_operations(self, setup_tortoise_custom_endpoint): + """Test basic read operations with custom endpoint plugin.""" + await run_basic_read_operations("Custom Test", "custom") + + @pytest.mark.asyncio + async def test_basic_write_operations(self, setup_tortoise_custom_endpoint): + """Test basic write operations with custom endpoint plugin.""" + await run_basic_write_operations("Custom", "customwrite") \ No newline at end of file diff --git a/tests/integration/container/tortoise/test_tortoise_failover.py b/tests/integration/container/tortoise/test_tortoise_failover.py new file mode 100644 index 00000000..0abe36f6 --- /dev/null +++ b/tests/integration/container/tortoise/test_tortoise_failover.py @@ -0,0 +1,249 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import threading +import time + +import pytest +import pytest_asyncio +from tortoise import connections +from tortoise.transactions import in_transaction + +from aws_advanced_python_wrapper.errors import ( + FailoverSuccessError, TransactionResolutionUnknownError) +from tests.integration.container.tortoise.test_tortoise_common import \ + setup_tortoise +from tests.integration.container.tortoise.test_tortoise_models import \ + TableWithSleepTrigger +from tests.integration.container.utils.rds_test_utility import RdsTestUtility +from tests.integration.container.utils.test_environment import TestEnvironment +from tests.integration.container.utils.test_utils import get_sleep_trigger_sql + + +class TestTortoiseFailover: + """Test class for Tortoise ORM with AWS wrapper plugins.""" + @pytest.fixture(scope='class') + def aurora_utility(self): + region: str = TestEnvironment.get_current().get_info().get_region() + return RdsTestUtility(region) + + async def _create_sleep_trigger_record(self, name_prefix="Plugin Test", value="test_value", using_db=None): + """Create 3 TableWithSleepTrigger records.""" + await TableWithSleepTrigger.create(name=f"{name_prefix}", value=value, using_db=using_db) + # await TableWithSleepTrigger.create(name=f"{name_prefix}2", value=value, using_db=using_db) + # await TableWithSleepTrigger.create(name=f"{name_prefix}3", value=value, using_db=using_db) + + @pytest_asyncio.fixture + async def sleep_trigger_setup(self): + """Setup and cleanup sleep trigger for testing.""" + connection = connections.get("default") + db_engine = TestEnvironment.get_current().get_engine() + trigger_sql = get_sleep_trigger_sql(db_engine, 60, "table_with_sleep_trigger") + await connection.execute_query("DROP TRIGGER IF EXISTS table_with_sleep_trigger_sleep_trigger") + await connection.execute_query(trigger_sql) + yield + await connection.execute_query("DROP TRIGGER IF EXISTS table_with_sleep_trigger_sleep_trigger") + + @pytest_asyncio.fixture + async def setup_tortoise_with_failover(self, conn_utils): + """Setup Tortoise with failover plugins.""" + kwargs = { + "topology_refresh_ms": 10000, + } + async for result in setup_tortoise(conn_utils, plugins="failover", **kwargs): + yield result + + + @pytest.mark.asyncio + async def test_basic_operations_with_failover(self, setup_tortoise_with_failover, sleep_trigger_setup, aurora_utility): + """Test basic operations work with AWS wrapper plugins enabled.""" + insert_exception = None + + def insert_thread(): + nonlocal insert_exception + try: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(self._create_sleep_trigger_record()) + except Exception as e: + insert_exception = e + + def failover_thread(): + time.sleep(5) # Wait for insert to start + aurora_utility.failover_cluster_and_wait_until_writer_changed() + + # Start both threads + insert_t = threading.Thread(target=insert_thread) + failover_t = threading.Thread(target=failover_thread) + + insert_t.start() + failover_t.start() + + # Wait for both threads to complete + insert_t.join() + failover_t.join() + + # Assert that insert thread got FailoverSuccessError + assert insert_exception is not None + assert isinstance(insert_exception, FailoverSuccessError) + + @pytest.mark.asyncio + async def test_transaction_with_failover(self, setup_tortoise_with_failover, sleep_trigger_setup, aurora_utility): + """Test transactions with failover during long-running operations.""" + transaction_exception = None + + def transaction_thread(): + nonlocal transaction_exception + try: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + async def run_transaction(): + async with in_transaction() as conn: + await self._create_sleep_trigger_record("TX Plugin Test", "tx_test_value", conn) + + loop.run_until_complete(run_transaction()) + except Exception as e: + transaction_exception = e + + def failover_thread(): + time.sleep(5) # Wait for transaction to start + aurora_utility.failover_cluster_and_wait_until_writer_changed() + + # Start both threads + tx_t = threading.Thread(target=transaction_thread) + failover_t = threading.Thread(target=failover_thread) + + tx_t.start() + failover_t.start() + + # Wait for both threads to complete + tx_t.join() + failover_t.join() + + # Assert that transaction thread got FailoverSuccessError + assert transaction_exception is not None + assert isinstance(transaction_exception, (FailoverSuccessError, TransactionResolutionUnknownError)) + + # Verify no records were created due to transaction rollback + record_count = await TableWithSleepTrigger.all().count() + assert record_count == 0 + + # Verify autocommit is re-enabled after failover by inserting a record + autocommit_record = await TableWithSleepTrigger.create( + name="Autocommit Test", + value="autocommit_value" + ) + + # Verify the record exists (should be auto-committed) + found_record = await TableWithSleepTrigger.get(id=autocommit_record.id) + assert found_record.name == "Autocommit Test" + assert found_record.value == "autocommit_value" + + # Clean up the test record + await found_record.delete() + + @pytest.mark.asyncio + async def test_concurrent_queries_with_failover(self, setup_tortoise_with_failover, sleep_trigger_setup, aurora_utility): + """Test concurrent queries with failover during long-running operation.""" + connection = connections.get("default") + + # Step 1: Run 15 concurrent select queries + async def run_select_query(query_id): + return await connection.execute_query(f"SELECT {query_id} as query_id") + + initial_tasks = [run_select_query(i) for i in range(15)] + initial_results = await asyncio.gather(*initial_tasks) + assert len(initial_results) == 15 + + # Step 2: Run sleep query with failover + sleep_exception = None + + def sleep_query_thread(): + nonlocal sleep_exception + for attempt in range(3): + try: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete( + TableWithSleepTrigger.create(name=f"Concurrent Test {attempt}", value="sleep_value") + ) + except Exception as e: + sleep_exception = e + break # Stop on first exception (likely the failover) + + def failover_thread(): + time.sleep(5) # Wait for sleep query to start + aurora_utility.failover_cluster_and_wait_until_writer_changed() + + sleep_t = threading.Thread(target=sleep_query_thread) + failover_t = threading.Thread(target=failover_thread) + + sleep_t.start() + failover_t.start() + + sleep_t.join() + failover_t.join() + + # Verify failover exception occurred + assert sleep_exception is not None + assert isinstance(sleep_exception, (FailoverSuccessError, TransactionResolutionUnknownError)) + + # Step 3: Run another 15 concurrent select queries after failover + post_failover_tasks = [run_select_query(i + 100) for i in range(15)] + post_failover_results = await asyncio.gather(*post_failover_tasks) + assert len(post_failover_results) == 15 + + @pytest.mark.asyncio + async def test_multiple_concurrent_inserts_with_failover(self, setup_tortoise_with_failover, sleep_trigger_setup, aurora_utility): + """Test multiple concurrent insert operations with failover during long-running operations.""" + # Track exceptions from all insert threads + insert_exceptions = [] + + def insert_thread(thread_id): + try: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete( + TableWithSleepTrigger.create(name=f"Concurrent Insert {thread_id}", value="insert_value") + ) + except Exception as e: + insert_exceptions.append(e) + + def failover_thread(): + time.sleep(5) # Wait for inserts to start + aurora_utility.failover_cluster_and_wait_until_writer_changed() + + # Start 15 insert threads + insert_threads = [] + for i in range(15): + t = threading.Thread(target=insert_thread, args=(i,)) + insert_threads.append(t) + t.start() + + # Start failover thread + failover_t = threading.Thread(target=failover_thread) + failover_t.start() + + # Wait for all threads to complete + for t in insert_threads: + t.join() + failover_t.join() + + # Verify ALL threads got FailoverSuccessError or TransactionResolutionUnknownError + assert len(insert_exceptions) == 15, f"Expected 15 exceptions, got {len(insert_exceptions)}" + failover_errors = [e for e in insert_exceptions if isinstance(e, (FailoverSuccessError, TransactionResolutionUnknownError))] + assert len(failover_errors) == 15, f"Expected all 15 threads to get failover errors, got {len(failover_errors)}" + print(f"Successfully verified all 15 threads got failover exceptions") \ No newline at end of file diff --git a/tests/integration/container/tortoise/test_tortoise_iam_authentication.py b/tests/integration/container/tortoise/test_tortoise_iam_authentication.py new file mode 100644 index 00000000..5dee829c --- /dev/null +++ b/tests/integration/container/tortoise/test_tortoise_iam_authentication.py @@ -0,0 +1,41 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import pytest_asyncio + +from tests.integration.container.tortoise.test_tortoise_common import ( + run_basic_read_operations, run_basic_write_operations, setup_tortoise) +from tests.integration.container.tortoise.test_tortoise_models import User +from tests.integration.container.utils.test_environment import TestEnvironment + + +class TestTortoiseIamAuthentication: + """Test class for Tortoise ORM with IAM authentication.""" + + @pytest_asyncio.fixture + async def setup_tortoise_iam(self, conn_utils): + """Setup Tortoise with IAM authentication.""" + async for result in setup_tortoise(conn_utils, plugins="iam", user=conn_utils.iam_user): + yield result + + @pytest.mark.asyncio + async def test_basic_read_operations(self, setup_tortoise_iam): + """Test basic read operations with IAM authentication.""" + await run_basic_read_operations() + + @pytest.mark.asyncio + async def test_basic_write_operations(self, setup_tortoise_iam): + """Test basic write operations with IAM authentication.""" + await run_basic_write_operations() \ No newline at end of file diff --git a/tests/integration/container/tortoise/test_tortoise_models.py b/tests/integration/container/tortoise/test_tortoise_models.py new file mode 100644 index 00000000..60f1e043 --- /dev/null +++ b/tests/integration/container/tortoise/test_tortoise_models.py @@ -0,0 +1,44 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from tortoise import fields +from tortoise.models import Model + + +class User(Model): + id = fields.IntField(primary_key=True) + name = fields.CharField(max_length=50) + email = fields.CharField(max_length=100, unique=True) + + class Meta: + table = "users" + + +class UniqueName(Model): + id = fields.IntField(primary_key=True) + name = fields.CharField(max_length=20, null=True, unique=True) + optional = fields.CharField(max_length=20, null=True) + other_optional = fields.CharField(max_length=20, null=True) + + class Meta: + table = "unique_names" + + +class TableWithSleepTrigger(Model): + id = fields.IntField(primary_key=True) + name = fields.CharField(max_length=50) + value = fields.CharField(max_length=100) + + class Meta: + table = "table_with_sleep_trigger" \ No newline at end of file diff --git a/tests/integration/container/tortoise/test_tortoise_multi_plugins.py b/tests/integration/container/tortoise/test_tortoise_multi_plugins.py new file mode 100644 index 00000000..83169d52 --- /dev/null +++ b/tests/integration/container/tortoise/test_tortoise_multi_plugins.py @@ -0,0 +1,264 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import threading +import time +from uuid import uuid4 + +import pytest +import pytest_asyncio +from boto3 import client +from botocore.exceptions import ClientError +from tortoise import connections + +from aws_advanced_python_wrapper.errors import FailoverSuccessError +from tests.integration.container.tortoise.test_tortoise_common import setup_tortoise, run_basic_read_operations, run_basic_write_operations +from tests.integration.container.tortoise.test_tortoise_models import TableWithSleepTrigger +from tests.integration.container.utils.rds_test_utility import RdsTestUtility +from tests.integration.container.utils.test_environment import TestEnvironment +from tests.integration.container.utils.test_utils import get_sleep_trigger_sql + + +class TestTortoiseMultiPlugins: + """Test class for Tortoise ORM with multiple AWS wrapper plugins.""" + endpoint_id = f"test-multi-endpoint-{uuid4()}" + endpoint_info = {} + + @pytest.fixture(scope='class') + def aurora_utility(self): + region: str = TestEnvironment.get_current().get_info().get_region() + return RdsTestUtility(region) + + @pytest.fixture(scope='class') + def create_custom_endpoint(self): + """Create a custom endpoint for testing.""" + env_info = TestEnvironment.get_current().get_info() + region = env_info.get_region() + rds_client = client('rds', region_name=region) + + instances = env_info.get_database_info().get_instances() + instance_ids = [instances[0].get_instance_id()] + + try: + rds_client.create_db_cluster_endpoint( + DBClusterEndpointIdentifier=self.endpoint_id, + DBClusterIdentifier=TestEnvironment.get_current().get_cluster_name(), + EndpointType="ANY", + StaticMembers=instance_ids + ) + + self._wait_until_endpoint_available(rds_client) + yield self.endpoint_info["Endpoint"] + finally: + try: + rds_client.delete_db_cluster_endpoint(DBClusterEndpointIdentifier=self.endpoint_id) + except ClientError as e: + if e.response['Error']['Code'] != 'DBClusterEndpointNotFoundFault': + pass + rds_client.close() + + def _wait_until_endpoint_available(self, rds_client): + """Wait for the custom endpoint to become available.""" + from time import perf_counter_ns, sleep + end_ns = perf_counter_ns() + 5 * 60 * 1_000_000_000 # 5 minutes + available = False + + while perf_counter_ns() < end_ns: + response = rds_client.describe_db_cluster_endpoints( + DBClusterEndpointIdentifier=self.endpoint_id, + Filters=[ + { + "Name": "db-cluster-endpoint-type", + "Values": ["custom"] + } + ] + ) + + response_endpoints = response["DBClusterEndpoints"] + if len(response_endpoints) != 1: + sleep(3) + continue + + response_endpoint = response_endpoints[0] + TestTortoiseMultiPlugins.endpoint_info = response_endpoint + available = "available" == response_endpoint["Status"] + if available: + break + + sleep(3) + + if not available: + pytest.fail(f"Timed out waiting for custom endpoint to become available: {self.endpoint_id}") + + async def _create_sleep_trigger_record(self, name_prefix="Multi Test", value="multi_value", using_db=None): + """Create TableWithSleepTrigger record.""" + await TableWithSleepTrigger.create(name=f"{name_prefix}", value=value, using_db=using_db) + + @pytest_asyncio.fixture + async def sleep_trigger_setup(self): + """Setup and cleanup sleep trigger for testing.""" + connection = connections.get("default") + db_engine = TestEnvironment.get_current().get_engine() + trigger_sql = get_sleep_trigger_sql(db_engine, 60, "table_with_sleep_trigger") + await connection.execute_query("DROP TRIGGER IF EXISTS table_with_sleep_trigger_sleep_trigger") + await connection.execute_query(trigger_sql) + yield + await connection.execute_query("DROP TRIGGER IF EXISTS table_with_sleep_trigger_sleep_trigger") + + @pytest_asyncio.fixture + async def setup_tortoise_multi_plugins(self, conn_utils, create_custom_endpoint): + """Setup Tortoise with multiple plugins.""" + kwargs = { + "topology_refresh_ms": 10000, + "reader_host_selector_strategy": "fastest_response" + } + async for result in setup_tortoise(conn_utils, plugins="failover,iam,custom_endpoint,aurora_connection_tracker,fastest_response_strategy", user=conn_utils.iam_user, **kwargs): + yield result + + @pytest.mark.asyncio + async def test_basic_read_operations(self, setup_tortoise_multi_plugins): + """Test basic read operations with multiple plugins.""" + await run_basic_read_operations("Multi", "multi") + + @pytest.mark.asyncio + async def test_basic_write_operations(self, setup_tortoise_multi_plugins): + """Test basic write operations with multiple plugins.""" + await run_basic_write_operations("Multi", "multi") + + @pytest.mark.asyncio + async def test_multi_plugins_with_failover(self, setup_tortoise_multi_plugins, sleep_trigger_setup, aurora_utility): + """Test multiple plugins work with failover during long-running operations.""" + insert_exception = None + + def insert_thread(): + nonlocal insert_exception + try: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(self._create_sleep_trigger_record()) + except Exception as e: + insert_exception = e + + def failover_thread(): + time.sleep(5) # Wait for insert to start + aurora_utility.failover_cluster_and_wait_until_writer_changed() + + # Start both threads + insert_t = threading.Thread(target=insert_thread) + failover_t = threading.Thread(target=failover_thread) + + insert_t.start() + failover_t.start() + + # Wait for both threads to complete + insert_t.join() + failover_t.join() + + # Assert that insert thread got FailoverSuccessError + assert insert_exception is not None + assert isinstance(insert_exception, FailoverSuccessError) + + @pytest.mark.asyncio + async def test_concurrent_queries_with_failover(self, setup_tortoise_multi_plugins, sleep_trigger_setup, aurora_utility): + """Test concurrent queries with failover during long-running operation.""" + connection = connections.get("default") + + # Step 1: Run 15 concurrent select queries + async def run_select_query(query_id): + return await connection.execute_query(f"SELECT {query_id} as query_id") + + initial_tasks = [run_select_query(i) for i in range(15)] + initial_results = await asyncio.gather(*initial_tasks) + assert len(initial_results) == 15 + + # Step 2: Run sleep query with failover + sleep_exception = None + + def sleep_query_thread(): + nonlocal sleep_exception + for attempt in range(3): + try: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete( + TableWithSleepTrigger.create(name=f"Multi Concurrent Test {attempt}", value="multi_sleep_value") + ) + except Exception as e: + sleep_exception = e + break # Stop on first exception (likely the failover) + + def failover_thread(): + time.sleep(5) # Wait for sleep query to start + aurora_utility.failover_cluster_and_wait_until_writer_changed() + + sleep_t = threading.Thread(target=sleep_query_thread) + failover_t = threading.Thread(target=failover_thread) + + sleep_t.start() + failover_t.start() + + sleep_t.join() + failover_t.join() + + # Verify failover exception occurred + assert sleep_exception is not None + assert isinstance(sleep_exception, FailoverSuccessError) + + # Step 3: Run another 15 concurrent select queries after failover + post_failover_tasks = [run_select_query(i + 100) for i in range(15)] + post_failover_results = await asyncio.gather(*post_failover_tasks) + assert len(post_failover_results) == 15 + + @pytest.mark.asyncio + async def test_multiple_concurrent_inserts_with_failover(self, setup_tortoise_multi_plugins, sleep_trigger_setup, aurora_utility): + """Test multiple concurrent insert operations with failover during long-running operations.""" + # Track exceptions from all insert threads + insert_exceptions = [] + + def insert_thread(thread_id): + try: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete( + TableWithSleepTrigger.create(name=f"Multi Concurrent Insert {thread_id}", value="multi_insert_value") + ) + except Exception as e: + insert_exceptions.append(e) + + def failover_thread(): + time.sleep(5) # Wait for inserts to start + aurora_utility.failover_cluster_and_wait_until_writer_changed() + + # Start 15 insert threads + insert_threads = [] + for i in range(15): + t = threading.Thread(target=insert_thread, args=(i,)) + insert_threads.append(t) + t.start() + + # Start failover thread + failover_t = threading.Thread(target=failover_thread) + failover_t.start() + + # Wait for all threads to complete + for t in insert_threads: + t.join() + failover_t.join() + + # Verify ALL threads got FailoverSuccessError or TransactionResolutionUnknownError + assert len(insert_exceptions) == 15, f"Expected 15 exceptions, got {len(insert_exceptions)}" + failover_errors = [e for e in insert_exceptions if isinstance(e, FailoverSuccessError)] + assert len(failover_errors) == 15, f"Expected all 15 threads to get failover errors, got {len(failover_errors)}" + print(f"Successfully verified all 15 threads got failover exceptions with multiple plugins") \ No newline at end of file diff --git a/tests/integration/container/tortoise/test_tortoise_secrets_manager.py b/tests/integration/container/tortoise/test_tortoise_secrets_manager.py new file mode 100644 index 00000000..d150c2fa --- /dev/null +++ b/tests/integration/container/tortoise/test_tortoise_secrets_manager.py @@ -0,0 +1,74 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +import boto3 +import pytest +import pytest_asyncio + +from tests.integration.container.tortoise.test_tortoise_common import ( + run_basic_read_operations, run_basic_write_operations, setup_tortoise) +from tests.integration.container.tortoise.test_tortoise_models import User +from tests.integration.container.utils.test_environment import TestEnvironment + + +class TestTortoiseSecretsManager: + """Test class for Tortoise ORM with AWS Secrets Manager authentication.""" + + @pytest.fixture(scope='class') + def create_secret(self, conn_utils): + """Create a secret in AWS Secrets Manager with database credentials.""" + region = TestEnvironment.get_current().get_info().get_region() + client = boto3.client('secretsmanager', region_name=region) + + secret_name = f"test-tortoise-secret-{TestEnvironment.get_current().get_info().get_db_name()}" + secret_value = { + "username": conn_utils.user, + "password": conn_utils.password + } + + try: + response = client.create_secret( + Name=secret_name, + SecretString=json.dumps(secret_value) + ) + secret_id = response['ARN'] + yield secret_id + finally: + try: + client.delete_secret( + SecretId=secret_name, + ForceDeleteWithoutRecovery=True + ) + except Exception: + pass + + @pytest_asyncio.fixture + async def setup_tortoise_secrets_manager(self, conn_utils, create_secret): + """Setup Tortoise with Secrets Manager authentication.""" + async for result in setup_tortoise(conn_utils, + plugins="aws_secrets_manager", + secrets_manager_secret_id=create_secret): + yield result + + @pytest.mark.asyncio + async def test_basic_read_operations(self, setup_tortoise_secrets_manager): + """Test basic read operations with Secrets Manager authentication.""" + await run_basic_read_operations("Secrets", "secrets") + + @pytest.mark.asyncio + async def test_basic_write_operations(self, setup_tortoise_secrets_manager): + """Test basic write operations with Secrets Manager authentication.""" + await run_basic_write_operations("Secrets", "secretswrite") \ No newline at end of file diff --git a/tests/integration/container/utils/connection_utils.py b/tests/integration/container/utils/connection_utils.py index 52d0add5..1cc57b06 100644 --- a/tests/integration/container/utils/connection_utils.py +++ b/tests/integration/container/utils/connection_utils.py @@ -16,6 +16,7 @@ from typing import Any, Dict, Optional +from .database_engine import DatabaseEngine from .driver_helper import DriverHelper from .test_environment import TestEnvironment @@ -97,3 +98,30 @@ def get_proxy_connect_params( password = self.password if password is None else password dbname = self.dbname if dbname is None else dbname return DriverHelper.get_connect_params(host, port, user, password, dbname) + + def get_aws_tortoise_url( + self, + db_engine: DatabaseEngine, + host: Optional[str] = None, + port: Optional[int] = None, + user: Optional[str] = None, + password: Optional[str] = None, + dbname: Optional[str] = None, + **kwargs) -> str: + """Build AWS MySQL connection URL for Tortoise ORM with query parameters.""" + host = self.writer_cluster_host if host is None else host + port = self.port if port is None else port + user = self.user if user is None else user + password = self.password if password is None else password + dbname = self.dbname if dbname is None else dbname + + # Build base URL + protocol = "aws-pg" if db_engine == DatabaseEngine.PG else "aws-mysql" + url = f"{protocol}://{user}:{password}@{host}:{port}/{dbname}" + + # Add all kwargs as query parameters + if kwargs: + params = [f"{key}={value}" for key, value in kwargs.items()] + url += "?" + "&".join(params) + + return url diff --git a/tests/integration/container/utils/test_environment_request.py b/tests/integration/container/utils/test_environment_request.py index db1bbeef..93fe23bf 100644 --- a/tests/integration/container/utils/test_environment_request.py +++ b/tests/integration/container/utils/test_environment_request.py @@ -40,7 +40,7 @@ def __init__(self, request: Dict[str, Any]) -> None: self._instances = getattr(DatabaseInstances, typing.cast('str', request.get("instances"))) self._deployment = getattr(DatabaseEngineDeployment, typing.cast('str', request.get("deployment"))) self._target_python_version = getattr(TargetPythonVersion, typing.cast('str', request.get("targetPythonVersion"))) - self._num_of_instances = typing.cast('int', request.get("numOfInstances")) + self._num_of_instances = typing.cast('int', 2) self._features = set() features: Iterable[str] = typing.cast('Iterable[str]', request.get("features")) if features is not None: diff --git a/tests/integration/container/utils/test_utils.py b/tests/integration/container/utils/test_utils.py new file mode 100644 index 00000000..bf669128 --- /dev/null +++ b/tests/integration/container/utils/test_utils.py @@ -0,0 +1,53 @@ +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .database_engine import DatabaseEngine + + +def get_sleep_sql(db_engine : DatabaseEngine, seconds : int): + if db_engine == DatabaseEngine.MYSQL: + return f"SELECT SLEEP({seconds});" + elif db_engine == DatabaseEngine.PG: + return f"SELECT PG_SLEEP({seconds});" + else: + raise ValueError("Unknown database engine: " + db) + + +def get_sleep_trigger_sql(db_engine: DatabaseEngine, duration: int, table_name: str) -> str: + """Generate SQL to create a sleep trigger for INSERT operations.""" + if db_engine == DatabaseEngine.MYSQL: + return f""" + CREATE TRIGGER {table_name}_sleep_trigger + BEFORE INSERT ON {table_name} + FOR EACH ROW + BEGIN + DO SLEEP({duration}); + END + """ + elif db_engine == DatabaseEngine.PG: + return f""" + CREATE OR REPLACE FUNCTION {table_name}_sleep_function() + RETURNS TRIGGER AS $$ + BEGIN + PERFORM pg_sleep({duration}); + RETURN NEW; + END; + $$ LANGUAGE plpgsql; + + CREATE TRIGGER {table_name}_sleep_trigger + BEFORE INSERT ON {table_name} + FOR EACH ROW + EXECUTE FUNCTION {table_name}_sleep_function(); + """ + else: + raise ValueError(f"Unknown database engine: {db_engine}") + From 04221c1cde7167157ac3f2cc5b24150f8d88d571 Mon Sep 17 00:00:00 2001 From: Favian Samatha Date: Thu, 4 Dec 2025 01:49:56 -0800 Subject: [PATCH 06/11] Added more intergation tests and fixed unit tests --- .../sql_alchemy_connection_provider.py | 1 - .../tortoise/backend/base/client.py | 62 +-- .../tortoise/backend/mysql/client.py | 9 +- tests/integration/container/conftest.py | 2 + .../container/tortoise/models/__init__.py | 13 + .../test_models.py} | 0 .../tortoise/models/test_models_copy.py | 44 ++ .../models/test_models_relationships.py | 78 ++++ .../container/tortoise/router/__init__.py | 13 + .../container/tortoise/router/test_router.py | 20 + .../container/tortoise/test_tortoise_basic.py | 20 +- .../tortoise/test_tortoise_common.py | 17 +- .../tortoise/test_tortoise_config.py | 292 +++++++++++++ .../tortoise/test_tortoise_custom_endpoint.py | 16 +- .../tortoise/test_tortoise_failover.py | 40 +- .../test_tortoise_iam_authentication.py | 13 +- .../tortoise/test_tortoise_multi_plugins.py | 29 +- .../tortoise/test_tortoise_relations.py | 227 +++++++++++ .../tortoise/test_tortoise_secrets_manager.py | 14 +- tests/unit/test_tortoise_base_client.py | 99 ++--- tests/unit/test_tortoise_mysql_client.py | 382 +++++++++--------- 21 files changed, 1062 insertions(+), 329 deletions(-) create mode 100644 tests/integration/container/tortoise/models/__init__.py rename tests/integration/container/tortoise/{test_tortoise_models.py => models/test_models.py} (100%) create mode 100644 tests/integration/container/tortoise/models/test_models_copy.py create mode 100644 tests/integration/container/tortoise/models/test_models_relationships.py create mode 100644 tests/integration/container/tortoise/router/__init__.py create mode 100644 tests/integration/container/tortoise/router/test_router.py create mode 100644 tests/integration/container/tortoise/test_tortoise_config.py create mode 100644 tests/integration/container/tortoise/test_tortoise_relations.py diff --git a/aws_advanced_python_wrapper/sql_alchemy_connection_provider.py b/aws_advanced_python_wrapper/sql_alchemy_connection_provider.py index f2f57407..cfdf0c6b 100644 --- a/aws_advanced_python_wrapper/sql_alchemy_connection_provider.py +++ b/aws_advanced_python_wrapper/sql_alchemy_connection_provider.py @@ -137,7 +137,6 @@ def connect( if queue_pool is None: raise AwsWrapperError(Messages.get_formatted("SqlAlchemyPooledConnectionProvider.PoolNone", host_info.url)) - return queue_pool.connect() # The pool key should always be retrieved using this method, because the username diff --git a/aws_advanced_python_wrapper/tortoise/backend/base/client.py b/aws_advanced_python_wrapper/tortoise/backend/base/client.py index f4fe3c3c..bb31b224 100644 --- a/aws_advanced_python_wrapper/tortoise/backend/base/client.py +++ b/aws_advanced_python_wrapper/tortoise/backend/base/client.py @@ -14,7 +14,6 @@ import asyncio import mysql.connector -from concurrent.futures import ThreadPoolExecutor from contextlib import asynccontextmanager from typing import Any, Callable, Generic @@ -26,66 +25,47 @@ class AwsWrapperAsyncConnector: - """Factory class for creating AWS wrapper connections.""" - - _executor: ThreadPoolExecutor = ThreadPoolExecutor( - thread_name_prefix="AwsWrapperAsyncExecutor" - ) + """Class for creating and closing AWS wrapper connections.""" @staticmethod async def ConnectWithAwsWrapper(connect_func: Callable, **kwargs) -> AwsWrapperConnection: """Create an AWS wrapper connection with async cursor support.""" - loop = asyncio.get_event_loop() - connection = await loop.run_in_executor( - AwsWrapperAsyncConnector._executor, - lambda: AwsWrapperConnection.connect(connect_func, **kwargs) + connection = await asyncio.to_thread( + AwsWrapperConnection.connect, connect_func, **kwargs ) return AwsConnectionAsyncWrapper(connection) @staticmethod async def CloseAwsWrapper(connection: AwsWrapperConnection) -> None: """Close an AWS wrapper connection asynchronously.""" - loop = asyncio.get_event_loop() - await loop.run_in_executor( - AwsWrapperAsyncConnector._executor, - connection.close - ) + await asyncio.to_thread(connection.close) class AwsCursorAsyncWrapper: - """Wraps a sync cursor to provide async interface.""" - - _executor: ThreadPoolExecutor = ThreadPoolExecutor( - thread_name_prefix="AwsCursorAsyncWrapperExecutor" - ) + """Wraps sync AwsCursor cursor with async support.""" def __init__(self, sync_cursor): self._cursor = sync_cursor async def execute(self, query, params=None): """Execute a query asynchronously.""" - loop = asyncio.get_event_loop() - return await loop.run_in_executor(self._executor, self._cursor.execute, query, params) + return await asyncio.to_thread(self._cursor.execute, query, params) async def executemany(self, query, params_list): """Execute multiple queries asynchronously.""" - loop = asyncio.get_event_loop() - return await loop.run_in_executor(self._executor, self._cursor.executemany, query, params_list) + return await asyncio.to_thread(self._cursor.executemany, query, params_list) async def fetchall(self): """Fetch all results asynchronously.""" - loop = asyncio.get_event_loop() - return await loop.run_in_executor(self._executor, self._cursor.fetchall) + return await asyncio.to_thread(self._cursor.fetchall) async def fetchone(self): """Fetch one result asynchronously.""" - loop = asyncio.get_event_loop() - return await loop.run_in_executor(self._executor, self._cursor.fetchone) + return await asyncio.to_thread(self._cursor.fetchone) async def close(self): """Close cursor asynchronously.""" - loop = asyncio.get_event_loop() - return await loop.run_in_executor(self._executor, self._cursor.close) + return await asyncio.to_thread(self._cursor.close) def __getattr__(self, name): """Delegate non-async attributes to the wrapped cursor.""" @@ -93,11 +73,7 @@ def __getattr__(self, name): class AwsConnectionAsyncWrapper(AwsWrapperConnection): - """AWS wrapper connection with async cursor support.""" - - _executor: ThreadPoolExecutor = ThreadPoolExecutor( - thread_name_prefix="AwsConnectionAsyncWrapperExecutor" - ) + """Wraps sync AwsConnection with async cursor support.""" def __init__(self, connection: AwsWrapperConnection): self._wrapped_connection = connection @@ -105,27 +81,23 @@ def __init__(self, connection: AwsWrapperConnection): @asynccontextmanager async def cursor(self): """Create an async cursor context manager.""" - loop = asyncio.get_event_loop() - cursor_obj = await loop.run_in_executor(self._executor, self._wrapped_connection.cursor) + cursor_obj = await asyncio.to_thread(self._wrapped_connection.cursor) try: yield AwsCursorAsyncWrapper(cursor_obj) finally: - await loop.run_in_executor(self._executor, cursor_obj.close) + await asyncio.to_thread(cursor_obj.close) async def rollback(self): """Rollback the current transaction.""" - loop = asyncio.get_event_loop() - return await loop.run_in_executor(self._executor, self._wrapped_connection.rollback) + return await asyncio.to_thread(self._wrapped_connection.rollback) async def commit(self): """Commit the current transaction.""" - loop = asyncio.get_event_loop() - return await loop.run_in_executor(self._executor, self._wrapped_connection.commit) + return await asyncio.to_thread(self._wrapped_connection.commit) async def set_autocommit(self, value: bool): """Set autocommit mode.""" - loop = asyncio.get_event_loop() - return await loop.run_in_executor(self._executor, lambda: setattr(self._wrapped_connection, 'autocommit', value)) + return await asyncio.to_thread(setattr, self._wrapped_connection, 'autocommit', value) def __getattr__(self, name): """Delegate all other attributes/methods to the wrapped connection.""" @@ -209,5 +181,5 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: else: await self.client.commit() finally: + await AwsWrapperAsyncConnector.CloseAwsWrapper(self.client._connection) connections.reset(self.token) - await AwsWrapperAsyncConnector.CloseAwsWrapper(self.client._connection) \ No newline at end of file diff --git a/aws_advanced_python_wrapper/tortoise/backend/mysql/client.py b/aws_advanced_python_wrapper/tortoise/backend/mysql/client.py index c4857acb..c63626b2 100644 --- a/aws_advanced_python_wrapper/tortoise/backend/mysql/client.py +++ b/aws_advanced_python_wrapper/tortoise/backend/mysql/client.py @@ -117,7 +117,7 @@ def __init__( self.host = host self.port = int(port) self.extra = kwargs.copy() - + # Extract MySQL-specific settings self.storage_engine = self.extra.pop("storage_engine", "innodb") self.charset = self.extra.pop("charset", "utf8mb4") @@ -128,6 +128,12 @@ def __init__( self.extra.pop("autocommit", None) self.extra.setdefault("sql_mode", "STRICT_TRANS_TABLES") + # Ensure that timeouts are integers + timeout_params = ["connect_timeout", "monitoring-connect_timeout",] + for param in timeout_params: + if param in self.extra and self.extra[param] is not None: + self.extra[param] = int(self.extra[param]) + # Initialize connection templates self._init_connection_templates() @@ -251,7 +257,6 @@ async def _execute_script(self, query: str, with_db: bool) -> None: """Execute a multi-statement query by parsing and running statements sequentially.""" async with self._acquire_connection(with_db) as connection: logger.debug(f"Executing script: {query}") - print(f"Executing script: {query}") async with connection.cursor() as cursor: # Parse multi-statement queries since MySQL Connector doesn't handle them well statements = sqlparse.split(query) diff --git a/tests/integration/container/conftest.py b/tests/integration/container/conftest.py index eed03977..2647de22 100644 --- a/tests/integration/container/conftest.py +++ b/tests/integration/container/conftest.py @@ -68,6 +68,8 @@ def pytest_runtest_setup(item): test_name = item.callspec.id else: TestEnvironment.get_current().set_current_driver(None) + # Fallback to item.name if no callspec (for non-parameterized tests) + test_name = getattr(item, 'name', None) or str(item) logger.info(f"Starting test preparation for: {test_name}") diff --git a/tests/integration/container/tortoise/models/__init__.py b/tests/integration/container/tortoise/models/__init__.py new file mode 100644 index 00000000..7d84fea8 --- /dev/null +++ b/tests/integration/container/tortoise/models/__init__.py @@ -0,0 +1,13 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License \ No newline at end of file diff --git a/tests/integration/container/tortoise/test_tortoise_models.py b/tests/integration/container/tortoise/models/test_models.py similarity index 100% rename from tests/integration/container/tortoise/test_tortoise_models.py rename to tests/integration/container/tortoise/models/test_models.py diff --git a/tests/integration/container/tortoise/models/test_models_copy.py b/tests/integration/container/tortoise/models/test_models_copy.py new file mode 100644 index 00000000..60f1e043 --- /dev/null +++ b/tests/integration/container/tortoise/models/test_models_copy.py @@ -0,0 +1,44 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from tortoise import fields +from tortoise.models import Model + + +class User(Model): + id = fields.IntField(primary_key=True) + name = fields.CharField(max_length=50) + email = fields.CharField(max_length=100, unique=True) + + class Meta: + table = "users" + + +class UniqueName(Model): + id = fields.IntField(primary_key=True) + name = fields.CharField(max_length=20, null=True, unique=True) + optional = fields.CharField(max_length=20, null=True) + other_optional = fields.CharField(max_length=20, null=True) + + class Meta: + table = "unique_names" + + +class TableWithSleepTrigger(Model): + id = fields.IntField(primary_key=True) + name = fields.CharField(max_length=50) + value = fields.CharField(max_length=100) + + class Meta: + table = "table_with_sleep_trigger" \ No newline at end of file diff --git a/tests/integration/container/tortoise/models/test_models_relationships.py b/tests/integration/container/tortoise/models/test_models_relationships.py new file mode 100644 index 00000000..b941311d --- /dev/null +++ b/tests/integration/container/tortoise/models/test_models_relationships.py @@ -0,0 +1,78 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from tortoise import fields +from tortoise.models import Model + + +# One-to-One Relationship Models +class RelTestAccount(Model): + id = fields.IntField(primary_key=True) + username = fields.CharField(max_length=50, unique=True) + email = fields.CharField(max_length=100) + + class Meta: + table = "rel_test_accounts" + + +class RelTestAccountProfile(Model): + id = fields.IntField(primary_key=True) + account = fields.OneToOneField("models.RelTestAccount", related_name="profile", on_delete=fields.CASCADE) + bio = fields.TextField(null=True) + avatar_url = fields.CharField(max_length=200, null=True) + + class Meta: + table = "rel_test_account_profiles" + + +# One-to-Many Relationship Models +class RelTestPublisher(Model): + id = fields.IntField(primary_key=True) + name = fields.CharField(max_length=100) + email = fields.CharField(max_length=100, unique=True) + + class Meta: + table = "rel_test_publishers" + + +class RelTestPublication(Model): + id = fields.IntField(primary_key=True) + title = fields.CharField(max_length=200) + isbn = fields.CharField(max_length=13, unique=True) + publisher = fields.ForeignKeyField("models.RelTestPublisher", related_name="publications", on_delete=fields.CASCADE) + published_date = fields.DateField(null=True) + + class Meta: + table = "rel_test_publications" + + +# Many-to-Many Relationship Models +class RelTestLearner(Model): + id = fields.IntField(primary_key=True) + name = fields.CharField(max_length=100) + learner_id = fields.CharField(max_length=20, unique=True) + + class Meta: + table = "rel_test_learners" + + +class RelTestSubject(Model): + id = fields.IntField(primary_key=True) + name = fields.CharField(max_length=100) + code = fields.CharField(max_length=10, unique=True) + credits = fields.IntField() + learners = fields.ManyToManyField("models.RelTestLearner", related_name="subjects") + + class Meta: + table = "rel_test_subjects" \ No newline at end of file diff --git a/tests/integration/container/tortoise/router/__init__.py b/tests/integration/container/tortoise/router/__init__.py new file mode 100644 index 00000000..7d84fea8 --- /dev/null +++ b/tests/integration/container/tortoise/router/__init__.py @@ -0,0 +1,13 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License \ No newline at end of file diff --git a/tests/integration/container/tortoise/router/test_router.py b/tests/integration/container/tortoise/router/test_router.py new file mode 100644 index 00000000..a0893277 --- /dev/null +++ b/tests/integration/container/tortoise/router/test_router.py @@ -0,0 +1,20 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +class TestRouter: + def db_for_read(self, model): + return "default" + + def db_for_write(self, model): + return "default" \ No newline at end of file diff --git a/tests/integration/container/tortoise/test_tortoise_basic.py b/tests/integration/container/tortoise/test_tortoise_basic.py index 4cbbbfde..f6c16661 100644 --- a/tests/integration/container/tortoise/test_tortoise_basic.py +++ b/tests/integration/container/tortoise/test_tortoise_basic.py @@ -13,22 +13,30 @@ # limitations under the License. import asyncio -import time import pytest import pytest_asyncio from tortoise.transactions import atomic, in_transaction +from tests.integration.container.tortoise.models.test_models import ( + UniqueName, User) from tests.integration.container.tortoise.test_tortoise_common import \ setup_tortoise -from tests.integration.container.tortoise.test_tortoise_models import ( - UniqueName, User) +from tests.integration.container.utils.conditions import (disable_on_engines, + disable_on_features) +from tests.integration.container.utils.database_engine import DatabaseEngine +from tests.integration.container.utils.test_environment_features import \ + TestEnvironmentFeatures +@disable_on_engines([DatabaseEngine.PG]) +@disable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY, + TestEnvironmentFeatures.BLUE_GREEN_DEPLOYMENT, + TestEnvironmentFeatures.PERFORMANCE]) class TestTortoiseBasic: """ Test class for Tortoise ORM integration with AWS Advanced Python Wrapper. - Contains tests related to tortoise operations + Contains tests related to basic test operations. """ @pytest_asyncio.fixture @@ -95,8 +103,8 @@ async def test_basic_crud_operations_config(self, setup_tortoise_basic): async def test_transaction_support(self, setup_tortoise_basic): """Test transaction handling with AWS wrapper.""" async with in_transaction() as conn: - user1 = await User.create(name="User 1", email="user1@example.com", using_db=conn) - user2 = await User.create(name="User 2", email="user2@example.com", using_db=conn) + await User.create(name="User 1", email="user1@example.com", using_db=conn) + await User.create(name="User 2", email="user2@example.com", using_db=conn) # Verify users exist within transaction users = await User.filter(name__in=["User 1", "User 2"]).using_db(conn) diff --git a/tests/integration/container/tortoise/test_tortoise_common.py b/tests/integration/container/tortoise/test_tortoise_common.py index 3dd44277..3b08f005 100644 --- a/tests/integration/container/tortoise/test_tortoise_common.py +++ b/tests/integration/container/tortoise/test_tortoise_common.py @@ -14,11 +14,11 @@ import pytest import pytest_asyncio -from tortoise import Tortoise +from tortoise import Tortoise, connections # Import to register the aws-mysql backend import aws_advanced_python_wrapper.tortoise -from tests.integration.container.tortoise.test_tortoise_models import * +from tests.integration.container.tortoise.models.test_models import * from tests.integration.container.utils.rds_test_utility import RdsTestUtility from tests.integration.container.utils.test_environment import TestEnvironment @@ -27,7 +27,7 @@ async def clear_test_models(): """Clear all test models by calling .all().delete() on each.""" from tortoise.models import Model - import tests.integration.container.tortoise.test_tortoise_models as models_module + import tests.integration.container.tortoise.models.test_models as models_module for attr_name in dir(models_module): attr = getattr(models_module, attr_name) @@ -48,7 +48,7 @@ async def setup_tortoise(conn_utils, plugins="aurora_connection_tracker", **kwar }, "apps": { "models": { - "models": ["tests.integration.container.tortoise.test_tortoise_models"], + "models": ["tests.integration.container.tortoise.models.test_models"], "default_connection": "default", } } @@ -65,7 +65,7 @@ async def setup_tortoise(conn_utils, plugins="aurora_connection_tracker", **kwar yield await clear_test_models() - await Tortoise.close_connections() + await reset_tortoise() async def run_basic_read_operations(name_prefix="Test", email_prefix="test"): @@ -96,3 +96,10 @@ async def run_basic_write_operations(name_prefix="Write", email_prefix="write"): with pytest.raises(Exception): await User.get(id=user.id) + +async def reset_tortoise(): + await Tortoise.close_connections() + await Tortoise._reset_apps() + Tortoise._inited = False + Tortoise.apps = {} + connections._db_config = {} \ No newline at end of file diff --git a/tests/integration/container/tortoise/test_tortoise_config.py b/tests/integration/container/tortoise/test_tortoise_config.py new file mode 100644 index 00000000..98431bf6 --- /dev/null +++ b/tests/integration/container/tortoise/test_tortoise_config.py @@ -0,0 +1,292 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import pytest_asyncio +from tortoise import Tortoise, connections + +# Import to register the aws-mysql backend +import aws_advanced_python_wrapper.tortoise +from aws_advanced_python_wrapper.tortoise.sql_alchemy_tortoise_connection_provider import \ + setup_tortoise_connection_provider +from tests.integration.container.tortoise.models.test_models import User +from tests.integration.container.tortoise.test_tortoise_common import \ + reset_tortoise +from tests.integration.container.utils.conditions import (disable_on_engines, + disable_on_features) +from tests.integration.container.utils.database_engine import DatabaseEngine +from tests.integration.container.utils.test_environment_features import \ + TestEnvironmentFeatures + + +@disable_on_engines([DatabaseEngine.PG]) +@disable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY, + TestEnvironmentFeatures.BLUE_GREEN_DEPLOYMENT, + TestEnvironmentFeatures.PERFORMANCE]) +class TestTortoiseConfig: + """Test class for Tortoise ORM configuration scenarios.""" + + async def _clear_all_test_models(self): + """Clear all test models for a specific connection.""" + await User.all().delete() + + @pytest_asyncio.fixture + async def setup_tortoise_dict_config(self, conn_utils): + """Setup Tortoise with dictionary configuration instead of URL.""" + # Ensure clean state + config = { + "connections": { + "default": { + "engine": "aws_advanced_python_wrapper.tortoise.backend.mysql", + "credentials": { + "host": conn_utils.writer_cluster_host, + "port": conn_utils.port, + "user": conn_utils.user, + "password": conn_utils.password, + "database": conn_utils.dbname, + "plugins": "aurora_connection_tracker,failover", + } + } + }, + "apps": { + "models": { + "models": ["tests.integration.container.tortoise.models.test_models"], + "default_connection": "default", + } + } + } + + setup_tortoise_connection_provider() + await Tortoise.init(config=config) + await Tortoise.generate_schemas() + await self._clear_all_test_models() + + yield + + await self._clear_all_test_models() + await reset_tortoise() + + @pytest_asyncio.fixture + async def setup_tortoise_multi_db(self, conn_utils): + """Setup Tortoise with two different databases using same backend.""" + # Create second database name + original_db = conn_utils.dbname + second_db = f"{original_db}_test2" + + config = { + "connections": { + "default": { + "engine": "aws_advanced_python_wrapper.tortoise.backend.mysql", + "credentials": { + "host": conn_utils.writer_cluster_host, + "port": conn_utils.port, + "user": conn_utils.user, + "password": conn_utils.password, + "database": original_db, + "plugins": "aurora_connection_tracker,failover", + } + }, + "second_db": { + "engine": "aws_advanced_python_wrapper.tortoise.backend.mysql", + "credentials": { + "host": conn_utils.writer_cluster_host, + "port": conn_utils.port, + "user": conn_utils.user, + "password": conn_utils.password, + "database": second_db, + "plugins": "aurora_connection_tracker,failover" + } + } + }, + "apps": { + "models": { + "models": ["tests.integration.container.tortoise.models.test_models"], + "default_connection": "default", + }, + "models2": { + "models": ["tests.integration.container.tortoise.models.test_models_copy"], + "default_connection": "second_db", + } + } + } + + setup_tortoise_connection_provider() + await Tortoise.init(config=config) + + # Create second database + conn = connections.get("second_db") + try: + await conn.db_create() + except Exception: + pass + + await Tortoise.generate_schemas() + + await self._clear_all_test_models() + + yield second_db + await self._clear_all_test_models() + + # Drop second database + conn = connections.get("second_db") + await conn.db_delete() + await reset_tortoise() + + @pytest_asyncio.fixture + async def setup_tortoise_with_router(self, conn_utils): + """Setup Tortoise with router configuration.""" + config = { + "connections": { + "default": { + "engine": "aws_advanced_python_wrapper.tortoise.backend.mysql", + "credentials": { + "host": conn_utils.writer_cluster_host, + "port": conn_utils.port, + "user": conn_utils.user, + "password": conn_utils.password, + "database": conn_utils.dbname, + "plugins": "aurora_connection_tracker,failover" + } + } + }, + "apps": { + "models": { + "models": ["tests.integration.container.tortoise.models.test_models"], + "default_connection": "default", + } + }, + "routers": ["tests.integration.container.tortoise.router.test_router.TestRouter"] + } + + setup_tortoise_connection_provider() + await Tortoise.init(config=config) + await Tortoise.generate_schemas() + await self._clear_all_test_models() + + yield + + await self._clear_all_test_models() + await reset_tortoise() + + @pytest.mark.asyncio + async def test_dict_config_read_operations(self, setup_tortoise_dict_config): + """Test basic read operations with dictionary configuration.""" + # Create test data + user = await User.create(name="Dict Config User", email="dict@example.com") + + # Read operations + found_user = await User.get(id=user.id) + assert found_user.name == "Dict Config User" + assert found_user.email == "dict@example.com" + + # Query operations + users = await User.filter(name="Dict Config User") + assert len(users) == 1 + assert users[0].id == user.id + + @pytest.mark.asyncio + async def test_dict_config_write_operations(self, setup_tortoise_dict_config): + """Test basic write operations with dictionary configuration.""" + # Create + user = await User.create(name="Dict Write Test", email="dictwrite@example.com") + assert user.id is not None + + # Update + user.name = "Updated Dict User" + await user.save() + + updated_user = await User.get(id=user.id) + assert updated_user.name == "Updated Dict User" + + # Delete + await updated_user.delete() + + with pytest.raises(Exception): + await User.get(id=user.id) + + @pytest.mark.asyncio + async def test_multi_db_operations(self, setup_tortoise_multi_db): + """Test operations with multiple databases using same backend.""" + # Get second connection + second_conn = connections.get("second_db") + + # Create users in different databases + user1 = await User.create(name="DB1 User", email="db1@example.com") + user2 = await User.create(name="DB2 User", email="db2@example.com", using_db=second_conn) + + # Verify users exist in their respective databases + db1_users = await User.all() + db2_users = await User.all().using_db(second_conn) + + assert len(db1_users) == 1 + assert len(db2_users) == 1 + assert db1_users[0].name == "DB1 User" + assert db2_users[0].name == "DB2 User" + + # Verify isolation - users don't exist in the other database + db1_user_in_db2 = await User.filter(name="DB1 User").using_db(second_conn) + db2_user_in_db1 = await User.filter(name="DB2 User") + + assert len(db1_user_in_db2) == 0 + assert len(db2_user_in_db1) == 0 + + # Test updates in different databases + user1.name = "Updated DB1 User" + await user1.save() + + user2.name = "Updated DB2 User" + await user2.save(using_db=second_conn) + + # Verify updates + updated_user1 = await User.get(id=user1.id) + updated_user2 = await User.get(id=user2.id, using_db=second_conn) + + assert updated_user1.name == "Updated DB1 User" + assert updated_user2.name == "Updated DB2 User" + + @pytest.mark.asyncio + async def test_router_read_operations(self, setup_tortoise_with_router): + """Test read operations with router configuration.""" + # Create test data + user = await User.create(name="Router User", email="router@example.com") + + # Read operations (should be routed by router) + found_user = await User.get(id=user.id) + assert found_user.name == "Router User" + assert found_user.email == "router@example.com" + + # Query operations + users = await User.filter(name="Router User") + assert len(users) == 1 + assert users[0].id == user.id + + @pytest.mark.asyncio + async def test_router_write_operations(self, setup_tortoise_with_router): + """Test write operations with router configuration.""" + # Create (should be routed by router) + user = await User.create(name="Router Write Test", email="routerwrite@example.com") + assert user.id is not None + + # Update (should be routed by router) + user.name = "Updated Router User" + await user.save() + + updated_user = await User.get(id=user.id) + assert updated_user.name == "Updated Router User" + + # Delete (should be routed by router) + await updated_user.delete() + + with pytest.raises(Exception): + await User.get(id=user.id) diff --git a/tests/integration/container/tortoise/test_tortoise_custom_endpoint.py b/tests/integration/container/tortoise/test_tortoise_custom_endpoint.py index 9a1697de..d70ea2c8 100644 --- a/tests/integration/container/tortoise/test_tortoise_custom_endpoint.py +++ b/tests/integration/container/tortoise/test_tortoise_custom_endpoint.py @@ -20,12 +20,26 @@ from boto3 import client from botocore.exceptions import ClientError +from tests.integration.container.tortoise.models.test_models import User from tests.integration.container.tortoise.test_tortoise_common import ( run_basic_read_operations, run_basic_write_operations, setup_tortoise) -from tests.integration.container.tortoise.test_tortoise_models import User +from tests.integration.container.utils.conditions import ( + disable_on_engines, disable_on_features, enable_on_deployments, + enable_on_num_instances) +from tests.integration.container.utils.database_engine import DatabaseEngine +from tests.integration.container.utils.database_engine_deployment import \ + DatabaseEngineDeployment from tests.integration.container.utils.test_environment import TestEnvironment +from tests.integration.container.utils.test_environment_features import \ + TestEnvironmentFeatures +@disable_on_engines([DatabaseEngine.PG]) +@enable_on_num_instances(min_instances=2) +@enable_on_deployments([DatabaseEngineDeployment.AURORA]) +@disable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY, + TestEnvironmentFeatures.BLUE_GREEN_DEPLOYMENT, + TestEnvironmentFeatures.PERFORMANCE]) class TestTortoiseCustomEndpoint: """Test class for Tortoise ORM with custom endpoint plugin.""" endpoint_id = f"test-tortoise-endpoint-{uuid4()}" diff --git a/tests/integration/container/tortoise/test_tortoise_failover.py b/tests/integration/container/tortoise/test_tortoise_failover.py index 0abe36f6..fdd73c79 100644 --- a/tests/integration/container/tortoise/test_tortoise_failover.py +++ b/tests/integration/container/tortoise/test_tortoise_failover.py @@ -23,27 +23,38 @@ from aws_advanced_python_wrapper.errors import ( FailoverSuccessError, TransactionResolutionUnknownError) +from tests.integration.container.tortoise.models.test_models import \ + TableWithSleepTrigger from tests.integration.container.tortoise.test_tortoise_common import \ setup_tortoise -from tests.integration.container.tortoise.test_tortoise_models import \ - TableWithSleepTrigger +from tests.integration.container.utils.conditions import ( + disable_on_engines, disable_on_features, enable_on_deployments, + enable_on_engines, enable_on_num_instances) +from tests.integration.container.utils.database_engine import DatabaseEngine +from tests.integration.container.utils.database_engine_deployment import \ + DatabaseEngineDeployment from tests.integration.container.utils.rds_test_utility import RdsTestUtility from tests.integration.container.utils.test_environment import TestEnvironment +from tests.integration.container.utils.test_environment_features import \ + TestEnvironmentFeatures from tests.integration.container.utils.test_utils import get_sleep_trigger_sql +@disable_on_engines([DatabaseEngine.PG]) +@enable_on_num_instances(min_instances=2) +@enable_on_deployments([DatabaseEngineDeployment.AURORA, DatabaseEngineDeployment.RDS_MULTI_AZ_CLUSTER]) +@disable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY, + TestEnvironmentFeatures.BLUE_GREEN_DEPLOYMENT, + TestEnvironmentFeatures.PERFORMANCE]) class TestTortoiseFailover: - """Test class for Tortoise ORM with AWS wrapper plugins.""" + """Test class for Tortoise ORM with Failover.""" @pytest.fixture(scope='class') def aurora_utility(self): region: str = TestEnvironment.get_current().get_info().get_region() return RdsTestUtility(region) async def _create_sleep_trigger_record(self, name_prefix="Plugin Test", value="test_value", using_db=None): - """Create 3 TableWithSleepTrigger records.""" await TableWithSleepTrigger.create(name=f"{name_prefix}", value=value, using_db=using_db) - # await TableWithSleepTrigger.create(name=f"{name_prefix}2", value=value, using_db=using_db) - # await TableWithSleepTrigger.create(name=f"{name_prefix}3", value=value, using_db=using_db) @pytest_asyncio.fixture async def sleep_trigger_setup(self): @@ -61,6 +72,9 @@ async def setup_tortoise_with_failover(self, conn_utils): """Setup Tortoise with failover plugins.""" kwargs = { "topology_refresh_ms": 10000, + "connect_timeout": 10, + "monitoring-connect_timeout": 5, + "use_pure": True, # MySQL specific } async for result in setup_tortoise(conn_utils, plugins="failover", **kwargs): yield result @@ -68,7 +82,7 @@ async def setup_tortoise_with_failover(self, conn_utils): @pytest.mark.asyncio async def test_basic_operations_with_failover(self, setup_tortoise_with_failover, sleep_trigger_setup, aurora_utility): - """Test basic operations work with AWS wrapper plugins enabled.""" + """Test failover when inserting to a single table""" insert_exception = None def insert_thread(): @@ -81,7 +95,7 @@ def insert_thread(): insert_exception = e def failover_thread(): - time.sleep(5) # Wait for insert to start + time.sleep(3) # Wait for insert to start aurora_utility.failover_cluster_and_wait_until_writer_changed() # Start both threads @@ -160,7 +174,7 @@ async def test_concurrent_queries_with_failover(self, setup_tortoise_with_failov """Test concurrent queries with failover during long-running operation.""" connection = connections.get("default") - # Step 1: Run 15 concurrent select queries + # Run 15 concurrent select queries async def run_select_query(query_id): return await connection.execute_query(f"SELECT {query_id} as query_id") @@ -168,10 +182,10 @@ async def run_select_query(query_id): initial_results = await asyncio.gather(*initial_tasks) assert len(initial_results) == 15 - # Step 2: Run sleep query with failover + # Run insert query with failover sleep_exception = None - def sleep_query_thread(): + def insert_query_thread(): nonlocal sleep_exception for attempt in range(3): try: @@ -188,7 +202,7 @@ def failover_thread(): time.sleep(5) # Wait for sleep query to start aurora_utility.failover_cluster_and_wait_until_writer_changed() - sleep_t = threading.Thread(target=sleep_query_thread) + sleep_t = threading.Thread(target=insert_query_thread) failover_t = threading.Thread(target=failover_thread) sleep_t.start() @@ -201,7 +215,7 @@ def failover_thread(): assert sleep_exception is not None assert isinstance(sleep_exception, (FailoverSuccessError, TransactionResolutionUnknownError)) - # Step 3: Run another 15 concurrent select queries after failover + # Run another 15 concurrent select queries after failover to verify that we can still make connections post_failover_tasks = [run_select_query(i + 100) for i in range(15)] post_failover_results = await asyncio.gather(*post_failover_tasks) assert len(post_failover_results) == 15 diff --git a/tests/integration/container/tortoise/test_tortoise_iam_authentication.py b/tests/integration/container/tortoise/test_tortoise_iam_authentication.py index 5dee829c..47160e1d 100644 --- a/tests/integration/container/tortoise/test_tortoise_iam_authentication.py +++ b/tests/integration/container/tortoise/test_tortoise_iam_authentication.py @@ -17,10 +17,19 @@ from tests.integration.container.tortoise.test_tortoise_common import ( run_basic_read_operations, run_basic_write_operations, setup_tortoise) -from tests.integration.container.tortoise.test_tortoise_models import User -from tests.integration.container.utils.test_environment import TestEnvironment +from tests.integration.container.utils.conditions import (disable_on_engines, + disable_on_features, + enable_on_features) +from tests.integration.container.utils.database_engine import DatabaseEngine +from tests.integration.container.utils.test_environment_features import \ + TestEnvironmentFeatures +@disable_on_engines([DatabaseEngine.PG]) +@enable_on_features([TestEnvironmentFeatures.IAM]) +@disable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY, + TestEnvironmentFeatures.BLUE_GREEN_DEPLOYMENT, + TestEnvironmentFeatures.PERFORMANCE]) class TestTortoiseIamAuthentication: """Test class for Tortoise ORM with IAM authentication.""" diff --git a/tests/integration/container/tortoise/test_tortoise_multi_plugins.py b/tests/integration/container/tortoise/test_tortoise_multi_plugins.py index 83169d52..c9c0fe34 100644 --- a/tests/integration/container/tortoise/test_tortoise_multi_plugins.py +++ b/tests/integration/container/tortoise/test_tortoise_multi_plugins.py @@ -15,6 +15,7 @@ import asyncio import threading import time +from time import perf_counter_ns, sleep from uuid import uuid4 import pytest @@ -24,13 +25,30 @@ from tortoise import connections from aws_advanced_python_wrapper.errors import FailoverSuccessError -from tests.integration.container.tortoise.test_tortoise_common import setup_tortoise, run_basic_read_operations, run_basic_write_operations -from tests.integration.container.tortoise.test_tortoise_models import TableWithSleepTrigger +from tests.integration.container.tortoise.models.test_models import \ + TableWithSleepTrigger +from tests.integration.container.tortoise.test_tortoise_common import ( + run_basic_read_operations, run_basic_write_operations, setup_tortoise) +from tests.integration.container.utils.conditions import ( + disable_on_engines, disable_on_features, enable_on_deployments, + enable_on_features, enable_on_num_instances) +from tests.integration.container.utils.database_engine import DatabaseEngine +from tests.integration.container.utils.database_engine_deployment import \ + DatabaseEngineDeployment from tests.integration.container.utils.rds_test_utility import RdsTestUtility from tests.integration.container.utils.test_environment import TestEnvironment +from tests.integration.container.utils.test_environment_features import \ + TestEnvironmentFeatures from tests.integration.container.utils.test_utils import get_sleep_trigger_sql +@disable_on_engines([DatabaseEngine.PG]) +@enable_on_num_instances(min_instances=2) +@enable_on_features([TestEnvironmentFeatures.IAM]) +@enable_on_deployments([DatabaseEngineDeployment.AURORA, DatabaseEngineDeployment.RDS_MULTI_AZ_CLUSTER]) +@disable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY, + TestEnvironmentFeatures.BLUE_GREEN_DEPLOYMENT, + TestEnvironmentFeatures.PERFORMANCE]) class TestTortoiseMultiPlugins: """Test class for Tortoise ORM with multiple AWS wrapper plugins.""" endpoint_id = f"test-multi-endpoint-{uuid4()}" @@ -71,7 +89,6 @@ def create_custom_endpoint(self): def _wait_until_endpoint_available(self, rds_client): """Wait for the custom endpoint to become available.""" - from time import perf_counter_ns, sleep end_ns = perf_counter_ns() + 5 * 60 * 1_000_000_000 # 5 minutes available = False @@ -175,7 +192,7 @@ async def test_concurrent_queries_with_failover(self, setup_tortoise_multi_plugi """Test concurrent queries with failover during long-running operation.""" connection = connections.get("default") - # Step 1: Run 15 concurrent select queries + # Run 15 concurrent select queries async def run_select_query(query_id): return await connection.execute_query(f"SELECT {query_id} as query_id") @@ -183,7 +200,7 @@ async def run_select_query(query_id): initial_results = await asyncio.gather(*initial_tasks) assert len(initial_results) == 15 - # Step 2: Run sleep query with failover + # Run sleep query with failover sleep_exception = None def sleep_query_thread(): @@ -216,7 +233,7 @@ def failover_thread(): assert sleep_exception is not None assert isinstance(sleep_exception, FailoverSuccessError) - # Step 3: Run another 15 concurrent select queries after failover + # Run another 15 concurrent select queries after failover post_failover_tasks = [run_select_query(i + 100) for i in range(15)] post_failover_results = await asyncio.gather(*post_failover_tasks) assert len(post_failover_results) == 15 diff --git a/tests/integration/container/tortoise/test_tortoise_relations.py b/tests/integration/container/tortoise/test_tortoise_relations.py new file mode 100644 index 00000000..59d0fe9a --- /dev/null +++ b/tests/integration/container/tortoise/test_tortoise_relations.py @@ -0,0 +1,227 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import pytest_asyncio +from tortoise import Tortoise, connections +from tortoise.exceptions import IntegrityError + +# Import to register the aws-mysql backend +import aws_advanced_python_wrapper.tortoise +from aws_advanced_python_wrapper.tortoise.sql_alchemy_tortoise_connection_provider import \ + setup_tortoise_connection_provider +from tests.integration.container.tortoise.models.test_models_relationships import ( + RelTestAccount, RelTestAccountProfile, RelTestLearner, RelTestPublication, + RelTestPublisher, RelTestSubject) +from tests.integration.container.tortoise.test_tortoise_common import \ + reset_tortoise +from tests.integration.container.utils.conditions import disable_on_engines +from tests.integration.container.utils.database_engine import DatabaseEngine +from tests.integration.container.utils.rds_test_utility import RdsTestUtility +from tests.integration.container.utils.test_environment import TestEnvironment + + +async def clear_relationship_models(): + """Clear all relationship test models.""" + await RelTestSubject.all().delete() + await RelTestLearner.all().delete() + await RelTestPublication.all().delete() + await RelTestPublisher.all().delete() + await RelTestAccountProfile.all().delete() + await RelTestAccount.all().delete() + + + +@disable_on_engines([DatabaseEngine.PG]) +class TestTortoiseRelationships: + + @pytest_asyncio.fixture(scope="function") + async def setup_tortoise_relationships(self, conn_utils): + """Setup Tortoise with relationship models.""" + db_url = conn_utils.get_aws_tortoise_url( + TestEnvironment.get_current().get_engine(), + plugins="aurora_connection_tracker", + ) + + config = { + "connections": { + "default": db_url + }, + "apps": { + "models": { + "models": ["tests.integration.container.tortoise.models.test_models_relationships"], + "default_connection": "default", + } + } + } + + setup_tortoise_connection_provider() + await Tortoise.init(config=config) + await Tortoise.generate_schemas() + + # Clear any existing data + await clear_relationship_models() + + yield + + # Cleanup + await clear_relationship_models() + await reset_tortoise() + + @pytest.mark.asyncio + async def test_one_to_one_relationship_creation(self, setup_tortoise_relationships): + """Test creating one-to-one relationships""" + account = await RelTestAccount.create(username="testuser", email="test@example.com") + profile = await RelTestAccountProfile.create(account=account, bio="Test bio", avatar_url="http://example.com/avatar.jpg") + + # Test forward relationship + assert profile.account_id == account.id + + # Test reverse relationship + account_with_profile = await RelTestAccount.get(id=account.id).prefetch_related("profile") + assert account_with_profile.profile.bio == "Test bio" + + @pytest.mark.asyncio + async def test_one_to_one_cascade_delete(self, setup_tortoise_relationships): + """Test cascade delete in one-to-one relationship""" + account = await RelTestAccount.create(username="deleteuser", email="delete@example.com") + await RelTestAccountProfile.create(account=account, bio="Will be deleted") + + # Delete account should cascade to profile + await account.delete() + + # Profile should be deleted + profile_count = await RelTestAccountProfile.filter(account_id=account.id).count() + assert profile_count == 0 + + @pytest.mark.asyncio + async def test_one_to_one_unique_constraint(self, setup_tortoise_relationships): + """Test unique constraint in one-to-one relationship""" + account = await RelTestAccount.create(username="uniqueuser", email="unique@example.com") + await RelTestAccountProfile.create(account=account, bio="First profile") + + # Creating another profile for same account should fail + with pytest.raises(IntegrityError): + await RelTestAccountProfile.create(account=account, bio="Second profile") + + @pytest.mark.asyncio + async def test_one_to_many_relationship_creation(self, setup_tortoise_relationships): + """Test creating one-to-many relationships""" + publisher = await RelTestPublisher.create(name="Test Publisher", email="publisher@example.com") + pub1 = await RelTestPublication.create(title="Publication 1", isbn="1234567890123", publisher=publisher) + pub2 = await RelTestPublication.create(title="Publication 2", isbn="1234567890124", publisher=publisher) + + # Test forward relationship + assert pub1.publisher_id == publisher.id + assert pub2.publisher_id == publisher.id + + # Test reverse relationship + publisher_with_pubs = await RelTestPublisher.get(id=publisher.id).prefetch_related("publications") + assert len(publisher_with_pubs.publications) == 2 + assert {pub.title for pub in publisher_with_pubs.publications} == {"Publication 1", "Publication 2"} + + @pytest.mark.asyncio + async def test_one_to_many_cascade_delete(self, setup_tortoise_relationships): + """Test cascade delete in one-to-many relationship""" + publisher = await RelTestPublisher.create(name="Delete Publisher", email="deletepub@example.com") + await RelTestPublication.create(title="Delete Pub 1", isbn="9999999999999", publisher=publisher) + await RelTestPublication.create(title="Delete Pub 2", isbn="9999999999998", publisher=publisher) + + # Delete publisher should cascade to publications + await publisher.delete() + + # Publications should be deleted + pub_count = await RelTestPublication.filter(publisher_id=publisher.id).count() + assert pub_count == 0 + + @pytest.mark.asyncio + async def test_foreign_key_constraint(self, setup_tortoise_relationships): + """Test foreign key constraint enforcement""" + # Creating publication with non-existent publisher should fail + with pytest.raises(IntegrityError): + await RelTestPublication.create(title="Orphan Publication", isbn="0000000000000", publisher_id=99999) + + @pytest.mark.asyncio + async def test_many_to_many_relationship_creation(self, setup_tortoise_relationships): + """Test creating many-to-many relationships""" + learner1 = await RelTestLearner.create(name="Learner 1", learner_id="L001") + learner2 = await RelTestLearner.create(name="Learner 2", learner_id="L002") + subject1 = await RelTestSubject.create(name="Math 101", code="MATH101", credits=3) + subject2 = await RelTestSubject.create(name="Physics 101", code="PHYS101", credits=4) + + # Add learners to subjects + await subject1.learners.add(learner1, learner2) + await subject2.learners.add(learner1) + + # Test forward relationship + subject1_with_learners = await RelTestSubject.get(id=subject1.id).prefetch_related("learners") + assert len(subject1_with_learners.learners) == 2 + assert {l.name for l in subject1_with_learners.learners} == {"Learner 1", "Learner 2"} + + # Test reverse relationship + learner1_with_subjects = await RelTestLearner.get(id=learner1.id).prefetch_related("subjects") + assert len(learner1_with_subjects.subjects) == 2 + assert {s.name for s in learner1_with_subjects.subjects} == {"Math 101", "Physics 101"} + + @pytest.mark.asyncio + async def test_many_to_many_remove_relationship(self, setup_tortoise_relationships): + """Test removing many-to-many relationships""" + learner = await RelTestLearner.create(name="Remove Learner", learner_id="L003") + subject = await RelTestSubject.create(name="Remove Subject", code="REM101", credits=2) + + # Add relationship + await subject.learners.add(learner) + subject_with_learners = await RelTestSubject.get(id=subject.id).prefetch_related("learners") + assert len(subject_with_learners.learners) == 1 + + # Remove relationship + await subject.learners.remove(learner) + subject_with_learners = await RelTestSubject.get(id=subject.id).prefetch_related("learners") + assert len(subject_with_learners.learners) == 0 + + @pytest.mark.asyncio + async def test_many_to_many_clear_relationships(self, setup_tortoise_relationships): + """Test clearing all many-to-many relationships""" + learner1 = await RelTestLearner.create(name="Clear Learner 1", learner_id="L004") + learner2 = await RelTestLearner.create(name="Clear Learner 2", learner_id="L005") + subject = await RelTestSubject.create(name="Clear Subject", code="CLR101", credits=1) + + # Add multiple relationships + await subject.learners.add(learner1, learner2) + subject_with_learners = await RelTestSubject.get(id=subject.id).prefetch_related("learners") + assert len(subject_with_learners.learners) == 2 + + # Clear all relationships + await subject.learners.clear() + subject_with_learners = await RelTestSubject.get(id=subject.id).prefetch_related("learners") + assert len(subject_with_learners.learners) == 0 + + @pytest.mark.asyncio + async def test_unique_constraints(self, setup_tortoise_relationships): + """Test unique constraints on model fields""" + # Test account unique username + await RelTestAccount.create(username="unique1", email="unique1@example.com") + with pytest.raises(IntegrityError): + await RelTestAccount.create(username="unique1", email="different@example.com") + + # Test publication unique ISBN + publisher = await RelTestPublisher.create(name="ISBN Publisher", email="isbn@example.com") + await RelTestPublication.create(title="ISBN Pub 1", isbn="1111111111111", publisher=publisher) + with pytest.raises(IntegrityError): + await RelTestPublication.create(title="ISBN Pub 2", isbn="1111111111111", publisher=publisher) + + # Test subject unique code + await RelTestSubject.create(name="Unique Subject 1", code="UNQ101", credits=3) + with pytest.raises(IntegrityError): + await RelTestSubject.create(name="Unique Subject 2", code="UNQ101", credits=4) \ No newline at end of file diff --git a/tests/integration/container/tortoise/test_tortoise_secrets_manager.py b/tests/integration/container/tortoise/test_tortoise_secrets_manager.py index d150c2fa..16989f21 100644 --- a/tests/integration/container/tortoise/test_tortoise_secrets_manager.py +++ b/tests/integration/container/tortoise/test_tortoise_secrets_manager.py @@ -18,12 +18,24 @@ import pytest import pytest_asyncio +from tests.integration.container.tortoise.models.test_models import User from tests.integration.container.tortoise.test_tortoise_common import ( run_basic_read_operations, run_basic_write_operations, setup_tortoise) -from tests.integration.container.tortoise.test_tortoise_models import User +from tests.integration.container.utils.conditions import ( + disable_on_engines, disable_on_features, enable_on_deployments) +from tests.integration.container.utils.database_engine import DatabaseEngine +from tests.integration.container.utils.database_engine_deployment import \ + DatabaseEngineDeployment from tests.integration.container.utils.test_environment import TestEnvironment +from tests.integration.container.utils.test_environment_features import \ + TestEnvironmentFeatures +@disable_on_engines([DatabaseEngine.PG]) +@enable_on_deployments([DatabaseEngineDeployment.AURORA, DatabaseEngineDeployment.RDS_MULTI_AZ_CLUSTER]) +@disable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY, + TestEnvironmentFeatures.BLUE_GREEN_DEPLOYMENT, + TestEnvironmentFeatures.PERFORMANCE]) class TestTortoiseSecretsManager: """Test class for Tortoise ORM with AWS Secrets Manager authentication.""" diff --git a/tests/unit/test_tortoise_base_client.py b/tests/unit/test_tortoise_base_client.py index 17c9e592..d6d79c6e 100644 --- a/tests/unit/test_tortoise_base_client.py +++ b/tests/unit/test_tortoise_base_client.py @@ -21,7 +21,7 @@ AwsConnectionAsyncWrapper, TortoiseAwsClientConnectionWrapper, TortoiseAwsClientTransactionContext, - ConnectWithAwsWrapper, + AwsWrapperAsyncConnector, ) @@ -38,6 +38,7 @@ async def test_execute(self): with patch('asyncio.to_thread') as mock_to_thread: mock_to_thread.return_value = "result" + result = await wrapper.execute("SELECT 1", ["param"]) mock_to_thread.assert_called_once_with(mock_cursor.execute, "SELECT 1", ["param"]) @@ -50,6 +51,7 @@ async def test_executemany(self): with patch('asyncio.to_thread') as mock_to_thread: mock_to_thread.return_value = "result" + result = await wrapper.executemany("INSERT", [["param1"], ["param2"]]) mock_to_thread.assert_called_once_with(mock_cursor.executemany, "INSERT", [["param1"], ["param2"]]) @@ -62,6 +64,7 @@ async def test_fetchall(self): with patch('asyncio.to_thread') as mock_to_thread: mock_to_thread.return_value = [("row1",), ("row2",)] + result = await wrapper.fetchall() mock_to_thread.assert_called_once_with(mock_cursor.fetchall) @@ -74,6 +77,7 @@ async def test_fetchone(self): with patch('asyncio.to_thread') as mock_to_thread: mock_to_thread.return_value = ("row1",) + result = await wrapper.fetchone() mock_to_thread.assert_called_once_with(mock_cursor.fetchone) @@ -85,6 +89,8 @@ async def test_close(self): wrapper = AwsCursorAsyncWrapper(mock_cursor) with patch('asyncio.to_thread') as mock_to_thread: + mock_to_thread.return_value = None + await wrapper.close() mock_to_thread.assert_called_once_with(mock_cursor.close) @@ -111,7 +117,7 @@ async def test_cursor_context_manager(self): wrapper = AwsConnectionAsyncWrapper(mock_connection) with patch('asyncio.to_thread') as mock_to_thread: - mock_to_thread.side_effect = [mock_cursor, None] # cursor creation, then close + mock_to_thread.side_effect = [mock_cursor, None] async with wrapper.cursor() as cursor: assert isinstance(cursor, AwsCursorAsyncWrapper) @@ -126,6 +132,7 @@ async def test_rollback(self): with patch('asyncio.to_thread') as mock_to_thread: mock_to_thread.return_value = "rollback_result" + result = await wrapper.rollback() mock_to_thread.assert_called_once_with(mock_connection.rollback) @@ -138,6 +145,7 @@ async def test_commit(self): with patch('asyncio.to_thread') as mock_to_thread: mock_to_thread.return_value = "commit_result" + result = await wrapper.commit() mock_to_thread.assert_called_once_with(mock_connection.commit) @@ -149,13 +157,11 @@ async def test_set_autocommit(self): wrapper = AwsConnectionAsyncWrapper(mock_connection) with patch('asyncio.to_thread') as mock_to_thread: + mock_to_thread.return_value = None + await wrapper.set_autocommit(True) - mock_to_thread.assert_called_once() - # Verify the lambda function sets autocommit - lambda_func = mock_to_thread.call_args[0][0] - lambda_func() - assert mock_connection.autocommit == True + mock_to_thread.assert_called_once_with(setattr, mock_connection, 'autocommit', True) def test_getattr(self): mock_connection = MagicMock() @@ -168,23 +174,21 @@ def test_getattr(self): class TestTortoiseAwsClientConnectionWrapper: def test_init(self): mock_client = MagicMock() - mock_lock = asyncio.Lock() mock_connect_func = MagicMock() - wrapper = TortoiseAwsClientConnectionWrapper(mock_client, mock_lock, mock_connect_func) + wrapper = TortoiseAwsClientConnectionWrapper(mock_client, mock_connect_func, with_db=True) assert wrapper.client == mock_client - assert wrapper._pool_init_lock == mock_lock assert wrapper.connect_func == mock_connect_func + assert wrapper.with_db == True assert wrapper.connection is None @pytest.mark.asyncio async def test_ensure_connection(self): mock_client = MagicMock() mock_client.create_connection = AsyncMock() - mock_lock = asyncio.Lock() - wrapper = TortoiseAwsClientConnectionWrapper(mock_client, mock_lock, MagicMock()) + wrapper = TortoiseAwsClientConnectionWrapper(mock_client, MagicMock(), with_db=True) await wrapper.ensure_connection() mock_client.create_connection.assert_called_once_with(with_db=True) @@ -194,43 +198,39 @@ async def test_context_manager(self): mock_client = MagicMock() mock_client._template = {"host": "localhost"} mock_client.create_connection = AsyncMock() - mock_lock = asyncio.Lock() mock_connect_func = MagicMock() mock_connection = MagicMock() - wrapper = TortoiseAwsClientConnectionWrapper(mock_client, mock_lock, mock_connect_func) + wrapper = TortoiseAwsClientConnectionWrapper(mock_client, mock_connect_func, with_db=True) - with patch('aws_advanced_python_wrapper.tortoise.backend.base.client.ConnectWithAwsWrapper') as mock_connect: + with patch.object(AwsWrapperAsyncConnector, 'ConnectWithAwsWrapper') as mock_connect: mock_connect.return_value = mock_connection - with patch('asyncio.to_thread') as mock_to_thread: + with patch.object(AwsWrapperAsyncConnector, 'CloseAwsWrapper') as mock_close: async with wrapper as conn: assert conn == mock_connection assert wrapper.connection == mock_connection # Verify close was called on exit - mock_to_thread.assert_called_once_with(mock_connection.close) + mock_close.assert_called_once_with(mock_connection) class TestTortoiseAwsClientTransactionContext: def test_init(self): mock_client = MagicMock() mock_client.connection_name = "test_conn" - mock_lock = asyncio.Lock() - context = TortoiseAwsClientTransactionContext(mock_client, mock_lock) + context = TortoiseAwsClientTransactionContext(mock_client) assert context.client == mock_client assert context.connection_name == "test_conn" - assert context._pool_init_lock == mock_lock @pytest.mark.asyncio async def test_ensure_connection(self): mock_client = MagicMock() mock_client._parent.create_connection = AsyncMock() - mock_lock = asyncio.Lock() - context = TortoiseAwsClientTransactionContext(mock_client, mock_lock) + context = TortoiseAwsClientTransactionContext(mock_client) await context.ensure_connection() mock_client._parent.create_connection.assert_called_once_with(with_db=True) @@ -244,27 +244,23 @@ async def test_context_manager_commit(self): mock_client._finalized = False mock_client.begin = AsyncMock() mock_client.commit = AsyncMock() - mock_lock = asyncio.Lock() mock_connection = MagicMock() - context = TortoiseAwsClientTransactionContext(mock_client, mock_lock) + context = TortoiseAwsClientTransactionContext(mock_client) - mock_to_thread_ref = None - - with patch('aws_advanced_python_wrapper.tortoise.backend.base.client.ConnectWithAwsWrapper') as mock_connect: + with patch.object(AwsWrapperAsyncConnector, 'ConnectWithAwsWrapper') as mock_connect: mock_connect.return_value = mock_connection with patch('tortoise.connection.connections') as mock_connections: mock_connections.set.return_value = "test_token" - with patch('asyncio.to_thread') as mock_to_thread: - mock_to_thread_ref = mock_to_thread + with patch.object(AwsWrapperAsyncConnector, 'CloseAwsWrapper') as mock_close: async with context as client: assert client == mock_client assert mock_client._connection == mock_connection # Verify commit was called and connection was closed mock_client.commit.assert_called_once() - mock_to_thread_ref.assert_called_once_with(mock_connection.close) + mock_close.assert_called_once_with(mock_connection) @pytest.mark.asyncio async def test_context_manager_rollback_on_exception(self): @@ -275,20 +271,16 @@ async def test_context_manager_rollback_on_exception(self): mock_client._finalized = False mock_client.begin = AsyncMock() mock_client.rollback = AsyncMock() - mock_lock = asyncio.Lock() mock_connection = MagicMock() - context = TortoiseAwsClientTransactionContext(mock_client, mock_lock) - - mock_to_thread_ref = None + context = TortoiseAwsClientTransactionContext(mock_client) - with patch('aws_advanced_python_wrapper.tortoise.backend.base.client.ConnectWithAwsWrapper') as mock_connect: + with patch.object(AwsWrapperAsyncConnector, 'ConnectWithAwsWrapper') as mock_connect: mock_connect.return_value = mock_connection with patch('tortoise.connection.connections') as mock_connections: mock_connections.set.return_value = "test_token" - with patch('asyncio.to_thread') as mock_to_thread: - mock_to_thread_ref = mock_to_thread + with patch.object(AwsWrapperAsyncConnector, 'CloseAwsWrapper') as mock_close: try: async with context as client: raise ValueError("Test exception") @@ -297,23 +289,32 @@ async def test_context_manager_rollback_on_exception(self): # Verify rollback was called and connection was closed mock_client.rollback.assert_called_once() - mock_to_thread_ref.assert_called_once_with(mock_connection.close) + mock_close.assert_called_once_with(mock_connection) -class TestConnectWithAwsWrapper: +class TestAwsWrapperAsyncConnector: @pytest.mark.asyncio async def test_connect_with_aws_wrapper(self): mock_connect_func = MagicMock() mock_connection = MagicMock() kwargs = {"host": "localhost", "user": "test"} - with patch('aws_advanced_python_wrapper.AwsWrapperConnection.connect') as mock_aws_connect: - mock_aws_connect.return_value = mock_connection - with patch('asyncio.to_thread') as mock_to_thread: - mock_to_thread.return_value = mock_connection - - result = await ConnectWithAwsWrapper(mock_connect_func, **kwargs) - - mock_to_thread.assert_called_once() - assert isinstance(result, AwsConnectionAsyncWrapper) - assert result._wrapped_connection == mock_connection \ No newline at end of file + with patch('asyncio.to_thread') as mock_to_thread: + mock_to_thread.return_value = mock_connection + + result = await AwsWrapperAsyncConnector.ConnectWithAwsWrapper(mock_connect_func, **kwargs) + + mock_to_thread.assert_called_once() + assert isinstance(result, AwsConnectionAsyncWrapper) + assert result._wrapped_connection == mock_connection + + @pytest.mark.asyncio + async def test_close_aws_wrapper(self): + mock_connection = MagicMock() + + with patch('asyncio.to_thread') as mock_to_thread: + mock_to_thread.return_value = None + + await AwsWrapperAsyncConnector.CloseAwsWrapper(mock_connection) + + mock_to_thread.assert_called_once_with(mock_connection.close) \ No newline at end of file diff --git a/tests/unit/test_tortoise_mysql_client.py b/tests/unit/test_tortoise_mysql_client.py index 82c4dffa..249edb96 100644 --- a/tests/unit/test_tortoise_mysql_client.py +++ b/tests/unit/test_tortoise_mysql_client.py @@ -57,18 +57,17 @@ async def test_func(self): class TestAwsMySQLClient: def test_init(self): - client = AwsMySQLClient( - user="test_user", - password="test_pass", - database="test_db", - host="localhost", - port=3306, - charset="utf8mb4", - storage_engine="InnoDB", - minsize=2, - maxsize=10, - connection_name="test_conn" - ) + with patch('builtins.print'): # Mock the print statement + client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + charset="utf8mb4", + storage_engine="InnoDB", + connection_name="test_conn" + ) assert client.user == "test_user" assert client.password == "test_pass" @@ -77,124 +76,97 @@ def test_init(self): assert client.port == 3306 assert client.charset == "utf8mb4" assert client.storage_engine == "InnoDB" - assert client.pool_minsize == 2 - assert client.pool_maxsize == 10 def test_init_defaults(self): - client = AwsMySQLClient( - user="test_user", - password="test_pass", - database="test_db", - host="localhost", - port=3306, - connection_name="test_conn" - ) + with patch('builtins.print'): # Mock the print statement + client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + connection_name="test_conn" + ) assert client.charset == "utf8mb4" assert client.storage_engine == "innodb" - assert client.pool_minsize == 1 - assert client.pool_maxsize == 5 - - def test_configure_pool(self): - client = AwsMySQLClient( - user="test_user", - password="test_pass", - database="test_db", - host="localhost", - port=3306, - maxsize=15, - connection_name="test_conn" - ) - - mock_host_info = MagicMock() - mock_props = {} - - result = client._configure_pool(mock_host_info, mock_props) - assert result == {"pool_size": 15} - - def test_get_pool_key(self): - mock_host_info = MagicMock() - mock_host_info.url = "mysql://localhost:3306" - mock_props = {"user": "testuser", "database": "testdb"} - - result = AwsMySQLClient._get_pool_key(mock_host_info, mock_props) - assert result == "mysql://localhost:3306testusertestdb" @pytest.mark.asyncio async def test_create_connection_invalid_charset(self): - client = AwsMySQLClient( - user="test_user", - password="test_pass", - database="test_db", - host="localhost", - port=3306, - charset="invalid_charset", - connection_name="test_conn" - ) + with patch('builtins.print'): # Mock the print statement + client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + charset="invalid_charset", + connection_name="test_conn" + ) with pytest.raises(DBConnectionError, match="Unknown character set"): await client.create_connection(with_db=True) @pytest.mark.asyncio async def test_create_connection_success(self): - client = AwsMySQLClient( - user="test_user", - password="test_pass", - database="test_db", - host="localhost", - port=3306, - connection_name="test_conn" - ) - - with patch.object(client, '_init_pool_if_needed') as mock_init_pool: - mock_init_pool.return_value = None - - with patch('aws_advanced_python_wrapper.tortoise.backend.mysql.client.logger'): - await client.create_connection(with_db=True) - - assert client._template["user"] == "test_user" - assert client._template["database"] == "test_db" - assert client._template["autocommit"] is True - mock_init_pool.assert_called_once() + with patch('builtins.print'): # Mock the print statement + client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + connection_name="test_conn" + ) + + with patch('aws_advanced_python_wrapper.tortoise.backend.mysql.client.logger'): + await client.create_connection(with_db=True) + + assert client._template["user"] == "test_user" + assert client._template["database"] == "test_db" + assert client._template["autocommit"] is True @pytest.mark.asyncio async def test_close(self): - client = AwsMySQLClient( - user="test_user", - password="test_pass", - database="test_db", - host="localhost", - port=3306, - connection_name="test_conn" - ) + with patch('builtins.print'): # Mock the print statement + client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + connection_name="test_conn" + ) # close() method now does nothing (AWS wrapper handles cleanup) await client.close() # No assertions needed since method is a no-op def test_acquire_connection(self): - client = AwsMySQLClient( - user="test_user", - password="test_pass", - database="test_db", - host="localhost", - port=3306, - connection_name="test_conn" - ) + with patch('builtins.print'): # Mock the print statement + client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + connection_name="test_conn" + ) connection_wrapper = client.acquire_connection() assert connection_wrapper.client == client @pytest.mark.asyncio async def test_execute_insert(self): - client = AwsMySQLClient( - user="test_user", - password="test_pass", - database="test_db", - host="localhost", - port=3306, - connection_name="test_conn" - ) + with patch('builtins.print'): # Mock the print statement + client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + connection_name="test_conn" + ) mock_connection = MagicMock() mock_cursor = MagicMock() @@ -215,15 +187,16 @@ async def test_execute_insert(self): @pytest.mark.asyncio async def test_execute_many_with_transactions(self): - client = AwsMySQLClient( - user="test_user", - password="test_pass", - database="test_db", - host="localhost", - port=3306, - connection_name="test_conn", - storage_engine="innodb" - ) + with patch('builtins.print'): # Mock the print statement + client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + connection_name="test_conn", + storage_engine="innodb" + ) mock_connection = MagicMock() mock_cursor = MagicMock() @@ -246,14 +219,15 @@ async def test_execute_many_with_transactions(self): @pytest.mark.asyncio async def test_execute_query(self): - client = AwsMySQLClient( - user="test_user", - password="test_pass", - database="test_db", - host="localhost", - port=3306, - connection_name="test_conn" - ) + with patch('builtins.print'): # Mock the print statement + client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + connection_name="test_conn" + ) mock_connection = MagicMock() mock_cursor = MagicMock() @@ -277,14 +251,15 @@ async def test_execute_query(self): @pytest.mark.asyncio async def test_execute_query_dict(self): - client = AwsMySQLClient( - user="test_user", - password="test_pass", - database="test_db", - host="localhost", - port=3306, - connection_name="test_conn" - ) + with patch('builtins.print'): # Mock the print statement + client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + connection_name="test_conn" + ) with patch.object(client, 'execute_query') as mock_execute_query: mock_execute_query.return_value = (2, [{"id": 1}, {"id": 2}]) @@ -294,14 +269,15 @@ async def test_execute_query_dict(self): assert results == [{"id": 1}, {"id": 2}] def test_in_transaction(self): - client = AwsMySQLClient( - user="test_user", - password="test_pass", - database="test_db", - host="localhost", - port=3306, - connection_name="test_conn" - ) + with patch('builtins.print'): # Mock the print statement + client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + connection_name="test_conn" + ) transaction_context = client._in_transaction() assert transaction_context is not None @@ -309,14 +285,15 @@ def test_in_transaction(self): class TestTransactionWrapper: def test_init(self): - parent_client = AwsMySQLClient( - user="test_user", - password="test_pass", - database="test_db", - host="localhost", - port=3306, - connection_name="test_conn" - ) + with patch('builtins.print'): # Mock the print statement + parent_client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + connection_name="test_conn" + ) wrapper = TransactionWrapper(parent_client) @@ -327,14 +304,15 @@ def test_init(self): @pytest.mark.asyncio async def test_begin(self): - parent_client = AwsMySQLClient( - user="test_user", - password="test_pass", - database="test_db", - host="localhost", - port=3306, - connection_name="test_conn" - ) + with patch('builtins.print'): # Mock the print statement + parent_client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + connection_name="test_conn" + ) wrapper = TransactionWrapper(parent_client) mock_connection = MagicMock() @@ -348,14 +326,15 @@ async def test_begin(self): @pytest.mark.asyncio async def test_commit(self): - parent_client = AwsMySQLClient( - user="test_user", - password="test_pass", - database="test_db", - host="localhost", - port=3306, - connection_name="test_conn" - ) + with patch('builtins.print'): # Mock the print statement + parent_client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + connection_name="test_conn" + ) wrapper = TransactionWrapper(parent_client) mock_connection = MagicMock() @@ -371,14 +350,15 @@ async def test_commit(self): @pytest.mark.asyncio async def test_commit_already_finalized(self): - parent_client = AwsMySQLClient( - user="test_user", - password="test_pass", - database="test_db", - host="localhost", - port=3306, - connection_name="test_conn" - ) + with patch('builtins.print'): # Mock the print statement + parent_client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + connection_name="test_conn" + ) wrapper = TransactionWrapper(parent_client) wrapper._finalized = True @@ -388,35 +368,39 @@ async def test_commit_already_finalized(self): @pytest.mark.asyncio async def test_rollback(self): - parent_client = AwsMySQLClient( - user="test_user", - password="test_pass", - database="test_db", - host="localhost", - port=3306, - connection_name="test_conn" - ) + with patch('builtins.print'): # Mock the print statement + parent_client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + connection_name="test_conn" + ) wrapper = TransactionWrapper(parent_client) mock_connection = MagicMock() mock_connection.rollback = AsyncMock() + mock_connection.set_autocommit = AsyncMock() wrapper._connection = mock_connection await wrapper.rollback() mock_connection.rollback.assert_called_once() + mock_connection.set_autocommit.assert_called_once_with(True) assert wrapper._finalized is True @pytest.mark.asyncio async def test_savepoint(self): - parent_client = AwsMySQLClient( - user="test_user", - password="test_pass", - database="test_db", - host="localhost", - port=3306, - connection_name="test_conn" - ) + with patch('builtins.print'): # Mock the print statement + parent_client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + connection_name="test_conn" + ) wrapper = TransactionWrapper(parent_client) mock_connection = MagicMock() @@ -435,14 +419,15 @@ async def test_savepoint(self): @pytest.mark.asyncio async def test_savepoint_rollback(self): - parent_client = AwsMySQLClient( - user="test_user", - password="test_pass", - database="test_db", - host="localhost", - port=3306, - connection_name="test_conn" - ) + with patch('builtins.print'): # Mock the print statement + parent_client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + connection_name="test_conn" + ) wrapper = TransactionWrapper(parent_client) wrapper._savepoint = "test_savepoint" @@ -462,14 +447,15 @@ async def test_savepoint_rollback(self): @pytest.mark.asyncio async def test_savepoint_rollback_no_savepoint(self): - parent_client = AwsMySQLClient( - user="test_user", - password="test_pass", - database="test_db", - host="localhost", - port=3306, - connection_name="test_conn" - ) + with patch('builtins.print'): # Mock the print statement + parent_client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + connection_name="test_conn" + ) wrapper = TransactionWrapper(parent_client) wrapper._savepoint = None From 6cb65ecc4fdce688ff71a52dc324e5dd1084a2e9 Mon Sep 17 00:00:00 2001 From: Favian Samatha Date: Fri, 5 Dec 2025 01:19:08 -0800 Subject: [PATCH 07/11] fix formatting --- aws_advanced_python_wrapper/pg_driver_dialect.py | 2 +- aws_advanced_python_wrapper/tortoise/__init__.py | 6 +++--- .../tortoise/backend/__init__.py | 2 +- .../tortoise/backend/base/__init__.py | 2 +- .../tortoise/backend/mysql/__init__.py | 2 +- .../tortoise/backend/mysql/client.py | 10 ++-------- .../integration/container/tortoise/models/__init__.py | 2 +- .../container/tortoise/models/test_models.py | 2 +- .../container/tortoise/models/test_models_copy.py | 2 +- .../tortoise/models/test_models_relationships.py | 2 +- .../integration/container/tortoise/router/__init__.py | 2 +- .../container/tortoise/test_tortoise_basic.py | 2 +- .../container/tortoise/test_tortoise_common.py | 2 +- .../tortoise/test_tortoise_custom_endpoint.py | 2 +- .../container/tortoise/test_tortoise_failover.py | 1 - .../tortoise/test_tortoise_iam_authentication.py | 2 +- .../container/tortoise/test_tortoise_multi_plugins.py | 1 - .../container/tortoise/test_tortoise_relations.py | 2 +- .../tortoise/test_tortoise_secrets_manager.py | 2 +- .../container/utils/test_environment_request.py | 2 +- tests/unit/test_tortoise_base_client.py | 2 +- tests/unit/test_tortoise_mysql_client.py | 2 +- 22 files changed, 23 insertions(+), 31 deletions(-) diff --git a/aws_advanced_python_wrapper/pg_driver_dialect.py b/aws_advanced_python_wrapper/pg_driver_dialect.py index 7d02e118..b55ba7c1 100644 --- a/aws_advanced_python_wrapper/pg_driver_dialect.py +++ b/aws_advanced_python_wrapper/pg_driver_dialect.py @@ -175,4 +175,4 @@ def supports_abort_connection(self) -> bool: return True def get_driver_module(self) -> ModuleType: - return psycopg \ No newline at end of file + return psycopg diff --git a/aws_advanced_python_wrapper/tortoise/__init__.py b/aws_advanced_python_wrapper/tortoise/__init__.py index ec7d689a..3059fbee 100644 --- a/aws_advanced_python_wrapper/tortoise/__init__.py +++ b/aws_advanced_python_wrapper/tortoise/__init__.py @@ -28,10 +28,10 @@ "cast": { "minsize": int, "maxsize": int, - "connect_timeout": float, + "connect_timeout": int, "echo": bool, "use_unicode": bool, - "pool_recycle": int, "ssl": bool, + "use_pure": bool, }, -} \ No newline at end of file +} diff --git a/aws_advanced_python_wrapper/tortoise/backend/__init__.py b/aws_advanced_python_wrapper/tortoise/backend/__init__.py index 7e87cab9..bd4acb2b 100644 --- a/aws_advanced_python_wrapper/tortoise/backend/__init__.py +++ b/aws_advanced_python_wrapper/tortoise/backend/__init__.py @@ -10,4 +10,4 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License. \ No newline at end of file +# limitations under the License. diff --git a/aws_advanced_python_wrapper/tortoise/backend/base/__init__.py b/aws_advanced_python_wrapper/tortoise/backend/base/__init__.py index 7e87cab9..bd4acb2b 100644 --- a/aws_advanced_python_wrapper/tortoise/backend/base/__init__.py +++ b/aws_advanced_python_wrapper/tortoise/backend/base/__init__.py @@ -10,4 +10,4 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License. \ No newline at end of file +# limitations under the License. diff --git a/aws_advanced_python_wrapper/tortoise/backend/mysql/__init__.py b/aws_advanced_python_wrapper/tortoise/backend/mysql/__init__.py index d66f3284..1092dd31 100644 --- a/aws_advanced_python_wrapper/tortoise/backend/mysql/__init__.py +++ b/aws_advanced_python_wrapper/tortoise/backend/mysql/__init__.py @@ -13,4 +13,4 @@ # limitations under the License. from .client import AwsMySQLClient -client_class = AwsMySQLClient \ No newline at end of file +client_class = AwsMySQLClient diff --git a/aws_advanced_python_wrapper/tortoise/backend/mysql/client.py b/aws_advanced_python_wrapper/tortoise/backend/mysql/client.py index c63626b2..d5db2530 100644 --- a/aws_advanced_python_wrapper/tortoise/backend/mysql/client.py +++ b/aws_advanced_python_wrapper/tortoise/backend/mysql/client.py @@ -127,13 +127,7 @@ def __init__( self.extra.pop("fetch_inserted", None) self.extra.pop("autocommit", None) self.extra.setdefault("sql_mode", "STRICT_TRANS_TABLES") - - # Ensure that timeouts are integers - timeout_params = ["connect_timeout", "monitoring-connect_timeout",] - for param in timeout_params: - if param in self.extra and self.extra[param] is not None: - self.extra[param] = int(self.extra[param]) - + # Initialize connection templates self._init_connection_templates() @@ -354,4 +348,4 @@ async def execute_many(self, query: str, values: List[List[Any]]) -> None: def _gen_savepoint_name(_c: count = count()) -> str: """Generate a unique savepoint name.""" - return f"tortoise_savepoint_{next(_c)}" \ No newline at end of file + return f"tortoise_savepoint_{next(_c)}" diff --git a/tests/integration/container/tortoise/models/__init__.py b/tests/integration/container/tortoise/models/__init__.py index 7d84fea8..12de791f 100644 --- a/tests/integration/container/tortoise/models/__init__.py +++ b/tests/integration/container/tortoise/models/__init__.py @@ -10,4 +10,4 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License \ No newline at end of file +# limitations under the License diff --git a/tests/integration/container/tortoise/models/test_models.py b/tests/integration/container/tortoise/models/test_models.py index 60f1e043..dca5330a 100644 --- a/tests/integration/container/tortoise/models/test_models.py +++ b/tests/integration/container/tortoise/models/test_models.py @@ -41,4 +41,4 @@ class TableWithSleepTrigger(Model): value = fields.CharField(max_length=100) class Meta: - table = "table_with_sleep_trigger" \ No newline at end of file + table = "table_with_sleep_trigger" diff --git a/tests/integration/container/tortoise/models/test_models_copy.py b/tests/integration/container/tortoise/models/test_models_copy.py index 60f1e043..dca5330a 100644 --- a/tests/integration/container/tortoise/models/test_models_copy.py +++ b/tests/integration/container/tortoise/models/test_models_copy.py @@ -41,4 +41,4 @@ class TableWithSleepTrigger(Model): value = fields.CharField(max_length=100) class Meta: - table = "table_with_sleep_trigger" \ No newline at end of file + table = "table_with_sleep_trigger" diff --git a/tests/integration/container/tortoise/models/test_models_relationships.py b/tests/integration/container/tortoise/models/test_models_relationships.py index b941311d..53bc622f 100644 --- a/tests/integration/container/tortoise/models/test_models_relationships.py +++ b/tests/integration/container/tortoise/models/test_models_relationships.py @@ -75,4 +75,4 @@ class RelTestSubject(Model): learners = fields.ManyToManyField("models.RelTestLearner", related_name="subjects") class Meta: - table = "rel_test_subjects" \ No newline at end of file + table = "rel_test_subjects" diff --git a/tests/integration/container/tortoise/router/__init__.py b/tests/integration/container/tortoise/router/__init__.py index 7d84fea8..12de791f 100644 --- a/tests/integration/container/tortoise/router/__init__.py +++ b/tests/integration/container/tortoise/router/__init__.py @@ -10,4 +10,4 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License \ No newline at end of file +# limitations under the License diff --git a/tests/integration/container/tortoise/test_tortoise_basic.py b/tests/integration/container/tortoise/test_tortoise_basic.py index f6c16661..25530775 100644 --- a/tests/integration/container/tortoise/test_tortoise_basic.py +++ b/tests/integration/container/tortoise/test_tortoise_basic.py @@ -260,4 +260,4 @@ async def create_users_with_error(): # Verify user was not created due to rollback users = await User.filter(name="Atomic Rollback User") - assert len(users) == 0 \ No newline at end of file + assert len(users) == 0 diff --git a/tests/integration/container/tortoise/test_tortoise_common.py b/tests/integration/container/tortoise/test_tortoise_common.py index 3b08f005..1cbf2a1c 100644 --- a/tests/integration/container/tortoise/test_tortoise_common.py +++ b/tests/integration/container/tortoise/test_tortoise_common.py @@ -102,4 +102,4 @@ async def reset_tortoise(): await Tortoise._reset_apps() Tortoise._inited = False Tortoise.apps = {} - connections._db_config = {} \ No newline at end of file + connections._db_config = {} diff --git a/tests/integration/container/tortoise/test_tortoise_custom_endpoint.py b/tests/integration/container/tortoise/test_tortoise_custom_endpoint.py index d70ea2c8..de51eb5f 100644 --- a/tests/integration/container/tortoise/test_tortoise_custom_endpoint.py +++ b/tests/integration/container/tortoise/test_tortoise_custom_endpoint.py @@ -119,4 +119,4 @@ async def test_basic_read_operations(self, setup_tortoise_custom_endpoint): @pytest.mark.asyncio async def test_basic_write_operations(self, setup_tortoise_custom_endpoint): """Test basic write operations with custom endpoint plugin.""" - await run_basic_write_operations("Custom", "customwrite") \ No newline at end of file + await run_basic_write_operations("Custom", "customwrite") diff --git a/tests/integration/container/tortoise/test_tortoise_failover.py b/tests/integration/container/tortoise/test_tortoise_failover.py index fdd73c79..baccfd2e 100644 --- a/tests/integration/container/tortoise/test_tortoise_failover.py +++ b/tests/integration/container/tortoise/test_tortoise_failover.py @@ -260,4 +260,3 @@ def failover_thread(): assert len(insert_exceptions) == 15, f"Expected 15 exceptions, got {len(insert_exceptions)}" failover_errors = [e for e in insert_exceptions if isinstance(e, (FailoverSuccessError, TransactionResolutionUnknownError))] assert len(failover_errors) == 15, f"Expected all 15 threads to get failover errors, got {len(failover_errors)}" - print(f"Successfully verified all 15 threads got failover exceptions") \ No newline at end of file diff --git a/tests/integration/container/tortoise/test_tortoise_iam_authentication.py b/tests/integration/container/tortoise/test_tortoise_iam_authentication.py index 47160e1d..cf76ac57 100644 --- a/tests/integration/container/tortoise/test_tortoise_iam_authentication.py +++ b/tests/integration/container/tortoise/test_tortoise_iam_authentication.py @@ -47,4 +47,4 @@ async def test_basic_read_operations(self, setup_tortoise_iam): @pytest.mark.asyncio async def test_basic_write_operations(self, setup_tortoise_iam): """Test basic write operations with IAM authentication.""" - await run_basic_write_operations() \ No newline at end of file + await run_basic_write_operations() diff --git a/tests/integration/container/tortoise/test_tortoise_multi_plugins.py b/tests/integration/container/tortoise/test_tortoise_multi_plugins.py index c9c0fe34..13856a49 100644 --- a/tests/integration/container/tortoise/test_tortoise_multi_plugins.py +++ b/tests/integration/container/tortoise/test_tortoise_multi_plugins.py @@ -278,4 +278,3 @@ def failover_thread(): assert len(insert_exceptions) == 15, f"Expected 15 exceptions, got {len(insert_exceptions)}" failover_errors = [e for e in insert_exceptions if isinstance(e, FailoverSuccessError)] assert len(failover_errors) == 15, f"Expected all 15 threads to get failover errors, got {len(failover_errors)}" - print(f"Successfully verified all 15 threads got failover exceptions with multiple plugins") \ No newline at end of file diff --git a/tests/integration/container/tortoise/test_tortoise_relations.py b/tests/integration/container/tortoise/test_tortoise_relations.py index 59d0fe9a..d2985857 100644 --- a/tests/integration/container/tortoise/test_tortoise_relations.py +++ b/tests/integration/container/tortoise/test_tortoise_relations.py @@ -224,4 +224,4 @@ async def test_unique_constraints(self, setup_tortoise_relationships): # Test subject unique code await RelTestSubject.create(name="Unique Subject 1", code="UNQ101", credits=3) with pytest.raises(IntegrityError): - await RelTestSubject.create(name="Unique Subject 2", code="UNQ101", credits=4) \ No newline at end of file + await RelTestSubject.create(name="Unique Subject 2", code="UNQ101", credits=4) diff --git a/tests/integration/container/tortoise/test_tortoise_secrets_manager.py b/tests/integration/container/tortoise/test_tortoise_secrets_manager.py index 16989f21..65763683 100644 --- a/tests/integration/container/tortoise/test_tortoise_secrets_manager.py +++ b/tests/integration/container/tortoise/test_tortoise_secrets_manager.py @@ -83,4 +83,4 @@ async def test_basic_read_operations(self, setup_tortoise_secrets_manager): @pytest.mark.asyncio async def test_basic_write_operations(self, setup_tortoise_secrets_manager): """Test basic write operations with Secrets Manager authentication.""" - await run_basic_write_operations("Secrets", "secretswrite") \ No newline at end of file + await run_basic_write_operations("Secrets", "secretswrite") diff --git a/tests/integration/container/utils/test_environment_request.py b/tests/integration/container/utils/test_environment_request.py index 93fe23bf..f8bd36ec 100644 --- a/tests/integration/container/utils/test_environment_request.py +++ b/tests/integration/container/utils/test_environment_request.py @@ -40,7 +40,7 @@ def __init__(self, request: Dict[str, Any]) -> None: self._instances = getattr(DatabaseInstances, typing.cast('str', request.get("instances"))) self._deployment = getattr(DatabaseEngineDeployment, typing.cast('str', request.get("deployment"))) self._target_python_version = getattr(TargetPythonVersion, typing.cast('str', request.get("targetPythonVersion"))) - self._num_of_instances = typing.cast('int', 2) + self._num_of_instances = request.get("numOfInstances") self._features = set() features: Iterable[str] = typing.cast('Iterable[str]', request.get("features")) if features is not None: diff --git a/tests/unit/test_tortoise_base_client.py b/tests/unit/test_tortoise_base_client.py index d6d79c6e..4074a9ee 100644 --- a/tests/unit/test_tortoise_base_client.py +++ b/tests/unit/test_tortoise_base_client.py @@ -317,4 +317,4 @@ async def test_close_aws_wrapper(self): await AwsWrapperAsyncConnector.CloseAwsWrapper(mock_connection) - mock_to_thread.assert_called_once_with(mock_connection.close) \ No newline at end of file + mock_to_thread.assert_called_once_with(mock_connection.close) diff --git a/tests/unit/test_tortoise_mysql_client.py b/tests/unit/test_tortoise_mysql_client.py index 249edb96..5eb4b572 100644 --- a/tests/unit/test_tortoise_mysql_client.py +++ b/tests/unit/test_tortoise_mysql_client.py @@ -471,4 +471,4 @@ def test_gen_savepoint_name(self): assert name1.startswith("tortoise_savepoint_") assert name2.startswith("tortoise_savepoint_") - assert name1 != name2 # Should be unique \ No newline at end of file + assert name1 != name2 # Should be unique From ae3259574b8b49c1a6e8bbea1d4bc01cbb46c737 Mon Sep 17 00:00:00 2001 From: Favian Samatha Date: Fri, 5 Dec 2025 16:09:56 -0800 Subject: [PATCH 08/11] fix static check --- aws_advanced_python_wrapper/driver_dialect.py | 2 +- .../driver_dialect_manager.py | 3 +- .../mysql_driver_dialect.py | 8 +- .../pg_driver_dialect.py | 4 +- .../sql_alchemy_connection_provider.py | 2 +- .../sqlalchemy_driver_dialect.py | 3 +- .../tortoise/__init__.py | 12 +- .../tortoise/backend/base/client.py | 78 ++++++---- .../tortoise/backend/mysql/client.py | 73 ++++----- .../tortoise/backend/mysql/executor.py | 5 +- .../backend/mysql/schema_generator.py | 5 +- ...ql_alchemy_tortoise_connection_provider.py | 63 ++++---- aws_advanced_python_wrapper/tortoise/utils.py | 6 +- aws_advanced_python_wrapper/wrapper.py | 6 +- tests/integration/container/conftest.py | 4 +- .../container/tortoise/models/test_models.py | 6 +- .../tortoise/models/test_models_copy.py | 6 +- .../models/test_models_relationships.py | 23 ++- .../container/tortoise/router/test_router.py | 4 +- .../container/tortoise/test_tortoise_basic.py | 82 +++++----- .../tortoise/test_tortoise_common.py | 32 ++-- .../tortoise/test_tortoise_config.py | 70 ++++----- .../tortoise/test_tortoise_custom_endpoint.py | 25 ++- .../tortoise/test_tortoise_failover.py | 83 +++++----- .../test_tortoise_iam_authentication.py | 2 +- .../tortoise/test_tortoise_multi_plugins.py | 85 ++++++----- .../tortoise/test_tortoise_relations.py | 78 +++++----- .../tortoise/test_tortoise_secrets_manager.py | 15 +- .../container/utils/connection_utils.py | 8 +- .../utils/test_environment_request.py | 2 +- .../integration/container/utils/test_utils.py | 7 +- tests/unit/test_tortoise_base_client.py | 142 +++++++++--------- tests/unit/test_tortoise_mysql_client.py | 114 +++++++------- 33 files changed, 540 insertions(+), 518 deletions(-) diff --git a/aws_advanced_python_wrapper/driver_dialect.py b/aws_advanced_python_wrapper/driver_dialect.py index f1b338e7..7eb7fd41 100644 --- a/aws_advanced_python_wrapper/driver_dialect.py +++ b/aws_advanced_python_wrapper/driver_dialect.py @@ -165,7 +165,7 @@ def ping(self, conn: Connection) -> bool: return True except Exception: return False - + def get_driver_module(self) -> ModuleType: raise UnsupportedOperationError( Messages.get_formatted("DriverDialect.UnsupportedOperationError", self._driver_name, "get_driver_module")) diff --git a/aws_advanced_python_wrapper/driver_dialect_manager.py b/aws_advanced_python_wrapper/driver_dialect_manager.py index 420008dc..7f754491 100644 --- a/aws_advanced_python_wrapper/driver_dialect_manager.py +++ b/aws_advanced_python_wrapper/driver_dialect_manager.py @@ -24,7 +24,8 @@ from aws_advanced_python_wrapper.errors import AwsWrapperError from aws_advanced_python_wrapper.utils.log import Logger from aws_advanced_python_wrapper.utils.messages import Messages -from aws_advanced_python_wrapper.utils.properties import Properties, WrapperProperties +from aws_advanced_python_wrapper.utils.properties import (Properties, + WrapperProperties) from aws_advanced_python_wrapper.utils.utils import Utils logger = Logger(__name__) diff --git a/aws_advanced_python_wrapper/mysql_driver_dialect.py b/aws_advanced_python_wrapper/mysql_driver_dialect.py index 6b0d7a95..6dd4d455 100644 --- a/aws_advanced_python_wrapper/mysql_driver_dialect.py +++ b/aws_advanced_python_wrapper/mysql_driver_dialect.py @@ -11,11 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from __future__ import annotations from typing import TYPE_CHECKING, Any, Callable, ClassVar, Set +import mysql.connector + if TYPE_CHECKING: from aws_advanced_python_wrapper.hostinfo import HostInfo from aws_advanced_python_wrapper.pep249 import Connection @@ -29,11 +30,12 @@ from aws_advanced_python_wrapper.errors import UnsupportedOperationError from aws_advanced_python_wrapper.utils.decorators import timeout from aws_advanced_python_wrapper.utils.messages import Messages -from aws_advanced_python_wrapper.utils.properties import Properties, PropertiesUtils, WrapperProperties +from aws_advanced_python_wrapper.utils.properties import (Properties, + PropertiesUtils, + WrapperProperties) CMYSQL_ENABLED = False -import mysql.connector from mysql.connector import MySQLConnection # noqa: E402 from mysql.connector.cursor import MySQLCursor # noqa: E402 diff --git a/aws_advanced_python_wrapper/pg_driver_dialect.py b/aws_advanced_python_wrapper/pg_driver_dialect.py index b55ba7c1..816540c9 100644 --- a/aws_advanced_python_wrapper/pg_driver_dialect.py +++ b/aws_advanced_python_wrapper/pg_driver_dialect.py @@ -28,7 +28,9 @@ from aws_advanced_python_wrapper.driver_dialect_codes import DriverDialectCodes from aws_advanced_python_wrapper.errors import UnsupportedOperationError from aws_advanced_python_wrapper.utils.messages import Messages -from aws_advanced_python_wrapper.utils.properties import Properties, PropertiesUtils, WrapperProperties +from aws_advanced_python_wrapper.utils.properties import (Properties, + PropertiesUtils, + WrapperProperties) class PgDriverDialect(DriverDialect): diff --git a/aws_advanced_python_wrapper/sql_alchemy_connection_provider.py b/aws_advanced_python_wrapper/sql_alchemy_connection_provider.py index cfdf0c6b..0523ea56 100644 --- a/aws_advanced_python_wrapper/sql_alchemy_connection_provider.py +++ b/aws_advanced_python_wrapper/sql_alchemy_connection_provider.py @@ -88,7 +88,7 @@ def accepts_host_info(self, host_info: HostInfo, props: Properties) -> bool: if self._accept_url_func: return self._accept_url_func(host_info, props) url_type = SqlAlchemyPooledConnectionProvider._rds_utils.identify_rds_type(host_info.host) - return RdsUrlType.RDS_INSTANCE == url_type or RdsUrlType.RDS_WRITER_CLUSTER + return url_type in (RdsUrlType.RDS_INSTANCE, RdsUrlType.RDS_WRITER_CLUSTER) def accepts_strategy(self, role: HostRole, strategy: str) -> bool: return strategy == SqlAlchemyPooledConnectionProvider._LEAST_CONNECTIONS or strategy in self._accepted_strategies diff --git a/aws_advanced_python_wrapper/sqlalchemy_driver_dialect.py b/aws_advanced_python_wrapper/sqlalchemy_driver_dialect.py index d167be2d..fad08fd9 100644 --- a/aws_advanced_python_wrapper/sqlalchemy_driver_dialect.py +++ b/aws_advanced_python_wrapper/sqlalchemy_driver_dialect.py @@ -32,6 +32,7 @@ class SqlAlchemyDriverDialect(DriverDialect): _driver_name: str = "SQLAlchemy" TARGET_DRIVER_CODE: str = "sqlalchemy" + _underlying_driver_dialect = None def __init__(self, underlying_driver: DriverDialect, props: Properties): super().__init__(props) @@ -126,6 +127,6 @@ def transfer_session_state(self, from_conn: Connection, to_conn: Connection): return None return self._underlying_driver.transfer_session_state(from_driver_conn, to_driver_conn) - + def get_driver_module(self) -> ModuleType: return self._underlying_driver.get_driver_module() diff --git a/aws_advanced_python_wrapper/tortoise/__init__.py b/aws_advanced_python_wrapper/tortoise/__init__.py index 3059fbee..7e384786 100644 --- a/aws_advanced_python_wrapper/tortoise/__init__.py +++ b/aws_advanced_python_wrapper/tortoise/__init__.py @@ -18,12 +18,12 @@ DB_LOOKUP["aws-mysql"] = { "engine": "aws_advanced_python_wrapper.tortoise.backend.mysql", "vmap": { - "path": "database", - "hostname": "host", - "port": "port", - "username": "user", - "password": "password", - }, + "path": "database", + "hostname": "host", + "port": "port", + "username": "user", + "password": "password", + }, "defaults": {"port": 3306, "charset": "utf8mb4", "sql_mode": "STRICT_TRANS_TABLES"}, "cast": { "minsize": int, diff --git a/aws_advanced_python_wrapper/tortoise/backend/base/client.py b/aws_advanced_python_wrapper/tortoise/backend/base/client.py index bb31b224..8d901c0f 100644 --- a/aws_advanced_python_wrapper/tortoise/backend/base/client.py +++ b/aws_advanced_python_wrapper/tortoise/backend/base/client.py @@ -12,12 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import asyncio -import mysql.connector from contextlib import asynccontextmanager -from typing import Any, Callable, Generic +from typing import Any, Callable, Dict, Generic, cast -from tortoise.backends.base.client import BaseDBAsyncClient, T_conn, TransactionalDBClient, TransactionContext +import mysql.connector +from tortoise.backends.base.client import (BaseDBAsyncClient, T_conn, + TransactionalDBClient, + TransactionContext) from tortoise.connection import connections from tortoise.exceptions import TransactionManagementError @@ -26,47 +30,47 @@ class AwsWrapperAsyncConnector: """Class for creating and closing AWS wrapper connections.""" - + @staticmethod - async def ConnectWithAwsWrapper(connect_func: Callable, **kwargs) -> AwsWrapperConnection: + async def connect_with_aws_wrapper(connect_func: Callable, **kwargs) -> AwsConnectionAsyncWrapper: """Create an AWS wrapper connection with async cursor support.""" connection = await asyncio.to_thread( AwsWrapperConnection.connect, connect_func, **kwargs ) return AwsConnectionAsyncWrapper(connection) - + @staticmethod - async def CloseAwsWrapper(connection: AwsWrapperConnection) -> None: + async def close_aws_wrapper(connection: AwsWrapperConnection) -> None: """Close an AWS wrapper connection asynchronously.""" await asyncio.to_thread(connection.close) class AwsCursorAsyncWrapper: """Wraps sync AwsCursor cursor with async support.""" - + def __init__(self, sync_cursor): self._cursor = sync_cursor - + async def execute(self, query, params=None): """Execute a query asynchronously.""" return await asyncio.to_thread(self._cursor.execute, query, params) - + async def executemany(self, query, params_list): """Execute multiple queries asynchronously.""" return await asyncio.to_thread(self._cursor.executemany, query, params_list) - + async def fetchall(self): """Fetch all results asynchronously.""" return await asyncio.to_thread(self._cursor.fetchall) - + async def fetchone(self): """Fetch one result asynchronously.""" return await asyncio.to_thread(self._cursor.fetchone) - + async def close(self): """Close cursor asynchronously.""" return await asyncio.to_thread(self._cursor.close) - + def __getattr__(self, name): """Delegate non-async attributes to the wrapped cursor.""" return getattr(self._cursor, name) @@ -74,7 +78,7 @@ def __getattr__(self, name): class AwsConnectionAsyncWrapper(AwsWrapperConnection): """Wraps sync AwsConnection with async cursor support.""" - + def __init__(self, connection: AwsWrapperConnection): self._wrapped_connection = connection @@ -90,11 +94,11 @@ async def cursor(self): async def rollback(self): """Rollback the current transaction.""" return await asyncio.to_thread(self._wrapped_connection.rollback) - + async def commit(self): """Commit the current transaction.""" return await asyncio.to_thread(self._wrapped_connection.commit) - + async def set_autocommit(self, value: bool): """Set autocommit mode.""" return await asyncio.to_thread(setattr, self._wrapped_connection, 'autocommit', value) @@ -102,7 +106,7 @@ async def set_autocommit(self, value: bool): def __getattr__(self, name): """Delegate all other attributes/methods to the wrapped connection.""" return getattr(self._wrapped_connection, name) - + def __del__(self): """Delegate cleanup to wrapped connection.""" if hasattr(self, '_wrapped_connection'): @@ -110,20 +114,30 @@ def __del__(self): pass +class AwsBaseDBAsyncClient(BaseDBAsyncClient): + _template: Dict[str, Any] + + +class AwsTransactionalDBClient(TransactionalDBClient): + _template: Dict[str, Any] + _parent: AwsBaseDBAsyncClient + pass + + class TortoiseAwsClientConnectionWrapper(Generic[T_conn]): """Manages acquiring from and releasing connections to a pool.""" __slots__ = ("client", "connection", "connect_func", "with_db") def __init__( - self, - client: BaseDBAsyncClient, - connect_func: Callable, + self, + client: AwsBaseDBAsyncClient, + connect_func: Callable, with_db: bool = True ) -> None: self.connect_func = connect_func self.client = client - self.connection: T_conn | None = None + self.connection: AwsConnectionAsyncWrapper | None = None self.with_db = with_db async def ensure_connection(self) -> None: @@ -133,13 +147,13 @@ async def ensure_connection(self) -> None: async def __aenter__(self) -> T_conn: """Acquire connection from pool.""" await self.ensure_connection() - self.connection = await AwsWrapperAsyncConnector.ConnectWithAwsWrapper(self.connect_func, **self.client._template) - return self.connection + self.connection = await AwsWrapperAsyncConnector.connect_with_aws_wrapper(self.connect_func, **self.client._template) + return cast("T_conn", self.connection) async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: """Close connection and release back to pool.""" if self.connection: - await AwsWrapperAsyncConnector.CloseAwsWrapper(self.connection) + await AwsWrapperAsyncConnector.close_aws_wrapper(self.connection) class TortoiseAwsClientTransactionContext(TransactionContext): @@ -147,8 +161,8 @@ class TortoiseAwsClientTransactionContext(TransactionContext): __slots__ = ("client", "connection_name", "token") - def __init__(self, client: TransactionalDBClient) -> None: - self.client = client + def __init__(self, client: AwsTransactionalDBClient) -> None: + self.client: AwsTransactionalDBClient = client self.connection_name = client.connection_name async def ensure_connection(self) -> None: @@ -158,13 +172,13 @@ async def ensure_connection(self) -> None: async def __aenter__(self) -> TransactionalDBClient: """Enter transaction context.""" await self.ensure_connection() - + # Set the context variable so the current task sees a TransactionWrapper connection self.token = connections.set(self.connection_name, self.client) - + # Create connection and begin transaction - self.client._connection = await AwsWrapperAsyncConnector.ConnectWithAwsWrapper( - mysql.connector.Connect, + self.client._connection = await AwsWrapperAsyncConnector.connect_with_aws_wrapper( + mysql.connector.Connect, **self.client._parent._template ) await self.client.begin() @@ -181,5 +195,5 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: else: await self.client.commit() finally: - await AwsWrapperAsyncConnector.CloseAwsWrapper(self.client._connection) + await AwsWrapperAsyncConnector.close_aws_wrapper(self.client._connection) connections.reset(self.token) diff --git a/aws_advanced_python_wrapper/tortoise/backend/mysql/client.py b/aws_advanced_python_wrapper/tortoise/backend/mysql/client.py index d5db2530..733a7a30 100644 --- a/aws_advanced_python_wrapper/tortoise/backend/mysql/client.py +++ b/aws_advanced_python_wrapper/tortoise/backend/mysql/client.py @@ -15,38 +15,28 @@ import asyncio from functools import wraps from itertools import count -from typing import Any, Callable, Coroutine, Dict, List, Optional, SupportsInt, Tuple, TypeVar +from typing import (Any, Callable, Coroutine, Dict, List, Optional, + SupportsInt, Tuple, TypeVar) import mysql.connector -import sqlparse +import sqlparse # type: ignore[import-untyped] from mysql.connector import errors from mysql.connector.charsets import MYSQL_CHARACTER_SETS from pypika_tortoise import MySQLQuery -from tortoise.backends.base.client import ( - BaseDBAsyncClient, - Capabilities, - ConnectionWrapper, - NestedTransactionContext, - TransactionalDBClient, - TransactionContext, -) -from tortoise.exceptions import ( - DBConnectionError, - IntegrityError, - OperationalError, - TransactionManagementError, -) - -from aws_advanced_python_wrapper.connection_provider import ConnectionProviderManager, ConnectionProvider +from tortoise.backends.base.client import (Capabilities, ConnectionWrapper, + NestedTransactionContext, + TransactionContext) +from tortoise.exceptions import (DBConnectionError, IntegrityError, + OperationalError, TransactionManagementError) + from aws_advanced_python_wrapper.errors import AwsWrapperError, FailoverError -from aws_advanced_python_wrapper.hostinfo import HostInfo from aws_advanced_python_wrapper.tortoise.backend.base.client import ( - TortoiseAwsClientConnectionWrapper, - TortoiseAwsClientTransactionContext, -) -from aws_advanced_python_wrapper.tortoise.backend.mysql.executor import AwsMySQLExecutor -from aws_advanced_python_wrapper.tortoise.backend.mysql.schema_generator import AwsMySQLSchemaGenerator -from aws_advanced_python_wrapper.tortoise.sql_alchemy_tortoise_connection_provider import SqlAlchemyTortoisePooledConnectionProvider + AwsBaseDBAsyncClient, AwsConnectionAsyncWrapper, AwsTransactionalDBClient, + TortoiseAwsClientConnectionWrapper, TortoiseAwsClientTransactionContext) +from aws_advanced_python_wrapper.tortoise.backend.mysql.executor import \ + AwsMySQLExecutor +from aws_advanced_python_wrapper.tortoise.backend.mysql.schema_generator import \ + AwsMySQLSchemaGenerator from aws_advanced_python_wrapper.utils.log import Logger logger = Logger(__name__) @@ -61,11 +51,11 @@ async def translate_exceptions_(self, *args) -> T: try: try: return await func(self, *args) - except AwsWrapperError as aws_err: # Unwrap any AwsWrappedErrors + except AwsWrapperError as aws_err: # Unwrap any AwsWrappedErrors if aws_err.__cause__: raise aws_err.__cause__ raise - except FailoverError as exc: # Raise any failover errors + except FailoverError: # Raise any failover errors raise except errors.IntegrityError as exc: raise IntegrityError(exc) @@ -81,7 +71,8 @@ async def translate_exceptions_(self, *args) -> T: return translate_exceptions_ -class AwsMySQLClient(BaseDBAsyncClient): + +class AwsMySQLClient(AwsBaseDBAsyncClient): """AWS Advanced Python Wrapper MySQL client for Tortoise ORM.""" query_class = MySQLQuery executor_class = AwsMySQLExecutor @@ -94,8 +85,6 @@ class AwsMySQLClient(BaseDBAsyncClient): support_for_posix_regex_queries=True, support_json_attributes=True, ) - _provider: Optional[ConnectionProvider] = None - _pool_init_class_lock = asyncio.Lock() def __init__( self, @@ -109,7 +98,7 @@ def __init__( ): """Initialize AWS MySQL client with connection parameters.""" super().__init__(**kwargs) - + # Basic connection parameters self.user = user self.password = password @@ -121,7 +110,7 @@ def __init__( # Extract MySQL-specific settings self.storage_engine = self.extra.pop("storage_engine", "innodb") self.charset = self.extra.pop("charset", "utf8mb4") - + # Remove Tortoise-specific parameters self.extra.pop("connection_name", None) self.extra.pop("fetch_inserted", None) @@ -133,7 +122,7 @@ def __init__( # Initialize state self._template: Dict[str, Any] = {} - self._connection: Optional[Any] = None + self._connection = None def _init_connection_templates(self) -> None: """Initialize connection templates for with/without database.""" @@ -145,7 +134,7 @@ def _init_connection_templates(self) -> None: "autocommit": True, **self.extra } - + self._template_with_db = {**base_template, "database": self.database} self._template_no_db = {**base_template, "database": None} @@ -159,7 +148,7 @@ async def create_connection(self, with_db: bool) -> None: # Set transaction support based on storage engine if self.storage_engine.lower() != "innodb": self.capabilities.__dict__["supports_transactions"] = False - + # Set template based on database requirement self._template = self._template_with_db if with_db else self._template_no_db @@ -237,11 +226,11 @@ async def execute_query(self, query: str, values: Optional[List[Any]] = None) -> fields = [desc[0] for desc in cursor.description] return cursor.rowcount, [dict(zip(fields, row)) for row in rows] return cursor.rowcount, [] - + async def execute_query_dict(self, query: str, values: Optional[List[Any]] = None) -> List[Dict[str, Any]]: """Execute a query and return only the results as dictionaries.""" return (await self.execute_query(query, values))[1] - + async def execute_script(self, query: str) -> None: """Execute a script query.""" await self._execute_script(query, True) @@ -265,12 +254,12 @@ def _in_transaction(self) -> TransactionContext: return TortoiseAwsClientTransactionContext(TransactionWrapper(self)) -class TransactionWrapper(AwsMySQLClient, TransactionalDBClient): +class TransactionWrapper(AwsMySQLClient, AwsTransactionalDBClient): """Transaction wrapper for AWS MySQL client.""" - + def __init__(self, connection: AwsMySQLClient) -> None: self.connection_name = connection.connection_name - self._connection = connection._connection + self._connection: AwsConnectionAsyncWrapper = connection._connection self._lock = asyncio.Lock() self._savepoint: Optional[str] = None self._finalized: bool = False @@ -279,7 +268,7 @@ def __init__(self, connection: AwsMySQLClient) -> None: def _in_transaction(self) -> TransactionContext: """Create a nested transaction context.""" return NestedTransactionContext(TransactionWrapper(self)) - + def acquire_connection(self): """Acquire the transaction connection.""" return ConnectionWrapper(self._lock, self) @@ -290,7 +279,7 @@ async def begin(self) -> None: """Begin the transaction.""" await self._connection.set_autocommit(False) self._finalized = False - + async def commit(self) -> None: """Commit the transaction.""" if self._finalized: diff --git a/aws_advanced_python_wrapper/tortoise/backend/mysql/executor.py b/aws_advanced_python_wrapper/tortoise/backend/mysql/executor.py index 53015ee5..4450c494 100644 --- a/aws_advanced_python_wrapper/tortoise/backend/mysql/executor.py +++ b/aws_advanced_python_wrapper/tortoise/backend/mysql/executor.py @@ -12,9 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Type + from aws_advanced_python_wrapper.tortoise.utils import load_mysql_module -MySQLExecutor = load_mysql_module("executor.py", "MySQLExecutor") +MySQLExecutor: Type = load_mysql_module("executor.py", "MySQLExecutor") + class AwsMySQLExecutor(MySQLExecutor): """AWS MySQL Executor for Tortoise ORM.""" diff --git a/aws_advanced_python_wrapper/tortoise/backend/mysql/schema_generator.py b/aws_advanced_python_wrapper/tortoise/backend/mysql/schema_generator.py index bc59e203..bda66c56 100644 --- a/aws_advanced_python_wrapper/tortoise/backend/mysql/schema_generator.py +++ b/aws_advanced_python_wrapper/tortoise/backend/mysql/schema_generator.py @@ -12,9 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Type + from aws_advanced_python_wrapper.tortoise.utils import load_mysql_module -MySQLSchemaGenerator = load_mysql_module("schema_generator.py", "MySQLSchemaGenerator") +MySQLSchemaGenerator: Type = load_mysql_module("schema_generator.py", "MySQLSchemaGenerator") + class AwsMySQLSchemaGenerator(MySQLSchemaGenerator): """AWS MySQL Executor for Tortoise ORM.""" diff --git a/aws_advanced_python_wrapper/tortoise/sql_alchemy_tortoise_connection_provider.py b/aws_advanced_python_wrapper/tortoise/sql_alchemy_tortoise_connection_provider.py index 4f9b5cfd..2d8bc032 100644 --- a/aws_advanced_python_wrapper/tortoise/sql_alchemy_tortoise_connection_provider.py +++ b/aws_advanced_python_wrapper/tortoise/sql_alchemy_tortoise_connection_provider.py @@ -11,38 +11,43 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from types import ModuleType -from typing import Any, Callable, Dict, Optional +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, cast + +if TYPE_CHECKING: + from types import ModuleType + from sqlalchemy import Dialect + from aws_advanced_python_wrapper.database_dialect import DatabaseDialect + from aws_advanced_python_wrapper.driver_dialect import DriverDialect + from aws_advanced_python_wrapper.hostinfo import HostInfo + from aws_advanced_python_wrapper.utils.properties import Properties -from sqlalchemy import Dialect from sqlalchemy.dialects.mysql import mysqlconnector from sqlalchemy.dialects.postgresql import psycopg -from aws_advanced_python_wrapper.connection_provider import ConnectionProviderManager -from aws_advanced_python_wrapper.database_dialect import DatabaseDialect -from aws_advanced_python_wrapper.driver_dialect import DriverDialect +from aws_advanced_python_wrapper.connection_provider import \ + ConnectionProviderManager from aws_advanced_python_wrapper.errors import AwsWrapperError -from aws_advanced_python_wrapper.hostinfo import HostInfo -from aws_advanced_python_wrapper.sql_alchemy_connection_provider import SqlAlchemyPooledConnectionProvider +from aws_advanced_python_wrapper.sql_alchemy_connection_provider import \ + SqlAlchemyPooledConnectionProvider +from aws_advanced_python_wrapper.sqlalchemy_driver_dialect import \ + SqlAlchemyDriverDialect from aws_advanced_python_wrapper.utils.log import Logger from aws_advanced_python_wrapper.utils.messages import Messages -from aws_advanced_python_wrapper.utils.properties import Properties logger = Logger(__name__) - class SqlAlchemyTortoisePooledConnectionProvider(SqlAlchemyPooledConnectionProvider): """ Tortoise-specific pooled connection provider that handles failover by disposing pools. """ - _sqlalchemy_dialect_map : Dict[str, ModuleType] = { + _sqlalchemy_dialect_map: Dict[str, "ModuleType"] = { "MySQLDriverDialect": mysqlconnector, "PostgresDriverDialect": psycopg } - def accepts_host_info(self, host_info: HostInfo, props: Properties) -> bool: + def accepts_host_info(self, host_info: "HostInfo", props: "Properties") -> bool: if self._accept_url_func: return self._accept_url_func(host_info, props) url_type = SqlAlchemyPooledConnectionProvider._rds_utils.identify_rds_type(host_info.host) @@ -51,10 +56,10 @@ def accepts_host_info(self, host_info: HostInfo, props: Properties) -> bool: def _create_pool( self, target_func: Callable, - driver_dialect: DriverDialect, - database_dialect: DatabaseDialect, - host_info: HostInfo, - props: Properties): + driver_dialect: "DriverDialect", + database_dialect: "DatabaseDialect", + host_info: "HostInfo", + props: "Properties"): kwargs = dict() if self._pool_configurator is None else self._pool_configurator(host_info, props) prepared_properties = driver_dialect.prepare_connect_info(host_info, props) database_dialect.prepare_conn_props(prepared_properties) @@ -70,12 +75,14 @@ def _create_pool( kwargs["pre_ping"] = True kwargs["dialect"] = dialect return self._create_sql_alchemy_pool(**kwargs) - - def _get_pool_dialect(self, driver_dialect: DriverDialect) -> Dialect: + + def _get_pool_dialect(self, driver_dialect: "DriverDialect") -> "Dialect | None": dialect = None driver_dialect_class_name = driver_dialect.__class__.__name__ if driver_dialect_class_name == "SqlAlchemyDriverDialect": - driver_dialect_class_name = driver_dialect_class_name._underlying_driver_dialect.__class__.__name__ + # Cast to access _underlying_driver_dialect attribute + sql_alchemy_dialect = cast("SqlAlchemyDriverDialect", driver_dialect) + driver_dialect_class_name = sql_alchemy_dialect._underlying_driver_dialect.__class__.__name__ module = self._sqlalchemy_dialect_map.get(driver_dialect_class_name) if not module: @@ -87,31 +94,31 @@ def _get_pool_dialect(self, driver_dialect: DriverDialect) -> Dialect: def setup_tortoise_connection_provider( - pool_configurator: Optional[Callable[[HostInfo, Properties], Dict[str, Any]]] = None, - pool_mapping: Optional[Callable[[HostInfo, Properties], str]] = None + pool_configurator: Optional[Callable[["HostInfo", "Properties"], Dict[str, Any]]] = None, + pool_mapping: Optional[Callable[["HostInfo", "Properties"], str]] = None ) -> SqlAlchemyTortoisePooledConnectionProvider: """ Helper function to set up and configure the Tortoise connection provider. - + Args: pool_configurator: Optional function to configure pool settings. Defaults to basic pool configuration. pool_mapping: Optional function to generate pool keys. Defaults to basic pool key generation. - + Returns: Configured SqlAlchemyTortoisePooledConnectionProvider instance. """ - def default_pool_configurator(host_info: HostInfo, props: Properties) -> Dict[str, Any]: + def default_pool_configurator(host_info: "HostInfo", props: "Properties") -> Dict[str, Any]: return {"pool_size": 5, "max_overflow": -1} - - def default_pool_mapping(host_info: HostInfo, props: Properties) -> str: + + def default_pool_mapping(host_info: "HostInfo", props: "Properties") -> str: return f"{host_info.url}{props.get('user', '')}{props.get('database', '')}" - + provider = SqlAlchemyTortoisePooledConnectionProvider( pool_configurator=pool_configurator or default_pool_configurator, pool_mapping=pool_mapping or default_pool_mapping ) - + ConnectionProviderManager.set_connection_provider(provider) return provider diff --git a/aws_advanced_python_wrapper/tortoise/utils.py b/aws_advanced_python_wrapper/tortoise/utils.py index 6fb8d995..90ecede0 100644 --- a/aws_advanced_python_wrapper/tortoise/utils.py +++ b/aws_advanced_python_wrapper/tortoise/utils.py @@ -23,11 +23,13 @@ def load_mysql_module(module_file: str, class_name: str) -> Type: tortoise_path = Path(tortoise.__file__).parent module_path = tortoise_path / "backends" / "mysql" / module_file module_name = f"tortoise.backends.mysql.{module_file[:-3]}" - + if module_name not in sys.modules: spec = importlib.util.spec_from_file_location(module_name, module_path) + if spec is None or spec.loader is None: + raise ImportError(f"Could not load module {module_name}") module = importlib.util.module_from_spec(spec) sys.modules[module_name] = module spec.loader.exec_module(module) - + return getattr(sys.modules[module_name], class_name) diff --git a/aws_advanced_python_wrapper/wrapper.py b/aws_advanced_python_wrapper/wrapper.py index b2835ba6..beafa307 100644 --- a/aws_advanced_python_wrapper/wrapper.py +++ b/aws_advanced_python_wrapper/wrapper.py @@ -264,11 +264,13 @@ def rowcount(self) -> int: @property def arraysize(self) -> int: return self.target_cursor.arraysize - + # Optional for PEP249 @property def lastrowid(self) -> int: - return self.target_cursor.lastrowid + if hasattr(self.target_cursor, 'lastrowid'): + return self.target_cursor.lastrowid + raise AttributeError("'Cursor' object has no attribute 'lastrowid'") def close(self) -> None: self._plugin_manager.execute(self.target_cursor, "Cursor.close", diff --git a/tests/integration/container/conftest.py b/tests/integration/container/conftest.py index 2647de22..51878f68 100644 --- a/tests/integration/container/conftest.py +++ b/tests/integration/container/conftest.py @@ -111,8 +111,8 @@ def pytest_runtest_setup(item): # Only sleep if we still need to retry if (len(instances) < request.get_num_of_instances() - or len(instances) == 0 - or not rds_utility.is_db_instance_writer(instances[0])): + or len(instances) == 0 + or not rds_utility.is_db_instance_writer(instances[0])): sleep(5) assert len(instances) > 0 diff --git a/tests/integration/container/tortoise/models/test_models.py b/tests/integration/container/tortoise/models/test_models.py index dca5330a..b7d3ffc5 100644 --- a/tests/integration/container/tortoise/models/test_models.py +++ b/tests/integration/container/tortoise/models/test_models.py @@ -20,7 +20,7 @@ class User(Model): id = fields.IntField(primary_key=True) name = fields.CharField(max_length=50) email = fields.CharField(max_length=100, unique=True) - + class Meta: table = "users" @@ -30,7 +30,7 @@ class UniqueName(Model): name = fields.CharField(max_length=20, null=True, unique=True) optional = fields.CharField(max_length=20, null=True) other_optional = fields.CharField(max_length=20, null=True) - + class Meta: table = "unique_names" @@ -39,6 +39,6 @@ class TableWithSleepTrigger(Model): id = fields.IntField(primary_key=True) name = fields.CharField(max_length=50) value = fields.CharField(max_length=100) - + class Meta: table = "table_with_sleep_trigger" diff --git a/tests/integration/container/tortoise/models/test_models_copy.py b/tests/integration/container/tortoise/models/test_models_copy.py index dca5330a..b7d3ffc5 100644 --- a/tests/integration/container/tortoise/models/test_models_copy.py +++ b/tests/integration/container/tortoise/models/test_models_copy.py @@ -20,7 +20,7 @@ class User(Model): id = fields.IntField(primary_key=True) name = fields.CharField(max_length=50) email = fields.CharField(max_length=100, unique=True) - + class Meta: table = "users" @@ -30,7 +30,7 @@ class UniqueName(Model): name = fields.CharField(max_length=20, null=True, unique=True) optional = fields.CharField(max_length=20, null=True) other_optional = fields.CharField(max_length=20, null=True) - + class Meta: table = "unique_names" @@ -39,6 +39,6 @@ class TableWithSleepTrigger(Model): id = fields.IntField(primary_key=True) name = fields.CharField(max_length=50) value = fields.CharField(max_length=100) - + class Meta: table = "table_with_sleep_trigger" diff --git a/tests/integration/container/tortoise/models/test_models_relationships.py b/tests/integration/container/tortoise/models/test_models_relationships.py index 53bc622f..16e2bfe6 100644 --- a/tests/integration/container/tortoise/models/test_models_relationships.py +++ b/tests/integration/container/tortoise/models/test_models_relationships.py @@ -12,26 +12,33 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import TYPE_CHECKING + from tortoise import fields from tortoise.models import Model +if TYPE_CHECKING: + from tortoise.fields.relational import (ForeignKeyFieldInstance, + OneToOneFieldInstance) # One-to-One Relationship Models + + class RelTestAccount(Model): id = fields.IntField(primary_key=True) username = fields.CharField(max_length=50, unique=True) email = fields.CharField(max_length=100) - + class Meta: table = "rel_test_accounts" class RelTestAccountProfile(Model): id = fields.IntField(primary_key=True) - account = fields.OneToOneField("models.RelTestAccount", related_name="profile", on_delete=fields.CASCADE) + account: "OneToOneFieldInstance" = fields.OneToOneField("models.RelTestAccount", related_name="profile", on_delete=fields.CASCADE) bio = fields.TextField(null=True) avatar_url = fields.CharField(max_length=200, null=True) - + class Meta: table = "rel_test_account_profiles" @@ -41,7 +48,7 @@ class RelTestPublisher(Model): id = fields.IntField(primary_key=True) name = fields.CharField(max_length=100) email = fields.CharField(max_length=100, unique=True) - + class Meta: table = "rel_test_publishers" @@ -50,9 +57,9 @@ class RelTestPublication(Model): id = fields.IntField(primary_key=True) title = fields.CharField(max_length=200) isbn = fields.CharField(max_length=13, unique=True) - publisher = fields.ForeignKeyField("models.RelTestPublisher", related_name="publications", on_delete=fields.CASCADE) + publisher: "ForeignKeyFieldInstance" = fields.ForeignKeyField("models.RelTestPublisher", related_name="publications", on_delete=fields.CASCADE) published_date = fields.DateField(null=True) - + class Meta: table = "rel_test_publications" @@ -62,7 +69,7 @@ class RelTestLearner(Model): id = fields.IntField(primary_key=True) name = fields.CharField(max_length=100) learner_id = fields.CharField(max_length=20, unique=True) - + class Meta: table = "rel_test_learners" @@ -73,6 +80,6 @@ class RelTestSubject(Model): code = fields.CharField(max_length=10, unique=True) credits = fields.IntField() learners = fields.ManyToManyField("models.RelTestLearner", related_name="subjects") - + class Meta: table = "rel_test_subjects" diff --git a/tests/integration/container/tortoise/router/test_router.py b/tests/integration/container/tortoise/router/test_router.py index a0893277..e4693120 100644 --- a/tests/integration/container/tortoise/router/test_router.py +++ b/tests/integration/container/tortoise/router/test_router.py @@ -15,6 +15,6 @@ class TestRouter: def db_for_read(self, model): return "default" - + def db_for_write(self, model): - return "default" \ No newline at end of file + return "default" diff --git a/tests/integration/container/tortoise/test_tortoise_basic.py b/tests/integration/container/tortoise/test_tortoise_basic.py index 25530775..d06f7de2 100644 --- a/tests/integration/container/tortoise/test_tortoise_basic.py +++ b/tests/integration/container/tortoise/test_tortoise_basic.py @@ -38,14 +38,12 @@ class TestTortoiseBasic: Test class for Tortoise ORM integration with AWS Advanced Python Wrapper. Contains tests related to basic test operations. """ - + @pytest_asyncio.fixture async def setup_tortoise_basic(self, conn_utils): """Setup Tortoise with default plugins.""" async for result in setup_tortoise(conn_utils): yield result - - @pytest.mark.asyncio async def test_basic_crud_operations(self, setup_tortoise_basic): @@ -54,22 +52,22 @@ async def test_basic_crud_operations(self, setup_tortoise_basic): user = await User.create(name="John Doe", email="john@example.com") assert user.id is not None assert user.name == "John Doe" - + # Read found_user = await User.get(id=user.id) assert found_user.name == "John Doe" assert found_user.email == "john@example.com" - + # Update found_user.name = "Jane Doe" await found_user.save() - + updated_user = await User.get(id=user.id) assert updated_user.name == "Jane Doe" - + # Delete await updated_user.delete() - + with pytest.raises(Exception): await User.get(id=user.id) @@ -80,22 +78,22 @@ async def test_basic_crud_operations_config(self, setup_tortoise_basic): user = await User.create(name="John Doe", email="john@example.com") assert user.id is not None assert user.name == "John Doe" - + # Read found_user = await User.get(id=user.id) assert found_user.name == "John Doe" assert found_user.email == "john@example.com" - + # Update found_user.name = "Jane Doe" await found_user.save() - + updated_user = await User.get(id=user.id) assert updated_user.name == "Jane Doe" - + # Delete await updated_user.delete() - + with pytest.raises(Exception): await User.get(id=user.id) @@ -105,7 +103,7 @@ async def test_transaction_support(self, setup_tortoise_basic): async with in_transaction() as conn: await User.create(name="User 1", email="user1@example.com", using_db=conn) await User.create(name="User 2", email="user2@example.com", using_db=conn) - + # Verify users exist within transaction users = await User.filter(name__in=["User 1", "User 2"]).using_db(conn) assert len(users) == 2 @@ -120,7 +118,7 @@ async def test_transaction_rollback(self, setup_tortoise_basic): raise ValueError("Test rollback") except ValueError: pass - + # Verify user was not created due to rollback users = await User.filter(name="Test User") assert len(users) == 0 @@ -134,14 +132,14 @@ async def test_bulk_operations(self, setup_tortoise_basic): for i in range(5) ] await User.bulk_create([User(**data) for data in users_data]) - + # Verify bulk creation users = await User.filter(name__startswith="User") assert len(users) == 5 - + # Bulk update await User.filter(name__startswith="User").update(name="Updated User") - + updated_users = await User.filter(name="Updated User") assert len(updated_users) == 5 @@ -152,23 +150,23 @@ async def test_query_operations(self, setup_tortoise_basic): await User.create(name="Alice", email="alice@example.com") await User.create(name="Bob", email="bob@example.com") await User.create(name="Charlie", email="charlie@example.com") - + # Test filtering alice = await User.get(name="Alice") assert alice.email == "alice@example.com" - + # Test count count = await User.all().count() assert count >= 3 - + # Test ordering users = await User.all().order_by("name") assert users[0].name == "Alice" - + # Test exists exists = await User.filter(name="Alice").exists() assert exists is True - + # Test values emails = await User.all().values_list("email", flat=True) assert "alice@example.com" in emails @@ -178,68 +176,68 @@ async def test_bulk_create_with_ids(self, setup_tortoise_basic): """Test bulk create operations with ID verification.""" # Bulk create 1000 UniqueName objects with no name (null values) await UniqueName.bulk_create([UniqueName() for _ in range(1000)]) - + # Get all created records with id and name all_ = await UniqueName.all().values("id", "name") - + # Get the starting ID inc = all_[0]["id"] - + # Sort by ID for comparison all_sorted = sorted(all_, key=lambda x: x["id"]) - + # Verify the IDs are sequential and names are None expected = [{"id": val + inc, "name": None} for val in range(1000)] - + assert len(all_sorted) == 1000 assert all_sorted == expected @pytest.mark.asyncio async def test_concurrency_read(self, setup_tortoise_basic): """Test concurrent read operations with AWS wrapper.""" - + await User.create(name="Test User", email="test@example.com") user1 = await User.first() - + # Perform 100 concurrent reads all_read = await asyncio.gather(*[User.first() for _ in range(100)]) - + # All reads should return the same user assert all_read == [user1 for _ in range(100)] @pytest.mark.asyncio async def test_concurrency_create(self, setup_tortoise_basic): """Test concurrent create operations with AWS wrapper.""" - + # Perform 100 concurrent creates with unique emails all_write = await asyncio.gather(*[ - User.create(name="Test", email=f"test{i}@example.com") + User.create(name="Test", email=f"test{i}@example.com") for i in range(100) ]) - + # Read all created users all_read = await User.all() - + # All created users should exist in the database assert set(all_write) == set(all_read) @pytest.mark.asyncio async def test_atomic_decorator(self, setup_tortoise_basic): """Test atomic decorator for transaction handling with AWS wrapper.""" - + @atomic() async def create_users_atomically(): user1 = await User.create(name="Atomic User 1", email="atomic1@example.com") user2 = await User.create(name="Atomic User 2", email="atomic2@example.com") return user1, user2 - + # Execute atomic operation user1, user2 = await create_users_atomically() - + # Verify both users were created assert user1.id is not None assert user2.id is not None - + # Verify users exist in database found_users = await User.filter(name__startswith="Atomic User") assert len(found_users) == 2 @@ -247,17 +245,17 @@ async def create_users_atomically(): @pytest.mark.asyncio async def test_atomic_decorator_rollback(self, setup_tortoise_basic): """Test atomic decorator rollback on exception with AWS wrapper.""" - + @atomic() async def create_users_with_error(): await User.create(name="Atomic Rollback User", email="rollback@example.com") # Force rollback by raising exception raise ValueError("Intentional error for rollback test") - + # Execute atomic operation that should fail with pytest.raises(ValueError): await create_users_with_error() - + # Verify user was not created due to rollback users = await User.filter(name="Atomic Rollback User") assert len(users) == 0 diff --git a/tests/integration/container/tortoise/test_tortoise_common.py b/tests/integration/container/tortoise/test_tortoise_common.py index 1cbf2a1c..f560147a 100644 --- a/tests/integration/container/tortoise/test_tortoise_common.py +++ b/tests/integration/container/tortoise/test_tortoise_common.py @@ -13,13 +13,11 @@ # limitations under the License. import pytest -import pytest_asyncio from tortoise import Tortoise, connections # Import to register the aws-mysql backend -import aws_advanced_python_wrapper.tortoise -from tests.integration.container.tortoise.models.test_models import * -from tests.integration.container.utils.rds_test_utility import RdsTestUtility +import aws_advanced_python_wrapper.tortoise # noqa: F401 +from tests.integration.container.tortoise.models.test_models import User from tests.integration.container.utils.test_environment import TestEnvironment @@ -28,12 +26,13 @@ async def clear_test_models(): from tortoise.models import Model import tests.integration.container.tortoise.models.test_models as models_module - + for attr_name in dir(models_module): attr = getattr(models_module, attr_name) if isinstance(attr, type) and issubclass(attr, Model) and attr != Model: await attr.all().delete() + async def setup_tortoise(conn_utils, plugins="aurora_connection_tracker", **kwargs): """Setup Tortoise with AWS MySQL backend and configurable plugins.""" db_url = conn_utils.get_aws_tortoise_url( @@ -41,7 +40,7 @@ async def setup_tortoise(conn_utils, plugins="aurora_connection_tracker", **kwar plugins=plugins, **kwargs, ) - + config = { "connections": { "default": db_url @@ -53,17 +52,17 @@ async def setup_tortoise(conn_utils, plugins="aurora_connection_tracker", **kwar } } } - + from aws_advanced_python_wrapper.tortoise.sql_alchemy_tortoise_connection_provider import \ setup_tortoise_connection_provider setup_tortoise_connection_provider() await Tortoise.init(config=config) await Tortoise.generate_schemas() - + await clear_test_models() - + yield - + await clear_test_models() await reset_tortoise() @@ -71,11 +70,11 @@ async def setup_tortoise(conn_utils, plugins="aurora_connection_tracker", **kwar async def run_basic_read_operations(name_prefix="Test", email_prefix="test"): """Common test logic for basic read operations.""" user = await User.create(name=f"{name_prefix} User", email=f"{email_prefix}@example.com") - + found_user = await User.get(id=user.id) assert found_user.name == f"{name_prefix} User" assert found_user.email == f"{email_prefix}@example.com" - + users = await User.filter(name=f"{name_prefix} User") assert len(users) == 1 assert users[0].id == user.id @@ -85,18 +84,19 @@ async def run_basic_write_operations(name_prefix="Write", email_prefix="write"): """Common test logic for basic write operations.""" user = await User.create(name=f"{name_prefix} Test", email=f"{email_prefix}@example.com") assert user.id is not None - + user.name = f"Updated {name_prefix}" await user.save() - + updated_user = await User.get(id=user.id) assert updated_user.name == f"Updated {name_prefix}" - + await updated_user.delete() - + with pytest.raises(Exception): await User.get(id=user.id) + async def reset_tortoise(): await Tortoise.close_connections() await Tortoise._reset_apps() diff --git a/tests/integration/container/tortoise/test_tortoise_config.py b/tests/integration/container/tortoise/test_tortoise_config.py index 98431bf6..06432001 100644 --- a/tests/integration/container/tortoise/test_tortoise_config.py +++ b/tests/integration/container/tortoise/test_tortoise_config.py @@ -17,7 +17,7 @@ from tortoise import Tortoise, connections # Import to register the aws-mysql backend -import aws_advanced_python_wrapper.tortoise +import aws_advanced_python_wrapper.tortoise # noqa: F401 from aws_advanced_python_wrapper.tortoise.sql_alchemy_tortoise_connection_provider import \ setup_tortoise_connection_provider from tests.integration.container.tortoise.models.test_models import User @@ -66,14 +66,14 @@ async def setup_tortoise_dict_config(self, conn_utils): } } } - + setup_tortoise_connection_provider() await Tortoise.init(config=config) await Tortoise.generate_schemas() await self._clear_all_test_models() - + yield - + await self._clear_all_test_models() await reset_tortoise() @@ -83,7 +83,7 @@ async def setup_tortoise_multi_db(self, conn_utils): # Create second database name original_db = conn_utils.dbname second_db = f"{original_db}_test2" - + config = { "connections": { "default": { @@ -120,24 +120,24 @@ async def setup_tortoise_multi_db(self, conn_utils): } } } - + setup_tortoise_connection_provider() await Tortoise.init(config=config) - + # Create second database conn = connections.get("second_db") try: await conn.db_create() except Exception: pass - + await Tortoise.generate_schemas() - + await self._clear_all_test_models() - + yield second_db await self._clear_all_test_models() - + # Drop second database conn = connections.get("second_db") await conn.db_delete() @@ -173,9 +173,9 @@ async def setup_tortoise_with_router(self, conn_utils): await Tortoise.init(config=config) await Tortoise.generate_schemas() await self._clear_all_test_models() - + yield - + await self._clear_all_test_models() await reset_tortoise() @@ -184,12 +184,12 @@ async def test_dict_config_read_operations(self, setup_tortoise_dict_config): """Test basic read operations with dictionary configuration.""" # Create test data user = await User.create(name="Dict Config User", email="dict@example.com") - + # Read operations found_user = await User.get(id=user.id) assert found_user.name == "Dict Config User" assert found_user.email == "dict@example.com" - + # Query operations users = await User.filter(name="Dict Config User") assert len(users) == 1 @@ -201,57 +201,57 @@ async def test_dict_config_write_operations(self, setup_tortoise_dict_config): # Create user = await User.create(name="Dict Write Test", email="dictwrite@example.com") assert user.id is not None - + # Update user.name = "Updated Dict User" await user.save() - + updated_user = await User.get(id=user.id) assert updated_user.name == "Updated Dict User" - + # Delete await updated_user.delete() - + with pytest.raises(Exception): await User.get(id=user.id) @pytest.mark.asyncio async def test_multi_db_operations(self, setup_tortoise_multi_db): - """Test operations with multiple databases using same backend.""" + """Test operations with multiple databases using same backend.""" # Get second connection second_conn = connections.get("second_db") - + # Create users in different databases user1 = await User.create(name="DB1 User", email="db1@example.com") user2 = await User.create(name="DB2 User", email="db2@example.com", using_db=second_conn) - + # Verify users exist in their respective databases db1_users = await User.all() db2_users = await User.all().using_db(second_conn) - + assert len(db1_users) == 1 assert len(db2_users) == 1 assert db1_users[0].name == "DB1 User" assert db2_users[0].name == "DB2 User" - + # Verify isolation - users don't exist in the other database db1_user_in_db2 = await User.filter(name="DB1 User").using_db(second_conn) db2_user_in_db1 = await User.filter(name="DB2 User") - + assert len(db1_user_in_db2) == 0 assert len(db2_user_in_db1) == 0 - + # Test updates in different databases user1.name = "Updated DB1 User" await user1.save() - + user2.name = "Updated DB2 User" await user2.save(using_db=second_conn) - + # Verify updates updated_user1 = await User.get(id=user1.id) updated_user2 = await User.get(id=user2.id, using_db=second_conn) - + assert updated_user1.name == "Updated DB1 User" assert updated_user2.name == "Updated DB2 User" @@ -260,12 +260,12 @@ async def test_router_read_operations(self, setup_tortoise_with_router): """Test read operations with router configuration.""" # Create test data user = await User.create(name="Router User", email="router@example.com") - + # Read operations (should be routed by router) found_user = await User.get(id=user.id) assert found_user.name == "Router User" assert found_user.email == "router@example.com" - + # Query operations users = await User.filter(name="Router User") assert len(users) == 1 @@ -277,16 +277,16 @@ async def test_router_write_operations(self, setup_tortoise_with_router): # Create (should be routed by router) user = await User.create(name="Router Write Test", email="routerwrite@example.com") assert user.id is not None - + # Update (should be routed by router) user.name = "Updated Router User" await user.save() - + updated_user = await User.get(id=user.id) assert updated_user.name == "Updated Router User" - + # Delete (should be routed by router) await updated_user.delete() - + with pytest.raises(Exception): await User.get(id=user.id) diff --git a/tests/integration/container/tortoise/test_tortoise_custom_endpoint.py b/tests/integration/container/tortoise/test_tortoise_custom_endpoint.py index de51eb5f..8a5752ad 100644 --- a/tests/integration/container/tortoise/test_tortoise_custom_endpoint.py +++ b/tests/integration/container/tortoise/test_tortoise_custom_endpoint.py @@ -20,7 +20,6 @@ from boto3 import client from botocore.exceptions import ClientError -from tests.integration.container.tortoise.models.test_models import User from tests.integration.container.tortoise.test_tortoise_common import ( run_basic_read_operations, run_basic_write_operations, setup_tortoise) from tests.integration.container.utils.conditions import ( @@ -43,18 +42,18 @@ class TestTortoiseCustomEndpoint: """Test class for Tortoise ORM with custom endpoint plugin.""" endpoint_id = f"test-tortoise-endpoint-{uuid4()}" - endpoint_info = {} - + endpoint_info: dict[str, str] = {} + @pytest.fixture(scope='class') def create_custom_endpoint(self): """Create a custom endpoint for testing.""" env_info = TestEnvironment.get_current().get_info() region = env_info.get_region() rds_client = client('rds', region_name=region) - + instances = env_info.get_database_info().get_instances() instance_ids = [instances[0].get_instance_id()] - + try: rds_client.create_db_cluster_endpoint( DBClusterEndpointIdentifier=self.endpoint_id, @@ -62,7 +61,7 @@ def create_custom_endpoint(self): EndpointType="ANY", StaticMembers=instance_ids ) - + self._wait_until_endpoint_available(rds_client) yield self.endpoint_info["Endpoint"] finally: @@ -72,12 +71,12 @@ def create_custom_endpoint(self): if e.response['Error']['Code'] != 'DBClusterEndpointNotFoundFault': pass # Ignore if endpoint doesn't exist rds_client.close() - + def _wait_until_endpoint_available(self, rds_client): """Wait for the custom endpoint to become available.""" end_ns = perf_counter_ns() + 5 * 60 * 1_000_000_000 # 5 minutes available = False - + while perf_counter_ns() < end_ns: response = rds_client.describe_db_cluster_endpoints( DBClusterEndpointIdentifier=self.endpoint_id, @@ -88,23 +87,23 @@ def _wait_until_endpoint_available(self, rds_client): } ] ) - + response_endpoints = response["DBClusterEndpoints"] if len(response_endpoints) != 1: sleep(3) continue - + response_endpoint = response_endpoints[0] TestTortoiseCustomEndpoint.endpoint_info = response_endpoint available = "available" == response_endpoint["Status"] if available: break - + sleep(3) - + if not available: pytest.fail(f"Timed out waiting for custom endpoint to become available: {self.endpoint_id}") - + @pytest_asyncio.fixture async def setup_tortoise_custom_endpoint(self, conn_utils, create_custom_endpoint): """Setup Tortoise with custom endpoint plugin.""" diff --git a/tests/integration/container/tortoise/test_tortoise_failover.py b/tests/integration/container/tortoise/test_tortoise_failover.py index baccfd2e..59139422 100644 --- a/tests/integration/container/tortoise/test_tortoise_failover.py +++ b/tests/integration/container/tortoise/test_tortoise_failover.py @@ -29,7 +29,7 @@ setup_tortoise from tests.integration.container.utils.conditions import ( disable_on_engines, disable_on_features, enable_on_deployments, - enable_on_engines, enable_on_num_instances) + enable_on_num_instances) from tests.integration.container.utils.database_engine import DatabaseEngine from tests.integration.container.utils.database_engine_deployment import \ DatabaseEngineDeployment @@ -52,10 +52,10 @@ class TestTortoiseFailover: def aurora_utility(self): region: str = TestEnvironment.get_current().get_info().get_region() return RdsTestUtility(region) - + async def _create_sleep_trigger_record(self, name_prefix="Plugin Test", value="test_value", using_db=None): await TableWithSleepTrigger.create(name=f"{name_prefix}", value=value, using_db=using_db) - + @pytest_asyncio.fixture async def sleep_trigger_setup(self): """Setup and cleanup sleep trigger for testing.""" @@ -66,7 +66,7 @@ async def sleep_trigger_setup(self): await connection.execute_query(trigger_sql) yield await connection.execute_query("DROP TRIGGER IF EXISTS table_with_sleep_trigger_sleep_trigger") - + @pytest_asyncio.fixture async def setup_tortoise_with_failover(self, conn_utils): """Setup Tortoise with failover plugins.""" @@ -74,17 +74,16 @@ async def setup_tortoise_with_failover(self, conn_utils): "topology_refresh_ms": 10000, "connect_timeout": 10, "monitoring-connect_timeout": 5, - "use_pure": True, # MySQL specific + "use_pure": True, # MySQL specific } async for result in setup_tortoise(conn_utils, plugins="failover", **kwargs): yield result - @pytest.mark.asyncio async def test_basic_operations_with_failover(self, setup_tortoise_with_failover, sleep_trigger_setup, aurora_utility): """Test failover when inserting to a single table""" insert_exception = None - + def insert_thread(): nonlocal insert_exception try: @@ -93,22 +92,22 @@ def insert_thread(): loop.run_until_complete(self._create_sleep_trigger_record()) except Exception as e: insert_exception = e - + def failover_thread(): time.sleep(3) # Wait for insert to start aurora_utility.failover_cluster_and_wait_until_writer_changed() - + # Start both threads insert_t = threading.Thread(target=insert_thread) failover_t = threading.Thread(target=failover_thread) - + insert_t.start() failover_t.start() - + # Wait for both threads to complete insert_t.join() failover_t.join() - + # Assert that insert thread got FailoverSuccessError assert insert_exception is not None assert isinstance(insert_exception, FailoverSuccessError) @@ -117,55 +116,55 @@ def failover_thread(): async def test_transaction_with_failover(self, setup_tortoise_with_failover, sleep_trigger_setup, aurora_utility): """Test transactions with failover during long-running operations.""" transaction_exception = None - + def transaction_thread(): nonlocal transaction_exception try: loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) - + async def run_transaction(): async with in_transaction() as conn: await self._create_sleep_trigger_record("TX Plugin Test", "tx_test_value", conn) - + loop.run_until_complete(run_transaction()) except Exception as e: transaction_exception = e - + def failover_thread(): time.sleep(5) # Wait for transaction to start aurora_utility.failover_cluster_and_wait_until_writer_changed() - + # Start both threads tx_t = threading.Thread(target=transaction_thread) failover_t = threading.Thread(target=failover_thread) - + tx_t.start() failover_t.start() - + # Wait for both threads to complete tx_t.join() failover_t.join() - + # Assert that transaction thread got FailoverSuccessError assert transaction_exception is not None assert isinstance(transaction_exception, (FailoverSuccessError, TransactionResolutionUnknownError)) - + # Verify no records were created due to transaction rollback record_count = await TableWithSleepTrigger.all().count() assert record_count == 0 - + # Verify autocommit is re-enabled after failover by inserting a record autocommit_record = await TableWithSleepTrigger.create( - name="Autocommit Test", + name="Autocommit Test", value="autocommit_value" ) - + # Verify the record exists (should be auto-committed) found_record = await TableWithSleepTrigger.get(id=autocommit_record.id) assert found_record.name == "Autocommit Test" assert found_record.value == "autocommit_value" - + # Clean up the test record await found_record.delete() @@ -173,18 +172,18 @@ def failover_thread(): async def test_concurrent_queries_with_failover(self, setup_tortoise_with_failover, sleep_trigger_setup, aurora_utility): """Test concurrent queries with failover during long-running operation.""" connection = connections.get("default") - + # Run 15 concurrent select queries async def run_select_query(query_id): return await connection.execute_query(f"SELECT {query_id} as query_id") - + initial_tasks = [run_select_query(i) for i in range(15)] initial_results = await asyncio.gather(*initial_tasks) assert len(initial_results) == 15 - + # Run insert query with failover sleep_exception = None - + def insert_query_thread(): nonlocal sleep_exception for attempt in range(3): @@ -197,35 +196,35 @@ def insert_query_thread(): except Exception as e: sleep_exception = e break # Stop on first exception (likely the failover) - + def failover_thread(): time.sleep(5) # Wait for sleep query to start aurora_utility.failover_cluster_and_wait_until_writer_changed() - + sleep_t = threading.Thread(target=insert_query_thread) failover_t = threading.Thread(target=failover_thread) - + sleep_t.start() failover_t.start() - + sleep_t.join() failover_t.join() - + # Verify failover exception occurred assert sleep_exception is not None assert isinstance(sleep_exception, (FailoverSuccessError, TransactionResolutionUnknownError)) - + # Run another 15 concurrent select queries after failover to verify that we can still make connections post_failover_tasks = [run_select_query(i + 100) for i in range(15)] post_failover_results = await asyncio.gather(*post_failover_tasks) assert len(post_failover_results) == 15 - + @pytest.mark.asyncio async def test_multiple_concurrent_inserts_with_failover(self, setup_tortoise_with_failover, sleep_trigger_setup, aurora_utility): """Test multiple concurrent insert operations with failover during long-running operations.""" # Track exceptions from all insert threads insert_exceptions = [] - + def insert_thread(thread_id): try: loop = asyncio.new_event_loop() @@ -235,27 +234,27 @@ def insert_thread(thread_id): ) except Exception as e: insert_exceptions.append(e) - + def failover_thread(): time.sleep(5) # Wait for inserts to start aurora_utility.failover_cluster_and_wait_until_writer_changed() - + # Start 15 insert threads insert_threads = [] for i in range(15): t = threading.Thread(target=insert_thread, args=(i,)) insert_threads.append(t) t.start() - + # Start failover thread failover_t = threading.Thread(target=failover_thread) failover_t.start() - + # Wait for all threads to complete for t in insert_threads: t.join() failover_t.join() - + # Verify ALL threads got FailoverSuccessError or TransactionResolutionUnknownError assert len(insert_exceptions) == 15, f"Expected 15 exceptions, got {len(insert_exceptions)}" failover_errors = [e for e in insert_exceptions if isinstance(e, (FailoverSuccessError, TransactionResolutionUnknownError))] diff --git a/tests/integration/container/tortoise/test_tortoise_iam_authentication.py b/tests/integration/container/tortoise/test_tortoise_iam_authentication.py index cf76ac57..b4f9e32d 100644 --- a/tests/integration/container/tortoise/test_tortoise_iam_authentication.py +++ b/tests/integration/container/tortoise/test_tortoise_iam_authentication.py @@ -32,7 +32,7 @@ TestEnvironmentFeatures.PERFORMANCE]) class TestTortoiseIamAuthentication: """Test class for Tortoise ORM with IAM authentication.""" - + @pytest_asyncio.fixture async def setup_tortoise_iam(self, conn_utils): """Setup Tortoise with IAM authentication.""" diff --git a/tests/integration/container/tortoise/test_tortoise_multi_plugins.py b/tests/integration/container/tortoise/test_tortoise_multi_plugins.py index 13856a49..d6abed4b 100644 --- a/tests/integration/container/tortoise/test_tortoise_multi_plugins.py +++ b/tests/integration/container/tortoise/test_tortoise_multi_plugins.py @@ -52,23 +52,23 @@ class TestTortoiseMultiPlugins: """Test class for Tortoise ORM with multiple AWS wrapper plugins.""" endpoint_id = f"test-multi-endpoint-{uuid4()}" - endpoint_info = {} - + endpoint_info: dict[str, str] = {} + @pytest.fixture(scope='class') def aurora_utility(self): region: str = TestEnvironment.get_current().get_info().get_region() return RdsTestUtility(region) - + @pytest.fixture(scope='class') def create_custom_endpoint(self): """Create a custom endpoint for testing.""" env_info = TestEnvironment.get_current().get_info() region = env_info.get_region() rds_client = client('rds', region_name=region) - + instances = env_info.get_database_info().get_instances() instance_ids = [instances[0].get_instance_id()] - + try: rds_client.create_db_cluster_endpoint( DBClusterEndpointIdentifier=self.endpoint_id, @@ -76,7 +76,7 @@ def create_custom_endpoint(self): EndpointType="ANY", StaticMembers=instance_ids ) - + self._wait_until_endpoint_available(rds_client) yield self.endpoint_info["Endpoint"] finally: @@ -86,12 +86,12 @@ def create_custom_endpoint(self): if e.response['Error']['Code'] != 'DBClusterEndpointNotFoundFault': pass rds_client.close() - + def _wait_until_endpoint_available(self, rds_client): """Wait for the custom endpoint to become available.""" end_ns = perf_counter_ns() + 5 * 60 * 1_000_000_000 # 5 minutes available = False - + while perf_counter_ns() < end_ns: response = rds_client.describe_db_cluster_endpoints( DBClusterEndpointIdentifier=self.endpoint_id, @@ -102,27 +102,27 @@ def _wait_until_endpoint_available(self, rds_client): } ] ) - + response_endpoints = response["DBClusterEndpoints"] if len(response_endpoints) != 1: sleep(3) continue - + response_endpoint = response_endpoints[0] TestTortoiseMultiPlugins.endpoint_info = response_endpoint available = "available" == response_endpoint["Status"] if available: break - + sleep(3) - + if not available: pytest.fail(f"Timed out waiting for custom endpoint to become available: {self.endpoint_id}") - + async def _create_sleep_trigger_record(self, name_prefix="Multi Test", value="multi_value", using_db=None): """Create TableWithSleepTrigger record.""" await TableWithSleepTrigger.create(name=f"{name_prefix}", value=value, using_db=using_db) - + @pytest_asyncio.fixture async def sleep_trigger_setup(self): """Setup and cleanup sleep trigger for testing.""" @@ -133,7 +133,7 @@ async def sleep_trigger_setup(self): await connection.execute_query(trigger_sql) yield await connection.execute_query("DROP TRIGGER IF EXISTS table_with_sleep_trigger_sleep_trigger") - + @pytest_asyncio.fixture async def setup_tortoise_multi_plugins(self, conn_utils, create_custom_endpoint): """Setup Tortoise with multiple plugins.""" @@ -141,7 +141,10 @@ async def setup_tortoise_multi_plugins(self, conn_utils, create_custom_endpoint) "topology_refresh_ms": 10000, "reader_host_selector_strategy": "fastest_response" } - async for result in setup_tortoise(conn_utils, plugins="failover,iam,custom_endpoint,aurora_connection_tracker,fastest_response_strategy", user=conn_utils.iam_user, **kwargs): + async for result in setup_tortoise(conn_utils, + plugins="failover,iam,custom_endpoint,aurora_connection_tracker,fastest_response_strategy", + user=conn_utils.iam_user, **kwargs + ): yield result @pytest.mark.asyncio @@ -153,12 +156,12 @@ async def test_basic_read_operations(self, setup_tortoise_multi_plugins): async def test_basic_write_operations(self, setup_tortoise_multi_plugins): """Test basic write operations with multiple plugins.""" await run_basic_write_operations("Multi", "multi") - + @pytest.mark.asyncio async def test_multi_plugins_with_failover(self, setup_tortoise_multi_plugins, sleep_trigger_setup, aurora_utility): """Test multiple plugins work with failover during long-running operations.""" insert_exception = None - + def insert_thread(): nonlocal insert_exception try: @@ -167,42 +170,42 @@ def insert_thread(): loop.run_until_complete(self._create_sleep_trigger_record()) except Exception as e: insert_exception = e - + def failover_thread(): time.sleep(5) # Wait for insert to start aurora_utility.failover_cluster_and_wait_until_writer_changed() - + # Start both threads insert_t = threading.Thread(target=insert_thread) failover_t = threading.Thread(target=failover_thread) - + insert_t.start() failover_t.start() - + # Wait for both threads to complete insert_t.join() failover_t.join() - + # Assert that insert thread got FailoverSuccessError assert insert_exception is not None assert isinstance(insert_exception, FailoverSuccessError) - + @pytest.mark.asyncio async def test_concurrent_queries_with_failover(self, setup_tortoise_multi_plugins, sleep_trigger_setup, aurora_utility): """Test concurrent queries with failover during long-running operation.""" connection = connections.get("default") - + # Run 15 concurrent select queries async def run_select_query(query_id): return await connection.execute_query(f"SELECT {query_id} as query_id") - + initial_tasks = [run_select_query(i) for i in range(15)] initial_results = await asyncio.gather(*initial_tasks) assert len(initial_results) == 15 - + # Run sleep query with failover sleep_exception = None - + def sleep_query_thread(): nonlocal sleep_exception for attempt in range(3): @@ -215,35 +218,35 @@ def sleep_query_thread(): except Exception as e: sleep_exception = e break # Stop on first exception (likely the failover) - + def failover_thread(): time.sleep(5) # Wait for sleep query to start aurora_utility.failover_cluster_and_wait_until_writer_changed() - + sleep_t = threading.Thread(target=sleep_query_thread) failover_t = threading.Thread(target=failover_thread) - + sleep_t.start() failover_t.start() - + sleep_t.join() failover_t.join() - + # Verify failover exception occurred assert sleep_exception is not None assert isinstance(sleep_exception, FailoverSuccessError) - + # Run another 15 concurrent select queries after failover post_failover_tasks = [run_select_query(i + 100) for i in range(15)] post_failover_results = await asyncio.gather(*post_failover_tasks) assert len(post_failover_results) == 15 - + @pytest.mark.asyncio async def test_multiple_concurrent_inserts_with_failover(self, setup_tortoise_multi_plugins, sleep_trigger_setup, aurora_utility): """Test multiple concurrent insert operations with failover during long-running operations.""" # Track exceptions from all insert threads insert_exceptions = [] - + def insert_thread(thread_id): try: loop = asyncio.new_event_loop() @@ -253,27 +256,27 @@ def insert_thread(thread_id): ) except Exception as e: insert_exceptions.append(e) - + def failover_thread(): time.sleep(5) # Wait for inserts to start aurora_utility.failover_cluster_and_wait_until_writer_changed() - + # Start 15 insert threads insert_threads = [] for i in range(15): t = threading.Thread(target=insert_thread, args=(i,)) insert_threads.append(t) t.start() - + # Start failover thread failover_t = threading.Thread(target=failover_thread) failover_t.start() - + # Wait for all threads to complete for t in insert_threads: t.join() failover_t.join() - + # Verify ALL threads got FailoverSuccessError or TransactionResolutionUnknownError assert len(insert_exceptions) == 15, f"Expected 15 exceptions, got {len(insert_exceptions)}" failover_errors = [e for e in insert_exceptions if isinstance(e, FailoverSuccessError)] diff --git a/tests/integration/container/tortoise/test_tortoise_relations.py b/tests/integration/container/tortoise/test_tortoise_relations.py index d2985857..61281749 100644 --- a/tests/integration/container/tortoise/test_tortoise_relations.py +++ b/tests/integration/container/tortoise/test_tortoise_relations.py @@ -14,11 +14,11 @@ import pytest import pytest_asyncio -from tortoise import Tortoise, connections +from tortoise import Tortoise from tortoise.exceptions import IntegrityError # Import to register the aws-mysql backend -import aws_advanced_python_wrapper.tortoise +import aws_advanced_python_wrapper.tortoise # noqa: F401 from aws_advanced_python_wrapper.tortoise.sql_alchemy_tortoise_connection_provider import \ setup_tortoise_connection_provider from tests.integration.container.tortoise.models.test_models_relationships import ( @@ -28,7 +28,6 @@ reset_tortoise from tests.integration.container.utils.conditions import disable_on_engines from tests.integration.container.utils.database_engine import DatabaseEngine -from tests.integration.container.utils.rds_test_utility import RdsTestUtility from tests.integration.container.utils.test_environment import TestEnvironment @@ -42,10 +41,9 @@ async def clear_relationship_models(): await RelTestAccount.all().delete() - @disable_on_engines([DatabaseEngine.PG]) class TestTortoiseRelationships: - + @pytest_asyncio.fixture(scope="function") async def setup_tortoise_relationships(self, conn_utils): """Setup Tortoise with relationship models.""" @@ -53,7 +51,7 @@ async def setup_tortoise_relationships(self, conn_utils): TestEnvironment.get_current().get_engine(), plugins="aurora_connection_tracker", ) - + config = { "connections": { "default": db_url @@ -65,93 +63,93 @@ async def setup_tortoise_relationships(self, conn_utils): } } } - + setup_tortoise_connection_provider() await Tortoise.init(config=config) await Tortoise.generate_schemas() - + # Clear any existing data await clear_relationship_models() - + yield - + # Cleanup await clear_relationship_models() await reset_tortoise() - + @pytest.mark.asyncio async def test_one_to_one_relationship_creation(self, setup_tortoise_relationships): """Test creating one-to-one relationships""" account = await RelTestAccount.create(username="testuser", email="test@example.com") profile = await RelTestAccountProfile.create(account=account, bio="Test bio", avatar_url="http://example.com/avatar.jpg") - + # Test forward relationship assert profile.account_id == account.id - + # Test reverse relationship account_with_profile = await RelTestAccount.get(id=account.id).prefetch_related("profile") assert account_with_profile.profile.bio == "Test bio" - + @pytest.mark.asyncio async def test_one_to_one_cascade_delete(self, setup_tortoise_relationships): """Test cascade delete in one-to-one relationship""" account = await RelTestAccount.create(username="deleteuser", email="delete@example.com") await RelTestAccountProfile.create(account=account, bio="Will be deleted") - + # Delete account should cascade to profile await account.delete() - + # Profile should be deleted profile_count = await RelTestAccountProfile.filter(account_id=account.id).count() assert profile_count == 0 - + @pytest.mark.asyncio async def test_one_to_one_unique_constraint(self, setup_tortoise_relationships): """Test unique constraint in one-to-one relationship""" account = await RelTestAccount.create(username="uniqueuser", email="unique@example.com") await RelTestAccountProfile.create(account=account, bio="First profile") - + # Creating another profile for same account should fail with pytest.raises(IntegrityError): await RelTestAccountProfile.create(account=account, bio="Second profile") - + @pytest.mark.asyncio async def test_one_to_many_relationship_creation(self, setup_tortoise_relationships): """Test creating one-to-many relationships""" publisher = await RelTestPublisher.create(name="Test Publisher", email="publisher@example.com") pub1 = await RelTestPublication.create(title="Publication 1", isbn="1234567890123", publisher=publisher) pub2 = await RelTestPublication.create(title="Publication 2", isbn="1234567890124", publisher=publisher) - + # Test forward relationship assert pub1.publisher_id == publisher.id assert pub2.publisher_id == publisher.id - + # Test reverse relationship publisher_with_pubs = await RelTestPublisher.get(id=publisher.id).prefetch_related("publications") assert len(publisher_with_pubs.publications) == 2 assert {pub.title for pub in publisher_with_pubs.publications} == {"Publication 1", "Publication 2"} - + @pytest.mark.asyncio async def test_one_to_many_cascade_delete(self, setup_tortoise_relationships): """Test cascade delete in one-to-many relationship""" publisher = await RelTestPublisher.create(name="Delete Publisher", email="deletepub@example.com") await RelTestPublication.create(title="Delete Pub 1", isbn="9999999999999", publisher=publisher) await RelTestPublication.create(title="Delete Pub 2", isbn="9999999999998", publisher=publisher) - + # Delete publisher should cascade to publications await publisher.delete() - + # Publications should be deleted pub_count = await RelTestPublication.filter(publisher_id=publisher.id).count() assert pub_count == 0 - + @pytest.mark.asyncio async def test_foreign_key_constraint(self, setup_tortoise_relationships): """Test foreign key constraint enforcement""" # Creating publication with non-existent publisher should fail with pytest.raises(IntegrityError): await RelTestPublication.create(title="Orphan Publication", isbn="0000000000000", publisher_id=99999) - + @pytest.mark.asyncio async def test_many_to_many_relationship_creation(self, setup_tortoise_relationships): """Test creating many-to-many relationships""" @@ -159,54 +157,54 @@ async def test_many_to_many_relationship_creation(self, setup_tortoise_relations learner2 = await RelTestLearner.create(name="Learner 2", learner_id="L002") subject1 = await RelTestSubject.create(name="Math 101", code="MATH101", credits=3) subject2 = await RelTestSubject.create(name="Physics 101", code="PHYS101", credits=4) - + # Add learners to subjects await subject1.learners.add(learner1, learner2) await subject2.learners.add(learner1) - + # Test forward relationship subject1_with_learners = await RelTestSubject.get(id=subject1.id).prefetch_related("learners") assert len(subject1_with_learners.learners) == 2 - assert {l.name for l in subject1_with_learners.learners} == {"Learner 1", "Learner 2"} - + assert {learner.name for learner in subject1_with_learners.learners} == {"Learner 1", "Learner 2"} + # Test reverse relationship learner1_with_subjects = await RelTestLearner.get(id=learner1.id).prefetch_related("subjects") assert len(learner1_with_subjects.subjects) == 2 - assert {s.name for s in learner1_with_subjects.subjects} == {"Math 101", "Physics 101"} - + assert {subject.name for subject in learner1_with_subjects.subjects} == {"Math 101", "Physics 101"} + @pytest.mark.asyncio async def test_many_to_many_remove_relationship(self, setup_tortoise_relationships): """Test removing many-to-many relationships""" learner = await RelTestLearner.create(name="Remove Learner", learner_id="L003") subject = await RelTestSubject.create(name="Remove Subject", code="REM101", credits=2) - + # Add relationship await subject.learners.add(learner) subject_with_learners = await RelTestSubject.get(id=subject.id).prefetch_related("learners") assert len(subject_with_learners.learners) == 1 - + # Remove relationship await subject.learners.remove(learner) subject_with_learners = await RelTestSubject.get(id=subject.id).prefetch_related("learners") assert len(subject_with_learners.learners) == 0 - + @pytest.mark.asyncio async def test_many_to_many_clear_relationships(self, setup_tortoise_relationships): """Test clearing all many-to-many relationships""" learner1 = await RelTestLearner.create(name="Clear Learner 1", learner_id="L004") learner2 = await RelTestLearner.create(name="Clear Learner 2", learner_id="L005") subject = await RelTestSubject.create(name="Clear Subject", code="CLR101", credits=1) - + # Add multiple relationships await subject.learners.add(learner1, learner2) subject_with_learners = await RelTestSubject.get(id=subject.id).prefetch_related("learners") assert len(subject_with_learners.learners) == 2 - + # Clear all relationships await subject.learners.clear() subject_with_learners = await RelTestSubject.get(id=subject.id).prefetch_related("learners") assert len(subject_with_learners.learners) == 0 - + @pytest.mark.asyncio async def test_unique_constraints(self, setup_tortoise_relationships): """Test unique constraints on model fields""" @@ -214,13 +212,13 @@ async def test_unique_constraints(self, setup_tortoise_relationships): await RelTestAccount.create(username="unique1", email="unique1@example.com") with pytest.raises(IntegrityError): await RelTestAccount.create(username="unique1", email="different@example.com") - + # Test publication unique ISBN publisher = await RelTestPublisher.create(name="ISBN Publisher", email="isbn@example.com") await RelTestPublication.create(title="ISBN Pub 1", isbn="1111111111111", publisher=publisher) with pytest.raises(IntegrityError): await RelTestPublication.create(title="ISBN Pub 2", isbn="1111111111111", publisher=publisher) - + # Test subject unique code await RelTestSubject.create(name="Unique Subject 1", code="UNQ101", credits=3) with pytest.raises(IntegrityError): diff --git a/tests/integration/container/tortoise/test_tortoise_secrets_manager.py b/tests/integration/container/tortoise/test_tortoise_secrets_manager.py index 65763683..9d9fabd6 100644 --- a/tests/integration/container/tortoise/test_tortoise_secrets_manager.py +++ b/tests/integration/container/tortoise/test_tortoise_secrets_manager.py @@ -18,7 +18,6 @@ import pytest import pytest_asyncio -from tests.integration.container.tortoise.models.test_models import User from tests.integration.container.tortoise.test_tortoise_common import ( run_basic_read_operations, run_basic_write_operations, setup_tortoise) from tests.integration.container.utils.conditions import ( @@ -38,19 +37,19 @@ TestEnvironmentFeatures.PERFORMANCE]) class TestTortoiseSecretsManager: """Test class for Tortoise ORM with AWS Secrets Manager authentication.""" - + @pytest.fixture(scope='class') def create_secret(self, conn_utils): """Create a secret in AWS Secrets Manager with database credentials.""" region = TestEnvironment.get_current().get_info().get_region() client = boto3.client('secretsmanager', region_name=region) - + secret_name = f"test-tortoise-secret-{TestEnvironment.get_current().get_info().get_db_name()}" secret_value = { "username": conn_utils.user, "password": conn_utils.password } - + try: response = client.create_secret( Name=secret_name, @@ -66,13 +65,13 @@ def create_secret(self, conn_utils): ) except Exception: pass - + @pytest_asyncio.fixture async def setup_tortoise_secrets_manager(self, conn_utils, create_secret): """Setup Tortoise with Secrets Manager authentication.""" - async for result in setup_tortoise(conn_utils, - plugins="aws_secrets_manager", - secrets_manager_secret_id=create_secret): + async for result in setup_tortoise(conn_utils, + plugins="aws_secrets_manager", + secrets_manager_secret_id=create_secret): yield result @pytest.mark.asyncio diff --git a/tests/integration/container/utils/connection_utils.py b/tests/integration/container/utils/connection_utils.py index 1cc57b06..324a7cc3 100644 --- a/tests/integration/container/utils/connection_utils.py +++ b/tests/integration/container/utils/connection_utils.py @@ -98,7 +98,7 @@ def get_proxy_connect_params( password = self.password if password is None else password dbname = self.dbname if dbname is None else dbname return DriverHelper.get_connect_params(host, port, user, password, dbname) - + def get_aws_tortoise_url( self, db_engine: DatabaseEngine, @@ -114,14 +114,14 @@ def get_aws_tortoise_url( user = self.user if user is None else user password = self.password if password is None else password dbname = self.dbname if dbname is None else dbname - + # Build base URL protocol = "aws-pg" if db_engine == DatabaseEngine.PG else "aws-mysql" url = f"{protocol}://{user}:{password}@{host}:{port}/{dbname}" - + # Add all kwargs as query parameters if kwargs: params = [f"{key}={value}" for key, value in kwargs.items()] url += "?" + "&".join(params) - + return url diff --git a/tests/integration/container/utils/test_environment_request.py b/tests/integration/container/utils/test_environment_request.py index f8bd36ec..187cfbb1 100644 --- a/tests/integration/container/utils/test_environment_request.py +++ b/tests/integration/container/utils/test_environment_request.py @@ -40,7 +40,7 @@ def __init__(self, request: Dict[str, Any]) -> None: self._instances = getattr(DatabaseInstances, typing.cast('str', request.get("instances"))) self._deployment = getattr(DatabaseEngineDeployment, typing.cast('str', request.get("deployment"))) self._target_python_version = getattr(TargetPythonVersion, typing.cast('str', request.get("targetPythonVersion"))) - self._num_of_instances = request.get("numOfInstances") + self._num_of_instances = request.get("numOfInstances", 1) self._features = set() features: Iterable[str] = typing.cast('Iterable[str]', request.get("features")) if features is not None: diff --git a/tests/integration/container/utils/test_utils.py b/tests/integration/container/utils/test_utils.py index bf669128..a8942c0b 100644 --- a/tests/integration/container/utils/test_utils.py +++ b/tests/integration/container/utils/test_utils.py @@ -13,13 +13,13 @@ from .database_engine import DatabaseEngine -def get_sleep_sql(db_engine : DatabaseEngine, seconds : int): +def get_sleep_sql(db_engine: DatabaseEngine, seconds: int): if db_engine == DatabaseEngine.MYSQL: return f"SELECT SLEEP({seconds});" elif db_engine == DatabaseEngine.PG: return f"SELECT PG_SLEEP({seconds});" else: - raise ValueError("Unknown database engine: " + db) + raise ValueError("Unknown database engine: " + str(db_engine)) def get_sleep_trigger_sql(db_engine: DatabaseEngine, duration: int, table_name: str) -> str: @@ -42,7 +42,7 @@ def get_sleep_trigger_sql(db_engine: DatabaseEngine, duration: int, table_name: RETURN NEW; END; $$ LANGUAGE plpgsql; - + CREATE TRIGGER {table_name}_sleep_trigger BEFORE INSERT ON {table_name} FOR EACH ROW @@ -50,4 +50,3 @@ def get_sleep_trigger_sql(db_engine: DatabaseEngine, duration: int, table_name: """ else: raise ValueError(f"Unknown database engine: {db_engine}") - diff --git a/tests/unit/test_tortoise_base_client.py b/tests/unit/test_tortoise_base_client.py index 4074a9ee..829ef1f3 100644 --- a/tests/unit/test_tortoise_base_client.py +++ b/tests/unit/test_tortoise_base_client.py @@ -12,17 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -import asyncio -import pytest from unittest.mock import AsyncMock, MagicMock, patch +import pytest + from aws_advanced_python_wrapper.tortoise.backend.base.client import ( - AwsCursorAsyncWrapper, - AwsConnectionAsyncWrapper, - TortoiseAwsClientConnectionWrapper, - TortoiseAwsClientTransactionContext, - AwsWrapperAsyncConnector, -) + AwsConnectionAsyncWrapper, AwsCursorAsyncWrapper, AwsWrapperAsyncConnector, + TortoiseAwsClientConnectionWrapper, TortoiseAwsClientTransactionContext) class TestAwsCursorAsyncWrapper: @@ -35,12 +31,12 @@ def test_init(self): async def test_execute(self): mock_cursor = MagicMock() wrapper = AwsCursorAsyncWrapper(mock_cursor) - + with patch('asyncio.to_thread') as mock_to_thread: mock_to_thread.return_value = "result" - + result = await wrapper.execute("SELECT 1", ["param"]) - + mock_to_thread.assert_called_once_with(mock_cursor.execute, "SELECT 1", ["param"]) assert result == "result" @@ -48,12 +44,12 @@ async def test_execute(self): async def test_executemany(self): mock_cursor = MagicMock() wrapper = AwsCursorAsyncWrapper(mock_cursor) - + with patch('asyncio.to_thread') as mock_to_thread: mock_to_thread.return_value = "result" - + result = await wrapper.executemany("INSERT", [["param1"], ["param2"]]) - + mock_to_thread.assert_called_once_with(mock_cursor.executemany, "INSERT", [["param1"], ["param2"]]) assert result == "result" @@ -61,12 +57,12 @@ async def test_executemany(self): async def test_fetchall(self): mock_cursor = MagicMock() wrapper = AwsCursorAsyncWrapper(mock_cursor) - + with patch('asyncio.to_thread') as mock_to_thread: mock_to_thread.return_value = [("row1",), ("row2",)] - + result = await wrapper.fetchall() - + mock_to_thread.assert_called_once_with(mock_cursor.fetchall) assert result == [("row1",), ("row2",)] @@ -74,12 +70,12 @@ async def test_fetchall(self): async def test_fetchone(self): mock_cursor = MagicMock() wrapper = AwsCursorAsyncWrapper(mock_cursor) - + with patch('asyncio.to_thread') as mock_to_thread: mock_to_thread.return_value = ("row1",) - + result = await wrapper.fetchone() - + mock_to_thread.assert_called_once_with(mock_cursor.fetchone) assert result == ("row1",) @@ -87,10 +83,10 @@ async def test_fetchone(self): async def test_close(self): mock_cursor = MagicMock() wrapper = AwsCursorAsyncWrapper(mock_cursor) - + with patch('asyncio.to_thread') as mock_to_thread: mock_to_thread.return_value = None - + await wrapper.close() mock_to_thread.assert_called_once_with(mock_cursor.close) @@ -98,7 +94,7 @@ def test_getattr(self): mock_cursor = MagicMock() mock_cursor.rowcount = 5 wrapper = AwsCursorAsyncWrapper(mock_cursor) - + assert wrapper.rowcount == 5 @@ -113,28 +109,28 @@ async def test_cursor_context_manager(self): mock_connection = MagicMock() mock_cursor = MagicMock() mock_connection.cursor.return_value = mock_cursor - + wrapper = AwsConnectionAsyncWrapper(mock_connection) - + with patch('asyncio.to_thread') as mock_to_thread: mock_to_thread.side_effect = [mock_cursor, None] - + async with wrapper.cursor() as cursor: assert isinstance(cursor, AwsCursorAsyncWrapper) assert cursor._cursor == mock_cursor - + assert mock_to_thread.call_count == 2 @pytest.mark.asyncio async def test_rollback(self): mock_connection = MagicMock() wrapper = AwsConnectionAsyncWrapper(mock_connection) - + with patch('asyncio.to_thread') as mock_to_thread: mock_to_thread.return_value = "rollback_result" - + result = await wrapper.rollback() - + mock_to_thread.assert_called_once_with(mock_connection.rollback) assert result == "rollback_result" @@ -142,12 +138,12 @@ async def test_rollback(self): async def test_commit(self): mock_connection = MagicMock() wrapper = AwsConnectionAsyncWrapper(mock_connection) - + with patch('asyncio.to_thread') as mock_to_thread: mock_to_thread.return_value = "commit_result" - + result = await wrapper.commit() - + mock_to_thread.assert_called_once_with(mock_connection.commit) assert result == "commit_result" @@ -155,19 +151,19 @@ async def test_commit(self): async def test_set_autocommit(self): mock_connection = MagicMock() wrapper = AwsConnectionAsyncWrapper(mock_connection) - + with patch('asyncio.to_thread') as mock_to_thread: mock_to_thread.return_value = None - + await wrapper.set_autocommit(True) - + mock_to_thread.assert_called_once_with(setattr, mock_connection, 'autocommit', True) def test_getattr(self): mock_connection = MagicMock() mock_connection.some_attr = "test_value" wrapper = AwsConnectionAsyncWrapper(mock_connection) - + assert wrapper.some_attr == "test_value" @@ -175,21 +171,21 @@ class TestTortoiseAwsClientConnectionWrapper: def test_init(self): mock_client = MagicMock() mock_connect_func = MagicMock() - + wrapper = TortoiseAwsClientConnectionWrapper(mock_client, mock_connect_func, with_db=True) - + assert wrapper.client == mock_client assert wrapper.connect_func == mock_connect_func - assert wrapper.with_db == True + assert wrapper.with_db assert wrapper.connection is None @pytest.mark.asyncio async def test_ensure_connection(self): mock_client = MagicMock() mock_client.create_connection = AsyncMock() - + wrapper = TortoiseAwsClientConnectionWrapper(mock_client, MagicMock(), with_db=True) - + await wrapper.ensure_connection() mock_client.create_connection.assert_called_once_with(with_db=True) @@ -200,17 +196,17 @@ async def test_context_manager(self): mock_client.create_connection = AsyncMock() mock_connect_func = MagicMock() mock_connection = MagicMock() - + wrapper = TortoiseAwsClientConnectionWrapper(mock_client, mock_connect_func, with_db=True) - - with patch.object(AwsWrapperAsyncConnector, 'ConnectWithAwsWrapper') as mock_connect: + + with patch.object(AwsWrapperAsyncConnector, 'connect_with_aws_wrapper') as mock_connect: mock_connect.return_value = mock_connection - - with patch.object(AwsWrapperAsyncConnector, 'CloseAwsWrapper') as mock_close: + + with patch.object(AwsWrapperAsyncConnector, 'close_aws_wrapper') as mock_close: async with wrapper as conn: assert conn == mock_connection assert wrapper.connection == mock_connection - + # Verify close was called on exit mock_close.assert_called_once_with(mock_connection) @@ -219,9 +215,9 @@ class TestTortoiseAwsClientTransactionContext: def test_init(self): mock_client = MagicMock() mock_client.connection_name = "test_conn" - + context = TortoiseAwsClientTransactionContext(mock_client) - + assert context.client == mock_client assert context.connection_name == "test_conn" @@ -229,9 +225,9 @@ def test_init(self): async def test_ensure_connection(self): mock_client = MagicMock() mock_client._parent.create_connection = AsyncMock() - + context = TortoiseAwsClientTransactionContext(mock_client) - + await context.ensure_connection() mock_client._parent.create_connection.assert_called_once_with(with_db=True) @@ -245,19 +241,19 @@ async def test_context_manager_commit(self): mock_client.begin = AsyncMock() mock_client.commit = AsyncMock() mock_connection = MagicMock() - + context = TortoiseAwsClientTransactionContext(mock_client) - - with patch.object(AwsWrapperAsyncConnector, 'ConnectWithAwsWrapper') as mock_connect: + + with patch.object(AwsWrapperAsyncConnector, 'connect_with_aws_wrapper') as mock_connect: mock_connect.return_value = mock_connection with patch('tortoise.connection.connections') as mock_connections: mock_connections.set.return_value = "test_token" - - with patch.object(AwsWrapperAsyncConnector, 'CloseAwsWrapper') as mock_close: + + with patch.object(AwsWrapperAsyncConnector, 'close_aws_wrapper') as mock_close: async with context as client: assert client == mock_client assert mock_client._connection == mock_connection - + # Verify commit was called and connection was closed mock_client.commit.assert_called_once() mock_close.assert_called_once_with(mock_connection) @@ -272,21 +268,21 @@ async def test_context_manager_rollback_on_exception(self): mock_client.begin = AsyncMock() mock_client.rollback = AsyncMock() mock_connection = MagicMock() - + context = TortoiseAwsClientTransactionContext(mock_client) - - with patch.object(AwsWrapperAsyncConnector, 'ConnectWithAwsWrapper') as mock_connect: + + with patch.object(AwsWrapperAsyncConnector, 'connect_with_aws_wrapper') as mock_connect: mock_connect.return_value = mock_connection with patch('tortoise.connection.connections') as mock_connections: mock_connections.set.return_value = "test_token" - - with patch.object(AwsWrapperAsyncConnector, 'CloseAwsWrapper') as mock_close: + + with patch.object(AwsWrapperAsyncConnector, 'close_aws_wrapper') as mock_close: try: - async with context as client: + async with context: raise ValueError("Test exception") except ValueError: pass - + # Verify rollback was called and connection was closed mock_client.rollback.assert_called_once() mock_close.assert_called_once_with(mock_connection) @@ -298,12 +294,12 @@ async def test_connect_with_aws_wrapper(self): mock_connect_func = MagicMock() mock_connection = MagicMock() kwargs = {"host": "localhost", "user": "test"} - + with patch('asyncio.to_thread') as mock_to_thread: mock_to_thread.return_value = mock_connection - - result = await AwsWrapperAsyncConnector.ConnectWithAwsWrapper(mock_connect_func, **kwargs) - + + result = await AwsWrapperAsyncConnector.connect_with_aws_wrapper(mock_connect_func, **kwargs) + mock_to_thread.assert_called_once() assert isinstance(result, AwsConnectionAsyncWrapper) assert result._wrapped_connection == mock_connection @@ -311,10 +307,10 @@ async def test_connect_with_aws_wrapper(self): @pytest.mark.asyncio async def test_close_aws_wrapper(self): mock_connection = MagicMock() - + with patch('asyncio.to_thread') as mock_to_thread: mock_to_thread.return_value = None - - await AwsWrapperAsyncConnector.CloseAwsWrapper(mock_connection) - + + await AwsWrapperAsyncConnector.close_aws_wrapper(mock_connection) + mock_to_thread.assert_called_once_with(mock_connection.close) diff --git a/tests/unit/test_tortoise_mysql_client.py b/tests/unit/test_tortoise_mysql_client.py index 5eb4b572..93fe77e5 100644 --- a/tests/unit/test_tortoise_mysql_client.py +++ b/tests/unit/test_tortoise_mysql_client.py @@ -12,18 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytest from unittest.mock import AsyncMock, MagicMock, patch +import pytest from mysql.connector import errors -from tortoise.exceptions import DBConnectionError, IntegrityError, OperationalError, TransactionManagementError +from tortoise.exceptions import (DBConnectionError, IntegrityError, + OperationalError, TransactionManagementError) from aws_advanced_python_wrapper.tortoise.backend.mysql.client import ( - AwsMySQLClient, - TransactionWrapper, - translate_exceptions, - _gen_savepoint_name, -) + AwsMySQLClient, TransactionWrapper, _gen_savepoint_name, + translate_exceptions) class TestTranslateExceptions: @@ -32,7 +30,7 @@ async def test_translate_operational_error(self): @translate_exceptions async def test_func(self): raise errors.OperationalError("Test error") - + with pytest.raises(OperationalError): await test_func(None) @@ -41,7 +39,7 @@ async def test_translate_integrity_error(self): @translate_exceptions async def test_func(self): raise errors.IntegrityError("Test error") - + with pytest.raises(IntegrityError): await test_func(None) @@ -50,7 +48,7 @@ async def test_no_exception(self): @translate_exceptions async def test_func(self): return "success" - + result = await test_func(None) assert result == "success" @@ -68,7 +66,7 @@ def test_init(self): storage_engine="InnoDB", connection_name="test_conn" ) - + assert client.user == "test_user" assert client.password == "test_pass" assert client.database == "test_db" @@ -87,7 +85,7 @@ def test_init_defaults(self): port=3306, connection_name="test_conn" ) - + assert client.charset == "utf8mb4" assert client.storage_engine == "innodb" @@ -103,7 +101,7 @@ async def test_create_connection_invalid_charset(self): charset="invalid_charset", connection_name="test_conn" ) - + with pytest.raises(DBConnectionError, match="Unknown character set"): await client.create_connection(with_db=True) @@ -118,10 +116,10 @@ async def test_create_connection_success(self): port=3306, connection_name="test_conn" ) - + with patch('aws_advanced_python_wrapper.tortoise.backend.mysql.client.logger'): await client.create_connection(with_db=True) - + assert client._template["user"] == "test_user" assert client._template["database"] == "test_db" assert client._template["autocommit"] is True @@ -137,7 +135,7 @@ async def test_close(self): port=3306, connection_name="test_conn" ) - + # close() method now does nothing (AWS wrapper handles cleanup) await client.close() # No assertions needed since method is a no-op @@ -152,7 +150,7 @@ def test_acquire_connection(self): port=3306, connection_name="test_conn" ) - + connection_wrapper = client.acquire_connection() assert connection_wrapper.client == client @@ -167,21 +165,21 @@ async def test_execute_insert(self): port=3306, connection_name="test_conn" ) - + mock_connection = MagicMock() mock_cursor = MagicMock() mock_cursor.lastrowid = 123 - + with patch.object(client, 'acquire_connection') as mock_acquire: mock_acquire.return_value.__aenter__ = AsyncMock(return_value=mock_connection) mock_acquire.return_value.__aexit__ = AsyncMock() mock_connection.cursor.return_value.__aenter__ = AsyncMock(return_value=mock_cursor) mock_connection.cursor.return_value.__aexit__ = AsyncMock() mock_cursor.execute = AsyncMock() - + with patch('aws_advanced_python_wrapper.tortoise.backend.mysql.client.logger'): result = await client.execute_insert("INSERT INTO test VALUES (?)", ["value"]) - + assert result == 123 mock_cursor.execute.assert_called_once_with("INSERT INTO test VALUES (?)", ["value"]) @@ -197,22 +195,22 @@ async def test_execute_many_with_transactions(self): connection_name="test_conn", storage_engine="innodb" ) - + mock_connection = MagicMock() mock_cursor = MagicMock() - + with patch.object(client, 'acquire_connection') as mock_acquire: mock_acquire.return_value.__aenter__ = AsyncMock(return_value=mock_connection) mock_acquire.return_value.__aexit__ = AsyncMock() mock_connection.cursor.return_value.__aenter__ = AsyncMock(return_value=mock_cursor) mock_connection.cursor.return_value.__aexit__ = AsyncMock() - + with patch.object(client, '_execute_many_with_transaction') as mock_execute_many_tx: mock_execute_many_tx.return_value = None - + with patch('aws_advanced_python_wrapper.tortoise.backend.mysql.client.logger'): await client.execute_many("INSERT INTO test VALUES (?)", [["val1"], ["val2"]]) - + mock_execute_many_tx.assert_called_once_with( mock_cursor, mock_connection, "INSERT INTO test VALUES (?)", [["val1"], ["val2"]] ) @@ -228,12 +226,12 @@ async def test_execute_query(self): port=3306, connection_name="test_conn" ) - + mock_connection = MagicMock() mock_cursor = MagicMock() mock_cursor.rowcount = 2 mock_cursor.description = [("id",), ("name",)] - + with patch.object(client, 'acquire_connection') as mock_acquire: mock_acquire.return_value.__aenter__ = AsyncMock(return_value=mock_connection) mock_acquire.return_value.__aexit__ = AsyncMock() @@ -241,10 +239,10 @@ async def test_execute_query(self): mock_connection.cursor.return_value.__aexit__ = AsyncMock() mock_cursor.execute = AsyncMock() mock_cursor.fetchall = AsyncMock(return_value=[(1, "test"), (2, "test2")]) - + with patch('aws_advanced_python_wrapper.tortoise.backend.mysql.client.logger'): rowcount, results = await client.execute_query("SELECT * FROM test") - + assert rowcount == 2 assert len(results) == 2 assert results[0] == {"id": 1, "name": "test"} @@ -260,12 +258,12 @@ async def test_execute_query_dict(self): port=3306, connection_name="test_conn" ) - + with patch.object(client, 'execute_query') as mock_execute_query: mock_execute_query.return_value = (2, [{"id": 1}, {"id": 2}]) - + results = await client.execute_query_dict("SELECT * FROM test") - + assert results == [{"id": 1}, {"id": 2}] def test_in_transaction(self): @@ -278,7 +276,7 @@ def test_in_transaction(self): port=3306, connection_name="test_conn" ) - + transaction_context = client._in_transaction() assert transaction_context is not None @@ -294,9 +292,9 @@ def test_init(self): port=3306, connection_name="test_conn" ) - + wrapper = TransactionWrapper(parent_client) - + assert wrapper.connection_name == "test_conn" assert wrapper._parent == parent_client assert wrapper._finalized is False @@ -313,14 +311,14 @@ async def test_begin(self): port=3306, connection_name="test_conn" ) - + wrapper = TransactionWrapper(parent_client) mock_connection = MagicMock() mock_connection.set_autocommit = AsyncMock() wrapper._connection = mock_connection - + await wrapper.begin() - + mock_connection.set_autocommit.assert_called_once_with(False) assert wrapper._finalized is False @@ -335,15 +333,15 @@ async def test_commit(self): port=3306, connection_name="test_conn" ) - + wrapper = TransactionWrapper(parent_client) mock_connection = MagicMock() mock_connection.commit = AsyncMock() mock_connection.set_autocommit = AsyncMock() wrapper._connection = mock_connection - + await wrapper.commit() - + mock_connection.commit.assert_called_once() mock_connection.set_autocommit.assert_called_once_with(True) assert wrapper._finalized is True @@ -359,10 +357,10 @@ async def test_commit_already_finalized(self): port=3306, connection_name="test_conn" ) - + wrapper = TransactionWrapper(parent_client) wrapper._finalized = True - + with pytest.raises(TransactionManagementError, match="Transaction already finalized"): await wrapper.commit() @@ -377,15 +375,15 @@ async def test_rollback(self): port=3306, connection_name="test_conn" ) - + wrapper = TransactionWrapper(parent_client) mock_connection = MagicMock() mock_connection.rollback = AsyncMock() mock_connection.set_autocommit = AsyncMock() wrapper._connection = mock_connection - + await wrapper.rollback() - + mock_connection.rollback.assert_called_once() mock_connection.set_autocommit.assert_called_once_with(True) assert wrapper._finalized is True @@ -401,18 +399,18 @@ async def test_savepoint(self): port=3306, connection_name="test_conn" ) - + wrapper = TransactionWrapper(parent_client) mock_connection = MagicMock() mock_cursor = MagicMock() mock_cursor.execute = AsyncMock() - + wrapper._connection = mock_connection mock_connection.cursor.return_value.__aenter__ = AsyncMock(return_value=mock_cursor) mock_connection.cursor.return_value.__aexit__ = AsyncMock() - + await wrapper.savepoint() - + assert wrapper._savepoint is not None assert wrapper._savepoint.startswith("tortoise_savepoint_") mock_cursor.execute.assert_called_once() @@ -428,19 +426,19 @@ async def test_savepoint_rollback(self): port=3306, connection_name="test_conn" ) - + wrapper = TransactionWrapper(parent_client) wrapper._savepoint = "test_savepoint" mock_connection = MagicMock() mock_cursor = MagicMock() mock_cursor.execute = AsyncMock() - + wrapper._connection = mock_connection mock_connection.cursor.return_value.__aenter__ = AsyncMock(return_value=mock_cursor) mock_connection.cursor.return_value.__aexit__ = AsyncMock() - + await wrapper.savepoint_rollback() - + mock_cursor.execute.assert_called_once_with("ROLLBACK TO SAVEPOINT test_savepoint") assert wrapper._savepoint is None assert wrapper._finalized is True @@ -456,10 +454,10 @@ async def test_savepoint_rollback_no_savepoint(self): port=3306, connection_name="test_conn" ) - + wrapper = TransactionWrapper(parent_client) wrapper._savepoint = None - + with pytest.raises(TransactionManagementError, match="No savepoint to rollback to"): await wrapper.savepoint_rollback() @@ -468,7 +466,7 @@ class TestGenSavepointName: def test_gen_savepoint_name(self): name1 = _gen_savepoint_name() name2 = _gen_savepoint_name() - + assert name1.startswith("tortoise_savepoint_") assert name2.startswith("tortoise_savepoint_") assert name1 != name2 # Should be unique From 75164e6ef213e4f8e343d060faa1ff35bdd7861e Mon Sep 17 00:00:00 2001 From: Favian Samatha Date: Fri, 5 Dec 2025 17:17:40 -0800 Subject: [PATCH 09/11] rename path to backends instead of backend --- .../sql_alchemy_connection_provider.py | 3 ++- aws_advanced_python_wrapper/tortoise/__init__.py | 2 +- .../tortoise/{backend => backends}/__init__.py | 0 .../tortoise/{backend => backends}/base/__init__.py | 0 .../tortoise/{backend => backends}/base/client.py | 0 .../tortoise/{backend => backends}/mysql/__init__.py | 0 .../tortoise/{backend => backends}/mysql/client.py | 6 +++--- .../tortoise/{backend => backends}/mysql/executor.py | 0 .../{backend => backends}/mysql/schema_generator.py | 0 .../container/tortoise/test_tortoise_config.py | 8 ++++---- tests/unit/test_tortoise_base_client.py | 2 +- tests/unit/test_tortoise_mysql_client.py | 10 +++++----- 12 files changed, 16 insertions(+), 15 deletions(-) rename aws_advanced_python_wrapper/tortoise/{backend => backends}/__init__.py (100%) rename aws_advanced_python_wrapper/tortoise/{backend => backends}/base/__init__.py (100%) rename aws_advanced_python_wrapper/tortoise/{backend => backends}/base/client.py (100%) rename aws_advanced_python_wrapper/tortoise/{backend => backends}/mysql/__init__.py (100%) rename aws_advanced_python_wrapper/tortoise/{backend => backends}/mysql/client.py (98%) rename aws_advanced_python_wrapper/tortoise/{backend => backends}/mysql/executor.py (100%) rename aws_advanced_python_wrapper/tortoise/{backend => backends}/mysql/schema_generator.py (100%) diff --git a/aws_advanced_python_wrapper/sql_alchemy_connection_provider.py b/aws_advanced_python_wrapper/sql_alchemy_connection_provider.py index 0523ea56..2f3c746d 100644 --- a/aws_advanced_python_wrapper/sql_alchemy_connection_provider.py +++ b/aws_advanced_python_wrapper/sql_alchemy_connection_provider.py @@ -88,7 +88,7 @@ def accepts_host_info(self, host_info: HostInfo, props: Properties) -> bool: if self._accept_url_func: return self._accept_url_func(host_info, props) url_type = SqlAlchemyPooledConnectionProvider._rds_utils.identify_rds_type(host_info.host) - return url_type in (RdsUrlType.RDS_INSTANCE, RdsUrlType.RDS_WRITER_CLUSTER) + return RdsUrlType.RDS_INSTANCE == url_type def accepts_strategy(self, role: HostRole, strategy: str) -> bool: return strategy == SqlAlchemyPooledConnectionProvider._LEAST_CONNECTIONS or strategy in self._accepted_strategies @@ -137,6 +137,7 @@ def connect( if queue_pool is None: raise AwsWrapperError(Messages.get_formatted("SqlAlchemyPooledConnectionProvider.PoolNone", host_info.url)) + return queue_pool.connect() # The pool key should always be retrieved using this method, because the username diff --git a/aws_advanced_python_wrapper/tortoise/__init__.py b/aws_advanced_python_wrapper/tortoise/__init__.py index 7e384786..dd958d83 100644 --- a/aws_advanced_python_wrapper/tortoise/__init__.py +++ b/aws_advanced_python_wrapper/tortoise/__init__.py @@ -16,7 +16,7 @@ # Register AWS MySQL backend DB_LOOKUP["aws-mysql"] = { - "engine": "aws_advanced_python_wrapper.tortoise.backend.mysql", + "engine": "aws_advanced_python_wrapper.tortoise.backends.mysql", "vmap": { "path": "database", "hostname": "host", diff --git a/aws_advanced_python_wrapper/tortoise/backend/__init__.py b/aws_advanced_python_wrapper/tortoise/backends/__init__.py similarity index 100% rename from aws_advanced_python_wrapper/tortoise/backend/__init__.py rename to aws_advanced_python_wrapper/tortoise/backends/__init__.py diff --git a/aws_advanced_python_wrapper/tortoise/backend/base/__init__.py b/aws_advanced_python_wrapper/tortoise/backends/base/__init__.py similarity index 100% rename from aws_advanced_python_wrapper/tortoise/backend/base/__init__.py rename to aws_advanced_python_wrapper/tortoise/backends/base/__init__.py diff --git a/aws_advanced_python_wrapper/tortoise/backend/base/client.py b/aws_advanced_python_wrapper/tortoise/backends/base/client.py similarity index 100% rename from aws_advanced_python_wrapper/tortoise/backend/base/client.py rename to aws_advanced_python_wrapper/tortoise/backends/base/client.py diff --git a/aws_advanced_python_wrapper/tortoise/backend/mysql/__init__.py b/aws_advanced_python_wrapper/tortoise/backends/mysql/__init__.py similarity index 100% rename from aws_advanced_python_wrapper/tortoise/backend/mysql/__init__.py rename to aws_advanced_python_wrapper/tortoise/backends/mysql/__init__.py diff --git a/aws_advanced_python_wrapper/tortoise/backend/mysql/client.py b/aws_advanced_python_wrapper/tortoise/backends/mysql/client.py similarity index 98% rename from aws_advanced_python_wrapper/tortoise/backend/mysql/client.py rename to aws_advanced_python_wrapper/tortoise/backends/mysql/client.py index 733a7a30..5b974bbf 100644 --- a/aws_advanced_python_wrapper/tortoise/backend/mysql/client.py +++ b/aws_advanced_python_wrapper/tortoise/backends/mysql/client.py @@ -30,12 +30,12 @@ OperationalError, TransactionManagementError) from aws_advanced_python_wrapper.errors import AwsWrapperError, FailoverError -from aws_advanced_python_wrapper.tortoise.backend.base.client import ( +from aws_advanced_python_wrapper.tortoise.backends.base.client import ( AwsBaseDBAsyncClient, AwsConnectionAsyncWrapper, AwsTransactionalDBClient, TortoiseAwsClientConnectionWrapper, TortoiseAwsClientTransactionContext) -from aws_advanced_python_wrapper.tortoise.backend.mysql.executor import \ +from aws_advanced_python_wrapper.tortoise.backends.mysql.executor import \ AwsMySQLExecutor -from aws_advanced_python_wrapper.tortoise.backend.mysql.schema_generator import \ +from aws_advanced_python_wrapper.tortoise.backends.mysql.schema_generator import \ AwsMySQLSchemaGenerator from aws_advanced_python_wrapper.utils.log import Logger diff --git a/aws_advanced_python_wrapper/tortoise/backend/mysql/executor.py b/aws_advanced_python_wrapper/tortoise/backends/mysql/executor.py similarity index 100% rename from aws_advanced_python_wrapper/tortoise/backend/mysql/executor.py rename to aws_advanced_python_wrapper/tortoise/backends/mysql/executor.py diff --git a/aws_advanced_python_wrapper/tortoise/backend/mysql/schema_generator.py b/aws_advanced_python_wrapper/tortoise/backends/mysql/schema_generator.py similarity index 100% rename from aws_advanced_python_wrapper/tortoise/backend/mysql/schema_generator.py rename to aws_advanced_python_wrapper/tortoise/backends/mysql/schema_generator.py diff --git a/tests/integration/container/tortoise/test_tortoise_config.py b/tests/integration/container/tortoise/test_tortoise_config.py index 06432001..9308c643 100644 --- a/tests/integration/container/tortoise/test_tortoise_config.py +++ b/tests/integration/container/tortoise/test_tortoise_config.py @@ -48,7 +48,7 @@ async def setup_tortoise_dict_config(self, conn_utils): config = { "connections": { "default": { - "engine": "aws_advanced_python_wrapper.tortoise.backend.mysql", + "engine": "aws_advanced_python_wrapper.tortoise.backends.mysql", "credentials": { "host": conn_utils.writer_cluster_host, "port": conn_utils.port, @@ -87,7 +87,7 @@ async def setup_tortoise_multi_db(self, conn_utils): config = { "connections": { "default": { - "engine": "aws_advanced_python_wrapper.tortoise.backend.mysql", + "engine": "aws_advanced_python_wrapper.tortoise.backends.mysql", "credentials": { "host": conn_utils.writer_cluster_host, "port": conn_utils.port, @@ -98,7 +98,7 @@ async def setup_tortoise_multi_db(self, conn_utils): } }, "second_db": { - "engine": "aws_advanced_python_wrapper.tortoise.backend.mysql", + "engine": "aws_advanced_python_wrapper.tortoise.backends.mysql", "credentials": { "host": conn_utils.writer_cluster_host, "port": conn_utils.port, @@ -149,7 +149,7 @@ async def setup_tortoise_with_router(self, conn_utils): config = { "connections": { "default": { - "engine": "aws_advanced_python_wrapper.tortoise.backend.mysql", + "engine": "aws_advanced_python_wrapper.tortoise.backends.mysql", "credentials": { "host": conn_utils.writer_cluster_host, "port": conn_utils.port, diff --git a/tests/unit/test_tortoise_base_client.py b/tests/unit/test_tortoise_base_client.py index 829ef1f3..d300a702 100644 --- a/tests/unit/test_tortoise_base_client.py +++ b/tests/unit/test_tortoise_base_client.py @@ -16,7 +16,7 @@ import pytest -from aws_advanced_python_wrapper.tortoise.backend.base.client import ( +from aws_advanced_python_wrapper.tortoise.backends.base.client import ( AwsConnectionAsyncWrapper, AwsCursorAsyncWrapper, AwsWrapperAsyncConnector, TortoiseAwsClientConnectionWrapper, TortoiseAwsClientTransactionContext) diff --git a/tests/unit/test_tortoise_mysql_client.py b/tests/unit/test_tortoise_mysql_client.py index 93fe77e5..701d4514 100644 --- a/tests/unit/test_tortoise_mysql_client.py +++ b/tests/unit/test_tortoise_mysql_client.py @@ -19,7 +19,7 @@ from tortoise.exceptions import (DBConnectionError, IntegrityError, OperationalError, TransactionManagementError) -from aws_advanced_python_wrapper.tortoise.backend.mysql.client import ( +from aws_advanced_python_wrapper.tortoise.backends.mysql.client import ( AwsMySQLClient, TransactionWrapper, _gen_savepoint_name, translate_exceptions) @@ -117,7 +117,7 @@ async def test_create_connection_success(self): connection_name="test_conn" ) - with patch('aws_advanced_python_wrapper.tortoise.backend.mysql.client.logger'): + with patch('aws_advanced_python_wrapper.tortoise.backends.mysql.client.logger'): await client.create_connection(with_db=True) assert client._template["user"] == "test_user" @@ -177,7 +177,7 @@ async def test_execute_insert(self): mock_connection.cursor.return_value.__aexit__ = AsyncMock() mock_cursor.execute = AsyncMock() - with patch('aws_advanced_python_wrapper.tortoise.backend.mysql.client.logger'): + with patch('aws_advanced_python_wrapper.tortoise.backends.mysql.client.logger'): result = await client.execute_insert("INSERT INTO test VALUES (?)", ["value"]) assert result == 123 @@ -208,7 +208,7 @@ async def test_execute_many_with_transactions(self): with patch.object(client, '_execute_many_with_transaction') as mock_execute_many_tx: mock_execute_many_tx.return_value = None - with patch('aws_advanced_python_wrapper.tortoise.backend.mysql.client.logger'): + with patch('aws_advanced_python_wrapper.tortoise.backends.mysql.client.logger'): await client.execute_many("INSERT INTO test VALUES (?)", [["val1"], ["val2"]]) mock_execute_many_tx.assert_called_once_with( @@ -240,7 +240,7 @@ async def test_execute_query(self): mock_cursor.execute = AsyncMock() mock_cursor.fetchall = AsyncMock(return_value=[(1, "test"), (2, "test2")]) - with patch('aws_advanced_python_wrapper.tortoise.backend.mysql.client.logger'): + with patch('aws_advanced_python_wrapper.tortoise.backends.mysql.client.logger'): rowcount, results = await client.execute_query("SELECT * FROM test") assert rowcount == 2 From bf1c9c2c18f96e24cd5276412092bebeafefa72a Mon Sep 17 00:00:00 2001 From: Favian Samatha Date: Fri, 5 Dec 2025 17:31:38 -0800 Subject: [PATCH 10/11] remove 3.8 tests as it's not compatible with asyncio --- .github/workflows/main.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 89659835..888b35f7 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -32,7 +32,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.8", "3.9", "3.10", "3.11"] + python-version: ["3.9", "3.10", "3.11"] poetry-version: ["1.8.2"] steps: From a538ced14f712570ee3cd0d54df73b400323a3ff Mon Sep 17 00:00:00 2001 From: Favian Samatha Date: Fri, 5 Dec 2025 17:34:58 -0800 Subject: [PATCH 11/11] temp: run integration tests on this branch --- .github/workflows/integration_tests.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/integration_tests.yml b/.github/workflows/integration_tests.yml index e8c99cde..dcdf6b82 100644 --- a/.github/workflows/integration_tests.yml +++ b/.github/workflows/integration_tests.yml @@ -5,6 +5,7 @@ on: push: branches: - main + - feat/tortoise permissions: id-token: write # This is required for requesting the JWT