Skip to content

Commit d4a1e1e

Browse files
committed
Add thread pool executors
1 parent cad4e9b commit d4a1e1e

File tree

1 file changed

+56
-21
lines changed
  • aws_advanced_python_wrapper/tortoise/backend/base

1 file changed

+56
-21
lines changed

aws_advanced_python_wrapper/tortoise/backend/base/client.py

Lines changed: 56 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import asyncio
1616
import mysql.connector
17+
from concurrent.futures import ThreadPoolExecutor
1718
from contextlib import asynccontextmanager
1819
from typing import Any, Callable, Generic
1920

@@ -24,41 +25,67 @@
2425
from aws_advanced_python_wrapper import AwsWrapperConnection
2526

2627

27-
async def ConnectWithAwsWrapper(connect_func: Callable, **kwargs) -> AwsWrapperConnection:
28-
"""Create an AWS wrapper connection with async cursor support."""
29-
connection = await asyncio.to_thread(
30-
AwsWrapperConnection.connect,
31-
connect_func,
32-
**kwargs,
28+
class AwsWrapperAsyncConnector:
29+
"""Factory class for creating AWS wrapper connections."""
30+
31+
_executor: ThreadPoolExecutor = ThreadPoolExecutor(
32+
thread_name_prefix="AwsWrapperConnectorExecutor"
3333
)
34-
return AwsConnectionAsyncWrapper(connection)
34+
35+
@staticmethod
36+
async def ConnectWithAwsWrapper(connect_func: Callable, **kwargs) -> AwsWrapperConnection:
37+
"""Create an AWS wrapper connection with async cursor support."""
38+
loop = asyncio.get_event_loop()
39+
connection = await loop.run_in_executor(
40+
AwsWrapperAsyncConnector._executor,
41+
lambda: AwsWrapperConnection.connect(connect_func, **kwargs)
42+
)
43+
return AwsConnectionAsyncWrapper(connection)
44+
45+
@staticmethod
46+
async def CloseAwsWrapper(connection: AwsWrapperConnection) -> None:
47+
"""Close an AWS wrapper connection asynchronously."""
48+
loop = asyncio.get_event_loop()
49+
await loop.run_in_executor(
50+
AwsWrapperAsyncConnector._executor,
51+
connection.close
52+
)
3553

3654

3755
class AwsCursorAsyncWrapper:
3856
"""Wraps a sync cursor to provide async interface."""
3957

58+
_executor: ThreadPoolExecutor = ThreadPoolExecutor(
59+
thread_name_prefix="AwsCursorAsyncWrapperExecutor"
60+
)
61+
4062
def __init__(self, sync_cursor):
4163
self._cursor = sync_cursor
4264

4365
async def execute(self, query, params=None):
4466
"""Execute a query asynchronously."""
45-
return await asyncio.to_thread(self._cursor.execute, query, params)
67+
loop = asyncio.get_event_loop()
68+
return await loop.run_in_executor(self._executor, self._cursor.execute, query, params)
4669

4770
async def executemany(self, query, params_list):
4871
"""Execute multiple queries asynchronously."""
49-
return await asyncio.to_thread(self._cursor.executemany, query, params_list)
72+
loop = asyncio.get_event_loop()
73+
return await loop.run_in_executor(self._executor, self._cursor.executemany, query, params_list)
5074

5175
async def fetchall(self):
5276
"""Fetch all results asynchronously."""
53-
return await asyncio.to_thread(self._cursor.fetchall)
77+
loop = asyncio.get_event_loop()
78+
return await loop.run_in_executor(self._executor, self._cursor.fetchall)
5479

5580
async def fetchone(self):
5681
"""Fetch one result asynchronously."""
57-
return await asyncio.to_thread(self._cursor.fetchone)
82+
loop = asyncio.get_event_loop()
83+
return await loop.run_in_executor(self._executor, self._cursor.fetchone)
5884

5985
async def close(self):
6086
"""Close cursor asynchronously."""
61-
return await asyncio.to_thread(self._cursor.close)
87+
loop = asyncio.get_event_loop()
88+
return await loop.run_in_executor(self._executor, self._cursor.close)
6289

6390
def __getattr__(self, name):
6491
"""Delegate non-async attributes to the wrapped cursor."""
@@ -68,29 +95,37 @@ def __getattr__(self, name):
6895
class AwsConnectionAsyncWrapper(AwsWrapperConnection):
6996
"""AWS wrapper connection with async cursor support."""
7097

98+
_executor: ThreadPoolExecutor = ThreadPoolExecutor(
99+
thread_name_prefix="AwsConnectionAsyncWrapperExecutor"
100+
)
101+
71102
def __init__(self, connection: AwsWrapperConnection):
72103
self._wrapped_connection = connection
73104

74105
@asynccontextmanager
75106
async def cursor(self):
76107
"""Create an async cursor context manager."""
77-
cursor_obj = await asyncio.to_thread(self._wrapped_connection.cursor)
108+
loop = asyncio.get_event_loop()
109+
cursor_obj = await loop.run_in_executor(self._executor, self._wrapped_connection.cursor)
78110
try:
79111
yield AwsCursorAsyncWrapper(cursor_obj)
80112
finally:
81-
await asyncio.to_thread(cursor_obj.close)
113+
await loop.run_in_executor(self._executor, cursor_obj.close)
82114

83115
async def rollback(self):
84116
"""Rollback the current transaction."""
85-
return await asyncio.to_thread(self._wrapped_connection.rollback)
117+
loop = asyncio.get_event_loop()
118+
return await loop.run_in_executor(self._executor, self._wrapped_connection.rollback)
86119

87120
async def commit(self):
88121
"""Commit the current transaction."""
89-
return await asyncio.to_thread(self._wrapped_connection.commit)
122+
loop = asyncio.get_event_loop()
123+
return await loop.run_in_executor(self._executor, self._wrapped_connection.commit)
90124

91125
async def set_autocommit(self, value: bool):
92126
"""Set autocommit mode."""
93-
return await asyncio.to_thread(lambda: setattr(self._wrapped_connection, 'autocommit', value))
127+
loop = asyncio.get_event_loop()
128+
return await loop.run_in_executor(self._executor, lambda: setattr(self._wrapped_connection, 'autocommit', value))
94129

95130
def __getattr__(self, name):
96131
"""Delegate all other attributes/methods to the wrapped connection."""
@@ -128,13 +163,13 @@ async def ensure_connection(self) -> None:
128163
async def __aenter__(self) -> T_conn:
129164
"""Acquire connection from pool."""
130165
await self.ensure_connection()
131-
self.connection = await ConnectWithAwsWrapper(self.connect_func, **self.client._template)
166+
self.connection = await AwsWrapperAsyncConnector.ConnectWithAwsWrapper(self.connect_func, **self.client._template)
132167
return self.connection
133168

134169
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
135170
"""Close connection and release back to pool."""
136171
if self.connection:
137-
await asyncio.to_thread(self.connection.close)
172+
await AwsWrapperAsyncConnector.CloseAwsWrapper(self.connection)
138173

139174

140175
class TortoiseAwsClientTransactionContext(TransactionContext):
@@ -159,7 +194,7 @@ async def __aenter__(self) -> TransactionalDBClient:
159194
self.token = connections.set(self.connection_name, self.client)
160195

161196
# Create connection and begin transaction
162-
self.client._connection = await ConnectWithAwsWrapper(
197+
self.client._connection = await AwsWrapperAsyncConnector.ConnectWithAwsWrapper(
163198
mysql.connector.Connect,
164199
**self.client._parent._template
165200
)
@@ -178,4 +213,4 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
178213
await self.client.commit()
179214
finally:
180215
connections.reset(self.token)
181-
await asyncio.to_thread(self.client._connection.close)
216+
await AwsWrapperAsyncConnector.CloseAwsWrapper(self.client._connection)

0 commit comments

Comments
 (0)