Skip to content

Commit 2a5b151

Browse files
committed
add sql alchemy tortoise connection provider
1 parent d224974 commit 2a5b151

File tree

9 files changed

+127
-38
lines changed

9 files changed

+127
-38
lines changed

aws_advanced_python_wrapper/driver_dialect.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
if TYPE_CHECKING:
2020
from aws_advanced_python_wrapper.hostinfo import HostInfo
2121
from aws_advanced_python_wrapper.pep249 import Connection, Cursor
22+
from types import ModuleType
2223

2324
from abc import ABC
2425
from concurrent.futures import Executor, ThreadPoolExecutor, TimeoutError
@@ -164,3 +165,7 @@ def ping(self, conn: Connection) -> bool:
164165
return True
165166
except Exception:
166167
return False
168+
169+
def get_driver_module(self) -> ModuleType:
170+
raise UnsupportedOperationError(
171+
Messages.get_formatted("DriverDialect.UnsupportedOperationError", self._driver_name, "get_driver_module"))

aws_advanced_python_wrapper/driver_dialect_manager.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,7 @@
2424
from aws_advanced_python_wrapper.errors import AwsWrapperError
2525
from aws_advanced_python_wrapper.utils.log import Logger
2626
from aws_advanced_python_wrapper.utils.messages import Messages
27-
from aws_advanced_python_wrapper.utils.properties import (Properties,
28-
WrapperProperties)
27+
from aws_advanced_python_wrapper.utils.properties import Properties, WrapperProperties
2928
from aws_advanced_python_wrapper.utils.utils import Utils
3029

3130
logger = Logger(__name__)
@@ -52,7 +51,8 @@ class DriverDialectManager(DriverDialectProvider):
5251
}
5352

5453
pool_connection_driver_dialect: Dict[str, str] = {
55-
"SqlAlchemyPooledConnectionProvider": "aws_advanced_python_wrapper.sqlalchemy_driver_dialect.SqlAlchemyDriverDialect"
54+
"SqlAlchemyPooledConnectionProvider": "aws_advanced_python_wrapper.sqlalchemy_driver_dialect.SqlAlchemyDriverDialect",
55+
"SqlAlchemyTortoisePooledConnectionProvider": "aws_advanced_python_wrapper.sqlalchemy_driver_dialect.SqlAlchemyDriverDialect",
5656
}
5757

5858
@staticmethod

aws_advanced_python_wrapper/mysql_driver_dialect.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
if TYPE_CHECKING:
2020
from aws_advanced_python_wrapper.hostinfo import HostInfo
2121
from aws_advanced_python_wrapper.pep249 import Connection
22+
from types import ModuleType
2223

2324
from concurrent.futures import Executor, ThreadPoolExecutor, TimeoutError
2425
from inspect import signature
@@ -28,12 +29,11 @@
2829
from aws_advanced_python_wrapper.errors import UnsupportedOperationError
2930
from aws_advanced_python_wrapper.utils.decorators import timeout
3031
from aws_advanced_python_wrapper.utils.messages import Messages
31-
from aws_advanced_python_wrapper.utils.properties import (Properties,
32-
PropertiesUtils,
33-
WrapperProperties)
32+
from aws_advanced_python_wrapper.utils.properties import Properties, PropertiesUtils, WrapperProperties
3433

3534
CMYSQL_ENABLED = False
3635

36+
import mysql.connector
3737
from mysql.connector import MySQLConnection # noqa: E402
3838
from mysql.connector.cursor import MySQLCursor # noqa: E402
3939

@@ -201,3 +201,6 @@ def prepare_connect_info(self, host_info: HostInfo, original_props: Properties)
201201

202202
def supports_connect_timeout(self) -> bool:
203203
return True
204+
205+
def get_driver_module(self) -> ModuleType:
206+
return mysql.connector

aws_advanced_python_wrapper/pg_driver_dialect.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,23 +14,21 @@
1414

1515
from __future__ import annotations
1616

17+
from inspect import signature
1718
from typing import TYPE_CHECKING, Any, Callable, Set
1819

1920
import psycopg
2021

2122
if TYPE_CHECKING:
2223
from aws_advanced_python_wrapper.hostinfo import HostInfo
2324
from aws_advanced_python_wrapper.pep249 import Connection
24-
25-
from inspect import signature
25+
from types import ModuleType
2626

2727
from aws_advanced_python_wrapper.driver_dialect import DriverDialect
2828
from aws_advanced_python_wrapper.driver_dialect_codes import DriverDialectCodes
2929
from aws_advanced_python_wrapper.errors import UnsupportedOperationError
3030
from aws_advanced_python_wrapper.utils.messages import Messages
31-
from aws_advanced_python_wrapper.utils.properties import (Properties,
32-
PropertiesUtils,
33-
WrapperProperties)
31+
from aws_advanced_python_wrapper.utils.properties import Properties, PropertiesUtils, WrapperProperties
3432

3533

3634
class PgDriverDialect(DriverDialect):
@@ -175,3 +173,6 @@ def supports_tcp_keepalive(self) -> bool:
175173

176174
def supports_abort_connection(self) -> bool:
177175
return True
176+
177+
def get_driver_module(self) -> ModuleType:
178+
return psycopg

aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,8 @@ WeightedRandomHostSelector.WeightedRandomInvalidDefaultWeight=[WeightedRandomHos
385385
SqlAlchemyPooledConnectionProvider.PoolNone=[SqlAlchemyPooledConnectionProvider] Attempted to find or create a pool for '{}' but the result of the attempt evaluated to None.
386386
SqlAlchemyPooledConnectionProvider.UnableToCreateDefaultKey=[SqlAlchemyPooledConnectionProvider] Unable to create a default key for internal connection pools. By default, the user parameter is used, but the given user evaluated to None or the empty string (""). Please ensure you have passed a valid user in the connection properties.
387387

388+
SqlAlchemyTortoiseConnectionProvider.UnableToDetermineDialect=[SqlAlchemyTortoiseConnectionProvider] Unable to resolve sql alchemy dialect for the following driver dialect '{}'.
389+
388390
SqlAlchemyDriverDialect.SetValueOnNoneConnection=[SqlAlchemyDriverDialect] Attempted to set the '{}' value on a pooled connection, but no underlying driver connection was found. This can happen if the pooled connection has previously been closed.
389391

390392
StaleDnsHelper.ClusterEndpointDns=[StaleDnsPlugin] Cluster endpoint {} resolves to {}.

aws_advanced_python_wrapper/sqlalchemy_driver_dialect.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,18 @@
1616

1717
from typing import TYPE_CHECKING, Any
1818

19-
from aws_advanced_python_wrapper.driver_dialect import DriverDialect
20-
from aws_advanced_python_wrapper.errors import AwsWrapperError
21-
from aws_advanced_python_wrapper.utils.messages import Messages
22-
2319
if TYPE_CHECKING:
2420
from aws_advanced_python_wrapper.hostinfo import HostInfo
2521
from aws_advanced_python_wrapper.pep249 import Connection
2622
from aws_advanced_python_wrapper.utils.properties import Properties
23+
from types import ModuleType
2724

2825
from sqlalchemy import PoolProxiedConnection
2926

27+
from aws_advanced_python_wrapper.driver_dialect import DriverDialect
28+
from aws_advanced_python_wrapper.errors import AwsWrapperError
29+
from aws_advanced_python_wrapper.utils.messages import Messages
30+
3031

3132
class SqlAlchemyDriverDialect(DriverDialect):
3233
_driver_name: str = "SQLAlchemy"
@@ -125,3 +126,6 @@ def transfer_session_state(self, from_conn: Connection, to_conn: Connection):
125126
return None
126127

127128
return self._underlying_driver.transfer_session_state(from_driver_conn, to_driver_conn)
129+
130+
def get_driver_module(self) -> ModuleType:
131+
return self._underlying_driver.get_driver_module()

aws_advanced_python_wrapper/tortoise/backend/mysql/client.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
from mysql.connector import errors
2323
from mysql.connector.charsets import MYSQL_CHARACTER_SETS
2424
from pypika_tortoise import MySQLQuery
25-
from tortoise import timezone
2625
from tortoise.backends.base.client import (
2726
BaseDBAsyncClient,
2827
Capabilities,
@@ -39,16 +38,16 @@
3938
)
4039

4140
from aws_advanced_python_wrapper.connection_provider import ConnectionProviderManager
41+
from aws_advanced_python_wrapper.errors import AwsWrapperError, FailoverError
4242
from aws_advanced_python_wrapper.hostinfo import HostInfo
43-
from aws_advanced_python_wrapper.sql_alchemy_connection_provider import SqlAlchemyPooledConnectionProvider
4443
from aws_advanced_python_wrapper.tortoise.backend.base.client import (
4544
TortoiseAwsClientConnectionWrapper,
4645
TortoiseAwsClientTransactionContext,
4746
)
4847
from aws_advanced_python_wrapper.tortoise.backend.mysql.executor import AwsMySQLExecutor
4948
from aws_advanced_python_wrapper.tortoise.backend.mysql.schema_generator import AwsMySQLSchemaGenerator
49+
from aws_advanced_python_wrapper.tortoise.sql_alchemy_tortoise_connection_provider import SqlAlchemyTortoisePooledConnectionProvider
5050
from aws_advanced_python_wrapper.utils.log import Logger
51-
from aws_advanced_python_wrapper.errors import AwsWrapperError
5251

5352
logger = Logger(__name__)
5453
T = TypeVar("T")
@@ -62,10 +61,12 @@ async def translate_exceptions_(self, *args) -> T:
6261
try:
6362
try:
6463
return await func(self, *args)
65-
except AwsWrapperError as aws_err:
64+
except AwsWrapperError as aws_err: # Unwrap any AwsWrappedErrors
6665
if aws_err.__cause__:
6766
raise aws_err.__cause__
6867
raise
68+
except FailoverError as exc: # Raise any failover errors
69+
raise
6970
except errors.IntegrityError as exc:
7071
raise IntegrityError(exc)
7172
except (
@@ -127,16 +128,16 @@ def __init__(
127128
self.extra.pop("connection_name", None)
128129
self.extra.pop("fetch_inserted", None)
129130
self.extra.pop("db", None)
130-
self.extra.pop("autocommit", None)
131+
self.extra.pop("autocommit", None) # We need this to be true
131132
self.extra.setdefault("sql_mode", "STRICT_TRANS_TABLES")
132133

133134
# Initialize connection templates
134135
self._init_connection_templates()
135-
136+
136137
# Initialize state
137138
self._template = {}
138139
self._connection = None
139-
self._pool = None
140+
self._pool: SqlAlchemyTortoisePooledConnectionProvider = None
140141
self._pool_init_lock = asyncio.Lock()
141142

142143
def _init_connection_templates(self) -> None:
@@ -158,20 +159,19 @@ def _configure_pool(self, host_info: HostInfo, props: Dict[str, Any]) -> Dict[st
158159
"""Configure connection pool settings."""
159160
return {"pool_size": self.pool_maxsize, "max_overflow": -1}
160161

161-
@staticmethod
162-
def _get_pool_key(host_info: HostInfo, props: Dict[str, Any]) -> str:
162+
def _get_pool_key(self, host_info: HostInfo, props: Dict[str, Any]) -> str:
163163
"""Generate unique pool key for connection pooling."""
164164
url = host_info.url
165165
user = props["user"]
166166
db = props["database"]
167-
return f"{url}{user}{db}"
167+
return f"{url}{user}{db}{self.connection_name}"
168168

169169
async def _init_pool_if_needed(self) -> None:
170170
"""Initialize connection pool only once across all instances."""
171171
if not AwsMySQLClient._pool_initialized:
172172
async with AwsMySQLClient._pool_init_class_lock:
173173
if not AwsMySQLClient._pool_initialized:
174-
self._pool = SqlAlchemyPooledConnectionProvider(
174+
self._pool = SqlAlchemyTortoisePooledConnectionProvider(
175175
pool_configurator=self._configure_pool,
176176
pool_mapping=self._get_pool_key,
177177
)
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License").
4+
# You may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from types import ModuleType
15+
from typing import Callable, Dict
16+
17+
from sqlalchemy import Dialect
18+
from sqlalchemy.dialects.mysql import mysqlconnector
19+
from sqlalchemy.dialects.postgresql import psycopg
20+
21+
from aws_advanced_python_wrapper.database_dialect import DatabaseDialect
22+
from aws_advanced_python_wrapper.driver_dialect import DriverDialect
23+
from aws_advanced_python_wrapper.errors import AwsWrapperError
24+
from aws_advanced_python_wrapper.hostinfo import HostInfo
25+
from aws_advanced_python_wrapper.sql_alchemy_connection_provider import SqlAlchemyPooledConnectionProvider
26+
from aws_advanced_python_wrapper.utils.log import Logger
27+
from aws_advanced_python_wrapper.utils.messages import Messages
28+
from aws_advanced_python_wrapper.utils.properties import Properties
29+
30+
logger = Logger(__name__)
31+
32+
33+
34+
class SqlAlchemyTortoisePooledConnectionProvider(SqlAlchemyPooledConnectionProvider):
35+
"""
36+
Tortoise-specific pooled connection provider that handles failover by disposing pools.
37+
"""
38+
39+
_sqlalchemy_dialect_map : Dict[str, ModuleType] = {
40+
"MySQLDriverDialect": mysqlconnector,
41+
"PostgresDriverDialect": psycopg
42+
}
43+
44+
def accepts_host_info(self, host_info: HostInfo, props: Properties) -> bool:
45+
if self._accept_url_func:
46+
return self._accept_url_func(host_info, props)
47+
url_type = SqlAlchemyPooledConnectionProvider._rds_utils.identify_rds_type(host_info.host)
48+
return url_type.is_rds
49+
50+
def _create_pool(
51+
self,
52+
target_func: Callable,
53+
driver_dialect: DriverDialect,
54+
database_dialect: DatabaseDialect,
55+
host_info: HostInfo,
56+
props: Properties):
57+
kwargs = dict() if self._pool_configurator is None else self._pool_configurator(host_info, props)
58+
prepared_properties = driver_dialect.prepare_connect_info(host_info, props)
59+
database_dialect.prepare_conn_props(prepared_properties)
60+
kwargs["creator"] = self._get_connection_func(target_func, prepared_properties)
61+
dialect = self._get_pool_dialect(driver_dialect)
62+
if not dialect:
63+
raise AwsWrapperError(Messages.get_formatted("SqlAlchemyTortoisePooledConnectionProvider.NoDialect", driver_dialect.__class__.__name__))
64+
65+
'''
66+
We need to pass in pre_ping and dialect to QueuePool the queue pool to enable health checks.
67+
Without this health check, we could be using dead connections after a failover.
68+
'''
69+
kwargs["pre_ping"] = True
70+
kwargs["dialect"] = dialect
71+
return self._create_sql_alchemy_pool(**kwargs)
72+
73+
def _get_pool_dialect(self, driver_dialect: DriverDialect) -> Dialect:
74+
dialect = None
75+
driver_dialect_class_name = driver_dialect.__class__.__name__
76+
if driver_dialect_class_name == "SqlAlchemyDriverDialect":
77+
driver_dialect_class_name = driver_dialect_class_name._underlying_driver_dialect.__class__.__name__
78+
module = self._sqlalchemy_dialect_map.get(driver_dialect_class_name)
79+
80+
if not module:
81+
return dialect
82+
dialect = module.dialect()
83+
dialect.dbapi = driver_dialect.get_driver_module()
84+
85+
return dialect

tests/integration/container/utils/test_telemetry_info.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
#
15-
# Licensed under the Apache License, Version 2.0 (the "License").
16-
# You may not use this file except in compliance with the License.
17-
# You may obtain a copy of the License at
18-
#
19-
# http://www.apache.org/licenses/LICENSE-2.0
20-
#
21-
# Unless required by applicable law or agreed to in writing, software
22-
# distributed under the License is distributed on an "AS IS" BASIS,
23-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
24-
# See the License for the specific language governing permissions and
25-
# limitations under the License.
14+
2615
import typing
2716
from typing import Any, Dict
2817

0 commit comments

Comments
 (0)