Skip to content

Commit 2fe1785

Browse files
committed
Fixed process-pool for sync tasks.
1 parent 653400e commit 2fe1785

2 files changed

Lines changed: 60 additions & 0 deletions

File tree

a.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import asyncio
2+
3+
from taskiq_redis import ListQueueBroker
4+
5+
broker = ListQueueBroker("redis://localhost")
6+
7+
8+
@broker.task
9+
async def my_task():
10+
pass
11+
12+
13+
async def main():
14+
async with broker:
15+
await my_task.kiq()
16+
17+
18+
if __name__ == "__main__":
19+
asyncio.run(main())

taskiq/receiver/receiver.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,24 @@
2828
QUEUE_DONE = b"-1"
2929

3030

31+
def _execute_sync_task_in_executor(
32+
target: Callable[..., Any],
33+
args: tuple[Any, ...],
34+
kwargs: dict[str, Any],
35+
) -> Any:
36+
"""Execute a sync task.
37+
38+
This is a wrapper to ensure we pass the target function directly
39+
to the executor, avoiding issues with pickling bound methods like ctx.run.
40+
41+
:param target: function to execute
42+
:param args: positional arguments
43+
:param kwargs: keyword arguments
44+
:return: result of the function call
45+
"""
46+
return target(*args, **kwargs)
47+
48+
3149
class Receiver:
3250
"""Class that uses as a callback handler."""
3351

@@ -69,6 +87,7 @@ def __init__(
6987
"can result in undefined behavior",
7088
)
7189
self.sem_prefetch = asyncio.Semaphore(max_prefetch)
90+
self.is_process_pool = isinstance(executor, ProcessPoolExecutor)
7291

7392
async def callback( # noqa: C901, PLR0912
7493
self,
@@ -254,6 +273,28 @@ async def run_task( # noqa: C901, PLR0912, PLR0915
254273
ctx.run,
255274
func,
256275
)
276+
if self.is_process_pool:
277+
# For ProcessPoolExecutor, we can't use ctx.run because it contains
278+
# a reference to contextvars.Context which cannot be pickled.
279+
# Instead, we call the target function directly in the executor.
280+
# Each worker process starts with its own context, so we don't need
281+
# to preserve the parent context.
282+
target_future = loop.run_in_executor(
283+
self.executor,
284+
_execute_sync_task_in_executor,
285+
target,
286+
tuple(message.args),
287+
kwargs,
288+
)
289+
else:
290+
# For ThreadPoolExecutor, we can use ctx.run with functools.partial
291+
ctx = contextvars.copy_context()
292+
func = functools.partial(target, *message.args, **kwargs)
293+
target_future = loop.run_in_executor(
294+
self.executor,
295+
ctx.run,
296+
func,
297+
)
257298
timeout = message.labels.get("timeout")
258299
if timeout is not None:
259300
if not is_coroutine:

0 commit comments

Comments
 (0)