diff --git a/.github/workflows/integration_tests.yml b/.github/workflows/integration_tests.yml index e8c99cde..dcdf6b82 100644 --- a/.github/workflows/integration_tests.yml +++ b/.github/workflows/integration_tests.yml @@ -5,6 +5,7 @@ on: push: branches: - main + - feat/tortoise permissions: id-token: write # This is required for requesting the JWT diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 89659835..888b35f7 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -32,7 +32,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.8", "3.9", "3.10", "3.11"] + python-version: ["3.9", "3.10", "3.11"] poetry-version: ["1.8.2"] steps: diff --git a/aws_advanced_python_wrapper/driver_dialect.py b/aws_advanced_python_wrapper/driver_dialect.py index c9df891a..7eb7fd41 100644 --- a/aws_advanced_python_wrapper/driver_dialect.py +++ b/aws_advanced_python_wrapper/driver_dialect.py @@ -19,6 +19,7 @@ if TYPE_CHECKING: from aws_advanced_python_wrapper.hostinfo import HostInfo from aws_advanced_python_wrapper.pep249 import Connection, Cursor + from types import ModuleType from abc import ABC from concurrent.futures import Executor, ThreadPoolExecutor, TimeoutError @@ -164,3 +165,7 @@ def ping(self, conn: Connection) -> bool: return True except Exception: return False + + def get_driver_module(self) -> ModuleType: + raise UnsupportedOperationError( + Messages.get_formatted("DriverDialect.UnsupportedOperationError", self._driver_name, "get_driver_module")) diff --git a/aws_advanced_python_wrapper/driver_dialect_manager.py b/aws_advanced_python_wrapper/driver_dialect_manager.py index dd71de40..7f754491 100644 --- a/aws_advanced_python_wrapper/driver_dialect_manager.py +++ b/aws_advanced_python_wrapper/driver_dialect_manager.py @@ -52,7 +52,8 @@ class DriverDialectManager(DriverDialectProvider): } pool_connection_driver_dialect: Dict[str, str] = { - "SqlAlchemyPooledConnectionProvider": "aws_advanced_python_wrapper.sqlalchemy_driver_dialect.SqlAlchemyDriverDialect" + "SqlAlchemyPooledConnectionProvider": "aws_advanced_python_wrapper.sqlalchemy_driver_dialect.SqlAlchemyDriverDialect", + "SqlAlchemyTortoisePooledConnectionProvider": "aws_advanced_python_wrapper.sqlalchemy_driver_dialect.SqlAlchemyDriverDialect", } @staticmethod diff --git a/aws_advanced_python_wrapper/mysql_driver_dialect.py b/aws_advanced_python_wrapper/mysql_driver_dialect.py index dd7055c5..6dd4d455 100644 --- a/aws_advanced_python_wrapper/mysql_driver_dialect.py +++ b/aws_advanced_python_wrapper/mysql_driver_dialect.py @@ -11,14 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from __future__ import annotations from typing import TYPE_CHECKING, Any, Callable, ClassVar, Set +import mysql.connector + if TYPE_CHECKING: from aws_advanced_python_wrapper.hostinfo import HostInfo from aws_advanced_python_wrapper.pep249 import Connection + from types import ModuleType from concurrent.futures import Executor, ThreadPoolExecutor, TimeoutError from inspect import signature @@ -201,3 +203,6 @@ def prepare_connect_info(self, host_info: HostInfo, original_props: Properties) def supports_connect_timeout(self) -> bool: return True + + def get_driver_module(self) -> ModuleType: + return mysql.connector diff --git a/aws_advanced_python_wrapper/pg_driver_dialect.py b/aws_advanced_python_wrapper/pg_driver_dialect.py index fbe441f3..816540c9 100644 --- a/aws_advanced_python_wrapper/pg_driver_dialect.py +++ b/aws_advanced_python_wrapper/pg_driver_dialect.py @@ -14,6 +14,7 @@ from __future__ import annotations +from inspect import signature from typing import TYPE_CHECKING, Any, Callable, Set import psycopg @@ -21,8 +22,7 @@ if TYPE_CHECKING: from aws_advanced_python_wrapper.hostinfo import HostInfo from aws_advanced_python_wrapper.pep249 import Connection - -from inspect import signature + from types import ModuleType from aws_advanced_python_wrapper.driver_dialect import DriverDialect from aws_advanced_python_wrapper.driver_dialect_codes import DriverDialectCodes @@ -175,3 +175,6 @@ def supports_tcp_keepalive(self) -> bool: def supports_abort_connection(self) -> bool: return True + + def get_driver_module(self) -> ModuleType: + return psycopg diff --git a/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties b/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties index 65f16806..5c95e152 100644 --- a/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties +++ b/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties @@ -385,6 +385,8 @@ WeightedRandomHostSelector.WeightedRandomInvalidDefaultWeight=[WeightedRandomHos SqlAlchemyPooledConnectionProvider.PoolNone=[SqlAlchemyPooledConnectionProvider] Attempted to find or create a pool for '{}' but the result of the attempt evaluated to None. SqlAlchemyPooledConnectionProvider.UnableToCreateDefaultKey=[SqlAlchemyPooledConnectionProvider] Unable to create a default key for internal connection pools. By default, the user parameter is used, but the given user evaluated to None or the empty string (""). Please ensure you have passed a valid user in the connection properties. +SqlAlchemyTortoiseConnectionProvider.UnableToDetermineDialect=[SqlAlchemyTortoiseConnectionProvider] Unable to resolve sql alchemy dialect for the following driver dialect '{}'. + SqlAlchemyDriverDialect.SetValueOnNoneConnection=[SqlAlchemyDriverDialect] Attempted to set the '{}' value on a pooled connection, but no underlying driver connection was found. This can happen if the pooled connection has previously been closed. StaleDnsHelper.ClusterEndpointDns=[StaleDnsPlugin] Cluster endpoint {} resolves to {}. diff --git a/aws_advanced_python_wrapper/sqlalchemy_driver_dialect.py b/aws_advanced_python_wrapper/sqlalchemy_driver_dialect.py index 290a8f86..fad08fd9 100644 --- a/aws_advanced_python_wrapper/sqlalchemy_driver_dialect.py +++ b/aws_advanced_python_wrapper/sqlalchemy_driver_dialect.py @@ -16,21 +16,23 @@ from typing import TYPE_CHECKING, Any -from aws_advanced_python_wrapper.driver_dialect import DriverDialect -from aws_advanced_python_wrapper.errors import AwsWrapperError -from aws_advanced_python_wrapper.utils.messages import Messages - if TYPE_CHECKING: from aws_advanced_python_wrapper.hostinfo import HostInfo from aws_advanced_python_wrapper.pep249 import Connection from aws_advanced_python_wrapper.utils.properties import Properties + from types import ModuleType from sqlalchemy import PoolProxiedConnection +from aws_advanced_python_wrapper.driver_dialect import DriverDialect +from aws_advanced_python_wrapper.errors import AwsWrapperError +from aws_advanced_python_wrapper.utils.messages import Messages + class SqlAlchemyDriverDialect(DriverDialect): _driver_name: str = "SQLAlchemy" TARGET_DRIVER_CODE: str = "sqlalchemy" + _underlying_driver_dialect = None def __init__(self, underlying_driver: DriverDialect, props: Properties): super().__init__(props) @@ -125,3 +127,6 @@ def transfer_session_state(self, from_conn: Connection, to_conn: Connection): return None return self._underlying_driver.transfer_session_state(from_driver_conn, to_driver_conn) + + def get_driver_module(self) -> ModuleType: + return self._underlying_driver.get_driver_module() diff --git a/aws_advanced_python_wrapper/tortoise/__init__.py b/aws_advanced_python_wrapper/tortoise/__init__.py new file mode 100644 index 00000000..dd958d83 --- /dev/null +++ b/aws_advanced_python_wrapper/tortoise/__init__.py @@ -0,0 +1,37 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from tortoise.backends.base.config_generator import DB_LOOKUP + +# Register AWS MySQL backend +DB_LOOKUP["aws-mysql"] = { + "engine": "aws_advanced_python_wrapper.tortoise.backends.mysql", + "vmap": { + "path": "database", + "hostname": "host", + "port": "port", + "username": "user", + "password": "password", + }, + "defaults": {"port": 3306, "charset": "utf8mb4", "sql_mode": "STRICT_TRANS_TABLES"}, + "cast": { + "minsize": int, + "maxsize": int, + "connect_timeout": int, + "echo": bool, + "use_unicode": bool, + "ssl": bool, + "use_pure": bool, + }, +} diff --git a/aws_advanced_python_wrapper/tortoise/backends/__init__.py b/aws_advanced_python_wrapper/tortoise/backends/__init__.py new file mode 100644 index 00000000..bd4acb2b --- /dev/null +++ b/aws_advanced_python_wrapper/tortoise/backends/__init__.py @@ -0,0 +1,13 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/aws_advanced_python_wrapper/tortoise/backends/base/__init__.py b/aws_advanced_python_wrapper/tortoise/backends/base/__init__.py new file mode 100644 index 00000000..bd4acb2b --- /dev/null +++ b/aws_advanced_python_wrapper/tortoise/backends/base/__init__.py @@ -0,0 +1,13 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/aws_advanced_python_wrapper/tortoise/backends/base/client.py b/aws_advanced_python_wrapper/tortoise/backends/base/client.py new file mode 100644 index 00000000..8d901c0f --- /dev/null +++ b/aws_advanced_python_wrapper/tortoise/backends/base/client.py @@ -0,0 +1,199 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import asyncio +from contextlib import asynccontextmanager +from typing import Any, Callable, Dict, Generic, cast + +import mysql.connector +from tortoise.backends.base.client import (BaseDBAsyncClient, T_conn, + TransactionalDBClient, + TransactionContext) +from tortoise.connection import connections +from tortoise.exceptions import TransactionManagementError + +from aws_advanced_python_wrapper import AwsWrapperConnection + + +class AwsWrapperAsyncConnector: + """Class for creating and closing AWS wrapper connections.""" + + @staticmethod + async def connect_with_aws_wrapper(connect_func: Callable, **kwargs) -> AwsConnectionAsyncWrapper: + """Create an AWS wrapper connection with async cursor support.""" + connection = await asyncio.to_thread( + AwsWrapperConnection.connect, connect_func, **kwargs + ) + return AwsConnectionAsyncWrapper(connection) + + @staticmethod + async def close_aws_wrapper(connection: AwsWrapperConnection) -> None: + """Close an AWS wrapper connection asynchronously.""" + await asyncio.to_thread(connection.close) + + +class AwsCursorAsyncWrapper: + """Wraps sync AwsCursor cursor with async support.""" + + def __init__(self, sync_cursor): + self._cursor = sync_cursor + + async def execute(self, query, params=None): + """Execute a query asynchronously.""" + return await asyncio.to_thread(self._cursor.execute, query, params) + + async def executemany(self, query, params_list): + """Execute multiple queries asynchronously.""" + return await asyncio.to_thread(self._cursor.executemany, query, params_list) + + async def fetchall(self): + """Fetch all results asynchronously.""" + return await asyncio.to_thread(self._cursor.fetchall) + + async def fetchone(self): + """Fetch one result asynchronously.""" + return await asyncio.to_thread(self._cursor.fetchone) + + async def close(self): + """Close cursor asynchronously.""" + return await asyncio.to_thread(self._cursor.close) + + def __getattr__(self, name): + """Delegate non-async attributes to the wrapped cursor.""" + return getattr(self._cursor, name) + + +class AwsConnectionAsyncWrapper(AwsWrapperConnection): + """Wraps sync AwsConnection with async cursor support.""" + + def __init__(self, connection: AwsWrapperConnection): + self._wrapped_connection = connection + + @asynccontextmanager + async def cursor(self): + """Create an async cursor context manager.""" + cursor_obj = await asyncio.to_thread(self._wrapped_connection.cursor) + try: + yield AwsCursorAsyncWrapper(cursor_obj) + finally: + await asyncio.to_thread(cursor_obj.close) + + async def rollback(self): + """Rollback the current transaction.""" + return await asyncio.to_thread(self._wrapped_connection.rollback) + + async def commit(self): + """Commit the current transaction.""" + return await asyncio.to_thread(self._wrapped_connection.commit) + + async def set_autocommit(self, value: bool): + """Set autocommit mode.""" + return await asyncio.to_thread(setattr, self._wrapped_connection, 'autocommit', value) + + def __getattr__(self, name): + """Delegate all other attributes/methods to the wrapped connection.""" + return getattr(self._wrapped_connection, name) + + def __del__(self): + """Delegate cleanup to wrapped connection.""" + if hasattr(self, '_wrapped_connection'): + # Let the wrapped connection handle its own cleanup + pass + + +class AwsBaseDBAsyncClient(BaseDBAsyncClient): + _template: Dict[str, Any] + + +class AwsTransactionalDBClient(TransactionalDBClient): + _template: Dict[str, Any] + _parent: AwsBaseDBAsyncClient + pass + + +class TortoiseAwsClientConnectionWrapper(Generic[T_conn]): + """Manages acquiring from and releasing connections to a pool.""" + + __slots__ = ("client", "connection", "connect_func", "with_db") + + def __init__( + self, + client: AwsBaseDBAsyncClient, + connect_func: Callable, + with_db: bool = True + ) -> None: + self.connect_func = connect_func + self.client = client + self.connection: AwsConnectionAsyncWrapper | None = None + self.with_db = with_db + + async def ensure_connection(self) -> None: + """Ensure the connection pool is initialized.""" + await self.client.create_connection(with_db=self.with_db) + + async def __aenter__(self) -> T_conn: + """Acquire connection from pool.""" + await self.ensure_connection() + self.connection = await AwsWrapperAsyncConnector.connect_with_aws_wrapper(self.connect_func, **self.client._template) + return cast("T_conn", self.connection) + + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + """Close connection and release back to pool.""" + if self.connection: + await AwsWrapperAsyncConnector.close_aws_wrapper(self.connection) + + +class TortoiseAwsClientTransactionContext(TransactionContext): + """Transaction context that uses a pool to acquire connections.""" + + __slots__ = ("client", "connection_name", "token") + + def __init__(self, client: AwsTransactionalDBClient) -> None: + self.client: AwsTransactionalDBClient = client + self.connection_name = client.connection_name + + async def ensure_connection(self) -> None: + """Ensure the connection pool is initialized.""" + await self.client._parent.create_connection(with_db=True) + + async def __aenter__(self) -> TransactionalDBClient: + """Enter transaction context.""" + await self.ensure_connection() + + # Set the context variable so the current task sees a TransactionWrapper connection + self.token = connections.set(self.connection_name, self.client) + + # Create connection and begin transaction + self.client._connection = await AwsWrapperAsyncConnector.connect_with_aws_wrapper( + mysql.connector.Connect, + **self.client._parent._template + ) + await self.client.begin() + return self.client + + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + """Exit transaction context with proper cleanup.""" + try: + if not self.client._finalized: + if exc_type: + # Can't rollback a transaction that already failed + if exc_type is not TransactionManagementError: + await self.client.rollback() + else: + await self.client.commit() + finally: + await AwsWrapperAsyncConnector.close_aws_wrapper(self.client._connection) + connections.reset(self.token) diff --git a/aws_advanced_python_wrapper/tortoise/backends/mysql/__init__.py b/aws_advanced_python_wrapper/tortoise/backends/mysql/__init__.py new file mode 100644 index 00000000..1092dd31 --- /dev/null +++ b/aws_advanced_python_wrapper/tortoise/backends/mysql/__init__.py @@ -0,0 +1,16 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .client import AwsMySQLClient + +client_class = AwsMySQLClient diff --git a/aws_advanced_python_wrapper/tortoise/backends/mysql/client.py b/aws_advanced_python_wrapper/tortoise/backends/mysql/client.py new file mode 100644 index 00000000..5b974bbf --- /dev/null +++ b/aws_advanced_python_wrapper/tortoise/backends/mysql/client.py @@ -0,0 +1,340 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from functools import wraps +from itertools import count +from typing import (Any, Callable, Coroutine, Dict, List, Optional, + SupportsInt, Tuple, TypeVar) + +import mysql.connector +import sqlparse # type: ignore[import-untyped] +from mysql.connector import errors +from mysql.connector.charsets import MYSQL_CHARACTER_SETS +from pypika_tortoise import MySQLQuery +from tortoise.backends.base.client import (Capabilities, ConnectionWrapper, + NestedTransactionContext, + TransactionContext) +from tortoise.exceptions import (DBConnectionError, IntegrityError, + OperationalError, TransactionManagementError) + +from aws_advanced_python_wrapper.errors import AwsWrapperError, FailoverError +from aws_advanced_python_wrapper.tortoise.backends.base.client import ( + AwsBaseDBAsyncClient, AwsConnectionAsyncWrapper, AwsTransactionalDBClient, + TortoiseAwsClientConnectionWrapper, TortoiseAwsClientTransactionContext) +from aws_advanced_python_wrapper.tortoise.backends.mysql.executor import \ + AwsMySQLExecutor +from aws_advanced_python_wrapper.tortoise.backends.mysql.schema_generator import \ + AwsMySQLSchemaGenerator +from aws_advanced_python_wrapper.utils.log import Logger + +logger = Logger(__name__) +T = TypeVar("T") +FuncType = Callable[..., Coroutine[None, None, T]] + + +def translate_exceptions(func: FuncType) -> FuncType: + """Decorator to translate MySQL connector exceptions to Tortoise exceptions.""" + @wraps(func) + async def translate_exceptions_(self, *args) -> T: + try: + try: + return await func(self, *args) + except AwsWrapperError as aws_err: # Unwrap any AwsWrappedErrors + if aws_err.__cause__: + raise aws_err.__cause__ + raise + except FailoverError: # Raise any failover errors + raise + except errors.IntegrityError as exc: + raise IntegrityError(exc) + except ( + errors.OperationalError, + errors.ProgrammingError, + errors.DataError, + errors.InternalError, + errors.NotSupportedError, + errors.DatabaseError + ) as exc: + raise OperationalError(exc) + + return translate_exceptions_ + + +class AwsMySQLClient(AwsBaseDBAsyncClient): + """AWS Advanced Python Wrapper MySQL client for Tortoise ORM.""" + query_class = MySQLQuery + executor_class = AwsMySQLExecutor + schema_generator = AwsMySQLSchemaGenerator + capabilities = Capabilities( + dialect="mysql", + requires_limit=True, + inline_comment=True, + support_index_hint=True, + support_for_posix_regex_queries=True, + support_json_attributes=True, + ) + + def __init__( + self, + *, + user: str, + password: str, + database: str, + host: str, + port: SupportsInt, + **kwargs, + ): + """Initialize AWS MySQL client with connection parameters.""" + super().__init__(**kwargs) + + # Basic connection parameters + self.user = user + self.password = password + self.database = database + self.host = host + self.port = int(port) + self.extra = kwargs.copy() + + # Extract MySQL-specific settings + self.storage_engine = self.extra.pop("storage_engine", "innodb") + self.charset = self.extra.pop("charset", "utf8mb4") + + # Remove Tortoise-specific parameters + self.extra.pop("connection_name", None) + self.extra.pop("fetch_inserted", None) + self.extra.pop("autocommit", None) + self.extra.setdefault("sql_mode", "STRICT_TRANS_TABLES") + + # Initialize connection templates + self._init_connection_templates() + + # Initialize state + self._template: Dict[str, Any] = {} + self._connection = None + + def _init_connection_templates(self) -> None: + """Initialize connection templates for with/without database.""" + base_template = { + "user": self.user, + "password": self.password, + "host": self.host, + "port": self.port, + "autocommit": True, + **self.extra + } + + self._template_with_db = {**base_template, "database": self.database} + self._template_no_db = {**base_template, "database": None} + + # Connection Management + async def create_connection(self, with_db: bool) -> None: + """Initialize connection pool and configure database settings.""" + # Validate charset + if self.charset.lower() not in [cs[0] for cs in MYSQL_CHARACTER_SETS if cs is not None]: + raise DBConnectionError(f"Unknown character set: {self.charset}") + + # Set transaction support based on storage engine + if self.storage_engine.lower() != "innodb": + self.capabilities.__dict__["supports_transactions"] = False + + # Set template based on database requirement + self._template = self._template_with_db if with_db else self._template_no_db + + async def close(self) -> None: + """Close connections - AWS wrapper handles cleanup internally.""" + pass + + def acquire_connection(self): + """Acquire a connection from the pool.""" + return self._acquire_connection(with_db=True) + + def _acquire_connection(self, with_db: bool) -> TortoiseAwsClientConnectionWrapper: + """Create connection wrapper for specified database mode.""" + return TortoiseAwsClientConnectionWrapper( + self, mysql.connector.Connect, with_db=with_db + ) + + # Database Operations + async def db_create(self) -> None: + """Create the database.""" + await self.create_connection(with_db=False) + await self._execute_script(f"CREATE DATABASE {self.database};", False) + await self.close() + + async def db_delete(self) -> None: + """Delete the database.""" + await self.create_connection(with_db=False) + await self._execute_script(f"DROP DATABASE {self.database};", False) + await self.close() + + # Query Execution Methods + @translate_exceptions + async def execute_insert(self, query: str, values: List[Any]) -> int: + """Execute an INSERT query and return the last inserted row ID.""" + async with self.acquire_connection() as connection: + logger.debug(f"{query}: {values}") + async with connection.cursor() as cursor: + await cursor.execute(query, values) + return cursor.lastrowid + + @translate_exceptions + async def execute_many(self, query: str, values: List[List[Any]]) -> None: + """Execute a query with multiple parameter sets.""" + async with self.acquire_connection() as connection: + logger.debug(f"{query}: {values}") + async with connection.cursor() as cursor: + if self.capabilities.supports_transactions: + await self._execute_many_with_transaction(cursor, connection, query, values) + else: + await cursor.executemany(query, values) + + async def _execute_many_with_transaction(self, cursor: Any, connection: Any, query: str, values: List[List[Any]]) -> None: + """Execute many queries within a transaction.""" + try: + await connection.set_autocommit(False) + try: + await cursor.executemany(query, values) + except Exception: + await connection.rollback() + raise + else: + await connection.commit() + finally: + await connection.set_autocommit(True) + + @translate_exceptions + async def execute_query(self, query: str, values: Optional[List[Any]] = None) -> Tuple[int, List[Dict[str, Any]]]: + """Execute a query and return row count and results.""" + async with self.acquire_connection() as connection: + logger.debug(f"{query}: {values}") + async with connection.cursor() as cursor: + await cursor.execute(query, values) + rows = await cursor.fetchall() + if rows: + fields = [desc[0] for desc in cursor.description] + return cursor.rowcount, [dict(zip(fields, row)) for row in rows] + return cursor.rowcount, [] + + async def execute_query_dict(self, query: str, values: Optional[List[Any]] = None) -> List[Dict[str, Any]]: + """Execute a query and return only the results as dictionaries.""" + return (await self.execute_query(query, values))[1] + + async def execute_script(self, query: str) -> None: + """Execute a script query.""" + await self._execute_script(query, True) + + @translate_exceptions + async def _execute_script(self, query: str, with_db: bool) -> None: + """Execute a multi-statement query by parsing and running statements sequentially.""" + async with self._acquire_connection(with_db) as connection: + logger.debug(f"Executing script: {query}") + async with connection.cursor() as cursor: + # Parse multi-statement queries since MySQL Connector doesn't handle them well + statements = sqlparse.split(query) + for statement in statements: + statement = statement.strip() + if statement: + await cursor.execute(statement) + + # Transaction Support + def _in_transaction(self) -> TransactionContext: + """Create a new transaction context.""" + return TortoiseAwsClientTransactionContext(TransactionWrapper(self)) + + +class TransactionWrapper(AwsMySQLClient, AwsTransactionalDBClient): + """Transaction wrapper for AWS MySQL client.""" + + def __init__(self, connection: AwsMySQLClient) -> None: + self.connection_name = connection.connection_name + self._connection: AwsConnectionAsyncWrapper = connection._connection + self._lock = asyncio.Lock() + self._savepoint: Optional[str] = None + self._finalized: bool = False + self._parent = connection + + def _in_transaction(self) -> TransactionContext: + """Create a nested transaction context.""" + return NestedTransactionContext(TransactionWrapper(self)) + + def acquire_connection(self): + """Acquire the transaction connection.""" + return ConnectionWrapper(self._lock, self) + + # Transaction Control Methods + @translate_exceptions + async def begin(self) -> None: + """Begin the transaction.""" + await self._connection.set_autocommit(False) + self._finalized = False + + async def commit(self) -> None: + """Commit the transaction.""" + if self._finalized: + raise TransactionManagementError("Transaction already finalized") + await self._connection.commit() + await self._connection.set_autocommit(True) + self._finalized = True + + async def rollback(self) -> None: + """Rollback the transaction.""" + if self._finalized: + raise TransactionManagementError("Transaction already finalized") + await self._connection.rollback() + await self._connection.set_autocommit(True) + self._finalized = True + + # Savepoint Management + @translate_exceptions + async def savepoint(self) -> None: + """Create a savepoint.""" + self._savepoint = _gen_savepoint_name() + async with self._connection.cursor() as cursor: + await cursor.execute(f"SAVEPOINT {self._savepoint}") + + async def savepoint_rollback(self): + """Rollback to the savepoint.""" + if self._finalized: + raise TransactionManagementError("Transaction already finalized") + if self._savepoint is None: + raise TransactionManagementError("No savepoint to rollback to") + async with self._connection.cursor() as cursor: + await cursor.execute(f"ROLLBACK TO SAVEPOINT {self._savepoint}") + self._savepoint = None + self._finalized = True + + async def release_savepoint(self): + """Release the savepoint.""" + if self._finalized: + raise TransactionManagementError("Transaction already finalized") + if self._savepoint is None: + raise TransactionManagementError("No savepoint to release") + async with self._connection.cursor() as cursor: + await cursor.execute(f"RELEASE SAVEPOINT {self._savepoint}") + self._savepoint = None + self._finalized = True + + @translate_exceptions + async def execute_many(self, query: str, values: List[List[Any]]) -> None: + """Execute many queries without autocommit handling (already in transaction).""" + async with self.acquire_connection() as connection: + logger.debug(f"{query}: {values}") + async with connection.cursor() as cursor: + await cursor.executemany(query, values) + + +def _gen_savepoint_name(_c: count = count()) -> str: + """Generate a unique savepoint name.""" + return f"tortoise_savepoint_{next(_c)}" diff --git a/aws_advanced_python_wrapper/tortoise/backends/mysql/executor.py b/aws_advanced_python_wrapper/tortoise/backends/mysql/executor.py new file mode 100644 index 00000000..4450c494 --- /dev/null +++ b/aws_advanced_python_wrapper/tortoise/backends/mysql/executor.py @@ -0,0 +1,24 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Type + +from aws_advanced_python_wrapper.tortoise.utils import load_mysql_module + +MySQLExecutor: Type = load_mysql_module("executor.py", "MySQLExecutor") + + +class AwsMySQLExecutor(MySQLExecutor): + """AWS MySQL Executor for Tortoise ORM.""" + pass diff --git a/aws_advanced_python_wrapper/tortoise/backends/mysql/schema_generator.py b/aws_advanced_python_wrapper/tortoise/backends/mysql/schema_generator.py new file mode 100644 index 00000000..bda66c56 --- /dev/null +++ b/aws_advanced_python_wrapper/tortoise/backends/mysql/schema_generator.py @@ -0,0 +1,24 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Type + +from aws_advanced_python_wrapper.tortoise.utils import load_mysql_module + +MySQLSchemaGenerator: Type = load_mysql_module("schema_generator.py", "MySQLSchemaGenerator") + + +class AwsMySQLSchemaGenerator(MySQLSchemaGenerator): + """AWS MySQL Executor for Tortoise ORM.""" + pass diff --git a/aws_advanced_python_wrapper/tortoise/sql_alchemy_tortoise_connection_provider.py b/aws_advanced_python_wrapper/tortoise/sql_alchemy_tortoise_connection_provider.py new file mode 100644 index 00000000..2d8bc032 --- /dev/null +++ b/aws_advanced_python_wrapper/tortoise/sql_alchemy_tortoise_connection_provider.py @@ -0,0 +1,124 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, cast + +if TYPE_CHECKING: + from types import ModuleType + from sqlalchemy import Dialect + from aws_advanced_python_wrapper.database_dialect import DatabaseDialect + from aws_advanced_python_wrapper.driver_dialect import DriverDialect + from aws_advanced_python_wrapper.hostinfo import HostInfo + from aws_advanced_python_wrapper.utils.properties import Properties + +from sqlalchemy.dialects.mysql import mysqlconnector +from sqlalchemy.dialects.postgresql import psycopg + +from aws_advanced_python_wrapper.connection_provider import \ + ConnectionProviderManager +from aws_advanced_python_wrapper.errors import AwsWrapperError +from aws_advanced_python_wrapper.sql_alchemy_connection_provider import \ + SqlAlchemyPooledConnectionProvider +from aws_advanced_python_wrapper.sqlalchemy_driver_dialect import \ + SqlAlchemyDriverDialect +from aws_advanced_python_wrapper.utils.log import Logger +from aws_advanced_python_wrapper.utils.messages import Messages + +logger = Logger(__name__) + + +class SqlAlchemyTortoisePooledConnectionProvider(SqlAlchemyPooledConnectionProvider): + """ + Tortoise-specific pooled connection provider that handles failover by disposing pools. + """ + + _sqlalchemy_dialect_map: Dict[str, "ModuleType"] = { + "MySQLDriverDialect": mysqlconnector, + "PostgresDriverDialect": psycopg + } + + def accepts_host_info(self, host_info: "HostInfo", props: "Properties") -> bool: + if self._accept_url_func: + return self._accept_url_func(host_info, props) + url_type = SqlAlchemyPooledConnectionProvider._rds_utils.identify_rds_type(host_info.host) + return url_type.is_rds + + def _create_pool( + self, + target_func: Callable, + driver_dialect: "DriverDialect", + database_dialect: "DatabaseDialect", + host_info: "HostInfo", + props: "Properties"): + kwargs = dict() if self._pool_configurator is None else self._pool_configurator(host_info, props) + prepared_properties = driver_dialect.prepare_connect_info(host_info, props) + database_dialect.prepare_conn_props(prepared_properties) + kwargs["creator"] = self._get_connection_func(target_func, prepared_properties) + dialect = self._get_pool_dialect(driver_dialect) + if not dialect: + raise AwsWrapperError(Messages.get_formatted("SqlAlchemyTortoisePooledConnectionProvider.NoDialect", driver_dialect.__class__.__name__)) + + ''' + We need to pass in pre_ping and dialect to QueuePool the queue pool to enable health checks. + Without this health check, we could be using dead connections after a failover. + ''' + kwargs["pre_ping"] = True + kwargs["dialect"] = dialect + return self._create_sql_alchemy_pool(**kwargs) + + def _get_pool_dialect(self, driver_dialect: "DriverDialect") -> "Dialect | None": + dialect = None + driver_dialect_class_name = driver_dialect.__class__.__name__ + if driver_dialect_class_name == "SqlAlchemyDriverDialect": + # Cast to access _underlying_driver_dialect attribute + sql_alchemy_dialect = cast("SqlAlchemyDriverDialect", driver_dialect) + driver_dialect_class_name = sql_alchemy_dialect._underlying_driver_dialect.__class__.__name__ + module = self._sqlalchemy_dialect_map.get(driver_dialect_class_name) + + if not module: + return dialect + dialect = module.dialect() + dialect.dbapi = driver_dialect.get_driver_module() + + return dialect + + +def setup_tortoise_connection_provider( + pool_configurator: Optional[Callable[["HostInfo", "Properties"], Dict[str, Any]]] = None, + pool_mapping: Optional[Callable[["HostInfo", "Properties"], str]] = None +) -> SqlAlchemyTortoisePooledConnectionProvider: + """ + Helper function to set up and configure the Tortoise connection provider. + + Args: + pool_configurator: Optional function to configure pool settings. + Defaults to basic pool configuration. + pool_mapping: Optional function to generate pool keys. + Defaults to basic pool key generation. + + Returns: + Configured SqlAlchemyTortoisePooledConnectionProvider instance. + """ + def default_pool_configurator(host_info: "HostInfo", props: "Properties") -> Dict[str, Any]: + return {"pool_size": 5, "max_overflow": -1} + + def default_pool_mapping(host_info: "HostInfo", props: "Properties") -> str: + return f"{host_info.url}{props.get('user', '')}{props.get('database', '')}" + + provider = SqlAlchemyTortoisePooledConnectionProvider( + pool_configurator=pool_configurator or default_pool_configurator, + pool_mapping=pool_mapping or default_pool_mapping + ) + + ConnectionProviderManager.set_connection_provider(provider) + return provider diff --git a/aws_advanced_python_wrapper/tortoise/utils.py b/aws_advanced_python_wrapper/tortoise/utils.py new file mode 100644 index 00000000..90ecede0 --- /dev/null +++ b/aws_advanced_python_wrapper/tortoise/utils.py @@ -0,0 +1,35 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import importlib.util +import sys +from pathlib import Path +from typing import Type + + +def load_mysql_module(module_file: str, class_name: str) -> Type: + """Load MySQL backend module without __init__.py.""" + import tortoise + tortoise_path = Path(tortoise.__file__).parent + module_path = tortoise_path / "backends" / "mysql" / module_file + module_name = f"tortoise.backends.mysql.{module_file[:-3]}" + + if module_name not in sys.modules: + spec = importlib.util.spec_from_file_location(module_name, module_path) + if spec is None or spec.loader is None: + raise ImportError(f"Could not load module {module_name}") + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + + return getattr(sys.modules[module_name], class_name) diff --git a/aws_advanced_python_wrapper/wrapper.py b/aws_advanced_python_wrapper/wrapper.py index e04d293f..beafa307 100644 --- a/aws_advanced_python_wrapper/wrapper.py +++ b/aws_advanced_python_wrapper/wrapper.py @@ -265,6 +265,13 @@ def rowcount(self) -> int: def arraysize(self) -> int: return self.target_cursor.arraysize + # Optional for PEP249 + @property + def lastrowid(self) -> int: + if hasattr(self.target_cursor, 'lastrowid'): + return self.target_cursor.lastrowid + raise AttributeError("'Cursor' object has no attribute 'lastrowid'") + def close(self) -> None: self._plugin_manager.execute(self.target_cursor, "Cursor.close", lambda: self.target_cursor.close()) diff --git a/poetry.lock b/poetry.lock index 5653d237..a00e1521 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,18 +1,24 @@ # This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand. [[package]] -name = "astor" -version = "0.8.1" -description = "Read/rewrite/write Python ASTs" +name = "aiosqlite" +version = "0.21.0" +description = "asyncio bridge to the standard sqlite3 module" optional = false -python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,>=2.7" -groups = ["dev"] -markers = "python_version < \"3.9\"" +python-versions = ">=3.9" +groups = ["main"] files = [ - {file = "astor-0.8.1-py2.py3-none-any.whl", hash = "sha256:070a54e890cefb5b3739d19f30f5a5ec840ffc9c50ffa7d23cc9fc1a38ebbfc5"}, - {file = "astor-0.8.1.tar.gz", hash = "sha256:6a6effda93f4e1ce9f618779b2dd1d9d84f1e32812c23a29b3fff6fd7f63fa5e"}, + {file = "aiosqlite-0.21.0-py3-none-any.whl", hash = "sha256:2549cf4057f95f53dcba16f2b64e8e2791d7e1adedb13197dd8ed77bb226d7d0"}, + {file = "aiosqlite-0.21.0.tar.gz", hash = "sha256:131bb8056daa3bc875608c631c678cda73922a2d4ba8aec373b19f18c17e7aa3"}, ] +[package.dependencies] +typing_extensions = ">=4.0" + +[package.extras] +dev = ["attribution (==1.7.1)", "black (==24.3.0)", "build (>=1.2)", "coverage[toml] (==7.6.10)", "flake8 (==7.0.0)", "flake8-bugbear (==24.12.12)", "flit (==3.10.1)", "mypy (==1.14.1)", "ufmt (==2.5.1)", "usort (==1.0.8.post1)"] +docs = ["sphinx (==8.1.3)", "sphinx-mdinclude (==0.6.1)"] + [[package]] name = "aws-xray-sdk" version = "2.15.0" @@ -29,36 +35,6 @@ files = [ botocore = ">=1.11.3" wrapt = "*" -[[package]] -name = "backports-zoneinfo" -version = "0.2.1" -description = "Backport of the standard library zoneinfo module" -optional = false -python-versions = ">=3.6" -groups = ["dev", "test"] -markers = "python_version < \"3.9\"" -files = [ - {file = "backports.zoneinfo-0.2.1-cp36-cp36m-macosx_10_14_x86_64.whl", hash = "sha256:da6013fd84a690242c310d77ddb8441a559e9cb3d3d59ebac9aca1a57b2e18bc"}, - {file = "backports.zoneinfo-0.2.1-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:89a48c0d158a3cc3f654da4c2de1ceba85263fafb861b98b59040a5086259722"}, - {file = "backports.zoneinfo-0.2.1-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:1c5742112073a563c81f786e77514969acb58649bcdf6cdf0b4ed31a348d4546"}, - {file = "backports.zoneinfo-0.2.1-cp36-cp36m-win32.whl", hash = "sha256:e8236383a20872c0cdf5a62b554b27538db7fa1bbec52429d8d106effbaeca08"}, - {file = "backports.zoneinfo-0.2.1-cp36-cp36m-win_amd64.whl", hash = "sha256:8439c030a11780786a2002261569bdf362264f605dfa4d65090b64b05c9f79a7"}, - {file = "backports.zoneinfo-0.2.1-cp37-cp37m-macosx_10_14_x86_64.whl", hash = "sha256:f04e857b59d9d1ccc39ce2da1021d196e47234873820cbeaad210724b1ee28ac"}, - {file = "backports.zoneinfo-0.2.1-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:17746bd546106fa389c51dbea67c8b7c8f0d14b5526a579ca6ccf5ed72c526cf"}, - {file = "backports.zoneinfo-0.2.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:5c144945a7752ca544b4b78c8c41544cdfaf9786f25fe5ffb10e838e19a27570"}, - {file = "backports.zoneinfo-0.2.1-cp37-cp37m-win32.whl", hash = "sha256:e55b384612d93be96506932a786bbcde5a2db7a9e6a4bb4bffe8b733f5b9036b"}, - {file = "backports.zoneinfo-0.2.1-cp37-cp37m-win_amd64.whl", hash = "sha256:a76b38c52400b762e48131494ba26be363491ac4f9a04c1b7e92483d169f6582"}, - {file = "backports.zoneinfo-0.2.1-cp38-cp38-macosx_10_14_x86_64.whl", hash = "sha256:8961c0f32cd0336fb8e8ead11a1f8cd99ec07145ec2931122faaac1c8f7fd987"}, - {file = "backports.zoneinfo-0.2.1-cp38-cp38-manylinux1_i686.whl", hash = "sha256:e81b76cace8eda1fca50e345242ba977f9be6ae3945af8d46326d776b4cf78d1"}, - {file = "backports.zoneinfo-0.2.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:7b0a64cda4145548fed9efc10322770f929b944ce5cee6c0dfe0c87bf4c0c8c9"}, - {file = "backports.zoneinfo-0.2.1-cp38-cp38-win32.whl", hash = "sha256:1b13e654a55cd45672cb54ed12148cd33628f672548f373963b0bff67b217328"}, - {file = "backports.zoneinfo-0.2.1-cp38-cp38-win_amd64.whl", hash = "sha256:4a0f800587060bf8880f954dbef70de6c11bbe59c673c3d818921f042f9954a6"}, - {file = "backports.zoneinfo-0.2.1.tar.gz", hash = "sha256:fadbfe37f74051d024037f223b8e001611eac868b5c5b06144ef4d8b799862f2"}, -] - -[package.extras] -tzdata = ["tzdata"] - [[package]] name = "beautifulsoup4" version = "4.12.3" @@ -483,7 +459,6 @@ files = [ ] [package.dependencies] -astor = {version = "*", markers = "python_version < \"3.9\""} classify-imports = "*" flake8 = "*" @@ -720,6 +695,18 @@ files = [ {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, ] +[[package]] +name = "iso8601" +version = "2.1.0" +description = "Simple module to parse ISO 8601 dates" +optional = false +python-versions = ">=3.7,<4.0" +groups = ["main"] +files = [ + {file = "iso8601-2.1.0-py3-none-any.whl", hash = "sha256:aac4145c4dcb66ad8b648a02830f5e2ff6c24af20f4f482689be402db2429242"}, + {file = "iso8601-2.1.0.tar.gz", hash = "sha256:6b1d3829ee8921c4301998c909f7829fa9ed3cbdac0d3b16af2d743aed1ba8df"}, +] + [[package]] name = "isort" version = "5.13.2" @@ -1205,7 +1192,6 @@ files = [ ] [package.dependencies] -"backports.zoneinfo" = {version = ">=0.2.0", markers = "python_version < \"3.9\""} typing-extensions = {version = ">=4.6", markers = "python_version < \"3.13\""} tzdata = {version = "*", markers = "sys_platform == \"win32\""} @@ -1324,6 +1310,18 @@ files = [ {file = "pyflakes-3.1.0.tar.gz", hash = "sha256:a0aae034c444db0071aa077972ba4768d40c830d9539fd45bf4cd3f8f6992efc"}, ] +[[package]] +name = "pypika-tortoise" +version = "0.6.2" +description = "Forked from pypika and streamline just for tortoise-orm" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "pypika_tortoise-0.6.2-py3-none-any.whl", hash = "sha256:425462b02ede0a5ed7b812ec12427419927ed6b19282c55667d1cbc9a440d3cb"}, + {file = "pypika_tortoise-0.6.2.tar.gz", hash = "sha256:f95ab59d9b6454db2e8daa0934728458350a1f3d56e81d9d1debc8eebeff26b3"}, +] + [[package]] name = "pytest" version = "7.4.4" @@ -1347,6 +1345,25 @@ tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""} [package.extras] testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] +[[package]] +name = "pytest-asyncio" +version = "0.21.2" +description = "Pytest support for asyncio" +optional = false +python-versions = ">=3.7" +groups = ["test"] +files = [ + {file = "pytest_asyncio-0.21.2-py3-none-any.whl", hash = "sha256:ab664c88bb7998f711d8039cacd4884da6430886ae8bbd4eded552ed2004f16b"}, + {file = "pytest_asyncio-0.21.2.tar.gz", hash = "sha256:d67738fc232b94b326b9d060750beb16e0074210b98dd8b58a5239fa2a154f45"}, +] + +[package.dependencies] +pytest = ">=7.0.0" + +[package.extras] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"] +testing = ["coverage (>=6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy (>=0.931)", "pytest-trio (>=0.7.0)"] + [[package]] name = "pytest-html" version = "4.1.1" @@ -1435,6 +1452,18 @@ files = [ [package.dependencies] six = ">=1.5" +[[package]] +name = "pytz" +version = "2025.2" +description = "World timezone definitions, modern and historical" +optional = false +python-versions = "*" +groups = ["main"] +files = [ + {file = "pytz-2025.2-py2.py3-none-any.whl", hash = "sha256:5ddf76296dd8c44c26eb8f4b6f35488f3ccbf6fbbd7adee0b7262d43f0ec2f00"}, + {file = "pytz-2025.2.tar.gz", hash = "sha256:360b9e3dbb49a209c21ad61809c7fb453643e048b38924c765813546746e81c3"}, +] + [[package]] name = "requests" version = "2.32.4" @@ -1607,6 +1636,23 @@ postgresql-psycopgbinary = ["psycopg[binary] (>=3.0.7)"] pymysql = ["pymysql"] sqlcipher = ["sqlcipher3_binary"] +[[package]] +name = "sqlparse" +version = "0.4.4" +description = "A non-validating SQL parser." +optional = false +python-versions = ">=3.5" +groups = ["main"] +files = [ + {file = "sqlparse-0.4.4-py3-none-any.whl", hash = "sha256:5430a4fe2ac7d0f93e66f1efc6e1338a41884b7ddf2a350cedd20ccc4d9d28f3"}, + {file = "sqlparse-0.4.4.tar.gz", hash = "sha256:d446183e84b8349fa3061f0fe7f06ca94ba65b426946ffebe6e3e8295332420c"}, +] + +[package.extras] +dev = ["build", "flake8"] +doc = ["sphinx"] +test = ["pytest", "pytest-cov"] + [[package]] name = "tabulate" version = "0.9.0" @@ -1647,6 +1693,32 @@ files = [ {file = "tomli-2.0.2.tar.gz", hash = "sha256:d46d457a85337051c36524bc5349dd91b1877838e2979ac5ced3e710ed8a60ed"}, ] +[[package]] +name = "tortoise-orm" +version = "0.25.1" +description = "Easy async ORM for python, built with relations in mind" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "tortoise_orm-0.25.1-py3-none-any.whl", hash = "sha256:df0ef7e06eb0650a7e5074399a51ee6e532043308c612db2cac3882486a3fd9f"}, + {file = "tortoise_orm-0.25.1.tar.gz", hash = "sha256:4d5bfd13d5750935ffe636a6b25597c5c8f51c47e5b72d7509d712eda1a239fe"}, +] + +[package.dependencies] +aiosqlite = ">=0.16.0,<1.0.0" +iso8601 = {version = ">=2.1.0,<3.0.0", markers = "python_version < \"4.0\""} +pypika-tortoise = {version = ">=0.6.1,<1.0.0", markers = "python_version < \"4.0\""} +pytz = "*" + +[package.extras] +accel = ["ciso8601 ; sys_platform != \"win32\" and implementation_name == \"cpython\"", "orjson", "uvloop ; sys_platform != \"win32\" and implementation_name == \"cpython\""] +aiomysql = ["aiomysql"] +asyncmy = ["asyncmy (>=0.2.8,<1.0.0) ; python_version < \"4.0\""] +asyncodbc = ["asyncodbc (>=0.1.1,<1.0.0) ; python_version < \"4.0\""] +asyncpg = ["asyncpg"] +psycopg = ["psycopg[binary,pool] (>=3.0.12,<4.0.0)"] + [[package]] name = "toxiproxy-python" version = "0.1.1" @@ -2160,7 +2232,7 @@ description = "HTTP library with thread-safe connection pooling, file post, and optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" groups = ["main", "test"] -markers = "python_version < \"3.10\"" +markers = "python_version == \"3.9\"" files = [ {file = "urllib3-1.26.20-py2.py3-none-any.whl", hash = "sha256:0ed14ccfbf1c30a9072c7ca157e4319b70d65f623e91e7b32fadb2853431016e"}, {file = "urllib3-1.26.20.tar.gz", hash = "sha256:40c2dc0c681e47eb8f90e7e27bf6ff7df2e677421fd46756da1161c39ca70d32"}, @@ -2292,5 +2364,5 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.1" -python-versions = "^3.8.1" -content-hash = "4d6dc3f0080c62b28f78487a4cbefa724da3af5e9fd100a6abd376bcae3baadf" +python-versions = "^3.9" +content-hash = "e290dd3792a38a361c52cf353dbd00f5eeb850beef88339648991b41eeddb883" diff --git a/pyproject.toml b/pyproject.toml index ae8796e2..bb516992 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ classifiers = [ ] [tool.poetry.dependencies] -python = "^3.8.1" +python = "^3.9" resourcebundle = "2.1.0" boto3 = "^1.34.111" toml = "^0.10.2" @@ -32,6 +32,8 @@ types_aws_xray_sdk = "^2.13.0" opentelemetry-api = "^1.22.0" opentelemetry-sdk = "^1.22.0" requests = "^2.32.2" +tortoise-orm = "^0.25.1" +sqlparse = "^0.4.4" [tool.poetry.group.dev.dependencies] mypy = "^1.9.0" @@ -63,6 +65,7 @@ mysql-connector-python = "9.0.0" opentelemetry-exporter-otlp = "^1.22.0" opentelemetry-exporter-otlp-proto-grpc = "^1.22.0" opentelemetry-sdk-extension-aws = "^2.0.1" +pytest-asyncio = "^0.21.0" [tool.isort] sections = "FUTURE,STDLIB,THIRDPARTY,FIRSTPARTY,LOCALFOLDER" @@ -77,4 +80,3 @@ filterwarnings = [ 'ignore:cache could not write path', 'ignore:could not create cache path' ] - diff --git a/tests/integration/container/conftest.py b/tests/integration/container/conftest.py index 2e23eeba..51878f68 100644 --- a/tests/integration/container/conftest.py +++ b/tests/integration/container/conftest.py @@ -68,8 +68,10 @@ def pytest_runtest_setup(item): test_name = item.callspec.id else: TestEnvironment.get_current().set_current_driver(None) + # Fallback to item.name if no callspec (for non-parameterized tests) + test_name = getattr(item, 'name', None) or str(item) - logger.info("Starting test preparation for: " + test_name) + logger.info(f"Starting test preparation for: {test_name}") segment: Optional[Segment] = None if TestEnvironmentFeatures.TELEMETRY_TRACES_ENABLED in TestEnvironment.get_current().get_features(): @@ -107,7 +109,11 @@ def pytest_runtest_setup(item): logger.warning("conftest.ExceptionWhileObtainingInstanceIDs", ex) instances = list() - sleep(5) + # Only sleep if we still need to retry + if (len(instances) < request.get_num_of_instances() + or len(instances) == 0 + or not rds_utility.is_db_instance_writer(instances[0])): + sleep(5) assert len(instances) > 0 current_writer = instances[0] diff --git a/tests/integration/container/tortoise/__init__.py b/tests/integration/container/tortoise/__init__.py new file mode 100644 index 00000000..bd4acb2b --- /dev/null +++ b/tests/integration/container/tortoise/__init__.py @@ -0,0 +1,13 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/integration/container/tortoise/models/__init__.py b/tests/integration/container/tortoise/models/__init__.py new file mode 100644 index 00000000..12de791f --- /dev/null +++ b/tests/integration/container/tortoise/models/__init__.py @@ -0,0 +1,13 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License diff --git a/tests/integration/container/tortoise/models/test_models.py b/tests/integration/container/tortoise/models/test_models.py new file mode 100644 index 00000000..b7d3ffc5 --- /dev/null +++ b/tests/integration/container/tortoise/models/test_models.py @@ -0,0 +1,44 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from tortoise import fields +from tortoise.models import Model + + +class User(Model): + id = fields.IntField(primary_key=True) + name = fields.CharField(max_length=50) + email = fields.CharField(max_length=100, unique=True) + + class Meta: + table = "users" + + +class UniqueName(Model): + id = fields.IntField(primary_key=True) + name = fields.CharField(max_length=20, null=True, unique=True) + optional = fields.CharField(max_length=20, null=True) + other_optional = fields.CharField(max_length=20, null=True) + + class Meta: + table = "unique_names" + + +class TableWithSleepTrigger(Model): + id = fields.IntField(primary_key=True) + name = fields.CharField(max_length=50) + value = fields.CharField(max_length=100) + + class Meta: + table = "table_with_sleep_trigger" diff --git a/tests/integration/container/tortoise/models/test_models_copy.py b/tests/integration/container/tortoise/models/test_models_copy.py new file mode 100644 index 00000000..b7d3ffc5 --- /dev/null +++ b/tests/integration/container/tortoise/models/test_models_copy.py @@ -0,0 +1,44 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from tortoise import fields +from tortoise.models import Model + + +class User(Model): + id = fields.IntField(primary_key=True) + name = fields.CharField(max_length=50) + email = fields.CharField(max_length=100, unique=True) + + class Meta: + table = "users" + + +class UniqueName(Model): + id = fields.IntField(primary_key=True) + name = fields.CharField(max_length=20, null=True, unique=True) + optional = fields.CharField(max_length=20, null=True) + other_optional = fields.CharField(max_length=20, null=True) + + class Meta: + table = "unique_names" + + +class TableWithSleepTrigger(Model): + id = fields.IntField(primary_key=True) + name = fields.CharField(max_length=50) + value = fields.CharField(max_length=100) + + class Meta: + table = "table_with_sleep_trigger" diff --git a/tests/integration/container/tortoise/models/test_models_relationships.py b/tests/integration/container/tortoise/models/test_models_relationships.py new file mode 100644 index 00000000..16e2bfe6 --- /dev/null +++ b/tests/integration/container/tortoise/models/test_models_relationships.py @@ -0,0 +1,85 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from tortoise import fields +from tortoise.models import Model + +if TYPE_CHECKING: + from tortoise.fields.relational import (ForeignKeyFieldInstance, + OneToOneFieldInstance) + +# One-to-One Relationship Models + + +class RelTestAccount(Model): + id = fields.IntField(primary_key=True) + username = fields.CharField(max_length=50, unique=True) + email = fields.CharField(max_length=100) + + class Meta: + table = "rel_test_accounts" + + +class RelTestAccountProfile(Model): + id = fields.IntField(primary_key=True) + account: "OneToOneFieldInstance" = fields.OneToOneField("models.RelTestAccount", related_name="profile", on_delete=fields.CASCADE) + bio = fields.TextField(null=True) + avatar_url = fields.CharField(max_length=200, null=True) + + class Meta: + table = "rel_test_account_profiles" + + +# One-to-Many Relationship Models +class RelTestPublisher(Model): + id = fields.IntField(primary_key=True) + name = fields.CharField(max_length=100) + email = fields.CharField(max_length=100, unique=True) + + class Meta: + table = "rel_test_publishers" + + +class RelTestPublication(Model): + id = fields.IntField(primary_key=True) + title = fields.CharField(max_length=200) + isbn = fields.CharField(max_length=13, unique=True) + publisher: "ForeignKeyFieldInstance" = fields.ForeignKeyField("models.RelTestPublisher", related_name="publications", on_delete=fields.CASCADE) + published_date = fields.DateField(null=True) + + class Meta: + table = "rel_test_publications" + + +# Many-to-Many Relationship Models +class RelTestLearner(Model): + id = fields.IntField(primary_key=True) + name = fields.CharField(max_length=100) + learner_id = fields.CharField(max_length=20, unique=True) + + class Meta: + table = "rel_test_learners" + + +class RelTestSubject(Model): + id = fields.IntField(primary_key=True) + name = fields.CharField(max_length=100) + code = fields.CharField(max_length=10, unique=True) + credits = fields.IntField() + learners = fields.ManyToManyField("models.RelTestLearner", related_name="subjects") + + class Meta: + table = "rel_test_subjects" diff --git a/tests/integration/container/tortoise/router/__init__.py b/tests/integration/container/tortoise/router/__init__.py new file mode 100644 index 00000000..12de791f --- /dev/null +++ b/tests/integration/container/tortoise/router/__init__.py @@ -0,0 +1,13 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License diff --git a/tests/integration/container/tortoise/router/test_router.py b/tests/integration/container/tortoise/router/test_router.py new file mode 100644 index 00000000..e4693120 --- /dev/null +++ b/tests/integration/container/tortoise/router/test_router.py @@ -0,0 +1,20 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +class TestRouter: + def db_for_read(self, model): + return "default" + + def db_for_write(self, model): + return "default" diff --git a/tests/integration/container/tortoise/test_tortoise_basic.py b/tests/integration/container/tortoise/test_tortoise_basic.py new file mode 100644 index 00000000..d06f7de2 --- /dev/null +++ b/tests/integration/container/tortoise/test_tortoise_basic.py @@ -0,0 +1,261 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio + +import pytest +import pytest_asyncio +from tortoise.transactions import atomic, in_transaction + +from tests.integration.container.tortoise.models.test_models import ( + UniqueName, User) +from tests.integration.container.tortoise.test_tortoise_common import \ + setup_tortoise +from tests.integration.container.utils.conditions import (disable_on_engines, + disable_on_features) +from tests.integration.container.utils.database_engine import DatabaseEngine +from tests.integration.container.utils.test_environment_features import \ + TestEnvironmentFeatures + + +@disable_on_engines([DatabaseEngine.PG]) +@disable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY, + TestEnvironmentFeatures.BLUE_GREEN_DEPLOYMENT, + TestEnvironmentFeatures.PERFORMANCE]) +class TestTortoiseBasic: + """ + Test class for Tortoise ORM integration with AWS Advanced Python Wrapper. + Contains tests related to basic test operations. + """ + + @pytest_asyncio.fixture + async def setup_tortoise_basic(self, conn_utils): + """Setup Tortoise with default plugins.""" + async for result in setup_tortoise(conn_utils): + yield result + + @pytest.mark.asyncio + async def test_basic_crud_operations(self, setup_tortoise_basic): + """Test basic CRUD operations with AWS wrapper.""" + # Create + user = await User.create(name="John Doe", email="john@example.com") + assert user.id is not None + assert user.name == "John Doe" + + # Read + found_user = await User.get(id=user.id) + assert found_user.name == "John Doe" + assert found_user.email == "john@example.com" + + # Update + found_user.name = "Jane Doe" + await found_user.save() + + updated_user = await User.get(id=user.id) + assert updated_user.name == "Jane Doe" + + # Delete + await updated_user.delete() + + with pytest.raises(Exception): + await User.get(id=user.id) + + @pytest.mark.asyncio + async def test_basic_crud_operations_config(self, setup_tortoise_basic): + """Test basic CRUD operations with AWS wrapper.""" + # Create + user = await User.create(name="John Doe", email="john@example.com") + assert user.id is not None + assert user.name == "John Doe" + + # Read + found_user = await User.get(id=user.id) + assert found_user.name == "John Doe" + assert found_user.email == "john@example.com" + + # Update + found_user.name = "Jane Doe" + await found_user.save() + + updated_user = await User.get(id=user.id) + assert updated_user.name == "Jane Doe" + + # Delete + await updated_user.delete() + + with pytest.raises(Exception): + await User.get(id=user.id) + + @pytest.mark.asyncio + async def test_transaction_support(self, setup_tortoise_basic): + """Test transaction handling with AWS wrapper.""" + async with in_transaction() as conn: + await User.create(name="User 1", email="user1@example.com", using_db=conn) + await User.create(name="User 2", email="user2@example.com", using_db=conn) + + # Verify users exist within transaction + users = await User.filter(name__in=["User 1", "User 2"]).using_db(conn) + assert len(users) == 2 + + @pytest.mark.asyncio + async def test_transaction_rollback(self, setup_tortoise_basic): + """Test transaction rollback with AWS wrapper.""" + try: + async with in_transaction() as conn: + await User.create(name="Test User", email="test@example.com", using_db=conn) + # Force rollback by raising exception + raise ValueError("Test rollback") + except ValueError: + pass + + # Verify user was not created due to rollback + users = await User.filter(name="Test User") + assert len(users) == 0 + + @pytest.mark.asyncio + async def test_bulk_operations(self, setup_tortoise_basic): + """Test bulk operations with AWS wrapper.""" + # Bulk create + users_data = [ + {"name": f"User {i}", "email": f"user{i}@example.com"} + for i in range(5) + ] + await User.bulk_create([User(**data) for data in users_data]) + + # Verify bulk creation + users = await User.filter(name__startswith="User") + assert len(users) == 5 + + # Bulk update + await User.filter(name__startswith="User").update(name="Updated User") + + updated_users = await User.filter(name="Updated User") + assert len(updated_users) == 5 + + @pytest.mark.asyncio + async def test_query_operations(self, setup_tortoise_basic): + """Test various query operations with AWS wrapper.""" + # Setup test data + await User.create(name="Alice", email="alice@example.com") + await User.create(name="Bob", email="bob@example.com") + await User.create(name="Charlie", email="charlie@example.com") + + # Test filtering + alice = await User.get(name="Alice") + assert alice.email == "alice@example.com" + + # Test count + count = await User.all().count() + assert count >= 3 + + # Test ordering + users = await User.all().order_by("name") + assert users[0].name == "Alice" + + # Test exists + exists = await User.filter(name="Alice").exists() + assert exists is True + + # Test values + emails = await User.all().values_list("email", flat=True) + assert "alice@example.com" in emails + + @pytest.mark.asyncio + async def test_bulk_create_with_ids(self, setup_tortoise_basic): + """Test bulk create operations with ID verification.""" + # Bulk create 1000 UniqueName objects with no name (null values) + await UniqueName.bulk_create([UniqueName() for _ in range(1000)]) + + # Get all created records with id and name + all_ = await UniqueName.all().values("id", "name") + + # Get the starting ID + inc = all_[0]["id"] + + # Sort by ID for comparison + all_sorted = sorted(all_, key=lambda x: x["id"]) + + # Verify the IDs are sequential and names are None + expected = [{"id": val + inc, "name": None} for val in range(1000)] + + assert len(all_sorted) == 1000 + assert all_sorted == expected + + @pytest.mark.asyncio + async def test_concurrency_read(self, setup_tortoise_basic): + """Test concurrent read operations with AWS wrapper.""" + + await User.create(name="Test User", email="test@example.com") + user1 = await User.first() + + # Perform 100 concurrent reads + all_read = await asyncio.gather(*[User.first() for _ in range(100)]) + + # All reads should return the same user + assert all_read == [user1 for _ in range(100)] + + @pytest.mark.asyncio + async def test_concurrency_create(self, setup_tortoise_basic): + """Test concurrent create operations with AWS wrapper.""" + + # Perform 100 concurrent creates with unique emails + all_write = await asyncio.gather(*[ + User.create(name="Test", email=f"test{i}@example.com") + for i in range(100) + ]) + + # Read all created users + all_read = await User.all() + + # All created users should exist in the database + assert set(all_write) == set(all_read) + + @pytest.mark.asyncio + async def test_atomic_decorator(self, setup_tortoise_basic): + """Test atomic decorator for transaction handling with AWS wrapper.""" + + @atomic() + async def create_users_atomically(): + user1 = await User.create(name="Atomic User 1", email="atomic1@example.com") + user2 = await User.create(name="Atomic User 2", email="atomic2@example.com") + return user1, user2 + + # Execute atomic operation + user1, user2 = await create_users_atomically() + + # Verify both users were created + assert user1.id is not None + assert user2.id is not None + + # Verify users exist in database + found_users = await User.filter(name__startswith="Atomic User") + assert len(found_users) == 2 + + @pytest.mark.asyncio + async def test_atomic_decorator_rollback(self, setup_tortoise_basic): + """Test atomic decorator rollback on exception with AWS wrapper.""" + + @atomic() + async def create_users_with_error(): + await User.create(name="Atomic Rollback User", email="rollback@example.com") + # Force rollback by raising exception + raise ValueError("Intentional error for rollback test") + + # Execute atomic operation that should fail + with pytest.raises(ValueError): + await create_users_with_error() + + # Verify user was not created due to rollback + users = await User.filter(name="Atomic Rollback User") + assert len(users) == 0 diff --git a/tests/integration/container/tortoise/test_tortoise_common.py b/tests/integration/container/tortoise/test_tortoise_common.py new file mode 100644 index 00000000..f560147a --- /dev/null +++ b/tests/integration/container/tortoise/test_tortoise_common.py @@ -0,0 +1,105 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from tortoise import Tortoise, connections + +# Import to register the aws-mysql backend +import aws_advanced_python_wrapper.tortoise # noqa: F401 +from tests.integration.container.tortoise.models.test_models import User +from tests.integration.container.utils.test_environment import TestEnvironment + + +async def clear_test_models(): + """Clear all test models by calling .all().delete() on each.""" + from tortoise.models import Model + + import tests.integration.container.tortoise.models.test_models as models_module + + for attr_name in dir(models_module): + attr = getattr(models_module, attr_name) + if isinstance(attr, type) and issubclass(attr, Model) and attr != Model: + await attr.all().delete() + + +async def setup_tortoise(conn_utils, plugins="aurora_connection_tracker", **kwargs): + """Setup Tortoise with AWS MySQL backend and configurable plugins.""" + db_url = conn_utils.get_aws_tortoise_url( + TestEnvironment.get_current().get_engine(), + plugins=plugins, + **kwargs, + ) + + config = { + "connections": { + "default": db_url + }, + "apps": { + "models": { + "models": ["tests.integration.container.tortoise.models.test_models"], + "default_connection": "default", + } + } + } + + from aws_advanced_python_wrapper.tortoise.sql_alchemy_tortoise_connection_provider import \ + setup_tortoise_connection_provider + setup_tortoise_connection_provider() + await Tortoise.init(config=config) + await Tortoise.generate_schemas() + + await clear_test_models() + + yield + + await clear_test_models() + await reset_tortoise() + + +async def run_basic_read_operations(name_prefix="Test", email_prefix="test"): + """Common test logic for basic read operations.""" + user = await User.create(name=f"{name_prefix} User", email=f"{email_prefix}@example.com") + + found_user = await User.get(id=user.id) + assert found_user.name == f"{name_prefix} User" + assert found_user.email == f"{email_prefix}@example.com" + + users = await User.filter(name=f"{name_prefix} User") + assert len(users) == 1 + assert users[0].id == user.id + + +async def run_basic_write_operations(name_prefix="Write", email_prefix="write"): + """Common test logic for basic write operations.""" + user = await User.create(name=f"{name_prefix} Test", email=f"{email_prefix}@example.com") + assert user.id is not None + + user.name = f"Updated {name_prefix}" + await user.save() + + updated_user = await User.get(id=user.id) + assert updated_user.name == f"Updated {name_prefix}" + + await updated_user.delete() + + with pytest.raises(Exception): + await User.get(id=user.id) + + +async def reset_tortoise(): + await Tortoise.close_connections() + await Tortoise._reset_apps() + Tortoise._inited = False + Tortoise.apps = {} + connections._db_config = {} diff --git a/tests/integration/container/tortoise/test_tortoise_config.py b/tests/integration/container/tortoise/test_tortoise_config.py new file mode 100644 index 00000000..9308c643 --- /dev/null +++ b/tests/integration/container/tortoise/test_tortoise_config.py @@ -0,0 +1,292 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import pytest_asyncio +from tortoise import Tortoise, connections + +# Import to register the aws-mysql backend +import aws_advanced_python_wrapper.tortoise # noqa: F401 +from aws_advanced_python_wrapper.tortoise.sql_alchemy_tortoise_connection_provider import \ + setup_tortoise_connection_provider +from tests.integration.container.tortoise.models.test_models import User +from tests.integration.container.tortoise.test_tortoise_common import \ + reset_tortoise +from tests.integration.container.utils.conditions import (disable_on_engines, + disable_on_features) +from tests.integration.container.utils.database_engine import DatabaseEngine +from tests.integration.container.utils.test_environment_features import \ + TestEnvironmentFeatures + + +@disable_on_engines([DatabaseEngine.PG]) +@disable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY, + TestEnvironmentFeatures.BLUE_GREEN_DEPLOYMENT, + TestEnvironmentFeatures.PERFORMANCE]) +class TestTortoiseConfig: + """Test class for Tortoise ORM configuration scenarios.""" + + async def _clear_all_test_models(self): + """Clear all test models for a specific connection.""" + await User.all().delete() + + @pytest_asyncio.fixture + async def setup_tortoise_dict_config(self, conn_utils): + """Setup Tortoise with dictionary configuration instead of URL.""" + # Ensure clean state + config = { + "connections": { + "default": { + "engine": "aws_advanced_python_wrapper.tortoise.backends.mysql", + "credentials": { + "host": conn_utils.writer_cluster_host, + "port": conn_utils.port, + "user": conn_utils.user, + "password": conn_utils.password, + "database": conn_utils.dbname, + "plugins": "aurora_connection_tracker,failover", + } + } + }, + "apps": { + "models": { + "models": ["tests.integration.container.tortoise.models.test_models"], + "default_connection": "default", + } + } + } + + setup_tortoise_connection_provider() + await Tortoise.init(config=config) + await Tortoise.generate_schemas() + await self._clear_all_test_models() + + yield + + await self._clear_all_test_models() + await reset_tortoise() + + @pytest_asyncio.fixture + async def setup_tortoise_multi_db(self, conn_utils): + """Setup Tortoise with two different databases using same backend.""" + # Create second database name + original_db = conn_utils.dbname + second_db = f"{original_db}_test2" + + config = { + "connections": { + "default": { + "engine": "aws_advanced_python_wrapper.tortoise.backends.mysql", + "credentials": { + "host": conn_utils.writer_cluster_host, + "port": conn_utils.port, + "user": conn_utils.user, + "password": conn_utils.password, + "database": original_db, + "plugins": "aurora_connection_tracker,failover", + } + }, + "second_db": { + "engine": "aws_advanced_python_wrapper.tortoise.backends.mysql", + "credentials": { + "host": conn_utils.writer_cluster_host, + "port": conn_utils.port, + "user": conn_utils.user, + "password": conn_utils.password, + "database": second_db, + "plugins": "aurora_connection_tracker,failover" + } + } + }, + "apps": { + "models": { + "models": ["tests.integration.container.tortoise.models.test_models"], + "default_connection": "default", + }, + "models2": { + "models": ["tests.integration.container.tortoise.models.test_models_copy"], + "default_connection": "second_db", + } + } + } + + setup_tortoise_connection_provider() + await Tortoise.init(config=config) + + # Create second database + conn = connections.get("second_db") + try: + await conn.db_create() + except Exception: + pass + + await Tortoise.generate_schemas() + + await self._clear_all_test_models() + + yield second_db + await self._clear_all_test_models() + + # Drop second database + conn = connections.get("second_db") + await conn.db_delete() + await reset_tortoise() + + @pytest_asyncio.fixture + async def setup_tortoise_with_router(self, conn_utils): + """Setup Tortoise with router configuration.""" + config = { + "connections": { + "default": { + "engine": "aws_advanced_python_wrapper.tortoise.backends.mysql", + "credentials": { + "host": conn_utils.writer_cluster_host, + "port": conn_utils.port, + "user": conn_utils.user, + "password": conn_utils.password, + "database": conn_utils.dbname, + "plugins": "aurora_connection_tracker,failover" + } + } + }, + "apps": { + "models": { + "models": ["tests.integration.container.tortoise.models.test_models"], + "default_connection": "default", + } + }, + "routers": ["tests.integration.container.tortoise.router.test_router.TestRouter"] + } + + setup_tortoise_connection_provider() + await Tortoise.init(config=config) + await Tortoise.generate_schemas() + await self._clear_all_test_models() + + yield + + await self._clear_all_test_models() + await reset_tortoise() + + @pytest.mark.asyncio + async def test_dict_config_read_operations(self, setup_tortoise_dict_config): + """Test basic read operations with dictionary configuration.""" + # Create test data + user = await User.create(name="Dict Config User", email="dict@example.com") + + # Read operations + found_user = await User.get(id=user.id) + assert found_user.name == "Dict Config User" + assert found_user.email == "dict@example.com" + + # Query operations + users = await User.filter(name="Dict Config User") + assert len(users) == 1 + assert users[0].id == user.id + + @pytest.mark.asyncio + async def test_dict_config_write_operations(self, setup_tortoise_dict_config): + """Test basic write operations with dictionary configuration.""" + # Create + user = await User.create(name="Dict Write Test", email="dictwrite@example.com") + assert user.id is not None + + # Update + user.name = "Updated Dict User" + await user.save() + + updated_user = await User.get(id=user.id) + assert updated_user.name == "Updated Dict User" + + # Delete + await updated_user.delete() + + with pytest.raises(Exception): + await User.get(id=user.id) + + @pytest.mark.asyncio + async def test_multi_db_operations(self, setup_tortoise_multi_db): + """Test operations with multiple databases using same backend.""" + # Get second connection + second_conn = connections.get("second_db") + + # Create users in different databases + user1 = await User.create(name="DB1 User", email="db1@example.com") + user2 = await User.create(name="DB2 User", email="db2@example.com", using_db=second_conn) + + # Verify users exist in their respective databases + db1_users = await User.all() + db2_users = await User.all().using_db(second_conn) + + assert len(db1_users) == 1 + assert len(db2_users) == 1 + assert db1_users[0].name == "DB1 User" + assert db2_users[0].name == "DB2 User" + + # Verify isolation - users don't exist in the other database + db1_user_in_db2 = await User.filter(name="DB1 User").using_db(second_conn) + db2_user_in_db1 = await User.filter(name="DB2 User") + + assert len(db1_user_in_db2) == 0 + assert len(db2_user_in_db1) == 0 + + # Test updates in different databases + user1.name = "Updated DB1 User" + await user1.save() + + user2.name = "Updated DB2 User" + await user2.save(using_db=second_conn) + + # Verify updates + updated_user1 = await User.get(id=user1.id) + updated_user2 = await User.get(id=user2.id, using_db=second_conn) + + assert updated_user1.name == "Updated DB1 User" + assert updated_user2.name == "Updated DB2 User" + + @pytest.mark.asyncio + async def test_router_read_operations(self, setup_tortoise_with_router): + """Test read operations with router configuration.""" + # Create test data + user = await User.create(name="Router User", email="router@example.com") + + # Read operations (should be routed by router) + found_user = await User.get(id=user.id) + assert found_user.name == "Router User" + assert found_user.email == "router@example.com" + + # Query operations + users = await User.filter(name="Router User") + assert len(users) == 1 + assert users[0].id == user.id + + @pytest.mark.asyncio + async def test_router_write_operations(self, setup_tortoise_with_router): + """Test write operations with router configuration.""" + # Create (should be routed by router) + user = await User.create(name="Router Write Test", email="routerwrite@example.com") + assert user.id is not None + + # Update (should be routed by router) + user.name = "Updated Router User" + await user.save() + + updated_user = await User.get(id=user.id) + assert updated_user.name == "Updated Router User" + + # Delete (should be routed by router) + await updated_user.delete() + + with pytest.raises(Exception): + await User.get(id=user.id) diff --git a/tests/integration/container/tortoise/test_tortoise_custom_endpoint.py b/tests/integration/container/tortoise/test_tortoise_custom_endpoint.py new file mode 100644 index 00000000..8a5752ad --- /dev/null +++ b/tests/integration/container/tortoise/test_tortoise_custom_endpoint.py @@ -0,0 +1,121 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from time import perf_counter_ns, sleep +from uuid import uuid4 + +import pytest +import pytest_asyncio +from boto3 import client +from botocore.exceptions import ClientError + +from tests.integration.container.tortoise.test_tortoise_common import ( + run_basic_read_operations, run_basic_write_operations, setup_tortoise) +from tests.integration.container.utils.conditions import ( + disable_on_engines, disable_on_features, enable_on_deployments, + enable_on_num_instances) +from tests.integration.container.utils.database_engine import DatabaseEngine +from tests.integration.container.utils.database_engine_deployment import \ + DatabaseEngineDeployment +from tests.integration.container.utils.test_environment import TestEnvironment +from tests.integration.container.utils.test_environment_features import \ + TestEnvironmentFeatures + + +@disable_on_engines([DatabaseEngine.PG]) +@enable_on_num_instances(min_instances=2) +@enable_on_deployments([DatabaseEngineDeployment.AURORA]) +@disable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY, + TestEnvironmentFeatures.BLUE_GREEN_DEPLOYMENT, + TestEnvironmentFeatures.PERFORMANCE]) +class TestTortoiseCustomEndpoint: + """Test class for Tortoise ORM with custom endpoint plugin.""" + endpoint_id = f"test-tortoise-endpoint-{uuid4()}" + endpoint_info: dict[str, str] = {} + + @pytest.fixture(scope='class') + def create_custom_endpoint(self): + """Create a custom endpoint for testing.""" + env_info = TestEnvironment.get_current().get_info() + region = env_info.get_region() + rds_client = client('rds', region_name=region) + + instances = env_info.get_database_info().get_instances() + instance_ids = [instances[0].get_instance_id()] + + try: + rds_client.create_db_cluster_endpoint( + DBClusterEndpointIdentifier=self.endpoint_id, + DBClusterIdentifier=TestEnvironment.get_current().get_cluster_name(), + EndpointType="ANY", + StaticMembers=instance_ids + ) + + self._wait_until_endpoint_available(rds_client) + yield self.endpoint_info["Endpoint"] + finally: + try: + rds_client.delete_db_cluster_endpoint(DBClusterEndpointIdentifier=self.endpoint_id) + except ClientError as e: + if e.response['Error']['Code'] != 'DBClusterEndpointNotFoundFault': + pass # Ignore if endpoint doesn't exist + rds_client.close() + + def _wait_until_endpoint_available(self, rds_client): + """Wait for the custom endpoint to become available.""" + end_ns = perf_counter_ns() + 5 * 60 * 1_000_000_000 # 5 minutes + available = False + + while perf_counter_ns() < end_ns: + response = rds_client.describe_db_cluster_endpoints( + DBClusterEndpointIdentifier=self.endpoint_id, + Filters=[ + { + "Name": "db-cluster-endpoint-type", + "Values": ["custom"] + } + ] + ) + + response_endpoints = response["DBClusterEndpoints"] + if len(response_endpoints) != 1: + sleep(3) + continue + + response_endpoint = response_endpoints[0] + TestTortoiseCustomEndpoint.endpoint_info = response_endpoint + available = "available" == response_endpoint["Status"] + if available: + break + + sleep(3) + + if not available: + pytest.fail(f"Timed out waiting for custom endpoint to become available: {self.endpoint_id}") + + @pytest_asyncio.fixture + async def setup_tortoise_custom_endpoint(self, conn_utils, create_custom_endpoint): + """Setup Tortoise with custom endpoint plugin.""" + async for result in setup_tortoise(conn_utils, plugins="custom_endpoint,aurora_connection_tracker"): + yield result + + @pytest.mark.asyncio + async def test_basic_read_operations(self, setup_tortoise_custom_endpoint): + """Test basic read operations with custom endpoint plugin.""" + await run_basic_read_operations("Custom Test", "custom") + + @pytest.mark.asyncio + async def test_basic_write_operations(self, setup_tortoise_custom_endpoint): + """Test basic write operations with custom endpoint plugin.""" + await run_basic_write_operations("Custom", "customwrite") diff --git a/tests/integration/container/tortoise/test_tortoise_failover.py b/tests/integration/container/tortoise/test_tortoise_failover.py new file mode 100644 index 00000000..59139422 --- /dev/null +++ b/tests/integration/container/tortoise/test_tortoise_failover.py @@ -0,0 +1,261 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import threading +import time + +import pytest +import pytest_asyncio +from tortoise import connections +from tortoise.transactions import in_transaction + +from aws_advanced_python_wrapper.errors import ( + FailoverSuccessError, TransactionResolutionUnknownError) +from tests.integration.container.tortoise.models.test_models import \ + TableWithSleepTrigger +from tests.integration.container.tortoise.test_tortoise_common import \ + setup_tortoise +from tests.integration.container.utils.conditions import ( + disable_on_engines, disable_on_features, enable_on_deployments, + enable_on_num_instances) +from tests.integration.container.utils.database_engine import DatabaseEngine +from tests.integration.container.utils.database_engine_deployment import \ + DatabaseEngineDeployment +from tests.integration.container.utils.rds_test_utility import RdsTestUtility +from tests.integration.container.utils.test_environment import TestEnvironment +from tests.integration.container.utils.test_environment_features import \ + TestEnvironmentFeatures +from tests.integration.container.utils.test_utils import get_sleep_trigger_sql + + +@disable_on_engines([DatabaseEngine.PG]) +@enable_on_num_instances(min_instances=2) +@enable_on_deployments([DatabaseEngineDeployment.AURORA, DatabaseEngineDeployment.RDS_MULTI_AZ_CLUSTER]) +@disable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY, + TestEnvironmentFeatures.BLUE_GREEN_DEPLOYMENT, + TestEnvironmentFeatures.PERFORMANCE]) +class TestTortoiseFailover: + """Test class for Tortoise ORM with Failover.""" + @pytest.fixture(scope='class') + def aurora_utility(self): + region: str = TestEnvironment.get_current().get_info().get_region() + return RdsTestUtility(region) + + async def _create_sleep_trigger_record(self, name_prefix="Plugin Test", value="test_value", using_db=None): + await TableWithSleepTrigger.create(name=f"{name_prefix}", value=value, using_db=using_db) + + @pytest_asyncio.fixture + async def sleep_trigger_setup(self): + """Setup and cleanup sleep trigger for testing.""" + connection = connections.get("default") + db_engine = TestEnvironment.get_current().get_engine() + trigger_sql = get_sleep_trigger_sql(db_engine, 60, "table_with_sleep_trigger") + await connection.execute_query("DROP TRIGGER IF EXISTS table_with_sleep_trigger_sleep_trigger") + await connection.execute_query(trigger_sql) + yield + await connection.execute_query("DROP TRIGGER IF EXISTS table_with_sleep_trigger_sleep_trigger") + + @pytest_asyncio.fixture + async def setup_tortoise_with_failover(self, conn_utils): + """Setup Tortoise with failover plugins.""" + kwargs = { + "topology_refresh_ms": 10000, + "connect_timeout": 10, + "monitoring-connect_timeout": 5, + "use_pure": True, # MySQL specific + } + async for result in setup_tortoise(conn_utils, plugins="failover", **kwargs): + yield result + + @pytest.mark.asyncio + async def test_basic_operations_with_failover(self, setup_tortoise_with_failover, sleep_trigger_setup, aurora_utility): + """Test failover when inserting to a single table""" + insert_exception = None + + def insert_thread(): + nonlocal insert_exception + try: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(self._create_sleep_trigger_record()) + except Exception as e: + insert_exception = e + + def failover_thread(): + time.sleep(3) # Wait for insert to start + aurora_utility.failover_cluster_and_wait_until_writer_changed() + + # Start both threads + insert_t = threading.Thread(target=insert_thread) + failover_t = threading.Thread(target=failover_thread) + + insert_t.start() + failover_t.start() + + # Wait for both threads to complete + insert_t.join() + failover_t.join() + + # Assert that insert thread got FailoverSuccessError + assert insert_exception is not None + assert isinstance(insert_exception, FailoverSuccessError) + + @pytest.mark.asyncio + async def test_transaction_with_failover(self, setup_tortoise_with_failover, sleep_trigger_setup, aurora_utility): + """Test transactions with failover during long-running operations.""" + transaction_exception = None + + def transaction_thread(): + nonlocal transaction_exception + try: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + async def run_transaction(): + async with in_transaction() as conn: + await self._create_sleep_trigger_record("TX Plugin Test", "tx_test_value", conn) + + loop.run_until_complete(run_transaction()) + except Exception as e: + transaction_exception = e + + def failover_thread(): + time.sleep(5) # Wait for transaction to start + aurora_utility.failover_cluster_and_wait_until_writer_changed() + + # Start both threads + tx_t = threading.Thread(target=transaction_thread) + failover_t = threading.Thread(target=failover_thread) + + tx_t.start() + failover_t.start() + + # Wait for both threads to complete + tx_t.join() + failover_t.join() + + # Assert that transaction thread got FailoverSuccessError + assert transaction_exception is not None + assert isinstance(transaction_exception, (FailoverSuccessError, TransactionResolutionUnknownError)) + + # Verify no records were created due to transaction rollback + record_count = await TableWithSleepTrigger.all().count() + assert record_count == 0 + + # Verify autocommit is re-enabled after failover by inserting a record + autocommit_record = await TableWithSleepTrigger.create( + name="Autocommit Test", + value="autocommit_value" + ) + + # Verify the record exists (should be auto-committed) + found_record = await TableWithSleepTrigger.get(id=autocommit_record.id) + assert found_record.name == "Autocommit Test" + assert found_record.value == "autocommit_value" + + # Clean up the test record + await found_record.delete() + + @pytest.mark.asyncio + async def test_concurrent_queries_with_failover(self, setup_tortoise_with_failover, sleep_trigger_setup, aurora_utility): + """Test concurrent queries with failover during long-running operation.""" + connection = connections.get("default") + + # Run 15 concurrent select queries + async def run_select_query(query_id): + return await connection.execute_query(f"SELECT {query_id} as query_id") + + initial_tasks = [run_select_query(i) for i in range(15)] + initial_results = await asyncio.gather(*initial_tasks) + assert len(initial_results) == 15 + + # Run insert query with failover + sleep_exception = None + + def insert_query_thread(): + nonlocal sleep_exception + for attempt in range(3): + try: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete( + TableWithSleepTrigger.create(name=f"Concurrent Test {attempt}", value="sleep_value") + ) + except Exception as e: + sleep_exception = e + break # Stop on first exception (likely the failover) + + def failover_thread(): + time.sleep(5) # Wait for sleep query to start + aurora_utility.failover_cluster_and_wait_until_writer_changed() + + sleep_t = threading.Thread(target=insert_query_thread) + failover_t = threading.Thread(target=failover_thread) + + sleep_t.start() + failover_t.start() + + sleep_t.join() + failover_t.join() + + # Verify failover exception occurred + assert sleep_exception is not None + assert isinstance(sleep_exception, (FailoverSuccessError, TransactionResolutionUnknownError)) + + # Run another 15 concurrent select queries after failover to verify that we can still make connections + post_failover_tasks = [run_select_query(i + 100) for i in range(15)] + post_failover_results = await asyncio.gather(*post_failover_tasks) + assert len(post_failover_results) == 15 + + @pytest.mark.asyncio + async def test_multiple_concurrent_inserts_with_failover(self, setup_tortoise_with_failover, sleep_trigger_setup, aurora_utility): + """Test multiple concurrent insert operations with failover during long-running operations.""" + # Track exceptions from all insert threads + insert_exceptions = [] + + def insert_thread(thread_id): + try: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete( + TableWithSleepTrigger.create(name=f"Concurrent Insert {thread_id}", value="insert_value") + ) + except Exception as e: + insert_exceptions.append(e) + + def failover_thread(): + time.sleep(5) # Wait for inserts to start + aurora_utility.failover_cluster_and_wait_until_writer_changed() + + # Start 15 insert threads + insert_threads = [] + for i in range(15): + t = threading.Thread(target=insert_thread, args=(i,)) + insert_threads.append(t) + t.start() + + # Start failover thread + failover_t = threading.Thread(target=failover_thread) + failover_t.start() + + # Wait for all threads to complete + for t in insert_threads: + t.join() + failover_t.join() + + # Verify ALL threads got FailoverSuccessError or TransactionResolutionUnknownError + assert len(insert_exceptions) == 15, f"Expected 15 exceptions, got {len(insert_exceptions)}" + failover_errors = [e for e in insert_exceptions if isinstance(e, (FailoverSuccessError, TransactionResolutionUnknownError))] + assert len(failover_errors) == 15, f"Expected all 15 threads to get failover errors, got {len(failover_errors)}" diff --git a/tests/integration/container/tortoise/test_tortoise_iam_authentication.py b/tests/integration/container/tortoise/test_tortoise_iam_authentication.py new file mode 100644 index 00000000..b4f9e32d --- /dev/null +++ b/tests/integration/container/tortoise/test_tortoise_iam_authentication.py @@ -0,0 +1,50 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import pytest_asyncio + +from tests.integration.container.tortoise.test_tortoise_common import ( + run_basic_read_operations, run_basic_write_operations, setup_tortoise) +from tests.integration.container.utils.conditions import (disable_on_engines, + disable_on_features, + enable_on_features) +from tests.integration.container.utils.database_engine import DatabaseEngine +from tests.integration.container.utils.test_environment_features import \ + TestEnvironmentFeatures + + +@disable_on_engines([DatabaseEngine.PG]) +@enable_on_features([TestEnvironmentFeatures.IAM]) +@disable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY, + TestEnvironmentFeatures.BLUE_GREEN_DEPLOYMENT, + TestEnvironmentFeatures.PERFORMANCE]) +class TestTortoiseIamAuthentication: + """Test class for Tortoise ORM with IAM authentication.""" + + @pytest_asyncio.fixture + async def setup_tortoise_iam(self, conn_utils): + """Setup Tortoise with IAM authentication.""" + async for result in setup_tortoise(conn_utils, plugins="iam", user=conn_utils.iam_user): + yield result + + @pytest.mark.asyncio + async def test_basic_read_operations(self, setup_tortoise_iam): + """Test basic read operations with IAM authentication.""" + await run_basic_read_operations() + + @pytest.mark.asyncio + async def test_basic_write_operations(self, setup_tortoise_iam): + """Test basic write operations with IAM authentication.""" + await run_basic_write_operations() diff --git a/tests/integration/container/tortoise/test_tortoise_multi_plugins.py b/tests/integration/container/tortoise/test_tortoise_multi_plugins.py new file mode 100644 index 00000000..d6abed4b --- /dev/null +++ b/tests/integration/container/tortoise/test_tortoise_multi_plugins.py @@ -0,0 +1,283 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import threading +import time +from time import perf_counter_ns, sleep +from uuid import uuid4 + +import pytest +import pytest_asyncio +from boto3 import client +from botocore.exceptions import ClientError +from tortoise import connections + +from aws_advanced_python_wrapper.errors import FailoverSuccessError +from tests.integration.container.tortoise.models.test_models import \ + TableWithSleepTrigger +from tests.integration.container.tortoise.test_tortoise_common import ( + run_basic_read_operations, run_basic_write_operations, setup_tortoise) +from tests.integration.container.utils.conditions import ( + disable_on_engines, disable_on_features, enable_on_deployments, + enable_on_features, enable_on_num_instances) +from tests.integration.container.utils.database_engine import DatabaseEngine +from tests.integration.container.utils.database_engine_deployment import \ + DatabaseEngineDeployment +from tests.integration.container.utils.rds_test_utility import RdsTestUtility +from tests.integration.container.utils.test_environment import TestEnvironment +from tests.integration.container.utils.test_environment_features import \ + TestEnvironmentFeatures +from tests.integration.container.utils.test_utils import get_sleep_trigger_sql + + +@disable_on_engines([DatabaseEngine.PG]) +@enable_on_num_instances(min_instances=2) +@enable_on_features([TestEnvironmentFeatures.IAM]) +@enable_on_deployments([DatabaseEngineDeployment.AURORA, DatabaseEngineDeployment.RDS_MULTI_AZ_CLUSTER]) +@disable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY, + TestEnvironmentFeatures.BLUE_GREEN_DEPLOYMENT, + TestEnvironmentFeatures.PERFORMANCE]) +class TestTortoiseMultiPlugins: + """Test class for Tortoise ORM with multiple AWS wrapper plugins.""" + endpoint_id = f"test-multi-endpoint-{uuid4()}" + endpoint_info: dict[str, str] = {} + + @pytest.fixture(scope='class') + def aurora_utility(self): + region: str = TestEnvironment.get_current().get_info().get_region() + return RdsTestUtility(region) + + @pytest.fixture(scope='class') + def create_custom_endpoint(self): + """Create a custom endpoint for testing.""" + env_info = TestEnvironment.get_current().get_info() + region = env_info.get_region() + rds_client = client('rds', region_name=region) + + instances = env_info.get_database_info().get_instances() + instance_ids = [instances[0].get_instance_id()] + + try: + rds_client.create_db_cluster_endpoint( + DBClusterEndpointIdentifier=self.endpoint_id, + DBClusterIdentifier=TestEnvironment.get_current().get_cluster_name(), + EndpointType="ANY", + StaticMembers=instance_ids + ) + + self._wait_until_endpoint_available(rds_client) + yield self.endpoint_info["Endpoint"] + finally: + try: + rds_client.delete_db_cluster_endpoint(DBClusterEndpointIdentifier=self.endpoint_id) + except ClientError as e: + if e.response['Error']['Code'] != 'DBClusterEndpointNotFoundFault': + pass + rds_client.close() + + def _wait_until_endpoint_available(self, rds_client): + """Wait for the custom endpoint to become available.""" + end_ns = perf_counter_ns() + 5 * 60 * 1_000_000_000 # 5 minutes + available = False + + while perf_counter_ns() < end_ns: + response = rds_client.describe_db_cluster_endpoints( + DBClusterEndpointIdentifier=self.endpoint_id, + Filters=[ + { + "Name": "db-cluster-endpoint-type", + "Values": ["custom"] + } + ] + ) + + response_endpoints = response["DBClusterEndpoints"] + if len(response_endpoints) != 1: + sleep(3) + continue + + response_endpoint = response_endpoints[0] + TestTortoiseMultiPlugins.endpoint_info = response_endpoint + available = "available" == response_endpoint["Status"] + if available: + break + + sleep(3) + + if not available: + pytest.fail(f"Timed out waiting for custom endpoint to become available: {self.endpoint_id}") + + async def _create_sleep_trigger_record(self, name_prefix="Multi Test", value="multi_value", using_db=None): + """Create TableWithSleepTrigger record.""" + await TableWithSleepTrigger.create(name=f"{name_prefix}", value=value, using_db=using_db) + + @pytest_asyncio.fixture + async def sleep_trigger_setup(self): + """Setup and cleanup sleep trigger for testing.""" + connection = connections.get("default") + db_engine = TestEnvironment.get_current().get_engine() + trigger_sql = get_sleep_trigger_sql(db_engine, 60, "table_with_sleep_trigger") + await connection.execute_query("DROP TRIGGER IF EXISTS table_with_sleep_trigger_sleep_trigger") + await connection.execute_query(trigger_sql) + yield + await connection.execute_query("DROP TRIGGER IF EXISTS table_with_sleep_trigger_sleep_trigger") + + @pytest_asyncio.fixture + async def setup_tortoise_multi_plugins(self, conn_utils, create_custom_endpoint): + """Setup Tortoise with multiple plugins.""" + kwargs = { + "topology_refresh_ms": 10000, + "reader_host_selector_strategy": "fastest_response" + } + async for result in setup_tortoise(conn_utils, + plugins="failover,iam,custom_endpoint,aurora_connection_tracker,fastest_response_strategy", + user=conn_utils.iam_user, **kwargs + ): + yield result + + @pytest.mark.asyncio + async def test_basic_read_operations(self, setup_tortoise_multi_plugins): + """Test basic read operations with multiple plugins.""" + await run_basic_read_operations("Multi", "multi") + + @pytest.mark.asyncio + async def test_basic_write_operations(self, setup_tortoise_multi_plugins): + """Test basic write operations with multiple plugins.""" + await run_basic_write_operations("Multi", "multi") + + @pytest.mark.asyncio + async def test_multi_plugins_with_failover(self, setup_tortoise_multi_plugins, sleep_trigger_setup, aurora_utility): + """Test multiple plugins work with failover during long-running operations.""" + insert_exception = None + + def insert_thread(): + nonlocal insert_exception + try: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(self._create_sleep_trigger_record()) + except Exception as e: + insert_exception = e + + def failover_thread(): + time.sleep(5) # Wait for insert to start + aurora_utility.failover_cluster_and_wait_until_writer_changed() + + # Start both threads + insert_t = threading.Thread(target=insert_thread) + failover_t = threading.Thread(target=failover_thread) + + insert_t.start() + failover_t.start() + + # Wait for both threads to complete + insert_t.join() + failover_t.join() + + # Assert that insert thread got FailoverSuccessError + assert insert_exception is not None + assert isinstance(insert_exception, FailoverSuccessError) + + @pytest.mark.asyncio + async def test_concurrent_queries_with_failover(self, setup_tortoise_multi_plugins, sleep_trigger_setup, aurora_utility): + """Test concurrent queries with failover during long-running operation.""" + connection = connections.get("default") + + # Run 15 concurrent select queries + async def run_select_query(query_id): + return await connection.execute_query(f"SELECT {query_id} as query_id") + + initial_tasks = [run_select_query(i) for i in range(15)] + initial_results = await asyncio.gather(*initial_tasks) + assert len(initial_results) == 15 + + # Run sleep query with failover + sleep_exception = None + + def sleep_query_thread(): + nonlocal sleep_exception + for attempt in range(3): + try: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete( + TableWithSleepTrigger.create(name=f"Multi Concurrent Test {attempt}", value="multi_sleep_value") + ) + except Exception as e: + sleep_exception = e + break # Stop on first exception (likely the failover) + + def failover_thread(): + time.sleep(5) # Wait for sleep query to start + aurora_utility.failover_cluster_and_wait_until_writer_changed() + + sleep_t = threading.Thread(target=sleep_query_thread) + failover_t = threading.Thread(target=failover_thread) + + sleep_t.start() + failover_t.start() + + sleep_t.join() + failover_t.join() + + # Verify failover exception occurred + assert sleep_exception is not None + assert isinstance(sleep_exception, FailoverSuccessError) + + # Run another 15 concurrent select queries after failover + post_failover_tasks = [run_select_query(i + 100) for i in range(15)] + post_failover_results = await asyncio.gather(*post_failover_tasks) + assert len(post_failover_results) == 15 + + @pytest.mark.asyncio + async def test_multiple_concurrent_inserts_with_failover(self, setup_tortoise_multi_plugins, sleep_trigger_setup, aurora_utility): + """Test multiple concurrent insert operations with failover during long-running operations.""" + # Track exceptions from all insert threads + insert_exceptions = [] + + def insert_thread(thread_id): + try: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete( + TableWithSleepTrigger.create(name=f"Multi Concurrent Insert {thread_id}", value="multi_insert_value") + ) + except Exception as e: + insert_exceptions.append(e) + + def failover_thread(): + time.sleep(5) # Wait for inserts to start + aurora_utility.failover_cluster_and_wait_until_writer_changed() + + # Start 15 insert threads + insert_threads = [] + for i in range(15): + t = threading.Thread(target=insert_thread, args=(i,)) + insert_threads.append(t) + t.start() + + # Start failover thread + failover_t = threading.Thread(target=failover_thread) + failover_t.start() + + # Wait for all threads to complete + for t in insert_threads: + t.join() + failover_t.join() + + # Verify ALL threads got FailoverSuccessError or TransactionResolutionUnknownError + assert len(insert_exceptions) == 15, f"Expected 15 exceptions, got {len(insert_exceptions)}" + failover_errors = [e for e in insert_exceptions if isinstance(e, FailoverSuccessError)] + assert len(failover_errors) == 15, f"Expected all 15 threads to get failover errors, got {len(failover_errors)}" diff --git a/tests/integration/container/tortoise/test_tortoise_relations.py b/tests/integration/container/tortoise/test_tortoise_relations.py new file mode 100644 index 00000000..61281749 --- /dev/null +++ b/tests/integration/container/tortoise/test_tortoise_relations.py @@ -0,0 +1,225 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import pytest_asyncio +from tortoise import Tortoise +from tortoise.exceptions import IntegrityError + +# Import to register the aws-mysql backend +import aws_advanced_python_wrapper.tortoise # noqa: F401 +from aws_advanced_python_wrapper.tortoise.sql_alchemy_tortoise_connection_provider import \ + setup_tortoise_connection_provider +from tests.integration.container.tortoise.models.test_models_relationships import ( + RelTestAccount, RelTestAccountProfile, RelTestLearner, RelTestPublication, + RelTestPublisher, RelTestSubject) +from tests.integration.container.tortoise.test_tortoise_common import \ + reset_tortoise +from tests.integration.container.utils.conditions import disable_on_engines +from tests.integration.container.utils.database_engine import DatabaseEngine +from tests.integration.container.utils.test_environment import TestEnvironment + + +async def clear_relationship_models(): + """Clear all relationship test models.""" + await RelTestSubject.all().delete() + await RelTestLearner.all().delete() + await RelTestPublication.all().delete() + await RelTestPublisher.all().delete() + await RelTestAccountProfile.all().delete() + await RelTestAccount.all().delete() + + +@disable_on_engines([DatabaseEngine.PG]) +class TestTortoiseRelationships: + + @pytest_asyncio.fixture(scope="function") + async def setup_tortoise_relationships(self, conn_utils): + """Setup Tortoise with relationship models.""" + db_url = conn_utils.get_aws_tortoise_url( + TestEnvironment.get_current().get_engine(), + plugins="aurora_connection_tracker", + ) + + config = { + "connections": { + "default": db_url + }, + "apps": { + "models": { + "models": ["tests.integration.container.tortoise.models.test_models_relationships"], + "default_connection": "default", + } + } + } + + setup_tortoise_connection_provider() + await Tortoise.init(config=config) + await Tortoise.generate_schemas() + + # Clear any existing data + await clear_relationship_models() + + yield + + # Cleanup + await clear_relationship_models() + await reset_tortoise() + + @pytest.mark.asyncio + async def test_one_to_one_relationship_creation(self, setup_tortoise_relationships): + """Test creating one-to-one relationships""" + account = await RelTestAccount.create(username="testuser", email="test@example.com") + profile = await RelTestAccountProfile.create(account=account, bio="Test bio", avatar_url="http://example.com/avatar.jpg") + + # Test forward relationship + assert profile.account_id == account.id + + # Test reverse relationship + account_with_profile = await RelTestAccount.get(id=account.id).prefetch_related("profile") + assert account_with_profile.profile.bio == "Test bio" + + @pytest.mark.asyncio + async def test_one_to_one_cascade_delete(self, setup_tortoise_relationships): + """Test cascade delete in one-to-one relationship""" + account = await RelTestAccount.create(username="deleteuser", email="delete@example.com") + await RelTestAccountProfile.create(account=account, bio="Will be deleted") + + # Delete account should cascade to profile + await account.delete() + + # Profile should be deleted + profile_count = await RelTestAccountProfile.filter(account_id=account.id).count() + assert profile_count == 0 + + @pytest.mark.asyncio + async def test_one_to_one_unique_constraint(self, setup_tortoise_relationships): + """Test unique constraint in one-to-one relationship""" + account = await RelTestAccount.create(username="uniqueuser", email="unique@example.com") + await RelTestAccountProfile.create(account=account, bio="First profile") + + # Creating another profile for same account should fail + with pytest.raises(IntegrityError): + await RelTestAccountProfile.create(account=account, bio="Second profile") + + @pytest.mark.asyncio + async def test_one_to_many_relationship_creation(self, setup_tortoise_relationships): + """Test creating one-to-many relationships""" + publisher = await RelTestPublisher.create(name="Test Publisher", email="publisher@example.com") + pub1 = await RelTestPublication.create(title="Publication 1", isbn="1234567890123", publisher=publisher) + pub2 = await RelTestPublication.create(title="Publication 2", isbn="1234567890124", publisher=publisher) + + # Test forward relationship + assert pub1.publisher_id == publisher.id + assert pub2.publisher_id == publisher.id + + # Test reverse relationship + publisher_with_pubs = await RelTestPublisher.get(id=publisher.id).prefetch_related("publications") + assert len(publisher_with_pubs.publications) == 2 + assert {pub.title for pub in publisher_with_pubs.publications} == {"Publication 1", "Publication 2"} + + @pytest.mark.asyncio + async def test_one_to_many_cascade_delete(self, setup_tortoise_relationships): + """Test cascade delete in one-to-many relationship""" + publisher = await RelTestPublisher.create(name="Delete Publisher", email="deletepub@example.com") + await RelTestPublication.create(title="Delete Pub 1", isbn="9999999999999", publisher=publisher) + await RelTestPublication.create(title="Delete Pub 2", isbn="9999999999998", publisher=publisher) + + # Delete publisher should cascade to publications + await publisher.delete() + + # Publications should be deleted + pub_count = await RelTestPublication.filter(publisher_id=publisher.id).count() + assert pub_count == 0 + + @pytest.mark.asyncio + async def test_foreign_key_constraint(self, setup_tortoise_relationships): + """Test foreign key constraint enforcement""" + # Creating publication with non-existent publisher should fail + with pytest.raises(IntegrityError): + await RelTestPublication.create(title="Orphan Publication", isbn="0000000000000", publisher_id=99999) + + @pytest.mark.asyncio + async def test_many_to_many_relationship_creation(self, setup_tortoise_relationships): + """Test creating many-to-many relationships""" + learner1 = await RelTestLearner.create(name="Learner 1", learner_id="L001") + learner2 = await RelTestLearner.create(name="Learner 2", learner_id="L002") + subject1 = await RelTestSubject.create(name="Math 101", code="MATH101", credits=3) + subject2 = await RelTestSubject.create(name="Physics 101", code="PHYS101", credits=4) + + # Add learners to subjects + await subject1.learners.add(learner1, learner2) + await subject2.learners.add(learner1) + + # Test forward relationship + subject1_with_learners = await RelTestSubject.get(id=subject1.id).prefetch_related("learners") + assert len(subject1_with_learners.learners) == 2 + assert {learner.name for learner in subject1_with_learners.learners} == {"Learner 1", "Learner 2"} + + # Test reverse relationship + learner1_with_subjects = await RelTestLearner.get(id=learner1.id).prefetch_related("subjects") + assert len(learner1_with_subjects.subjects) == 2 + assert {subject.name for subject in learner1_with_subjects.subjects} == {"Math 101", "Physics 101"} + + @pytest.mark.asyncio + async def test_many_to_many_remove_relationship(self, setup_tortoise_relationships): + """Test removing many-to-many relationships""" + learner = await RelTestLearner.create(name="Remove Learner", learner_id="L003") + subject = await RelTestSubject.create(name="Remove Subject", code="REM101", credits=2) + + # Add relationship + await subject.learners.add(learner) + subject_with_learners = await RelTestSubject.get(id=subject.id).prefetch_related("learners") + assert len(subject_with_learners.learners) == 1 + + # Remove relationship + await subject.learners.remove(learner) + subject_with_learners = await RelTestSubject.get(id=subject.id).prefetch_related("learners") + assert len(subject_with_learners.learners) == 0 + + @pytest.mark.asyncio + async def test_many_to_many_clear_relationships(self, setup_tortoise_relationships): + """Test clearing all many-to-many relationships""" + learner1 = await RelTestLearner.create(name="Clear Learner 1", learner_id="L004") + learner2 = await RelTestLearner.create(name="Clear Learner 2", learner_id="L005") + subject = await RelTestSubject.create(name="Clear Subject", code="CLR101", credits=1) + + # Add multiple relationships + await subject.learners.add(learner1, learner2) + subject_with_learners = await RelTestSubject.get(id=subject.id).prefetch_related("learners") + assert len(subject_with_learners.learners) == 2 + + # Clear all relationships + await subject.learners.clear() + subject_with_learners = await RelTestSubject.get(id=subject.id).prefetch_related("learners") + assert len(subject_with_learners.learners) == 0 + + @pytest.mark.asyncio + async def test_unique_constraints(self, setup_tortoise_relationships): + """Test unique constraints on model fields""" + # Test account unique username + await RelTestAccount.create(username="unique1", email="unique1@example.com") + with pytest.raises(IntegrityError): + await RelTestAccount.create(username="unique1", email="different@example.com") + + # Test publication unique ISBN + publisher = await RelTestPublisher.create(name="ISBN Publisher", email="isbn@example.com") + await RelTestPublication.create(title="ISBN Pub 1", isbn="1111111111111", publisher=publisher) + with pytest.raises(IntegrityError): + await RelTestPublication.create(title="ISBN Pub 2", isbn="1111111111111", publisher=publisher) + + # Test subject unique code + await RelTestSubject.create(name="Unique Subject 1", code="UNQ101", credits=3) + with pytest.raises(IntegrityError): + await RelTestSubject.create(name="Unique Subject 2", code="UNQ101", credits=4) diff --git a/tests/integration/container/tortoise/test_tortoise_secrets_manager.py b/tests/integration/container/tortoise/test_tortoise_secrets_manager.py new file mode 100644 index 00000000..9d9fabd6 --- /dev/null +++ b/tests/integration/container/tortoise/test_tortoise_secrets_manager.py @@ -0,0 +1,85 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +import boto3 +import pytest +import pytest_asyncio + +from tests.integration.container.tortoise.test_tortoise_common import ( + run_basic_read_operations, run_basic_write_operations, setup_tortoise) +from tests.integration.container.utils.conditions import ( + disable_on_engines, disable_on_features, enable_on_deployments) +from tests.integration.container.utils.database_engine import DatabaseEngine +from tests.integration.container.utils.database_engine_deployment import \ + DatabaseEngineDeployment +from tests.integration.container.utils.test_environment import TestEnvironment +from tests.integration.container.utils.test_environment_features import \ + TestEnvironmentFeatures + + +@disable_on_engines([DatabaseEngine.PG]) +@enable_on_deployments([DatabaseEngineDeployment.AURORA, DatabaseEngineDeployment.RDS_MULTI_AZ_CLUSTER]) +@disable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY, + TestEnvironmentFeatures.BLUE_GREEN_DEPLOYMENT, + TestEnvironmentFeatures.PERFORMANCE]) +class TestTortoiseSecretsManager: + """Test class for Tortoise ORM with AWS Secrets Manager authentication.""" + + @pytest.fixture(scope='class') + def create_secret(self, conn_utils): + """Create a secret in AWS Secrets Manager with database credentials.""" + region = TestEnvironment.get_current().get_info().get_region() + client = boto3.client('secretsmanager', region_name=region) + + secret_name = f"test-tortoise-secret-{TestEnvironment.get_current().get_info().get_db_name()}" + secret_value = { + "username": conn_utils.user, + "password": conn_utils.password + } + + try: + response = client.create_secret( + Name=secret_name, + SecretString=json.dumps(secret_value) + ) + secret_id = response['ARN'] + yield secret_id + finally: + try: + client.delete_secret( + SecretId=secret_name, + ForceDeleteWithoutRecovery=True + ) + except Exception: + pass + + @pytest_asyncio.fixture + async def setup_tortoise_secrets_manager(self, conn_utils, create_secret): + """Setup Tortoise with Secrets Manager authentication.""" + async for result in setup_tortoise(conn_utils, + plugins="aws_secrets_manager", + secrets_manager_secret_id=create_secret): + yield result + + @pytest.mark.asyncio + async def test_basic_read_operations(self, setup_tortoise_secrets_manager): + """Test basic read operations with Secrets Manager authentication.""" + await run_basic_read_operations("Secrets", "secrets") + + @pytest.mark.asyncio + async def test_basic_write_operations(self, setup_tortoise_secrets_manager): + """Test basic write operations with Secrets Manager authentication.""" + await run_basic_write_operations("Secrets", "secretswrite") diff --git a/tests/integration/container/utils/connection_utils.py b/tests/integration/container/utils/connection_utils.py index 52d0add5..324a7cc3 100644 --- a/tests/integration/container/utils/connection_utils.py +++ b/tests/integration/container/utils/connection_utils.py @@ -16,6 +16,7 @@ from typing import Any, Dict, Optional +from .database_engine import DatabaseEngine from .driver_helper import DriverHelper from .test_environment import TestEnvironment @@ -97,3 +98,30 @@ def get_proxy_connect_params( password = self.password if password is None else password dbname = self.dbname if dbname is None else dbname return DriverHelper.get_connect_params(host, port, user, password, dbname) + + def get_aws_tortoise_url( + self, + db_engine: DatabaseEngine, + host: Optional[str] = None, + port: Optional[int] = None, + user: Optional[str] = None, + password: Optional[str] = None, + dbname: Optional[str] = None, + **kwargs) -> str: + """Build AWS MySQL connection URL for Tortoise ORM with query parameters.""" + host = self.writer_cluster_host if host is None else host + port = self.port if port is None else port + user = self.user if user is None else user + password = self.password if password is None else password + dbname = self.dbname if dbname is None else dbname + + # Build base URL + protocol = "aws-pg" if db_engine == DatabaseEngine.PG else "aws-mysql" + url = f"{protocol}://{user}:{password}@{host}:{port}/{dbname}" + + # Add all kwargs as query parameters + if kwargs: + params = [f"{key}={value}" for key, value in kwargs.items()] + url += "?" + "&".join(params) + + return url diff --git a/tests/integration/container/utils/test_environment_request.py b/tests/integration/container/utils/test_environment_request.py index db1bbeef..187cfbb1 100644 --- a/tests/integration/container/utils/test_environment_request.py +++ b/tests/integration/container/utils/test_environment_request.py @@ -40,7 +40,7 @@ def __init__(self, request: Dict[str, Any]) -> None: self._instances = getattr(DatabaseInstances, typing.cast('str', request.get("instances"))) self._deployment = getattr(DatabaseEngineDeployment, typing.cast('str', request.get("deployment"))) self._target_python_version = getattr(TargetPythonVersion, typing.cast('str', request.get("targetPythonVersion"))) - self._num_of_instances = typing.cast('int', request.get("numOfInstances")) + self._num_of_instances = request.get("numOfInstances", 1) self._features = set() features: Iterable[str] = typing.cast('Iterable[str]', request.get("features")) if features is not None: diff --git a/tests/integration/container/utils/test_telemetry_info.py b/tests/integration/container/utils/test_telemetry_info.py index 24267dc8..e37941ef 100644 --- a/tests/integration/container/utils/test_telemetry_info.py +++ b/tests/integration/container/utils/test_telemetry_info.py @@ -11,18 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# -# Licensed under the Apache License, Version 2.0 (the "License"). -# You may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. + import typing from typing import Any, Dict diff --git a/tests/integration/container/utils/test_utils.py b/tests/integration/container/utils/test_utils.py new file mode 100644 index 00000000..a8942c0b --- /dev/null +++ b/tests/integration/container/utils/test_utils.py @@ -0,0 +1,52 @@ +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .database_engine import DatabaseEngine + + +def get_sleep_sql(db_engine: DatabaseEngine, seconds: int): + if db_engine == DatabaseEngine.MYSQL: + return f"SELECT SLEEP({seconds});" + elif db_engine == DatabaseEngine.PG: + return f"SELECT PG_SLEEP({seconds});" + else: + raise ValueError("Unknown database engine: " + str(db_engine)) + + +def get_sleep_trigger_sql(db_engine: DatabaseEngine, duration: int, table_name: str) -> str: + """Generate SQL to create a sleep trigger for INSERT operations.""" + if db_engine == DatabaseEngine.MYSQL: + return f""" + CREATE TRIGGER {table_name}_sleep_trigger + BEFORE INSERT ON {table_name} + FOR EACH ROW + BEGIN + DO SLEEP({duration}); + END + """ + elif db_engine == DatabaseEngine.PG: + return f""" + CREATE OR REPLACE FUNCTION {table_name}_sleep_function() + RETURNS TRIGGER AS $$ + BEGIN + PERFORM pg_sleep({duration}); + RETURN NEW; + END; + $$ LANGUAGE plpgsql; + + CREATE TRIGGER {table_name}_sleep_trigger + BEFORE INSERT ON {table_name} + FOR EACH ROW + EXECUTE FUNCTION {table_name}_sleep_function(); + """ + else: + raise ValueError(f"Unknown database engine: {db_engine}") diff --git a/tests/unit/test_tortoise_base_client.py b/tests/unit/test_tortoise_base_client.py new file mode 100644 index 00000000..d300a702 --- /dev/null +++ b/tests/unit/test_tortoise_base_client.py @@ -0,0 +1,316 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from aws_advanced_python_wrapper.tortoise.backends.base.client import ( + AwsConnectionAsyncWrapper, AwsCursorAsyncWrapper, AwsWrapperAsyncConnector, + TortoiseAwsClientConnectionWrapper, TortoiseAwsClientTransactionContext) + + +class TestAwsCursorAsyncWrapper: + def test_init(self): + mock_cursor = MagicMock() + wrapper = AwsCursorAsyncWrapper(mock_cursor) + assert wrapper._cursor == mock_cursor + + @pytest.mark.asyncio + async def test_execute(self): + mock_cursor = MagicMock() + wrapper = AwsCursorAsyncWrapper(mock_cursor) + + with patch('asyncio.to_thread') as mock_to_thread: + mock_to_thread.return_value = "result" + + result = await wrapper.execute("SELECT 1", ["param"]) + + mock_to_thread.assert_called_once_with(mock_cursor.execute, "SELECT 1", ["param"]) + assert result == "result" + + @pytest.mark.asyncio + async def test_executemany(self): + mock_cursor = MagicMock() + wrapper = AwsCursorAsyncWrapper(mock_cursor) + + with patch('asyncio.to_thread') as mock_to_thread: + mock_to_thread.return_value = "result" + + result = await wrapper.executemany("INSERT", [["param1"], ["param2"]]) + + mock_to_thread.assert_called_once_with(mock_cursor.executemany, "INSERT", [["param1"], ["param2"]]) + assert result == "result" + + @pytest.mark.asyncio + async def test_fetchall(self): + mock_cursor = MagicMock() + wrapper = AwsCursorAsyncWrapper(mock_cursor) + + with patch('asyncio.to_thread') as mock_to_thread: + mock_to_thread.return_value = [("row1",), ("row2",)] + + result = await wrapper.fetchall() + + mock_to_thread.assert_called_once_with(mock_cursor.fetchall) + assert result == [("row1",), ("row2",)] + + @pytest.mark.asyncio + async def test_fetchone(self): + mock_cursor = MagicMock() + wrapper = AwsCursorAsyncWrapper(mock_cursor) + + with patch('asyncio.to_thread') as mock_to_thread: + mock_to_thread.return_value = ("row1",) + + result = await wrapper.fetchone() + + mock_to_thread.assert_called_once_with(mock_cursor.fetchone) + assert result == ("row1",) + + @pytest.mark.asyncio + async def test_close(self): + mock_cursor = MagicMock() + wrapper = AwsCursorAsyncWrapper(mock_cursor) + + with patch('asyncio.to_thread') as mock_to_thread: + mock_to_thread.return_value = None + + await wrapper.close() + mock_to_thread.assert_called_once_with(mock_cursor.close) + + def test_getattr(self): + mock_cursor = MagicMock() + mock_cursor.rowcount = 5 + wrapper = AwsCursorAsyncWrapper(mock_cursor) + + assert wrapper.rowcount == 5 + + +class TestAwsConnectionAsyncWrapper: + def test_init(self): + mock_connection = MagicMock() + wrapper = AwsConnectionAsyncWrapper(mock_connection) + assert wrapper._wrapped_connection == mock_connection + + @pytest.mark.asyncio + async def test_cursor_context_manager(self): + mock_connection = MagicMock() + mock_cursor = MagicMock() + mock_connection.cursor.return_value = mock_cursor + + wrapper = AwsConnectionAsyncWrapper(mock_connection) + + with patch('asyncio.to_thread') as mock_to_thread: + mock_to_thread.side_effect = [mock_cursor, None] + + async with wrapper.cursor() as cursor: + assert isinstance(cursor, AwsCursorAsyncWrapper) + assert cursor._cursor == mock_cursor + + assert mock_to_thread.call_count == 2 + + @pytest.mark.asyncio + async def test_rollback(self): + mock_connection = MagicMock() + wrapper = AwsConnectionAsyncWrapper(mock_connection) + + with patch('asyncio.to_thread') as mock_to_thread: + mock_to_thread.return_value = "rollback_result" + + result = await wrapper.rollback() + + mock_to_thread.assert_called_once_with(mock_connection.rollback) + assert result == "rollback_result" + + @pytest.mark.asyncio + async def test_commit(self): + mock_connection = MagicMock() + wrapper = AwsConnectionAsyncWrapper(mock_connection) + + with patch('asyncio.to_thread') as mock_to_thread: + mock_to_thread.return_value = "commit_result" + + result = await wrapper.commit() + + mock_to_thread.assert_called_once_with(mock_connection.commit) + assert result == "commit_result" + + @pytest.mark.asyncio + async def test_set_autocommit(self): + mock_connection = MagicMock() + wrapper = AwsConnectionAsyncWrapper(mock_connection) + + with patch('asyncio.to_thread') as mock_to_thread: + mock_to_thread.return_value = None + + await wrapper.set_autocommit(True) + + mock_to_thread.assert_called_once_with(setattr, mock_connection, 'autocommit', True) + + def test_getattr(self): + mock_connection = MagicMock() + mock_connection.some_attr = "test_value" + wrapper = AwsConnectionAsyncWrapper(mock_connection) + + assert wrapper.some_attr == "test_value" + + +class TestTortoiseAwsClientConnectionWrapper: + def test_init(self): + mock_client = MagicMock() + mock_connect_func = MagicMock() + + wrapper = TortoiseAwsClientConnectionWrapper(mock_client, mock_connect_func, with_db=True) + + assert wrapper.client == mock_client + assert wrapper.connect_func == mock_connect_func + assert wrapper.with_db + assert wrapper.connection is None + + @pytest.mark.asyncio + async def test_ensure_connection(self): + mock_client = MagicMock() + mock_client.create_connection = AsyncMock() + + wrapper = TortoiseAwsClientConnectionWrapper(mock_client, MagicMock(), with_db=True) + + await wrapper.ensure_connection() + mock_client.create_connection.assert_called_once_with(with_db=True) + + @pytest.mark.asyncio + async def test_context_manager(self): + mock_client = MagicMock() + mock_client._template = {"host": "localhost"} + mock_client.create_connection = AsyncMock() + mock_connect_func = MagicMock() + mock_connection = MagicMock() + + wrapper = TortoiseAwsClientConnectionWrapper(mock_client, mock_connect_func, with_db=True) + + with patch.object(AwsWrapperAsyncConnector, 'connect_with_aws_wrapper') as mock_connect: + mock_connect.return_value = mock_connection + + with patch.object(AwsWrapperAsyncConnector, 'close_aws_wrapper') as mock_close: + async with wrapper as conn: + assert conn == mock_connection + assert wrapper.connection == mock_connection + + # Verify close was called on exit + mock_close.assert_called_once_with(mock_connection) + + +class TestTortoiseAwsClientTransactionContext: + def test_init(self): + mock_client = MagicMock() + mock_client.connection_name = "test_conn" + + context = TortoiseAwsClientTransactionContext(mock_client) + + assert context.client == mock_client + assert context.connection_name == "test_conn" + + @pytest.mark.asyncio + async def test_ensure_connection(self): + mock_client = MagicMock() + mock_client._parent.create_connection = AsyncMock() + + context = TortoiseAwsClientTransactionContext(mock_client) + + await context.ensure_connection() + mock_client._parent.create_connection.assert_called_once_with(with_db=True) + + @pytest.mark.asyncio + async def test_context_manager_commit(self): + mock_client = MagicMock() + mock_client._parent._template = {"host": "localhost"} + mock_client._parent.create_connection = AsyncMock() + mock_client.connection_name = "test_conn" + mock_client._finalized = False + mock_client.begin = AsyncMock() + mock_client.commit = AsyncMock() + mock_connection = MagicMock() + + context = TortoiseAwsClientTransactionContext(mock_client) + + with patch.object(AwsWrapperAsyncConnector, 'connect_with_aws_wrapper') as mock_connect: + mock_connect.return_value = mock_connection + with patch('tortoise.connection.connections') as mock_connections: + mock_connections.set.return_value = "test_token" + + with patch.object(AwsWrapperAsyncConnector, 'close_aws_wrapper') as mock_close: + async with context as client: + assert client == mock_client + assert mock_client._connection == mock_connection + + # Verify commit was called and connection was closed + mock_client.commit.assert_called_once() + mock_close.assert_called_once_with(mock_connection) + + @pytest.mark.asyncio + async def test_context_manager_rollback_on_exception(self): + mock_client = MagicMock() + mock_client._parent._template = {"host": "localhost"} + mock_client._parent.create_connection = AsyncMock() + mock_client.connection_name = "test_conn" + mock_client._finalized = False + mock_client.begin = AsyncMock() + mock_client.rollback = AsyncMock() + mock_connection = MagicMock() + + context = TortoiseAwsClientTransactionContext(mock_client) + + with patch.object(AwsWrapperAsyncConnector, 'connect_with_aws_wrapper') as mock_connect: + mock_connect.return_value = mock_connection + with patch('tortoise.connection.connections') as mock_connections: + mock_connections.set.return_value = "test_token" + + with patch.object(AwsWrapperAsyncConnector, 'close_aws_wrapper') as mock_close: + try: + async with context: + raise ValueError("Test exception") + except ValueError: + pass + + # Verify rollback was called and connection was closed + mock_client.rollback.assert_called_once() + mock_close.assert_called_once_with(mock_connection) + + +class TestAwsWrapperAsyncConnector: + @pytest.mark.asyncio + async def test_connect_with_aws_wrapper(self): + mock_connect_func = MagicMock() + mock_connection = MagicMock() + kwargs = {"host": "localhost", "user": "test"} + + with patch('asyncio.to_thread') as mock_to_thread: + mock_to_thread.return_value = mock_connection + + result = await AwsWrapperAsyncConnector.connect_with_aws_wrapper(mock_connect_func, **kwargs) + + mock_to_thread.assert_called_once() + assert isinstance(result, AwsConnectionAsyncWrapper) + assert result._wrapped_connection == mock_connection + + @pytest.mark.asyncio + async def test_close_aws_wrapper(self): + mock_connection = MagicMock() + + with patch('asyncio.to_thread') as mock_to_thread: + mock_to_thread.return_value = None + + await AwsWrapperAsyncConnector.close_aws_wrapper(mock_connection) + + mock_to_thread.assert_called_once_with(mock_connection.close) diff --git a/tests/unit/test_tortoise_mysql_client.py b/tests/unit/test_tortoise_mysql_client.py new file mode 100644 index 00000000..701d4514 --- /dev/null +++ b/tests/unit/test_tortoise_mysql_client.py @@ -0,0 +1,472 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from mysql.connector import errors +from tortoise.exceptions import (DBConnectionError, IntegrityError, + OperationalError, TransactionManagementError) + +from aws_advanced_python_wrapper.tortoise.backends.mysql.client import ( + AwsMySQLClient, TransactionWrapper, _gen_savepoint_name, + translate_exceptions) + + +class TestTranslateExceptions: + @pytest.mark.asyncio + async def test_translate_operational_error(self): + @translate_exceptions + async def test_func(self): + raise errors.OperationalError("Test error") + + with pytest.raises(OperationalError): + await test_func(None) + + @pytest.mark.asyncio + async def test_translate_integrity_error(self): + @translate_exceptions + async def test_func(self): + raise errors.IntegrityError("Test error") + + with pytest.raises(IntegrityError): + await test_func(None) + + @pytest.mark.asyncio + async def test_no_exception(self): + @translate_exceptions + async def test_func(self): + return "success" + + result = await test_func(None) + assert result == "success" + + +class TestAwsMySQLClient: + def test_init(self): + with patch('builtins.print'): # Mock the print statement + client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + charset="utf8mb4", + storage_engine="InnoDB", + connection_name="test_conn" + ) + + assert client.user == "test_user" + assert client.password == "test_pass" + assert client.database == "test_db" + assert client.host == "localhost" + assert client.port == 3306 + assert client.charset == "utf8mb4" + assert client.storage_engine == "InnoDB" + + def test_init_defaults(self): + with patch('builtins.print'): # Mock the print statement + client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + connection_name="test_conn" + ) + + assert client.charset == "utf8mb4" + assert client.storage_engine == "innodb" + + @pytest.mark.asyncio + async def test_create_connection_invalid_charset(self): + with patch('builtins.print'): # Mock the print statement + client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + charset="invalid_charset", + connection_name="test_conn" + ) + + with pytest.raises(DBConnectionError, match="Unknown character set"): + await client.create_connection(with_db=True) + + @pytest.mark.asyncio + async def test_create_connection_success(self): + with patch('builtins.print'): # Mock the print statement + client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + connection_name="test_conn" + ) + + with patch('aws_advanced_python_wrapper.tortoise.backends.mysql.client.logger'): + await client.create_connection(with_db=True) + + assert client._template["user"] == "test_user" + assert client._template["database"] == "test_db" + assert client._template["autocommit"] is True + + @pytest.mark.asyncio + async def test_close(self): + with patch('builtins.print'): # Mock the print statement + client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + connection_name="test_conn" + ) + + # close() method now does nothing (AWS wrapper handles cleanup) + await client.close() + # No assertions needed since method is a no-op + + def test_acquire_connection(self): + with patch('builtins.print'): # Mock the print statement + client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + connection_name="test_conn" + ) + + connection_wrapper = client.acquire_connection() + assert connection_wrapper.client == client + + @pytest.mark.asyncio + async def test_execute_insert(self): + with patch('builtins.print'): # Mock the print statement + client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + connection_name="test_conn" + ) + + mock_connection = MagicMock() + mock_cursor = MagicMock() + mock_cursor.lastrowid = 123 + + with patch.object(client, 'acquire_connection') as mock_acquire: + mock_acquire.return_value.__aenter__ = AsyncMock(return_value=mock_connection) + mock_acquire.return_value.__aexit__ = AsyncMock() + mock_connection.cursor.return_value.__aenter__ = AsyncMock(return_value=mock_cursor) + mock_connection.cursor.return_value.__aexit__ = AsyncMock() + mock_cursor.execute = AsyncMock() + + with patch('aws_advanced_python_wrapper.tortoise.backends.mysql.client.logger'): + result = await client.execute_insert("INSERT INTO test VALUES (?)", ["value"]) + + assert result == 123 + mock_cursor.execute.assert_called_once_with("INSERT INTO test VALUES (?)", ["value"]) + + @pytest.mark.asyncio + async def test_execute_many_with_transactions(self): + with patch('builtins.print'): # Mock the print statement + client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + connection_name="test_conn", + storage_engine="innodb" + ) + + mock_connection = MagicMock() + mock_cursor = MagicMock() + + with patch.object(client, 'acquire_connection') as mock_acquire: + mock_acquire.return_value.__aenter__ = AsyncMock(return_value=mock_connection) + mock_acquire.return_value.__aexit__ = AsyncMock() + mock_connection.cursor.return_value.__aenter__ = AsyncMock(return_value=mock_cursor) + mock_connection.cursor.return_value.__aexit__ = AsyncMock() + + with patch.object(client, '_execute_many_with_transaction') as mock_execute_many_tx: + mock_execute_many_tx.return_value = None + + with patch('aws_advanced_python_wrapper.tortoise.backends.mysql.client.logger'): + await client.execute_many("INSERT INTO test VALUES (?)", [["val1"], ["val2"]]) + + mock_execute_many_tx.assert_called_once_with( + mock_cursor, mock_connection, "INSERT INTO test VALUES (?)", [["val1"], ["val2"]] + ) + + @pytest.mark.asyncio + async def test_execute_query(self): + with patch('builtins.print'): # Mock the print statement + client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + connection_name="test_conn" + ) + + mock_connection = MagicMock() + mock_cursor = MagicMock() + mock_cursor.rowcount = 2 + mock_cursor.description = [("id",), ("name",)] + + with patch.object(client, 'acquire_connection') as mock_acquire: + mock_acquire.return_value.__aenter__ = AsyncMock(return_value=mock_connection) + mock_acquire.return_value.__aexit__ = AsyncMock() + mock_connection.cursor.return_value.__aenter__ = AsyncMock(return_value=mock_cursor) + mock_connection.cursor.return_value.__aexit__ = AsyncMock() + mock_cursor.execute = AsyncMock() + mock_cursor.fetchall = AsyncMock(return_value=[(1, "test"), (2, "test2")]) + + with patch('aws_advanced_python_wrapper.tortoise.backends.mysql.client.logger'): + rowcount, results = await client.execute_query("SELECT * FROM test") + + assert rowcount == 2 + assert len(results) == 2 + assert results[0] == {"id": 1, "name": "test"} + + @pytest.mark.asyncio + async def test_execute_query_dict(self): + with patch('builtins.print'): # Mock the print statement + client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + connection_name="test_conn" + ) + + with patch.object(client, 'execute_query') as mock_execute_query: + mock_execute_query.return_value = (2, [{"id": 1}, {"id": 2}]) + + results = await client.execute_query_dict("SELECT * FROM test") + + assert results == [{"id": 1}, {"id": 2}] + + def test_in_transaction(self): + with patch('builtins.print'): # Mock the print statement + client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + connection_name="test_conn" + ) + + transaction_context = client._in_transaction() + assert transaction_context is not None + + +class TestTransactionWrapper: + def test_init(self): + with patch('builtins.print'): # Mock the print statement + parent_client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + connection_name="test_conn" + ) + + wrapper = TransactionWrapper(parent_client) + + assert wrapper.connection_name == "test_conn" + assert wrapper._parent == parent_client + assert wrapper._finalized is False + assert wrapper._savepoint is None + + @pytest.mark.asyncio + async def test_begin(self): + with patch('builtins.print'): # Mock the print statement + parent_client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + connection_name="test_conn" + ) + + wrapper = TransactionWrapper(parent_client) + mock_connection = MagicMock() + mock_connection.set_autocommit = AsyncMock() + wrapper._connection = mock_connection + + await wrapper.begin() + + mock_connection.set_autocommit.assert_called_once_with(False) + assert wrapper._finalized is False + + @pytest.mark.asyncio + async def test_commit(self): + with patch('builtins.print'): # Mock the print statement + parent_client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + connection_name="test_conn" + ) + + wrapper = TransactionWrapper(parent_client) + mock_connection = MagicMock() + mock_connection.commit = AsyncMock() + mock_connection.set_autocommit = AsyncMock() + wrapper._connection = mock_connection + + await wrapper.commit() + + mock_connection.commit.assert_called_once() + mock_connection.set_autocommit.assert_called_once_with(True) + assert wrapper._finalized is True + + @pytest.mark.asyncio + async def test_commit_already_finalized(self): + with patch('builtins.print'): # Mock the print statement + parent_client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + connection_name="test_conn" + ) + + wrapper = TransactionWrapper(parent_client) + wrapper._finalized = True + + with pytest.raises(TransactionManagementError, match="Transaction already finalized"): + await wrapper.commit() + + @pytest.mark.asyncio + async def test_rollback(self): + with patch('builtins.print'): # Mock the print statement + parent_client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + connection_name="test_conn" + ) + + wrapper = TransactionWrapper(parent_client) + mock_connection = MagicMock() + mock_connection.rollback = AsyncMock() + mock_connection.set_autocommit = AsyncMock() + wrapper._connection = mock_connection + + await wrapper.rollback() + + mock_connection.rollback.assert_called_once() + mock_connection.set_autocommit.assert_called_once_with(True) + assert wrapper._finalized is True + + @pytest.mark.asyncio + async def test_savepoint(self): + with patch('builtins.print'): # Mock the print statement + parent_client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + connection_name="test_conn" + ) + + wrapper = TransactionWrapper(parent_client) + mock_connection = MagicMock() + mock_cursor = MagicMock() + mock_cursor.execute = AsyncMock() + + wrapper._connection = mock_connection + mock_connection.cursor.return_value.__aenter__ = AsyncMock(return_value=mock_cursor) + mock_connection.cursor.return_value.__aexit__ = AsyncMock() + + await wrapper.savepoint() + + assert wrapper._savepoint is not None + assert wrapper._savepoint.startswith("tortoise_savepoint_") + mock_cursor.execute.assert_called_once() + + @pytest.mark.asyncio + async def test_savepoint_rollback(self): + with patch('builtins.print'): # Mock the print statement + parent_client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + connection_name="test_conn" + ) + + wrapper = TransactionWrapper(parent_client) + wrapper._savepoint = "test_savepoint" + mock_connection = MagicMock() + mock_cursor = MagicMock() + mock_cursor.execute = AsyncMock() + + wrapper._connection = mock_connection + mock_connection.cursor.return_value.__aenter__ = AsyncMock(return_value=mock_cursor) + mock_connection.cursor.return_value.__aexit__ = AsyncMock() + + await wrapper.savepoint_rollback() + + mock_cursor.execute.assert_called_once_with("ROLLBACK TO SAVEPOINT test_savepoint") + assert wrapper._savepoint is None + assert wrapper._finalized is True + + @pytest.mark.asyncio + async def test_savepoint_rollback_no_savepoint(self): + with patch('builtins.print'): # Mock the print statement + parent_client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + connection_name="test_conn" + ) + + wrapper = TransactionWrapper(parent_client) + wrapper._savepoint = None + + with pytest.raises(TransactionManagementError, match="No savepoint to rollback to"): + await wrapper.savepoint_rollback() + + +class TestGenSavepointName: + def test_gen_savepoint_name(self): + name1 = _gen_savepoint_name() + name2 = _gen_savepoint_name() + + assert name1.startswith("tortoise_savepoint_") + assert name2.startswith("tortoise_savepoint_") + assert name1 != name2 # Should be unique