|
28 | 28 | QUEUE_DONE = b"-1" |
29 | 29 |
|
30 | 30 |
|
| 31 | +def _execute_sync_task_in_executor( |
| 32 | + target: Callable[..., Any], |
| 33 | + args: tuple[Any, ...], |
| 34 | + kwargs: dict[str, Any], |
| 35 | +) -> Any: |
| 36 | + """Execute a sync task. |
| 37 | +
|
| 38 | + This is a wrapper to ensure we pass the target function directly |
| 39 | + to the executor, avoiding issues with pickling bound methods like ctx.run. |
| 40 | +
|
| 41 | + :param target: function to execute |
| 42 | + :param args: positional arguments |
| 43 | + :param kwargs: keyword arguments |
| 44 | + :return: result of the function call |
| 45 | + """ |
| 46 | + return target(*args, **kwargs) |
| 47 | + |
| 48 | + |
31 | 49 | class Receiver: |
32 | 50 | """Class that uses as a callback handler.""" |
33 | 51 |
|
@@ -69,6 +87,7 @@ def __init__( |
69 | 87 | "can result in undefined behavior", |
70 | 88 | ) |
71 | 89 | self.sem_prefetch = asyncio.Semaphore(max_prefetch) |
| 90 | + self.is_process_pool = isinstance(executor, ProcessPoolExecutor) |
72 | 91 |
|
73 | 92 | async def callback( # noqa: C901, PLR0912 |
74 | 93 | self, |
@@ -254,6 +273,28 @@ async def run_task( # noqa: C901, PLR0912, PLR0915 |
254 | 273 | ctx.run, |
255 | 274 | func, |
256 | 275 | ) |
| 276 | + if self.is_process_pool: |
| 277 | + # For ProcessPoolExecutor, we can't use ctx.run because it contains |
| 278 | + # a reference to contextvars.Context which cannot be pickled. |
| 279 | + # Instead, we call the target function directly in the executor. |
| 280 | + # Each worker process starts with its own context, so we don't need |
| 281 | + # to preserve the parent context. |
| 282 | + target_future = loop.run_in_executor( |
| 283 | + self.executor, |
| 284 | + _execute_sync_task_in_executor, |
| 285 | + target, |
| 286 | + tuple(message.args), |
| 287 | + kwargs, |
| 288 | + ) |
| 289 | + else: |
| 290 | + # For ThreadPoolExecutor, we can use ctx.run with functools.partial |
| 291 | + ctx = contextvars.copy_context() |
| 292 | + func = functools.partial(target, *message.args, **kwargs) |
| 293 | + target_future = loop.run_in_executor( |
| 294 | + self.executor, |
| 295 | + ctx.run, |
| 296 | + func, |
| 297 | + ) |
257 | 298 | timeout = message.labels.get("timeout") |
258 | 299 | if timeout is not None: |
259 | 300 | if not is_coroutine: |
|
0 commit comments