Skip to content

Commit b5ef21d

Browse files
committed
Fixed process-pool for sync tasks.
1 parent 653400e commit b5ef21d

1 file changed

Lines changed: 42 additions & 10 deletions

File tree

taskiq/receiver/receiver.py

Lines changed: 42 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import inspect
55
import sys
66
from collections.abc import Callable
7-
from concurrent.futures import Executor
7+
from concurrent.futures import Executor, ProcessPoolExecutor
88
from logging import getLogger
99
from time import time
1010
from typing import Any, get_type_hints
@@ -28,6 +28,24 @@
2828
QUEUE_DONE = b"-1"
2929

3030

31+
def _execute_sync_task_in_executor(
32+
target: Callable[..., Any],
33+
args: tuple[Any, ...],
34+
kwargs: dict[str, Any],
35+
) -> Any:
36+
"""Execute a sync task.
37+
38+
This is a wrapper to ensure we pass the target function directly
39+
to the executor, avoiding issues with pickling bound methods like ctx.run.
40+
41+
:param target: function to execute
42+
:param args: positional arguments
43+
:param kwargs: keyword arguments
44+
:return: result of the function call
45+
"""
46+
return target(*args, **kwargs)
47+
48+
3149
class Receiver:
3250
"""Class that uses as a callback handler."""
3351

@@ -69,6 +87,7 @@ def __init__(
6987
"can result in undefined behavior",
7088
)
7189
self.sem_prefetch = asyncio.Semaphore(max_prefetch)
90+
self.is_process_pool = isinstance(executor, ProcessPoolExecutor)
7291

7392
async def callback( # noqa: C901, PLR0912
7493
self,
@@ -245,15 +264,28 @@ async def run_task( # noqa: C901, PLR0912, PLR0915
245264
target_future = target(*message.args, **kwargs)
246265
else:
247266
is_coroutine = False
248-
# If this is a synchronous function, we
249-
# run it in executor and preserve the context.
250-
ctx = contextvars.copy_context()
251-
func = functools.partial(target, *message.args, **kwargs)
252-
target_future = loop.run_in_executor(
253-
self.executor,
254-
ctx.run,
255-
func,
256-
)
267+
if self.is_process_pool:
268+
# For ProcessPoolExecutor, we can't use ctx.run because it contains
269+
# a reference to contextvars.Context which cannot be pickled.
270+
# Instead, we call the target function directly in the executor.
271+
# Each worker process starts with its own context, so we don't need
272+
# to preserve the parent context.
273+
target_future = loop.run_in_executor(
274+
self.executor,
275+
_execute_sync_task_in_executor,
276+
target,
277+
tuple(message.args),
278+
kwargs,
279+
)
280+
else:
281+
# For ThreadPoolExecutor, we can use ctx.run with functools.partial
282+
ctx = contextvars.copy_context()
283+
func = functools.partial(target, *message.args, **kwargs)
284+
target_future = loop.run_in_executor(
285+
self.executor,
286+
ctx.run,
287+
func,
288+
)
257289
timeout = message.labels.get("timeout")
258290
if timeout is not None:
259291
if not is_coroutine:

0 commit comments

Comments
 (0)