|
4 | 4 | import inspect |
5 | 5 | import sys |
6 | 6 | from collections.abc import Callable |
7 | | -from concurrent.futures import Executor |
| 7 | +from concurrent.futures import Executor, ProcessPoolExecutor |
8 | 8 | from logging import getLogger |
9 | 9 | from time import time |
10 | 10 | from typing import Any, get_type_hints |
|
28 | 28 | QUEUE_DONE = b"-1" |
29 | 29 |
|
30 | 30 |
|
| 31 | +def _execute_sync_task_in_executor( |
| 32 | + target: Callable[..., Any], |
| 33 | + args: tuple[Any, ...], |
| 34 | + kwargs: dict[str, Any], |
| 35 | +) -> Any: |
| 36 | + """Execute a sync task. |
| 37 | +
|
| 38 | + This is a wrapper to ensure we pass the target function directly |
| 39 | + to the executor, avoiding issues with pickling bound methods like ctx.run. |
| 40 | +
|
| 41 | + :param target: function to execute |
| 42 | + :param args: positional arguments |
| 43 | + :param kwargs: keyword arguments |
| 44 | + :return: result of the function call |
| 45 | + """ |
| 46 | + return target(*args, **kwargs) |
| 47 | + |
| 48 | + |
31 | 49 | class Receiver: |
32 | 50 | """Class that uses as a callback handler.""" |
33 | 51 |
|
@@ -69,6 +87,7 @@ def __init__( |
69 | 87 | "can result in undefined behavior", |
70 | 88 | ) |
71 | 89 | self.sem_prefetch = asyncio.Semaphore(max_prefetch) |
| 90 | + self.is_process_pool = isinstance(executor, ProcessPoolExecutor) |
72 | 91 |
|
73 | 92 | async def callback( # noqa: C901, PLR0912 |
74 | 93 | self, |
@@ -245,15 +264,28 @@ async def run_task( # noqa: C901, PLR0912, PLR0915 |
245 | 264 | target_future = target(*message.args, **kwargs) |
246 | 265 | else: |
247 | 266 | is_coroutine = False |
248 | | - # If this is a synchronous function, we |
249 | | - # run it in executor and preserve the context. |
250 | | - ctx = contextvars.copy_context() |
251 | | - func = functools.partial(target, *message.args, **kwargs) |
252 | | - target_future = loop.run_in_executor( |
253 | | - self.executor, |
254 | | - ctx.run, |
255 | | - func, |
256 | | - ) |
| 267 | + if self.is_process_pool: |
| 268 | + # For ProcessPoolExecutor, we can't use ctx.run because it contains |
| 269 | + # a reference to contextvars.Context which cannot be pickled. |
| 270 | + # Instead, we call the target function directly in the executor. |
| 271 | + # Each worker process starts with its own context, so we don't need |
| 272 | + # to preserve the parent context. |
| 273 | + target_future = loop.run_in_executor( |
| 274 | + self.executor, |
| 275 | + _execute_sync_task_in_executor, |
| 276 | + target, |
| 277 | + tuple(message.args), |
| 278 | + kwargs, |
| 279 | + ) |
| 280 | + else: |
| 281 | + # For ThreadPoolExecutor, we can use ctx.run with functools.partial |
| 282 | + ctx = contextvars.copy_context() |
| 283 | + func = functools.partial(target, *message.args, **kwargs) |
| 284 | + target_future = loop.run_in_executor( |
| 285 | + self.executor, |
| 286 | + ctx.run, |
| 287 | + func, |
| 288 | + ) |
257 | 289 | timeout = message.labels.get("timeout") |
258 | 290 | if timeout is not None: |
259 | 291 | if not is_coroutine: |
|
0 commit comments