diff --git a/pyproject.toml b/pyproject.toml index bb436d8ea..0f57263e3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "singlestoredb" -version = "1.16.10" +version = "1.16.11rc1+byasaini" description = "Interface to the SingleStoreDB database and workspace management APIs" readme = {file = "README.md", content-type = "text/markdown"} license = {text = "Apache-2.0"} diff --git a/singlestoredb/ai/embeddings.py b/singlestoredb/ai/embeddings.py index aba8e1c47..8529b21ae 100644 --- a/singlestoredb/ai/embeddings.py +++ b/singlestoredb/ai/embeddings.py @@ -1,5 +1,10 @@ +import contextvars +import logging import os +import time +import uuid from typing import Any +from typing import AsyncIterator from typing import Callable from typing import Optional from typing import Union @@ -9,6 +14,231 @@ from singlestoredb import manage_workspaces from singlestoredb.management.inference_api import InferenceAPIInfo + +# Per-task trace id propagated into the embeddings HTTP transport. +# +# Callers (e.g. an EMBED_TEXT UDF) can do ``http_trace_id.set("")`` +# right before invoking ``aembed_documents``. The :class:`TracingAsyncTransport` +# reads this var inside ``handle_async_request`` and stamps every log line +# with it, so each per-stage HTTP timing line can be correlated back to the +# UDF request that produced it. +# +# ContextVars are per-asyncio-Task: even with many concurrent EMBED_TEXT +# coroutines sharing a single dispatch loop, each call has its own private +# context, so a set() in one task does not leak into another. +http_trace_id: 'contextvars.ContextVar[str]' = contextvars.ContextVar( + 'singlestoredb_embeddings_http_trace_id', default='-', +) + +_http_log = logging.getLogger('singlestoredb.ai.embeddings.http') + + +def _fmt_addr(addr: Any) -> str: + """Format a ``(host, port, ...)`` sockaddr tuple as ``host:port``.""" + if not addr: + return '?' 
+ try: + return f'{addr[0]}:{addr[1]}' + except Exception: + return str(addr) + + +# Hosts already logged as transport-pinned, so we only emit one line per host +# instead of one per request. +_pin_logged_hosts: 'set[str]' = set() + + +class HttpTraceIdFilter(logging.Filter): + """ + Stamps every log record with the current :data:`http_trace_id` value + under ``record.trace_id``. + + Attach this to handlers (or loggers) for ``httpx`` / ``httpcore`` so + that their low-level per-stage log lines (``connect_tcp.started``, + ``start_tls.started``, ``send_request_body.complete`` etc.) carry the + same trace id the caller stamped before invoking ``aembed_documents``. + + Why this works: ``httpcore`` runs inside the same ``asyncio.Task`` as + the caller (it's just deeper in the ``await`` chain), and ContextVars + are per-Task, so ``http_trace_id.get()`` inside ``filter()`` returns + the value set by the caller for that specific request. Each concurrent + EMBED_TEXT call has its own value and they do not bleed across. + """ + + def filter(self, record: logging.LogRecord) -> bool: + record.trace_id = http_trace_id.get() + return True + + +def enable_http_debug_logging(level: int = logging.DEBUG) -> None: + """ + Turn on per-stage httpx / httpcore logging. + + This is very verbose (one log line per TCP connect, TLS handshake, + header send, body send, response header read, response body read, etc.) + but is the fastest way to pinpoint which network phase a stuck embedding + request is sitting in. Enable on demand, e.g. by setting the + ``SINGLESTOREDB_EMBEDDINGS_HTTP_DEBUG=1`` env var or by calling this + function directly. 
+ """ + for name in ( + 'httpx', + 'httpcore', + 'httpcore.connection', + 'httpcore.http11', + 'httpcore.http2', + 'httpcore.proxy', + ): + logging.getLogger(name).setLevel(level) + + +if os.environ.get( + 'SINGLESTOREDB_EMBEDDINGS_HTTP_DEBUG', '', +).lower() in ('1', 'true', 'yes'): + enable_http_debug_logging() + + +class TracingAsyncTransport(httpx.AsyncBaseTransport): + """ + Wraps another :class:`httpx.AsyncBaseTransport` and logs per-stage + timings for every request, stamped with the current :data:`http_trace_id` + context value. + + Emits three log lines per request: + + 1. ``->`` when the request is handed to the inner transport, with the + outgoing body size. + 2. ``<- headers`` when the response headers arrive (gives time-to-first- + byte, i.e. the gap that captures DNS + TCP + TLS + upload + upstream + processing). + 3. ``<- body`` when the response body is fully consumed (gives separate + body-download elapsed and total elapsed). + + The 1->2 gap vs the 2->3 gap is what tells you whether a hang is on the + request/upstream side or on the response/download side. + """ + + def __init__(self, inner: httpx.AsyncBaseTransport) -> None: + self._inner = inner + + async def handle_async_request( + self, request: httpx.Request, + ) -> httpx.Response: + tid = http_trace_id.get() + rid = uuid.uuid4().hex[:6] + + # Deterministic, resolver-independent IP pin. A DNS pin (monkeypatching + # socket.getaddrinfo) is bypassed by httpx/anyio's resolution path, so + # connections still spread across every NLB IP and hit the cross-AZ + # source-port collision. Here we instead dial a fixed IP at the + # transport while preserving TLS SNI + certificate validation + the + # Host header, so every embedding request lands on one endpoint + # regardless of how DNS is resolved. Enabled via + # SINGLESTOREDB_EMBEDDINGS_PIN_IP (optionally restricted to + # SINGLESTOREDB_EMBEDDINGS_PIN_HOST). Read per-request so a pin set + # after this client was constructed (e.g. 
by the EMBED_TEXT notebook) + # still takes effect. + pin_ip = os.environ.get('SINGLESTOREDB_EMBEDDINGS_PIN_IP') + if pin_ip: + pin_host = os.environ.get('SINGLESTOREDB_EMBEDDINGS_PIN_HOST') + host = request.url.host + if host and host != pin_ip and (not pin_host or host == pin_host): + try: + # Preserve the original Host header (httpx set it to the + # hostname at build time); only ensure it is present. + if 'host' not in request.headers: + request.headers['Host'] = host + request.extensions = { + **request.extensions, 'sni_hostname': host, + } + request.url = request.url.copy_with(host=pin_ip) + if host not in _pin_logged_hosts: + _pin_logged_hosts.add(host) + _http_log.warning( + '[%s/%s] TRANSPORT IP PIN active: dialing %s via %s ' + '(SNI/cert/Host preserved)', tid, rid, host, pin_ip, + ) + except Exception as e: + _http_log.warning( + '[%s/%s] TRANSPORT IP PIN failed: %r', tid, rid, e, + ) + + try: + body_len = len(request.content or b'') + except httpx.RequestNotRead: + body_len = -1 + t0 = time.perf_counter() + _http_log.info( + '[%s/%s] -> %s %s body=%dB', + tid, rid, request.method, request.url, body_len, + ) + try: + response = await self._inner.handle_async_request(request) + except BaseException as e: + _http_log.error( + '[%s/%s] xx EXC after %.3fs: %r', + tid, rid, time.perf_counter() - t0, e, + ) + raise + + t_headers = time.perf_counter() + # Surface the actual local/remote socket addresses for this request so + # the (src_ip:src_port -> dst_ip:dst_port) 4-tuple can be correlated + # with node-side tcpdump/conntrack captures. ``src`` here is the pod's + # chosen ephemeral port (the one Cilium masquerade normally preserves), + # and ``dst`` is the resolved endpoint IP actually connected to. 
+ local_addr = remote_addr = None + try: + stream = response.extensions.get('network_stream') + if stream is not None: + local_addr = stream.get_extra_info('client_addr') + remote_addr = stream.get_extra_info('server_addr') + if local_addr is None or remote_addr is None: + sock = stream.get_extra_info('socket') + if sock is not None: + local_addr = local_addr or sock.getsockname() + remote_addr = remote_addr or sock.getpeername() + except Exception: + pass + _http_log.info( + '[%s/%s] <- headers status=%d ttfb=%.3fs src=%s dst=%s', + tid, rid, response.status_code, t_headers - t0, + _fmt_addr(local_addr), _fmt_addr(remote_addr), + ) + + # Wrap the body stream so we also time how long the body download + # itself takes. Captures `tid`, `rid`, `t0`, `t_headers` by closure + # so the log line is correctly correlated even though body consumption + # happens later (after handle_async_request has returned). + original_stream = response.stream + + class _TimingStream(httpx.AsyncByteStream): + + async def __aiter__(self) -> AsyncIterator[bytes]: + total = 0 + t_body_start = time.perf_counter() + try: + async for chunk in original_stream: + total += len(chunk) + yield chunk + finally: + t_body_end = time.perf_counter() + _http_log.info( + '[%s/%s] <- body bytes=%d body_elapsed=%.3fs ' + 'total=%.3fs', + tid, rid, total, + t_body_end - t_body_start, t_body_end - t0, + ) + + async def aclose(self) -> None: + await original_stream.aclose() + + response.stream = _TimingStream() + return response + + async def aclose(self) -> None: + await self._inner.aclose() + try: from langchain_openai import OpenAIEmbeddings except ImportError: @@ -34,6 +264,7 @@ def SingleStoreEmbeddingsFactory( model_name: str, api_key: Optional[str] = None, http_client: Optional[httpx.Client] = None, + http_async_client: Optional[httpx.AsyncClient] = None, obo_token_getter: Optional[Callable[[], Optional[str]]] = None, base_url: Optional[str] = None, hosting_platform: Optional[str] = None, @@ -152,6 
+383,70 @@ def _inject_headers(request: Any, **_ignored: Any) -> None: ) if http_client is not None: openai_kwargs['http_client'] = http_client + + if http_async_client is None: + # Explicit timeouts: without these, httpx falls back to its 5s + # default at the client level, but the OpenAI SDK overrides that + # with a per-request 600s read timeout, so a stalled response can + # sit on the socket for ~10 minutes before httpx notices. We use a + # tighter read timeout so a dead/half-open connection fails fast + # instead of waiting for the application-level defensive timeout + # (e.g. EMBED_TEXT's asyncio.wait_for) to fire. + client_timeout = httpx.Timeout( + connect=float( + os.environ.get( + 'SINGLESTOREDB_EMBEDDINGS_CONNECT_TIMEOUT', '10', + ), + ), + read=float( + os.environ.get( + 'SINGLESTOREDB_EMBEDDINGS_READ_TIMEOUT', '60', + ), + ), + write=float( + os.environ.get( + 'SINGLESTOREDB_EMBEDDINGS_WRITE_TIMEOUT', '30', + ), + ), + pool=float( + os.environ.get( + 'SINGLESTOREDB_EMBEDDINGS_POOL_TIMEOUT', '10', + ), + ), + ) + # Allow connection reuse. The previous configuration + # (max_keepalive_connections=0) forced a fresh TCP+TLS handshake + # for every request, which under heavy concurrency churns sockets + # and occasionally yields one connection that the upstream accepts + # but never finishes responding on. 
+ client_limits = httpx.Limits( + max_connections=int( + os.environ.get( + 'SINGLESTOREDB_EMBEDDINGS_MAX_CONNECTIONS', '64', + ), + ), + max_keepalive_connections=int( + os.environ.get( + 'SINGLESTOREDB_EMBEDDINGS_MAX_KEEPALIVE', '16', + ), + ), + keepalive_expiry=float( + os.environ.get( + 'SINGLESTOREDB_EMBEDDINGS_KEEPALIVE_EXPIRY', '30', + ), + ), + ) + http_async_client = httpx.AsyncClient( + timeout=client_timeout, + limits=client_limits, + transport=TracingAsyncTransport( + httpx.AsyncHTTPTransport( + limits=client_limits, + ), + ), + ) + openai_kwargs['http_async_client'] = http_async_client + return OpenAIEmbeddings( **openai_kwargs, **kwargs, diff --git a/singlestoredb/apps/_python_udfs.py b/singlestoredb/apps/_python_udfs.py index e45718dec..0c757b394 100644 --- a/singlestoredb/apps/_python_udfs.py +++ b/singlestoredb/apps/_python_udfs.py @@ -61,11 +61,24 @@ async def run_udf_app( f'You can only define a maximum of {MAX_UDFS_LIMIT} functions.', ) + # uvicorn's default keep-alive timeout (5s) makes the server the active + # closer of idle pooled connections from the upstream proxy (cilium-envoy). + # The server then holds the TCP TIME-WAIT for that 4-tuple for ~60s; when the + # proxy reuses the same source port within that window the new SYN collides + # with the lingering TIME-WAIT socket and is silently dropped, surfacing as + # 5s connect timeouts (503 "upstream connect error ... connection timeout"). + # Keeping connections warm past the proxy's idle interval avoids that churn. + # keep_alive_timeout = int( + # os.environ.get('SINGLESTOREDB_APP_KEEPALIVE_TIMEOUT', '120'), + # ) + keep_alive_timeout = 120 + config = uvicorn.Config( app, host='0.0.0.0', port=app_config.listen_port, log_config=app.get_uvicorn_log_config(), + timeout_keep_alive=keep_alive_timeout, ) # Register the functions only if the app is running interactively. 
diff --git a/singlestoredb/functions/ext/asgi.py b/singlestoredb/functions/ext/asgi.py
index 331828d17..dafa7650d 100755
--- a/singlestoredb/functions/ext/asgi.py
+++ b/singlestoredb/functions/ext/asgi.py
@@ -24,6 +24,7 @@
 """
 import argparse
 import asyncio
+import atexit
 import contextvars
 import dataclasses
 import datetime
@@ -113,6 +114,252 @@ async def to_thread(
     return await loop.run_in_executor(None, func_call)
 
 
+async def _poll_cancel(cancel_event: threading.Event) -> None:
+    """Return once ``cancel_event`` is set, polling it every 0.1 seconds."""
+    # A ``threading.Event`` cannot be awaited, so it is polled with a short
+    # ``asyncio.sleep`` that yields to the event loop between checks.
+    while not cancel_event.is_set():
+        await asyncio.sleep(0.1)
+
+
+async def _cancellable_run(
+    cancel_event: threading.Event,
+    coro: Any,
+) -> Any:
+    """
+    Run ``coro`` as a task, aborting it once ``cancel_event`` is set.
+
+    ``coro`` and a ``_poll_cancel`` watcher are started as tasks on the
+    running loop. If the watcher has finished when the wait returns, ``coro``
+    is cancelled and ``asyncio.CancelledError`` is raised; otherwise the
+    result of ``coro`` is returned (or its exception is raised).
+
+    Both tasks are cancelled and awaited on every exit path, including
+    cancellation of the caller, so that neither is left running on the
+    loop once this coroutine has finished.
+    """
+    task = asyncio.create_task(coro)
+    cancel_check = asyncio.create_task(_poll_cancel(cancel_event))
+    try:
+        done, _ = await asyncio.wait(
+            [task, cancel_check], return_when=asyncio.FIRST_COMPLETED,
+        )
+        if cancel_check in done:
+            raise asyncio.CancelledError()
+        return task.result()
+    finally:
+        # ``asyncio.wait`` leaves its awaitables running when the waiter is
+        # cancelled, so the cleanup must happen here rather than after it.
+        unfinished = [t for t in (task, cancel_check) if not t.done()]
+        for t in unfinished:
+            t.cancel()
+        if unfinished:
+            # Wait for the cancellation to be processed now; a caller using
+            # ``run_until_complete`` stops driving the loop on return.
+            await asyncio.gather(*unfinished, return_exceptions=True)
+        if task.done() and not task.cancelled():
+            # Mark a failure as retrieved when the cancel request won.
+            task.exception()
+
+
+# Each `to_thread` worker thread owns a long-lived event loop reused across
+# requests, so loop-bound resources (HTTP pools, DB sessions, sockets) can
+# survive between calls handled by the same thread.
+#
+# This per-thread loop is only used for SYNC user UDFs: a sync UDF blocks
+# its worker thread for the duration of the call, so giving each worker
+# thread its own loop avoids cross-thread loop sharing for those calls.
+_thread_local = threading.local() +_loop_registry: 'Set[asyncio.AbstractEventLoop]' = set() +_loop_registry_lock = threading.Lock() + + +def _get_thread_loop() -> asyncio.AbstractEventLoop: + """Return (creating if needed) the calling thread's persistent loop.""" + loop = getattr(_thread_local, 'loop', None) + if loop is None or loop.is_closed(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + _thread_local.loop = loop + with _loop_registry_lock: + _loop_registry.add(loop) + return loop + + +def _run_on_thread_loop(coro: Any) -> Any: + """ + Run ``coro`` on the calling thread's persistent loop. + + The loop is never closed between calls, so loop-bound resources (e.g. + httpx keep-alive pools) survive across requests and the deferred + "Event loop is closed" errors thrown by httpx/anyio at teardown do not + occur. + + Caveat: tasks the user code spawns via ``asyncio.create_task`` and + leaves running outlive the current call too. That is the price of + keeping shared resources alive; ``cancel_event`` does not reach them. + """ + loop = _get_thread_loop() + return loop.run_until_complete(coro) + + +def _shutdown_thread_loops() -> None: + """Best-effort cleanup of all persistent worker-thread loops at exit.""" + with _loop_registry_lock: + loops = list(_loop_registry) + _loop_registry.clear() + + for loop in loops: + if loop.is_closed(): + continue + try: + # Owning thread is no longer running the loop; safe to drive + # teardown from this (exiting) thread. + loop.run_until_complete(loop.shutdown_asyncgens()) + loop.run_until_complete(loop.shutdown_default_executor()) + except Exception: + pass + finally: + try: + loop.close() + except Exception: + pass + + +atexit.register(_shutdown_thread_loops) + + +# Dedicated event loop used for ALL async UDF requests. +# +# Async UDFs commonly create resources that are bound to the event loop +# they are first used on (httpx connection pools, async DB clients, anyio +# streams, ...). 
Dispatching async requests to ad-hoc worker threads — +# each with its own loop — produced " is bound +# to a different event loop" errors when those cached resources were +# reused by a request that landed on a different worker thread. +# +# Routing every async UDF onto a single dedicated loop fixes that, and +# also gives true concurrency across requests: ``run_coroutine_threadsafe`` +# schedules each new coroutine immediately so that incoming requests do +# not queue behind in-flight ones. +# +# Sync UDFs intentionally still go through the worker-thread / per-thread +# loop path above: a sync UDF would block this dedicated loop and starve +# every other in-flight async request. +_async_dispatch_loop: 'Optional[asyncio.AbstractEventLoop]' = None +_async_dispatch_thread: 'Optional[threading.Thread]' = None +_async_dispatch_lock = threading.Lock() + + +def _get_async_dispatch_loop() -> asyncio.AbstractEventLoop: + """ + Return (lazily creating) the singleton async-dispatch event loop. + + The loop is owned by a dedicated daemon thread that runs ``run_forever`` + for the lifetime of the process. All async UDF coroutines are scheduled + on this loop so that loop-bound resources can be safely reused across + requests. 
+ """ + global _async_dispatch_loop, _async_dispatch_thread + + loop = _async_dispatch_loop + if loop is not None and not loop.is_closed(): + return loop + + with _async_dispatch_lock: + if _async_dispatch_loop is not None and \ + not _async_dispatch_loop.is_closed(): + return _async_dispatch_loop + + ready = threading.Event() + captured: List[asyncio.AbstractEventLoop] = [] + + def run_loop() -> None: + new_loop = asyncio.new_event_loop() + asyncio.set_event_loop(new_loop) + captured.append(new_loop) + ready.set() + try: + new_loop.run_forever() + finally: + try: + new_loop.run_until_complete(new_loop.shutdown_asyncgens()) + except Exception: + pass + try: + new_loop.run_until_complete( + new_loop.shutdown_default_executor(), + ) + except Exception: + pass + try: + new_loop.close() + except Exception: + pass + + thread = threading.Thread( + target=run_loop, + name='singlestoredb-udf-async-dispatch', + daemon=True, + ) + thread.start() + ready.wait() + + _async_dispatch_loop = captured[0] + _async_dispatch_thread = thread + return _async_dispatch_loop + + +def _get_async_dispatch_thread() -> 'Optional[threading.Thread]': + """Return the dedicated dispatch thread (or ``None`` if not started).""" + return _async_dispatch_thread + + +async def _dispatch_to_async_loop(coro: Any) -> Any: + """ + Schedule ``coro`` on the dedicated async-dispatch loop and await its result. + + The coroutine begins running immediately on the dispatch loop — it does + NOT wait for any earlier in-flight async UDF to complete — so concurrent + async requests run in parallel on a single shared loop. + + Cancellation of the awaiting task on the caller's loop is propagated + best-effort to the work scheduled on the dispatch loop. The user code's + ``cancel_event`` (set by the request handler on timeout / disconnect) + remains the authoritative cancellation signal because it is observed by + ``_cancellable_run`` from inside the dispatch loop. 
+ """ + loop = _get_async_dispatch_loop() + cf = asyncio.run_coroutine_threadsafe(coro, loop) + try: + return await asyncio.wrap_future(cf) + except asyncio.CancelledError: + cf.cancel() + raise + + +def _shutdown_async_dispatch_loop() -> None: + """Best-effort cleanup of the dedicated async-dispatch loop at exit.""" + global _async_dispatch_loop, _async_dispatch_thread + with _async_dispatch_lock: + loop = _async_dispatch_loop + thread = _async_dispatch_thread + _async_dispatch_loop = None + _async_dispatch_thread = None + + if loop is not None and not loop.is_closed(): + try: + loop.call_soon_threadsafe(loop.stop) + except Exception: + pass + + if thread is not None: + thread.join(timeout=5) + + +atexit.register(_shutdown_async_dispatch_loop) + + # Use negative values to indicate unsigned ints / binary data / usec time precision rowdat_1_type_map = { 'bool': ft.LONGLONG, @@ -1190,20 +1411,44 @@ async def __call__( cancel_event = threading.Event() - with timer('parse_input'): - inputs = input_handler['load']( # type: ignore - func_info['colspec'], b''.join(data), + # Parsing the request body can be CPU heavy (esp. for + # rowdat_1 / arrow payloads). Run it in the default + # executor thread pool so the main uvicorn loop is not + # blocked while inputs are being decoded. + load_input = input_handler['load'] # type: ignore + colspec = func_info['colspec'] + async with timer('parse_input'): + inputs = await to_thread( + lambda: load_input(colspec, b''.join(data)), ) - func_task = asyncio.create_task( - func(cancel_event, call_timer, *inputs) - if func_info['is_async'] - else to_thread( - lambda: asyncio.run( - func(cancel_event, call_timer, *inputs), + # Async user UDFs share a single dedicated event-loop thread + # so that loop-bound resources (httpx pools, async clients, + # ...) can be reused across requests; new requests are + # scheduled immediately and run concurrently on that loop. 
+ # Sync user UDFs continue to use the worker-thread pool (one + # persistent loop per thread) because a sync call would + # block the shared dispatch loop and starve other requests. + if func_info.get('is_async'): + func_task = asyncio.create_task( + _dispatch_to_async_loop( + _cancellable_run( + cancel_event, + func(cancel_event, call_timer, *inputs), + ), ), - ), - ) + ) + else: + func_task = asyncio.create_task( + to_thread( + lambda: _run_on_thread_loop( + _cancellable_run( + cancel_event, + func(cancel_event, call_timer, *inputs), + ), + ), + ), + ) disconnect_task = asyncio.create_task( asyncio.sleep(int(1e9)) if ignore_cancel else cancel_on_disconnect(receive), @@ -1219,17 +1464,21 @@ async def __call__( all_tasks, return_when=asyncio.FIRST_COMPLETED, ) + # Signal the worker before awaiting cancellation: cancelling + # func_task only flips its asyncio wrapper, not the executor + # work; only cancel_event reaches the worker loop. + if func_task in pending: + cancel_event.set() + await cancel_all_tasks(pending) for task in done: if task is disconnect_task: - cancel_event.set() raise asyncio.CancelledError( 'Function call was cancelled by client disconnect', ) elif task is timeout_task: - cancel_event.set() raise asyncio.TimeoutError( 'Function call was cancelled due to timeout', ) @@ -1237,9 +1486,15 @@ async def __call__( elif task is func_task: result.extend(task.result()) - with timer('format_output'): - body = output_handler['dump']( - [x[1] for x in func_info['returns']], *result, # type: ignore + # Serializing the response can also be CPU heavy. Run it + # in the default executor thread pool so the main + # uvicorn loop stays responsive to other connections + # while this request is being encoded. 
+ dump_output = output_handler['dump'] # type: ignore + return_types = [x[1] for x in func_info['returns']] + async with timer('format_output'): + body = await to_thread( + dump_output, return_types, *result, ) await send(output_handler['response']) @@ -1292,6 +1547,7 @@ async def __call__( await send(self.error_response_dict) finally: + cancel_event.set() await cancel_all_tasks(all_tasks) # Handle api reflection diff --git a/singlestoredb/tests/test_udf_event_loop.py b/singlestoredb/tests/test_udf_event_loop.py new file mode 100644 index 000000000..7c89c397a --- /dev/null +++ b/singlestoredb/tests/test_udf_event_loop.py @@ -0,0 +1,835 @@ +"""Tests for the async UDF persistent per-thread event loop.""" +import asyncio +import contextvars +import json as jsonlib +import threading +import time +import unittest +from typing import Any +from typing import Dict +from typing import List +from typing import Set +from typing import Tuple + +from ..functions import udf +from ..functions.ext.asgi import _cancellable_run +from ..functions.ext.asgi import _dispatch_to_async_loop +from ..functions.ext.asgi import _get_async_dispatch_loop +from ..functions.ext.asgi import _get_async_dispatch_thread +from ..functions.ext.asgi import _get_thread_loop +from ..functions.ext.asgi import _run_on_thread_loop +from ..functions.ext.asgi import Application +from ..functions.ext.asgi import to_thread + + +class TestUDFDispatchEdgeCases(unittest.TestCase): + """Test edge cases in the UDF dispatch stack.""" + + def test_timeout_cancels_running_function(self) -> None: + """Cancel event set from timer thread cancels a blocked coroutine.""" + cancel_event = threading.Event() + + async def long_running() -> str: + await asyncio.sleep(999) + return 'should not reach' + + def set_cancel_after_delay() -> None: + time.sleep(0.2) + cancel_event.set() + + timer = threading.Thread(target=set_cancel_after_delay) + timer.start() + + start = time.monotonic() + with 
self.assertRaises(asyncio.CancelledError): + _run_on_thread_loop( + _cancellable_run(cancel_event, long_running()), + ) + elapsed = time.monotonic() - start + timer.join() + # 0.2s delay + up to 0.1s poll interval + margin + self.assertLess(elapsed, 0.5) + + def test_exception_propagates_through_full_stack(self) -> None: + """User exception propagates unwrapped through the entire dispatch.""" + cancel_event = threading.Event() + + class CustomUDFError(Exception): + pass + + async def failing_udf() -> None: + raise CustomUDFError('embedding service unavailable') + + with self.assertRaises(CustomUDFError) as ctx: + _run_on_thread_loop( + _cancellable_run(cancel_event, failing_udf()), + ) + self.assertEqual(str(ctx.exception), 'embedding service unavailable') + + def test_cancel_event_detected_within_poll_interval(self) -> None: + """Cancellation is detected within one poll cycle (0.1s).""" + cancel_event = threading.Event() + + async def blocked() -> str: + await asyncio.sleep(999) + return 'unreachable' + + def set_cancel() -> None: + time.sleep(0.05) + cancel_event.set() + + timer = threading.Thread(target=set_cancel) + timer.start() + + start = time.monotonic() + with self.assertRaises(asyncio.CancelledError): + _run_on_thread_loop( + _cancellable_run(cancel_event, blocked()), + ) + elapsed = time.monotonic() - start + timer.join() + # 0.05s delay + 0.1s poll interval + margin + self.assertLess(elapsed, 0.25) + + def test_context_vars_propagate_through_to_thread(self) -> None: + """Context variables are visible inside to_thread executor.""" + test_var: contextvars.ContextVar[str] = contextvars.ContextVar( + 'test_var', + ) + test_var.set('hello_from_parent') + captured: List[str] = [] + + def read_context_var() -> str: + val = test_var.get('NOT_FOUND') + captured.append(val) + return val + + async def run_in_thread() -> str: + return await to_thread(read_context_var) + + result = _run_on_thread_loop(run_in_thread()) + self.assertEqual(result, 'hello_from_parent') 
+ self.assertEqual(captured, ['hello_from_parent']) + + def test_concurrent_requests_isolated(self) -> None: + """Parallel executions don't share state.""" + results: List[Any] = [None, None, None] + + def run_isolated(index: int) -> None: + async def compute() -> int: + await asyncio.sleep(0.05) + return index * 10 + + results[index] = _run_on_thread_loop(compute()) + + threads = [ + threading.Thread(target=run_isolated, args=(i,)) + for i in range(3) + ] + for t in threads: + t.start() + for t in threads: + t.join() + + self.assertEqual(results, [0, 10, 20]) + + def test_sync_function_through_async_wrapper(self) -> None: + """Synchronous function works when wrapped as async coroutine.""" + cancel_event = threading.Event() + + async def sync_as_async() -> int: + # Simulates what decorator.py's async_wrapper does for sync UDFs + return 42 + 1 + + result = _run_on_thread_loop( + _cancellable_run(cancel_event, sync_as_async()), + ) + self.assertEqual(result, 43) + + def test_cancel_event_not_set_on_success(self) -> None: + """Cancel event remains unset after successful execution.""" + cancel_event = threading.Event() + + async def quick() -> str: + return 'fast' + + result = _run_on_thread_loop( + _cancellable_run(cancel_event, quick()), + ) + self.assertEqual(result, 'fast') + self.assertFalse(cancel_event.is_set()) + + +class TestRunOnThreadLoop(unittest.TestCase): + """Test _run_on_thread_loop reuses a persistent per-thread event loop.""" + + def test_basic_coroutine(self) -> None: + async def simple() -> int: + return 42 + + self.assertEqual(_run_on_thread_loop(simple()), 42) + + def test_loop_reused_across_calls(self) -> None: + """The same loop object is reused for successive calls in a thread.""" + loops: List[asyncio.AbstractEventLoop] = [] + + async def capture_loop() -> bool: + loops.append(asyncio.get_running_loop()) + return True + + _run_on_thread_loop(capture_loop()) + _run_on_thread_loop(capture_loop()) + + self.assertIs(loops[0], loops[1]) + + def 
test_loop_not_closed_between_calls(self) -> None: + """The persistent loop stays open so resources survive requests.""" + captured: List[asyncio.AbstractEventLoop] = [] + + async def capture_loop() -> bool: + captured.append(asyncio.get_running_loop()) + return True + + _run_on_thread_loop(capture_loop()) + loop = captured[0] + self.assertFalse(loop.is_closed()) + + # Still usable for the next request. + _run_on_thread_loop(capture_loop()) + self.assertFalse(loop.is_closed()) + + def test_async_resource_survives_between_calls(self) -> None: + """An object bound to the loop can be reused on the next call. + + This mirrors caching e.g. an httpx.AsyncClient keyed by the loop and + reusing its connection pool on subsequent requests. + """ + clients: dict = {} + + async def get_or_create_client() -> int: + loop = asyncio.get_running_loop() + if loop not in clients: + clients[loop] = object() + return id(clients[loop]) + + first = _run_on_thread_loop(get_or_create_client()) + second = _run_on_thread_loop(get_or_create_client()) + + self.assertEqual(first, second) + self.assertEqual(len(clients), 1) + + def test_separate_threads_get_separate_loops(self) -> None: + """Each worker thread owns its own persistent loop.""" + loops: List[asyncio.AbstractEventLoop] = [] + lock = threading.Lock() + + def run_in_thread() -> None: + async def capture() -> bool: + with lock: + loops.append(asyncio.get_running_loop()) + return True + + _run_on_thread_loop(capture()) + + threads = [threading.Thread(target=run_in_thread) for _ in range(3)] + for t in threads: + t.start() + for t in threads: + t.join() + + self.assertEqual(len(loops), 3) + self.assertEqual(len({id(loop) for loop in loops}), 3) + + def test_get_thread_loop_idempotent(self) -> None: + """_get_thread_loop returns the same loop on repeated calls.""" + def run_in_thread(out: List[asyncio.AbstractEventLoop]) -> None: + out.append(_get_thread_loop()) + out.append(_get_thread_loop()) + + out: List[asyncio.AbstractEventLoop] = 
[] + t = threading.Thread(target=run_in_thread, args=(out,)) + t.start() + t.join() + + self.assertIs(out[0], out[1]) + + def test_exception_propagates(self) -> None: + async def failing() -> None: + raise ValueError('test error') + + with self.assertRaises(ValueError) as ctx: + _run_on_thread_loop(failing()) + self.assertEqual(str(ctx.exception), 'test error') + + def test_cancellable_run_integration(self) -> None: + """_cancellable_run works on the persistent loop.""" + cancel_event = threading.Event() + + async def slow_func() -> str: + return 'completed' + + result = _run_on_thread_loop( + _cancellable_run(cancel_event, slow_func()), + ) + self.assertEqual(result, 'completed') + + def test_cancellation_via_event(self) -> None: + """Cancellation propagates through the persistent-loop stack.""" + cancel_event = threading.Event() + cancel_event.set() + + async def blocked_func() -> str: + await asyncio.sleep(999) + return 'should not reach' + + with self.assertRaises(asyncio.CancelledError): + _run_on_thread_loop( + _cancellable_run(cancel_event, blocked_func()), + ) + + # Loop must remain usable after a cancelled request. + async def quick() -> str: + return 'ok' + + self.assertEqual( + _run_on_thread_loop(_cancellable_run(threading.Event(), quick())), + 'ok', + ) + + +class TestAsyncDispatchLoop(unittest.TestCase): + """All async UDF dispatches share a single dedicated event-loop thread. + + The dispatch loop is process-global and lazily started; resources bound + to it (HTTP pools, async clients, connection caches) are reused across + every async UDF request. New requests are scheduled immediately and run + concurrently on that loop instead of being serialized behind earlier + in-flight requests. 
+ """ + + def test_dispatch_loop_is_single_dedicated_thread(self) -> None: + """All dispatches run on the same dedicated thread (not the caller).""" + seen_threads: Set[int] = set() + + async def capture() -> int: + seen_threads.add(threading.get_ident()) + return 1 + + async def run_many() -> None: + await asyncio.gather(*[ + _dispatch_to_async_loop(capture()) for _ in range(8) + ]) + + caller_thread = threading.get_ident() + asyncio.run(run_many()) + + self.assertEqual(len(seen_threads), 1) + self.assertNotIn(caller_thread, seen_threads) + # The thread we observed is the singleton dispatch thread. + dispatch_thread = _get_async_dispatch_thread() + assert dispatch_thread is not None + self.assertEqual(seen_threads.pop(), dispatch_thread.ident) + + def test_dispatch_loop_is_single_event_loop(self) -> None: + """All dispatches run on the SAME event loop instance.""" + captured: List[asyncio.AbstractEventLoop] = [] + + async def capture() -> int: + captured.append(asyncio.get_running_loop()) + return 1 + + async def run_many() -> None: + await asyncio.gather(*[ + _dispatch_to_async_loop(capture()) for _ in range(5) + ]) + + asyncio.run(run_many()) + + self.assertEqual(len(captured), 5) + first = captured[0] + for loop in captured: + self.assertIs(loop, first) + self.assertIs(first, _get_async_dispatch_loop()) + + def test_concurrent_dispatches_do_not_serialize(self) -> None: + """Slow dispatches run in parallel on the loop; new requests do not + wait for earlier ones to finish.""" + n = 6 + per_call_sleep = 0.3 + + async def slow() -> str: + await asyncio.sleep(per_call_sleep) + return 'done' + + async def run_many() -> List[str]: + return await asyncio.gather(*[ + _dispatch_to_async_loop(slow()) for _ in range(n) + ]) + + start = time.monotonic() + results = asyncio.run(run_many()) + elapsed = time.monotonic() - start + + self.assertEqual(results, ['done'] * n) + # Serialized would be ~ n * per_call_sleep. Parallel ~ per_call_sleep. 
+ # Allow generous margin for CI noise. + self.assertLess(elapsed, per_call_sleep * 2) + + def test_new_request_does_not_wait_for_in_flight_request(self) -> None: + """A new async request is submitted to the dispatch thread + immediately and runs while an earlier request is still in-flight. + + This is the explicit guarantee that async UDF dispatch is not + serialized: a request fired AFTER another long one has started + must (a) start before the long one finishes, (b) finish before + the long one finishes, and (c) be submitted with negligible + latency from the caller's perspective. + """ + long_sleep = 1.0 + ts: Dict[str, float] = {} + # Created lazily on the dispatch loop so the asyncio.Event is bound + # to the correct loop. + signals: Dict[str, asyncio.Event] = {} + + async def long_running() -> str: + ts['long_started'] = time.monotonic() + signals['started'] = asyncio.Event() + signals['started'].set() + await asyncio.sleep(long_sleep) + ts['long_finished'] = time.monotonic() + return 'long' + + async def quick() -> str: + ts['quick_started'] = time.monotonic() + await asyncio.sleep(0) + ts['quick_finished'] = time.monotonic() + return 'quick' + + async def driver() -> None: + long_task = asyncio.create_task( + _dispatch_to_async_loop(long_running()), + ) + # Wait until the long task has actually started on the + # dispatch loop. Only after this point can we be sure the + # next dispatch is "during" an in-flight request. + for _ in range(100): + await asyncio.sleep(0.01) + if 'started' in signals and signals['started'].is_set(): + break + self.assertIn('long_started', ts) + + ts['quick_dispatch_called'] = time.monotonic() + quick_result = await _dispatch_to_async_loop(quick()) + ts['quick_dispatch_returned'] = time.monotonic() + self.assertEqual(quick_result, 'quick') + + long_result = await long_task + self.assertEqual(long_result, 'long') + + asyncio.run(driver()) + + # The new request actually overlapped the in-flight one. 
+ self.assertGreater(ts['quick_started'], ts['long_started']) + self.assertLess(ts['quick_started'], ts['long_finished']) + self.assertLess(ts['quick_finished'], ts['long_finished']) + + # Submission of the new request to the dispatch thread is + # non-blocking: the awaiter returned in well under the long + # request's remaining time. + dispatch_latency = ts['quick_dispatch_returned'] \ + - ts['quick_dispatch_called'] + self.assertLess(dispatch_latency, long_sleep / 2) + + def test_many_new_requests_run_during_one_in_flight_request(self) -> None: + """Many new async requests, each fired sequentially while a single + long-running request is in-flight, all start AND finish before the + long one finishes.""" + long_sleep = 1.0 + n_quick = 8 + ts: Dict[str, float] = {} + quick_finished: List[float] = [] + signals: Dict[str, asyncio.Event] = {} + + async def long_running() -> str: + ts['long_started'] = time.monotonic() + signals['started'] = asyncio.Event() + signals['started'].set() + await asyncio.sleep(long_sleep) + ts['long_finished'] = time.monotonic() + return 'long' + + async def quick(i: int) -> int: + await asyncio.sleep(0.01) + quick_finished.append(time.monotonic()) + return i + + async def driver() -> None: + long_task = asyncio.create_task( + _dispatch_to_async_loop(long_running()), + ) + # Wait for the long task to start. + for _ in range(100): + await asyncio.sleep(0.01) + if 'started' in signals and signals['started'].is_set(): + break + + results = await asyncio.gather(*[ + _dispatch_to_async_loop(quick(i)) for i in range(n_quick) + ]) + self.assertEqual(results, list(range(n_quick))) + await long_task + + asyncio.run(driver()) + + # All quick requests finished before the long one did, proving + # they were not queued behind it. 
+ self.assertEqual(len(quick_finished), n_quick) + for finish in quick_finished: + self.assertLess(finish, ts['long_finished']) + self.assertGreater(finish, ts['long_started']) + + def test_loop_bound_resource_reused_across_dispatches(self) -> None: + """A resource keyed by id(loop) is shared by every async request, + even across separate caller event loops (separate parent runs).""" + cache: Dict[int, object] = {} + + async def acquire() -> int: + loop = asyncio.get_running_loop() + key = id(loop) + if key not in cache: + cache[key] = object() + return id(cache[key]) + + async def run_one() -> int: + return await _dispatch_to_async_loop(acquire()) + + first = asyncio.run(run_one()) + second = asyncio.run(run_one()) + third = asyncio.run(run_one()) + + self.assertEqual(first, second) + self.assertEqual(second, third) + self.assertEqual(len(cache), 1) + + def test_dispatch_propagates_exception(self) -> None: + """Exceptions from the dispatched coroutine surface to the caller.""" + class DispatchedError(Exception): + pass + + async def failing() -> None: + raise DispatchedError('boom') + + async def driver() -> None: + await _dispatch_to_async_loop(failing()) + + with self.assertRaises(DispatchedError) as ctx: + asyncio.run(driver()) + self.assertEqual(str(ctx.exception), 'boom') + + def test_dispatch_with_cancel_event(self) -> None: + """`_cancellable_run` on the dispatch loop honors the cancel event.""" + cancel_event = threading.Event() + + async def blocked() -> str: + await asyncio.sleep(999) + return 'unreachable' + + def trip_cancel() -> None: + time.sleep(0.1) + cancel_event.set() + + timer = threading.Thread(target=trip_cancel) + timer.start() + + async def driver() -> None: + await _dispatch_to_async_loop( + _cancellable_run(cancel_event, blocked()), + ) + + start = time.monotonic() + with self.assertRaises(asyncio.CancelledError): + asyncio.run(driver()) + elapsed = time.monotonic() - start + timer.join() + # 0.1s delay + 0.1s poll interval + margin + 
self.assertLess(elapsed, 0.5) + + def test_dispatch_loop_survives_after_cancellation(self) -> None: + """The dispatch loop remains usable after a cancelled request.""" + cancel_event = threading.Event() + cancel_event.set() + + async def blocked() -> str: + await asyncio.sleep(999) + return 'unreachable' + + async def driver_cancel() -> None: + await _dispatch_to_async_loop( + _cancellable_run(cancel_event, blocked()), + ) + + with self.assertRaises(asyncio.CancelledError): + asyncio.run(driver_cancel()) + + async def quick() -> str: + return 'ok' + + async def driver_ok() -> str: + return await _dispatch_to_async_loop(quick()) + + self.assertEqual(asyncio.run(driver_ok()), 'ok') + + +# Module-level UDFs used by the Application integration tests below. They +# must be defined at module scope so the signature inspection helpers can +# resolve their type hints. + +# Records the thread that actually executes each UDF body, keyed by tag. +_dispatch_observation: Dict[str, int] = {} +_dispatch_observation_lock = threading.Lock() +# Per-tag start / finish timestamps, used by the "no waiting for in-flight" +# test below to assert overlap between concurrent requests. 
_dispatch_started_at: Dict[str, float] = {}
_dispatch_finished_at: Dict[str, float] = {}


def _record(tag: str) -> None:
    """Record the executing thread id and the start time for ``tag``."""
    with _dispatch_observation_lock:
        _dispatch_observation[tag] = threading.get_ident()
        _dispatch_started_at[tag] = time.monotonic()


def _record_finish(tag: str) -> None:
    """Record the finish time for ``tag``."""
    with _dispatch_observation_lock:
        _dispatch_finished_at[tag] = time.monotonic()


@udf
async def _async_record_udf(tag: str) -> int:
    """Async UDF that records start/finish around a single loop yield."""
    _record(tag)
    await asyncio.sleep(0)
    _record_finish(tag)
    return len(tag)


@udf
async def _async_slow_udf(tag: str) -> int:
    """Async UDF that records start/finish around a 0.4s sleep."""
    _record(tag)
    await asyncio.sleep(0.4)
    _record_finish(tag)
    return len(tag)


@udf
async def _async_long_udf(tag: str) -> int:
    """Long-running async UDF used to verify that newly arriving async
    requests do not have to wait for it to finish."""
    _record(tag)
    await asyncio.sleep(1.0)
    _record_finish(tag)
    return len(tag)


@udf
def _sync_record_udf(tag: str) -> int:
    """Sync UDF that records its executing thread and returns at once."""
    _record(tag)
    _record_finish(tag)
    return len(tag)


def _make_invoke_args(
    name: str,
    rows: List[Tuple[Any, ...]],
) -> Tuple[Dict[str, Any], Any, Any, List[Dict[str, Any]]]:
    """Build a minimal ASGI scope/receive/send for an /invoke request.

    Returns ``(scope, receive, send, sent)`` where ``sent`` is the list
    that ``send`` appends every outgoing ASGI message to.
    """
    # Each row is prefixed with its index to form the JSON request body.
    payload = jsonlib.dumps({
        'data': [[i, *row] for i, row in enumerate(rows)],
    }).encode('utf-8')

    received: Dict[str, bool] = {'sent': False}

    async def receive() -> Dict[str, Any]:
        # First call delivers the whole body; any later call waits 60s and
        # then reports a disconnect.
        if received['sent']:
            await asyncio.sleep(60)
            return {'type': 'http.disconnect'}
        received['sent'] = True
        return {'type': 'http.request', 'body': payload, 'more_body': False}

    sent: List[Dict[str, Any]] = []

    async def send(msg: Dict[str, Any]) -> None:
        sent.append(msg)

    scope = {
        'type': 'http',
        'method': 'POST',
        'path': '/invoke',
        'scheme': 'http',
        'headers': [
            (b'content-type', b'application/json'),
            (b'accepts', b'application/json'),
            (b's2-ef-name', name.encode('utf-8')),
            (b's2-ef-version', b'1.0'),
            (b's2-ef-ignore-cancel', b'true'),
        ],
    }
    return scope, receive, send, sent


def _reset_dispatch_observation() -> None:
    """Clear all recorded thread ids and timestamps."""
    with _dispatch_observation_lock:
        _dispatch_observation.clear()
        _dispatch_started_at.clear()
        _dispatch_finished_at.clear()


class TestApplicationDispatchRouting(unittest.TestCase):
    """End-to-end: Application routes async UDFs to the dispatch loop and
    sync UDFs to a worker thread, and concurrent async requests run in
    parallel on the dispatch loop."""

    def setUp(self) -> None:
        _reset_dispatch_observation()
        self.app = Application(
            functions=[
                _async_record_udf,
                _async_slow_udf,
                _async_long_udf,
                _sync_record_udf,
            ],
            disable_metrics=True,
        )

    @staticmethod
    def _headers_dict(scope: Dict[str, Any]) -> Dict[bytes, bytes]:
        """Return the scope's header pairs as a dict."""
        return dict(scope['headers'])

    def _invoke(self, name: str, rows: List[Tuple[Any, ...]]) -> List[Dict[str, Any]]:
        """Run one /invoke request on a fresh event loop; return sent messages."""
        return asyncio.run(self._invoke_async(name, rows))

    async def _invoke_async(
        self, name: str, rows: List[Tuple[Any, ...]],
    ) -> List[Dict[str, Any]]:
        """Run one /invoke request on the current loop; return sent messages."""
        # ``scope['headers']`` is already a list of tuples, which is the
        # form the Application is expected to turn into a dict.
        scope, receive, send, sent = _make_invoke_args(name, rows)
        await self.app(scope, receive, send)
        return sent

    def test_async_udf_runs_on_dispatch_thread(self) -> None:
        """An async UDF body executes on the dedicated dispatch thread."""
        sent = self._invoke('_async_record_udf', [('alpha',)])
        statuses = [m for m in sent if m.get('type') == 'http.response.start']
        self.assertTrue(statuses and statuses[0]['status'] == 200, sent)

        dispatch_thread = _get_async_dispatch_thread()
        assert dispatch_thread is not None
        with _dispatch_observation_lock:
            self.assertEqual(_dispatch_observation['alpha'], dispatch_thread.ident)

    def test_sync_udf_runs_on_a_worker_thread_not_dispatch(self) -> None:
        """A sync UDF body runs on a worker thread, NOT the dispatch thread."""
        # Force the dispatch thread to exist so we can compare ids.
        _get_async_dispatch_loop()
        dispatch_thread = _get_async_dispatch_thread()
        assert dispatch_thread is not None

        sent = self._invoke('_sync_record_udf', [('beta',)])
        statuses = [m for m in sent if m.get('type') == 'http.response.start']
        self.assertTrue(statuses and statuses[0]['status'] == 200, sent)

        with _dispatch_observation_lock:
            sync_thread = _dispatch_observation['beta']

        self.assertNotEqual(sync_thread, threading.get_ident())
        self.assertNotEqual(sync_thread, dispatch_thread.ident)

    def test_concurrent_async_requests_share_dispatch_thread(self) -> None:
        """Two concurrent async UDF requests both execute on the dispatch thread."""

        async def driver() -> None:
            await asyncio.gather(
                self._invoke_async('_async_record_udf', [('one',)]),
                self._invoke_async('_async_record_udf', [('two',)]),
                self._invoke_async('_async_record_udf', [('three',)]),
            )

        asyncio.run(driver())

        dispatch_thread = _get_async_dispatch_thread()
        assert dispatch_thread is not None
        with _dispatch_observation_lock:
            for tag in ('one', 'two', 'three'):
                self.assertEqual(
                    _dispatch_observation[tag], dispatch_thread.ident,
                    f'tag {tag} ran on wrong thread',
                )

    def test_concurrent_async_requests_do_not_serialize(self) -> None:
        """Concurrent async UDF requests run in parallel on the dispatch loop;
        a new request does not wait for in-flight ones."""
        n = 4
        per_call_sleep = 0.4

        async def driver() -> None:
            await asyncio.gather(*[
                self._invoke_async('_async_slow_udf', [(f'r{i}',)])
                for i in range(n)
            ])

        start = time.monotonic()
        asyncio.run(driver())
        elapsed = time.monotonic() - start

        # Serialized would be ~ n * per_call_sleep. Parallel ~ per_call_sleep.
        self.assertLess(elapsed, per_call_sleep * 2)

    def test_new_async_request_runs_during_in_flight_request(self) -> None:
        """An async request arriving while another is still running gets
        dispatched onto the async thread immediately and finishes before
        the in-flight one — i.e., a new request does not wait for any
        existing async request to be served."""

        async def driver() -> None:
            long_call = asyncio.create_task(
                self._invoke_async('_async_long_udf', [('long',)]),
            )
            # Spin until the long request has actually started executing
            # on the dispatch thread, so any new dispatch we fire after
            # this point is genuinely "during" an in-flight request.
            for _ in range(200):
                await asyncio.sleep(0.01)
                with _dispatch_observation_lock:
                    if 'long' in _dispatch_started_at:
                        break
            self.assertIn('long', _dispatch_started_at)

            t_call = time.monotonic()
            await self._invoke_async('_async_record_udf', [('quick',)])
            t_returned = time.monotonic()
            # The quick request must come back well inside the long UDF's
            # 1.0s sleep, i.e. it was not queued behind it.
            self.assertLess(t_returned - t_call, 0.5)
            await long_call

        asyncio.run(driver())

        with _dispatch_observation_lock:
            long_started = _dispatch_started_at['long']
            long_finished = _dispatch_finished_at['long']
            quick_started = _dispatch_started_at['quick']
            quick_finished = _dispatch_finished_at['quick']

        # quick must have started AFTER long started (it was fired later)
        # but BEFORE long finished, and itself finished before long did.
        self.assertGreater(quick_started, long_started)
        self.assertLess(quick_started, long_finished)
        self.assertLess(quick_finished, long_finished)

        # Sanity: the long UDF body really did span the long sleep.
        self.assertGreaterEqual(long_finished - long_started, 0.9)


if __name__ == '__main__':
    unittest.main()