-
Notifications
You must be signed in to change notification settings - Fork 88
feat: isolate async handler execution on dedicated worker event loop #273
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,16 +5,19 @@ | |
|
|
||
| import asyncio | ||
| import contextvars | ||
| import functools | ||
| import inspect | ||
| import json | ||
| import logging | ||
| import queue | ||
| import threading | ||
| import time | ||
| import uuid | ||
| from collections.abc import Sequence | ||
| from typing import Any, Callable, Dict, Optional | ||
|
|
||
| from starlette.applications import Starlette | ||
| from starlette.concurrency import run_in_threadpool | ||
| from starlette.middleware import Middleware | ||
| from starlette.responses import JSONResponse, Response, StreamingResponse | ||
| from starlette.routing import Route, WebSocketRoute | ||
|
|
@@ -39,6 +42,30 @@ | |
| from .utils import convert_complex_objects | ||
|
|
||
|
|
||
| def _is_async_callable(obj: Any) -> bool: | ||
| """Check if obj is async-callable, unwrapping functools.partial.""" | ||
| while isinstance(obj, functools.partial): | ||
| obj = obj.func | ||
| return asyncio.iscoroutinefunction(obj) or (callable(obj) and asyncio.iscoroutinefunction(obj.__call__)) | ||
|
|
||
|
|
||
| def _is_async_gen_callable(obj: Any) -> bool: | ||
| """Check if obj is an async generator function, unwrapping functools.partial.""" | ||
| while isinstance(obj, functools.partial): | ||
| obj = obj.func | ||
| return inspect.isasyncgenfunction(obj) or (callable(obj) and inspect.isasyncgenfunction(obj.__call__)) | ||
|
|
||
|
|
||
| def _restore_context(ctx: contextvars.Context) -> None: | ||
| """Restore context variables from a snapshot (Django asgiref pattern).""" | ||
| for var, value in ctx.items(): | ||
| try: | ||
| if var.get() != value: | ||
| var.set(value) | ||
| except LookupError: | ||
| var.set(value) | ||
|
|
||
|
|
||
| class RequestContextFormatter(logging.Formatter): | ||
| """Formatter including request and session IDs.""" | ||
|
|
||
|
|
@@ -96,6 +123,9 @@ def __init__( | |
| self._task_counter_lock: threading.Lock = threading.Lock() | ||
| self._forced_ping_status: Optional[PingStatus] = None | ||
| self._last_status_update_time: float = time.time() | ||
| self._worker_loop: Optional[asyncio.AbstractEventLoop] = None | ||
| self._worker_thread: Optional[threading.Thread] = None | ||
| self._worker_loop_lock: threading.Lock = threading.Lock() | ||
|
|
||
| routes = [ | ||
| Route("/invocations", self._handle_invocation, methods=["POST"]), | ||
|
|
@@ -163,7 +193,7 @@ def async_task(self, func: Callable) -> Callable: | |
| - Set ping status to HEALTHY_BUSY while running | ||
| - Revert to HEALTHY when complete | ||
| """ | ||
| if not asyncio.iscoroutinefunction(func): | ||
| if not _is_async_callable(func): | ||
| raise ValueError("@async_task can only be applied to async functions") | ||
|
|
||
| async def wrapper(*args, **kwargs): | ||
|
|
@@ -463,16 +493,92 @@ def run(self, port: int = 8080, host: Optional[str] = None, **kwargs): | |
|
|
||
| uvicorn.run(self, **uvicorn_params) | ||
|
|
||
| async def _invoke_handler(self, handler, request_context, takes_context, payload): | ||
| def _ensure_worker_loop(self) -> asyncio.AbstractEventLoop: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The sync generator returned by this method is iterated by Starlette via Under concurrent streaming workloads (e.g., LLM streaming responses each taking ~20 seconds), this can exhaust the default thread pool (~40 threads) ref. Once exhausted, all new requests are queued waiting for a thread to free up, even though the occupied threads are mostly idle waiting on
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this one is expected though, customers at max have 2 concurrent connections to runtime servers |
||
| """Lazily create and start a dedicated worker event loop in a background thread. | ||
|
|
||
| The worker loop isolates async handler execution from the main event loop, | ||
| ensuring that blocking async handlers do not prevent /ping from responding. | ||
| """ | ||
| if self._worker_loop is not None and self._worker_loop.is_running(): | ||
| return self._worker_loop | ||
| with self._worker_loop_lock: | ||
| if self._worker_loop is None or not self._worker_loop.is_running(): | ||
| self._worker_loop = asyncio.new_event_loop() | ||
| self._worker_thread = threading.Thread( | ||
| target=self._run_worker_loop, | ||
| daemon=True, | ||
| name="agentcore-worker-loop", | ||
| ) | ||
| self._worker_thread.start() | ||
| return self._worker_loop | ||
|
|
||
| def _run_worker_loop(self) -> None: | ||
| """Entry point for the worker loop background thread.""" | ||
| asyncio.set_event_loop(self._worker_loop) | ||
| self._worker_loop.run_forever() | ||
|
|
||
@staticmethod
async def _run_with_context(coro: Any, ctx: contextvars.Context) -> Any:
    """Apply the context-variable snapshot *ctx*, then await *coro* and return its result."""
    _restore_context(ctx)
    result = await coro
    return result
|
|
||
def _async_gen_to_sync_gen(self, async_gen: Any, ctx: contextvars.Context) -> Any:
    """Bridge an async generator through the worker loop as a sync generator.

    The async generator is iterated on the worker loop. Chunks are sent to
    a thread-safe queue and yielded synchronously. Starlette's StreamingResponse
    iterates this sync generator via iterate_in_threadpool, so the main event
    loop is never blocked.

    Args:
        async_gen: The async generator produced by the user's handler.
        ctx: Context-variable snapshot restored on the worker loop before
            iteration starts.

    Yields:
        Chunks from ``async_gen`` in order; any exception raised by the
        generator is re-raised on the consuming thread.
    """
    worker_loop = self._ensure_worker_loop()
    # Bounded queue provides backpressure: the producer cannot buffer
    # unboundedly far ahead of a slow consumer.
    q: queue.Queue = queue.Queue(maxsize=100)
    # Unique sentinel marking normal exhaustion of the async generator.
    _DONE = object()

    async def _produce() -> None:
        # Runs on the worker loop. Items are (ok, value) pairs:
        # ok=True carries a chunk (or the _DONE sentinel); ok=False carries
        # an exception for the consumer to re-raise.
        _restore_context(ctx)
        try:
            async for chunk in async_gen:
                # NOTE(review): q.put() blocks the worker event-loop thread
                # when the queue is full (slow consumer), stalling all tasks
                # on the worker loop until space frees — presumably acceptable
                # at the expected low connection concurrency; confirm.
                q.put((True, chunk))
            q.put((True, _DONE))
        except BaseException as e:
            q.put((False, e))

    # Schedule the producer on the worker loop from this (non-loop) thread;
    # create_task must itself run on the loop, hence the lambda indirection.
    worker_loop.call_soon_threadsafe(lambda: worker_loop.create_task(_produce()))

    # Consumer side: drain the queue synchronously until sentinel or error.
    while True:
        ok, value = q.get()
        if not ok:
            # Producer failed — surface its exception to the caller here.
            raise value
        if value is _DONE:
            break
        yield value
|
|
||
| async def _invoke_handler(self, handler: Callable, request_context: Any, takes_context: bool, payload: Any) -> Any: | ||
| """Dispatch handler execution based on handler type. | ||
|
|
||
| - Async generator functions: bridged through the worker loop as a sync generator | ||
| - Regular async functions: run on the dedicated worker event loop | ||
| - Sync functions (including sync generators): run in the thread pool | ||
|
|
||
| This ensures the main event loop stays responsive for /ping health checks | ||
| regardless of whether handlers contain blocking operations. | ||
| """ | ||
| try: | ||
| args = (payload, request_context) if takes_context else (payload,) | ||
|
|
||
| if asyncio.iscoroutinefunction(handler): | ||
| return await handler(*args) | ||
| ctx = contextvars.copy_context() | ||
|
|
||
| if _is_async_gen_callable(handler): | ||
| return self._async_gen_to_sync_gen(handler(*args), ctx) | ||
| elif _is_async_callable(handler): | ||
| worker_loop = self._ensure_worker_loop() | ||
| future = asyncio.run_coroutine_threadsafe(self._run_with_context(handler(*args), ctx), worker_loop) | ||
| result = await asyncio.wrap_future(future) | ||
| if inspect.isasyncgen(result): | ||
| return self._async_gen_to_sync_gen(result, ctx) | ||
| return result | ||
| else: | ||
| loop = asyncio.get_event_loop() | ||
| ctx = contextvars.copy_context() | ||
| return await loop.run_in_executor(None, ctx.run, handler, *args) | ||
| return await run_in_threadpool(ctx.run, handler, *args) | ||
| except Exception: | ||
| handler_name = getattr(handler, "__name__", "unknown") | ||
| self.logger.debug("Handler '%s' execution failed", handler_name) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: This only unwraps `functools.partial` but doesn't follow `__wrapped__` chains set by `functools.wraps`. A decorated async function would be silently misclassified as sync and dispatched to the thread pool instead of the worker loop.

Unverified suggestion — use `inspect.unwrap()` to follow `__wrapped__` chains (it has built-in cycle protection, raising `ValueError` on a wrapper loop):