Skip to content

Commit 04221c1

Browse files
committed
Added more intergation tests and fixed unit tests
1 parent 1a5d163 commit 04221c1

21 files changed

+1062
-329
lines changed

aws_advanced_python_wrapper/sql_alchemy_connection_provider.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,6 @@ def connect(
137137

138138
if queue_pool is None:
139139
raise AwsWrapperError(Messages.get_formatted("SqlAlchemyPooledConnectionProvider.PoolNone", host_info.url))
140-
141140
return queue_pool.connect()
142141

143142
# The pool key should always be retrieved using this method, because the username

aws_advanced_python_wrapper/tortoise/backend/base/client.py

Lines changed: 17 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
import asyncio
1616
import mysql.connector
17-
from concurrent.futures import ThreadPoolExecutor
1817
from contextlib import asynccontextmanager
1918
from typing import Any, Callable, Generic
2019

@@ -26,106 +25,79 @@
2625

2726

2827
class AwsWrapperAsyncConnector:
29-
"""Factory class for creating AWS wrapper connections."""
30-
31-
_executor: ThreadPoolExecutor = ThreadPoolExecutor(
32-
thread_name_prefix="AwsWrapperAsyncExecutor"
33-
)
28+
"""Class for creating and closing AWS wrapper connections."""
3429

3530
@staticmethod
3631
async def ConnectWithAwsWrapper(connect_func: Callable, **kwargs) -> AwsWrapperConnection:
3732
"""Create an AWS wrapper connection with async cursor support."""
38-
loop = asyncio.get_event_loop()
39-
connection = await loop.run_in_executor(
40-
AwsWrapperAsyncConnector._executor,
41-
lambda: AwsWrapperConnection.connect(connect_func, **kwargs)
33+
connection = await asyncio.to_thread(
34+
AwsWrapperConnection.connect, connect_func, **kwargs
4235
)
4336
return AwsConnectionAsyncWrapper(connection)
4437

4538
@staticmethod
4639
async def CloseAwsWrapper(connection: AwsWrapperConnection) -> None:
4740
"""Close an AWS wrapper connection asynchronously."""
48-
loop = asyncio.get_event_loop()
49-
await loop.run_in_executor(
50-
AwsWrapperAsyncConnector._executor,
51-
connection.close
52-
)
41+
await asyncio.to_thread(connection.close)
5342

5443

5544
class AwsCursorAsyncWrapper:
56-
"""Wraps a sync cursor to provide async interface."""
57-
58-
_executor: ThreadPoolExecutor = ThreadPoolExecutor(
59-
thread_name_prefix="AwsCursorAsyncWrapperExecutor"
60-
)
45+
"""Wraps sync AwsCursor cursor with async support."""
6146

6247
def __init__(self, sync_cursor):
6348
self._cursor = sync_cursor
6449

6550
async def execute(self, query, params=None):
6651
"""Execute a query asynchronously."""
67-
loop = asyncio.get_event_loop()
68-
return await loop.run_in_executor(self._executor, self._cursor.execute, query, params)
52+
return await asyncio.to_thread(self._cursor.execute, query, params)
6953

7054
async def executemany(self, query, params_list):
7155
"""Execute multiple queries asynchronously."""
72-
loop = asyncio.get_event_loop()
73-
return await loop.run_in_executor(self._executor, self._cursor.executemany, query, params_list)
56+
return await asyncio.to_thread(self._cursor.executemany, query, params_list)
7457

7558
async def fetchall(self):
7659
"""Fetch all results asynchronously."""
77-
loop = asyncio.get_event_loop()
78-
return await loop.run_in_executor(self._executor, self._cursor.fetchall)
60+
return await asyncio.to_thread(self._cursor.fetchall)
7961

8062
async def fetchone(self):
8163
"""Fetch one result asynchronously."""
82-
loop = asyncio.get_event_loop()
83-
return await loop.run_in_executor(self._executor, self._cursor.fetchone)
64+
return await asyncio.to_thread(self._cursor.fetchone)
8465

8566
async def close(self):
8667
"""Close cursor asynchronously."""
87-
loop = asyncio.get_event_loop()
88-
return await loop.run_in_executor(self._executor, self._cursor.close)
68+
return await asyncio.to_thread(self._cursor.close)
8969

9070
def __getattr__(self, name):
9171
"""Delegate non-async attributes to the wrapped cursor."""
9272
return getattr(self._cursor, name)
9373

9474

9575
class AwsConnectionAsyncWrapper(AwsWrapperConnection):
96-
"""AWS wrapper connection with async cursor support."""
97-
98-
_executor: ThreadPoolExecutor = ThreadPoolExecutor(
99-
thread_name_prefix="AwsConnectionAsyncWrapperExecutor"
100-
)
76+
"""Wraps sync AwsConnection with async cursor support."""
10177

10278
def __init__(self, connection: AwsWrapperConnection):
10379
self._wrapped_connection = connection
10480

10581
@asynccontextmanager
10682
async def cursor(self):
10783
"""Create an async cursor context manager."""
108-
loop = asyncio.get_event_loop()
109-
cursor_obj = await loop.run_in_executor(self._executor, self._wrapped_connection.cursor)
84+
cursor_obj = await asyncio.to_thread(self._wrapped_connection.cursor)
11085
try:
11186
yield AwsCursorAsyncWrapper(cursor_obj)
11287
finally:
113-
await loop.run_in_executor(self._executor, cursor_obj.close)
88+
await asyncio.to_thread(cursor_obj.close)
11489

11590
async def rollback(self):
11691
"""Rollback the current transaction."""
117-
loop = asyncio.get_event_loop()
118-
return await loop.run_in_executor(self._executor, self._wrapped_connection.rollback)
92+
return await asyncio.to_thread(self._wrapped_connection.rollback)
11993

12094
async def commit(self):
12195
"""Commit the current transaction."""
122-
loop = asyncio.get_event_loop()
123-
return await loop.run_in_executor(self._executor, self._wrapped_connection.commit)
96+
return await asyncio.to_thread(self._wrapped_connection.commit)
12497

12598
async def set_autocommit(self, value: bool):
12699
"""Set autocommit mode."""
127-
loop = asyncio.get_event_loop()
128-
return await loop.run_in_executor(self._executor, lambda: setattr(self._wrapped_connection, 'autocommit', value))
100+
return await asyncio.to_thread(setattr, self._wrapped_connection, 'autocommit', value)
129101

130102
def __getattr__(self, name):
131103
"""Delegate all other attributes/methods to the wrapped connection."""
@@ -209,5 +181,5 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
209181
else:
210182
await self.client.commit()
211183
finally:
184+
await AwsWrapperAsyncConnector.CloseAwsWrapper(self.client._connection)
212185
connections.reset(self.token)
213-
await AwsWrapperAsyncConnector.CloseAwsWrapper(self.client._connection)

aws_advanced_python_wrapper/tortoise/backend/mysql/client.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def __init__(
117117
self.host = host
118118
self.port = int(port)
119119
self.extra = kwargs.copy()
120-
120+
121121
# Extract MySQL-specific settings
122122
self.storage_engine = self.extra.pop("storage_engine", "innodb")
123123
self.charset = self.extra.pop("charset", "utf8mb4")
@@ -128,6 +128,12 @@ def __init__(
128128
self.extra.pop("autocommit", None)
129129
self.extra.setdefault("sql_mode", "STRICT_TRANS_TABLES")
130130

131+
# Ensure that timeouts are integers
132+
timeout_params = ["connect_timeout", "monitoring-connect_timeout",]
133+
for param in timeout_params:
134+
if param in self.extra and self.extra[param] is not None:
135+
self.extra[param] = int(self.extra[param])
136+
131137
# Initialize connection templates
132138
self._init_connection_templates()
133139

@@ -251,7 +257,6 @@ async def _execute_script(self, query: str, with_db: bool) -> None:
251257
"""Execute a multi-statement query by parsing and running statements sequentially."""
252258
async with self._acquire_connection(with_db) as connection:
253259
logger.debug(f"Executing script: {query}")
254-
print(f"Executing script: {query}")
255260
async with connection.cursor() as cursor:
256261
# Parse multi-statement queries since MySQL Connector doesn't handle them well
257262
statements = sqlparse.split(query)

tests/integration/container/conftest.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ def pytest_runtest_setup(item):
6868
test_name = item.callspec.id
6969
else:
7070
TestEnvironment.get_current().set_current_driver(None)
71+
# Fallback to item.name if no callspec (for non-parameterized tests)
72+
test_name = getattr(item, 'name', None) or str(item)
7173

7274
logger.info(f"Starting test preparation for: {test_name}")
7375

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License").
4+
# You may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License").
4+
# You may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from tortoise import fields
16+
from tortoise.models import Model
17+
18+
19+
class User(Model):
20+
id = fields.IntField(primary_key=True)
21+
name = fields.CharField(max_length=50)
22+
email = fields.CharField(max_length=100, unique=True)
23+
24+
class Meta:
25+
table = "users"
26+
27+
28+
class UniqueName(Model):
29+
id = fields.IntField(primary_key=True)
30+
name = fields.CharField(max_length=20, null=True, unique=True)
31+
optional = fields.CharField(max_length=20, null=True)
32+
other_optional = fields.CharField(max_length=20, null=True)
33+
34+
class Meta:
35+
table = "unique_names"
36+
37+
38+
class TableWithSleepTrigger(Model):
39+
id = fields.IntField(primary_key=True)
40+
name = fields.CharField(max_length=50)
41+
value = fields.CharField(max_length=100)
42+
43+
class Meta:
44+
table = "table_with_sleep_trigger"
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License").
4+
# You may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from tortoise import fields
16+
from tortoise.models import Model
17+
18+
19+
# One-to-One Relationship Models
20+
class RelTestAccount(Model):
21+
id = fields.IntField(primary_key=True)
22+
username = fields.CharField(max_length=50, unique=True)
23+
email = fields.CharField(max_length=100)
24+
25+
class Meta:
26+
table = "rel_test_accounts"
27+
28+
29+
class RelTestAccountProfile(Model):
30+
id = fields.IntField(primary_key=True)
31+
account = fields.OneToOneField("models.RelTestAccount", related_name="profile", on_delete=fields.CASCADE)
32+
bio = fields.TextField(null=True)
33+
avatar_url = fields.CharField(max_length=200, null=True)
34+
35+
class Meta:
36+
table = "rel_test_account_profiles"
37+
38+
39+
# One-to-Many Relationship Models
40+
class RelTestPublisher(Model):
41+
id = fields.IntField(primary_key=True)
42+
name = fields.CharField(max_length=100)
43+
email = fields.CharField(max_length=100, unique=True)
44+
45+
class Meta:
46+
table = "rel_test_publishers"
47+
48+
49+
class RelTestPublication(Model):
50+
id = fields.IntField(primary_key=True)
51+
title = fields.CharField(max_length=200)
52+
isbn = fields.CharField(max_length=13, unique=True)
53+
publisher = fields.ForeignKeyField("models.RelTestPublisher", related_name="publications", on_delete=fields.CASCADE)
54+
published_date = fields.DateField(null=True)
55+
56+
class Meta:
57+
table = "rel_test_publications"
58+
59+
60+
# Many-to-Many Relationship Models
61+
class RelTestLearner(Model):
62+
id = fields.IntField(primary_key=True)
63+
name = fields.CharField(max_length=100)
64+
learner_id = fields.CharField(max_length=20, unique=True)
65+
66+
class Meta:
67+
table = "rel_test_learners"
68+
69+
70+
class RelTestSubject(Model):
71+
id = fields.IntField(primary_key=True)
72+
name = fields.CharField(max_length=100)
73+
code = fields.CharField(max_length=10, unique=True)
74+
credits = fields.IntField()
75+
learners = fields.ManyToManyField("models.RelTestLearner", related_name="subjects")
76+
77+
class Meta:
78+
table = "rel_test_subjects"
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License").
4+
# You may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License").
4+
# You may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License
14+
15+
class TestRouter:
16+
def db_for_read(self, model):
17+
return "default"
18+
19+
def db_for_write(self, model):
20+
return "default"

0 commit comments

Comments
 (0)