Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 114 additions & 8 deletions src/bedrock_agentcore/runtime/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,19 @@

import asyncio
import contextvars
import functools
import inspect
import json
import logging
import queue
import threading
import time
import uuid
from collections.abc import Sequence
from typing import Any, Callable, Dict, Optional

from starlette.applications import Starlette
from starlette.concurrency import run_in_threadpool
from starlette.middleware import Middleware
from starlette.responses import JSONResponse, Response, StreamingResponse
from starlette.routing import Route, WebSocketRoute
Expand All @@ -39,6 +42,30 @@
from .utils import convert_complex_objects


def _is_async_callable(obj: Any) -> bool:
"""Check if obj is async-callable, unwrapping functools.partial."""
while isinstance(obj, functools.partial):
obj = obj.func
return asyncio.iscoroutinefunction(obj) or (callable(obj) and asyncio.iscoroutinefunction(obj.__call__))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: This only unwraps functools.partial but doesn't follow __wrapped__ chains set by functools.wraps. A decorated async function would be silently misclassified as sync and dispatched to the thread pool instead of the worker loop:

def my_decorator(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper

@my_decorator
async def handler(payload):
    return {"ok": True}

# _is_async_callable(handler) → False (checks wrapper, a regular def)

Unverified suggestion — use inspect.unwrap() to follow __wrapped__ chains (it raises ValueError when it detects a wrapper cycle, with a safety cap of sys.getrecursionlimit() hops):

def _unwrap(obj: Any) -> Any:
    """Unwrap functools.partial and decorator chains."""
    while isinstance(obj, functools.partial):
        obj = obj.func
    try:
        obj = inspect.unwrap(obj)
    except ValueError:
        pass  # cycle detected (inspect.unwrap raises ValueError), use as-is
    return obj


def _is_async_callable(obj: Any) -> bool:
    obj = _unwrap(obj)
    return asyncio.iscoroutinefunction(obj) or (
        callable(obj) and asyncio.iscoroutinefunction(obj.__call__)
    )


def _is_async_gen_callable(obj: Any) -> bool:
    obj = _unwrap(obj)
    return inspect.isasyncgenfunction(obj) or (
        callable(obj) and inspect.isasyncgenfunction(obj.__call__)
    )



def _is_async_gen_callable(obj: Any) -> bool:
"""Check if obj is an async generator function, unwrapping functools.partial."""
while isinstance(obj, functools.partial):
obj = obj.func
return inspect.isasyncgenfunction(obj) or (callable(obj) and inspect.isasyncgenfunction(obj.__call__))


def _restore_context(ctx: contextvars.Context) -> None:
"""Restore context variables from a snapshot (Django asgiref pattern)."""
for var, value in ctx.items():
try:
if var.get() != value:
var.set(value)
except LookupError:
var.set(value)


class RequestContextFormatter(logging.Formatter):
"""Formatter including request and session IDs."""

Expand Down Expand Up @@ -96,6 +123,9 @@ def __init__(
self._task_counter_lock: threading.Lock = threading.Lock()
self._forced_ping_status: Optional[PingStatus] = None
self._last_status_update_time: float = time.time()
self._worker_loop: Optional[asyncio.AbstractEventLoop] = None
self._worker_thread: Optional[threading.Thread] = None
self._worker_loop_lock: threading.Lock = threading.Lock()

routes = [
Route("/invocations", self._handle_invocation, methods=["POST"]),
Expand Down Expand Up @@ -163,7 +193,7 @@ def async_task(self, func: Callable) -> Callable:
- Set ping status to HEALTHY_BUSY while running
- Revert to HEALTHY when complete
"""
if not asyncio.iscoroutinefunction(func):
if not _is_async_callable(func):
raise ValueError("@async_task can only be applied to async functions")

async def wrapper(*args, **kwargs):
Expand Down Expand Up @@ -463,16 +493,92 @@ def run(self, port: int = 8080, host: Optional[str] = None, **kwargs):

uvicorn.run(self, **uvicorn_params)

async def _invoke_handler(self, handler, request_context, takes_context, payload):
def _ensure_worker_loop(self) -> asyncio.AbstractEventLoop:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The sync generator returned by this method is iterated by Starlette via iterate_in_threadpool, which holds a thread for the entire duration of the stream. The thread spends most of its time blocked on q.get() waiting for the next chunk from the worker loop.

Under concurrent streaming workloads (e.g., LLM streaming responses each taking ~20 seconds), this can exhaust the default thread pool (~40 threads) ref. Once exhausted, all new requests are queued waiting for a thread to free up, even though the occupied threads are mostly idle waiting on q.get().

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this one is expected though, customers at max have 2 concurrent connections to runtime servers

"""Lazily create and start a dedicated worker event loop in a background thread.

The worker loop isolates async handler execution from the main event loop,
ensuring that blocking async handlers do not prevent /ping from responding.
"""
if self._worker_loop is not None and self._worker_loop.is_running():
return self._worker_loop
with self._worker_loop_lock:
if self._worker_loop is None or not self._worker_loop.is_running():
self._worker_loop = asyncio.new_event_loop()
self._worker_thread = threading.Thread(
target=self._run_worker_loop,
daemon=True,
name="agentcore-worker-loop",
)
self._worker_thread.start()
return self._worker_loop

def _run_worker_loop(self) -> None:
"""Entry point for the worker loop background thread."""
asyncio.set_event_loop(self._worker_loop)
self._worker_loop.run_forever()

@staticmethod
async def _run_with_context(coro: Any, ctx: contextvars.Context) -> Any:
    """Re-apply the context-variable snapshot *ctx*, then await and return *coro*."""
    _restore_context(ctx)
    result = await coro
    return result

def _async_gen_to_sync_gen(self, async_gen: Any, ctx: contextvars.Context) -> Any:
    """Bridge an async generator through the worker loop as a sync generator.

    The async generator is iterated on the worker loop. Chunks are sent to
    a thread-safe queue and yielded synchronously. Starlette's StreamingResponse
    iterates this sync generator via iterate_in_threadpool, so the main event
    loop is never blocked.

    Args:
        async_gen: The async generator produced by the user handler.
        ctx: Context-variable snapshot to restore before iterating.

    Yields:
        Chunks produced by ``async_gen``, in order.

    Raises:
        Whatever ``async_gen`` raises during iteration, re-raised on the
        consuming thread.
    """
    worker_loop = self._ensure_worker_loop()
    # Bounded queue provides back-pressure. NOTE: q.put() blocks the worker
    # loop thread when full — acceptable at the low concurrency expected.
    q: queue.Queue = queue.Queue(maxsize=100)
    _DONE = object()

    async def _produce() -> None:
        _restore_context(ctx)
        try:
            async for chunk in async_gen:
                q.put((True, chunk))
            q.put((True, _DONE))
        except BaseException as e:  # includes CancelledError — forwarded to consumer
            q.put((False, e))

    # Keep a strong reference to the task: the event loop only holds running
    # tasks weakly, and we also need the handle to cancel on early exit.
    tasks: list = []
    worker_loop.call_soon_threadsafe(lambda: tasks.append(worker_loop.create_task(_produce())))

    try:
        while True:
            ok, value = q.get()
            if not ok:
                raise value
            if value is _DONE:
                break
            yield value
    finally:
        # Runs on normal completion (no-ops) and, critically, when the
        # consumer abandons the stream (client disconnect -> GeneratorExit):
        # cancel the producer and drain the queue so a producer blocked on a
        # full q.put() can reach its next await point and observe the
        # cancellation, instead of stalling the shared worker loop forever.
        try:
            worker_loop.call_soon_threadsafe(lambda: [t.cancel() for t in tasks])
        except RuntimeError:
            pass  # worker loop already shut down; nothing left to cancel
        while True:
            try:
                q.get_nowait()
            except queue.Empty:
                break

async def _invoke_handler(self, handler: Callable, request_context: Any, takes_context: bool, payload: Any) -> Any:
"""Dispatch handler execution based on handler type.

- Async generator functions: bridged through the worker loop as a sync generator
- Regular async functions: run on the dedicated worker event loop
- Sync functions (including sync generators): run in the thread pool

This ensures the main event loop stays responsive for /ping health checks
regardless of whether handlers contain blocking operations.
"""
try:
args = (payload, request_context) if takes_context else (payload,)

if asyncio.iscoroutinefunction(handler):
return await handler(*args)
ctx = contextvars.copy_context()

if _is_async_gen_callable(handler):
return self._async_gen_to_sync_gen(handler(*args), ctx)
elif _is_async_callable(handler):
worker_loop = self._ensure_worker_loop()
future = asyncio.run_coroutine_threadsafe(self._run_with_context(handler(*args), ctx), worker_loop)
result = await asyncio.wrap_future(future)
if inspect.isasyncgen(result):
return self._async_gen_to_sync_gen(result, ctx)
return result
else:
loop = asyncio.get_event_loop()
ctx = contextvars.copy_context()
return await loop.run_in_executor(None, ctx.run, handler, *args)
return await run_in_threadpool(ctx.run, handler, *args)
except Exception:
handler_name = getattr(handler, "__name__", "unknown")
self.logger.debug("Handler '%s' execution failed", handler_name)
Expand Down
Loading
Loading