From b644ddda352c4833cecad9f64b56908d0f4a7259 Mon Sep 17 00:00:00 2001 From: Mergen Nachin Date: Wed, 3 Jun 2026 15:32:32 -0700 Subject: [PATCH 1/4] [INITIAL] Update [ghstack-poisoned] --- extension/llm/server/README.md | 29 ++ .../conformance/test_openai_contract.py | 207 +++++++++ extension/llm/server/python/README.md | 129 ++++++ extension/llm/server/python/runner_pool.py | 300 +++++++++++++ extension/llm/server/python/server.py | 146 +++++++ extension/llm/server/python/serving_chat.py | 405 ++++++++++++++++++ extension/llm/server/python/tests/conftest.py | 101 +++++ .../llm/server/python/tests/test_contract.py | 333 ++++++++++++++ .../server/python/tests/test_runner_pool.py | 276 ++++++++++++ .../python/tests/test_sampling_params.py | 269 ++++++++++++ .../llm/server/python/tests/test_template.py | 139 ++++++ .../server/python/tests/test_tool_calls.py | 155 +++++++ 12 files changed, 2489 insertions(+) create mode 100644 extension/llm/server/README.md create mode 100644 extension/llm/server/conformance/test_openai_contract.py create mode 100644 extension/llm/server/python/README.md create mode 100644 extension/llm/server/python/runner_pool.py create mode 100644 extension/llm/server/python/server.py create mode 100644 extension/llm/server/python/serving_chat.py create mode 100644 extension/llm/server/python/tests/conftest.py create mode 100644 extension/llm/server/python/tests/test_contract.py create mode 100644 extension/llm/server/python/tests/test_runner_pool.py create mode 100644 extension/llm/server/python/tests/test_sampling_params.py create mode 100644 extension/llm/server/python/tests/test_template.py create mode 100644 extension/llm/server/python/tests/test_tool_calls.py diff --git a/extension/llm/server/README.md b/extension/llm/server/README.md new file mode 100644 index 00000000000..7d2d2393e12 --- /dev/null +++ b/extension/llm/server/README.md @@ -0,0 +1,29 @@ +# ExecuTorch LLM Server + +OpenAI-compatible serving for ExecuTorch LLMs, so any OpenAI-compatible agent +harness (pi, opencode, ...) can use ExecuTorch as a local backend. + +``` +extension/llm/server/ + spec/ # language-neutral OpenAI contract ExecuTorch targets + conformance/ # one test suite every language server must pass + python/ # Python server implementation (current) + # cpp/ # future: no-Python single-binary server +``` + +Why this layout: the OpenAI contract is identical across languages, so the +**spec** and **conformance** suite are shared, and each language gets its own +implementation directory. The real cross-language reuse comes from the C++ +`TextLLMRunner` (and the planned `Session` primitives) underneath — each server +is a thin protocol shell over that engine. See `python/README.md` to run it. + +Status: experimental, reliability-first and deliberately narrow. Implemented: +`/health`, `/v1/models`, `/v1/chat/completions` (streaming + non-streaming), +Hugging Face chat templates (`--hf-tokenizer`), `temperature` / `max_tokens` / +`max_completion_tokens` / `stop`, Hermes/Qwen tool calling +(`...`, complete calls only) with `tool_choice="none"`, +structured API errors, cancellation, and an opt-in conservative per-runner KV +prefix cache (`--enable-prefix-cache`). Unsupported params (`top_p`, `seed`, +`n>1`, `reasoning_effort`) are rejected with a structured 400 rather than +silently ignored. See `python/README.md` to run it and `spec/README.md` for the +exact contract. diff --git a/extension/llm/server/conformance/test_openai_contract.py b/extension/llm/server/conformance/test_openai_contract.py new file mode 100644 index 00000000000..3347126b5f4 --- /dev/null +++ b/extension/llm/server/conformance/test_openai_contract.py @@ -0,0 +1,207 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Language-neutral OpenAI-contract conformance tests. + +Runs against any base URL (ExecuTorch, llama.cpp, mlx-lm, ...) so every server +implementation is validated against one shared spec. Point it at a running +server: + + OPENAI_BASE_URL=http://127.0.0.1:8000/v1 pytest test_openai_contract.py + +Skips automatically if no server is reachable. +""" + +import json +import os +import urllib.error +import urllib.request + +import pytest + +BASE_URL = os.environ.get("OPENAI_BASE_URL", "http://127.0.0.1:8000/v1").rstrip("/") +MODEL = os.environ.get("OPENAI_MODEL", "executorch") + + +def _post(path: str, body: dict, stream: bool = False): + req = urllib.request.Request( + f"{BASE_URL}{path}", + data=json.dumps(body).encode(), + headers={"Content-Type": "application/json"}, + method="POST", + ) + return urllib.request.urlopen(req, timeout=120) + + +def _server_up() -> bool: + try: + urllib.request.urlopen(f"{BASE_URL}/models", timeout=5) + return True + except Exception: + return False + + +pytestmark = pytest.mark.skipif( + not _server_up(), reason="no OpenAI server at OPENAI_BASE_URL" +) + + +def test_models_listing(): + with urllib.request.urlopen(f"{BASE_URL}/models", timeout=10) as r: + data = json.loads(r.read()) + assert data["object"] == "list" + assert any("id" in m for m in data["data"]) + + +def test_chat_completion_nonstreaming(): + body = { + "model": MODEL, + "messages": [{"role": "user", "content": "Say hello in one word."}], + "max_tokens": 16, + "temperature": 0.0, + } + with _post("/chat/completions", body) as r: + data = json.loads(r.read()) + assert data["object"] == "chat.completion" + assert data["choices"][0]["message"]["role"] == "assistant" + assert isinstance(data["choices"][0]["message"]["content"], str) + assert data["choices"][0]["finish_reason"] is not None + + +def test_chat_completion_streaming(): + body = { + "model": MODEL, + "messages": [{"role": "user", "content": "Count to three."}], + "max_tokens": 32, + "stream": True, + } + saw_role = saw_content = saw_done = False + with _post("/chat/completions", body, stream=True) as r: + for raw in r: + line = raw.decode().strip() + if not line.startswith("data:"): + continue + payload = line[len("data:") :].strip() + if payload == "[DONE]": + saw_done = True + break + chunk = json.loads(payload) + assert chunk["object"] == "chat.completion.chunk" + delta = chunk["choices"][0]["delta"] + saw_role = saw_role or delta.get("role") == "assistant" + saw_content = saw_content or bool(delta.get("content")) + assert saw_role and saw_content and saw_done + + +def test_multibyte_streaming_integrity(): + # Byte-level BPE can split a multi-byte character across tokens; the stream + # must reassemble it, not abort with a UTF-8 decode error. + body = { + "model": MODEL, + "messages": [ + {"role": "user", "content": "Reply with exactly: 你好世界 🌍 café"} + ], + "max_tokens": 32, + "temperature": 0.0, + "stream": True, + } + content, saw_done, saw_error = "", False, False + with _post("/chat/completions", body, stream=True) as r: + for raw in r: + line = raw.decode().strip() + if not line.startswith("data:"): + continue + payload = line[len("data:") :].strip() + if payload == "[DONE]": + saw_done = True + break + chunk = json.loads(payload) + if "error" in chunk: + saw_error = True + content += ( + chunk["choices"][0]["delta"].get("content", "") + if chunk.get("choices") + else "" + ) + assert saw_done and not saw_error + assert isinstance(content, str) and content # reassembled, valid UTF-8 + + +def test_usage_chunk_in_stream(): + body = { + "model": MODEL, + "messages": [{"role": "user", "content": "Say hi."}], + "max_tokens": 16, + "stream": True, + "stream_options": {"include_usage": True}, + } + usage = None + with _post("/chat/completions", body, stream=True) as r: + for raw in r: + line = raw.decode().strip() + if not line.startswith("data:"): + continue + payload = line[len("data:") :].strip() + if payload == "[DONE]": + break + chunk = json.loads(payload) + if chunk.get("usage"): + usage = chunk["usage"] + assert usage is not None, "no usage chunk emitted with include_usage" + assert usage["prompt_tokens"] > 0 and usage["completion_tokens"] > 0 + assert usage["total_tokens"] == usage["prompt_tokens"] + usage["completion_tokens"] + + +WEATHER_TOOL = { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the current weather for a city.", + "parameters": { + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + }, + }, +} + + +def test_tool_call_response_shape(): + body = { + "model": MODEL, + "messages": [ + {"role": "user", "content": "What is the weather in Paris? Use the tool."} + ], + "tools": [WEATHER_TOOL], + "max_tokens": 128, + "temperature": 0.0, + } + with _post("/chat/completions", body) as r: + data = json.loads(r.read()) + calls = data["choices"][0]["message"].get("tool_calls") + assert calls, "expected tool_calls in response" + tc = calls[0] + assert tc["type"] == "function" + assert tc["id"] + assert tc["function"]["name"] == "get_weather" + json.loads(tc["function"]["arguments"]) # arguments is a JSON string + assert data["choices"][0]["finish_reason"] == "tool_calls" + + +def test_error_body_shape(): + # Over-long prompt -> structured 400 (OpenAI error envelope), not a 500/drop. + body = { + "model": MODEL, + "messages": [{"role": "user", "content": "word " * 40000}], + "max_tokens": 8, + } + try: + _post("/chat/completions", body) + raise AssertionError("expected an HTTP error for over-long prompt") + except urllib.error.HTTPError as e: + assert 400 <= e.code < 500 + err = json.loads(e.read())["error"] + assert err["message"] and err["type"] diff --git a/extension/llm/server/python/README.md b/extension/llm/server/python/README.md new file mode 100644 index 00000000000..1d535e9fdc7 --- /dev/null +++ b/extension/llm/server/python/README.md @@ -0,0 +1,129 @@ +# ExecuTorch LLM Server — Python + +A thin OpenAI-compatible HTTP server over ExecuTorch's `TextLLMRunner`. + +## Install + +```bash +pip install -r requirements.txt +# transformers is optional but recommended for model-correct chat templates +pip install transformers +``` + +Requires an ExecuTorch build with the LLM runner pybindings +(`EXECUTORCH_BUILD_PYBIND=ON`) so `executorch.extension.llm.runner` imports. + +### Model & runtime requirements + +LLM `.pte` files exported via `export_llm` use ExecuTorch custom/quantized ops: +`use_sdpa_with_kv_cache` → `llama::custom_sdpa`, and quantized exports +(`embedding_quantize`, `8da4w`, ...) → `quantized_decomposed` ops. These are the +Python-runtime equivalent of the C++ build flags in the canonical +[Llama README](../../../../examples/models/llama/README.md) +(`-DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON`, +`-DEXECUTORCH_BUILD_EXTENSION_LLM_RUNNER=ON`). **The server registers them +automatically** (imports `executorch.extension.llm.custom_ops.custom_ops` and +`executorch.kernels.quantized` before constructing the runner); without them the +runner fails with `Missing operator ... load_method('forward') failed`. + +Tokenizer: pass the model's tokenizer — `tokenizer.json` (HF, e.g. Qwen3) or +`tokenizer.model` (Llama); the runner auto-detects. If you see an RE2 lookahead +warning it falls back to PCRE2 and still works (build with +`-DSUPPORT_REGEX_LOOKAHEAD=ON` for the native regex path). + +## Run + +```bash +python -m executorch.extension.llm.server.python.server \ + --model-path /path/to/model.pte \ + --tokenizer-path /path/to/tokenizer.bin \ + --hf-tokenizer Qwen/Qwen2.5-Coder-7B-Instruct \ + --model-id qwen2.5-coder \ + --enable-prefix-cache \ + --host 127.0.0.1 --port 8000 +``` + +`--hf-tokenizer` is **required** (it applies the model's real `chat_template`) +unless you pass `--allow-chatml-fallback` to opt into approximate generic ChatML +— which is wrong for many instruct/tool models and can't reproduce controls like +`enable_thinking`. + +Key flags: + +| Flag | Effect | +|------|--------| +| `--hf-tokenizer` | model's HF chat template (required unless fallback) | +| `--allow-chatml-fallback` | opt into approximate ChatML when no HF tokenizer | +| `--no-think` | default `enable_thinking=False` (e.g. Qwen3) | +| `--max-context N` | reject over-long prompts with 400 instead of failing mid-gen | +| `--num-runners N` | *requested* physical sessions (each = one KV cache, N × memory); the actual count is clamped by the engine's `serving_capacity()` — XNNPACK self-contained `.pte` is single-slot, so N>1 is clamped to 1 and extra requests queue | +| `--enable-prefix-cache` | opt-in turn-to-turn KV reuse (requires `--hf-tokenizer`; runs the LLMEngine/LLMSession path) | + +## Use from an agent harness + +- **opencode** (`opencode.json`): + ```json + { "provider": { "executorch": { + "npm": "@ai-sdk/openai-compatible", + "options": { "baseURL": "http://127.0.0.1:8000/v1" }, + "models": { "qwen2.5-coder": { "name": "Qwen2.5-Coder (ExecuTorch)" } } } } } + ``` +- **pi** (`~/.pi/agent/models.json`): + ```json + { "providers": { "executorch": { + "baseUrl": "http://127.0.0.1:8000/v1", "api": "openai-completions", + "apiKey": "x", "models": [ { "id": "qwen2.5-coder" } ] } } } + ``` + +## Validate + +Two layers, both contract-focused (assert on the wire, not internals): + +```bash +# 1. Hermetic unit tests — fake engine, no model/GPU, fast (CI-friendly). +pip install pytest httpx +pytest tests/ + +# 2. Conformance — black-box, against a LIVE server (real model, or llama.cpp/mlx-lm). +OPENAI_BASE_URL=http://127.0.0.1:8000/v1 pytest ../conformance/test_openai_contract.py +``` + +`tests/` swaps in a `FakeRunner` via `RunnerPool(runner_factory=...)`, so the real +server/protocol/streaming code is tested over HTTP without a `.pte`. + +## Architecture + +Control plane (this dir, Python): server, OpenAI protocol, chat templating, +session routing/streaming, and prefix-reuse *policy*. Data plane (C++): the +`LLMEngine`/`LLMSession` API owns token stepping and KV mutation (prefill/decode/ +sampling) and releases the GIL. Python depends on `LLMEngine`/`LLMSession`, not on +`TextLLMRunner` token-step internals (`TextLLMRunner` is a legacy/direct runner +and a C++ implementation detail behind the session adapter). How many physical +sessions can exist without multiplying model memory is decided by +`serving_capacity()`, not by `--num-runners`. Tensor data never crosses into +Python element-wise. + +| File | Role | +|------|------| +| `server.py` | FastAPI app, routes, CLI entrypoint | +| `protocol.py` | OpenAI request/response schemas | +| `chat_template.py` | messages (+tools) → prompt string | +| `runner_pool.py` | session pool + serving-capacity admission + affinity routing + async streaming bridge | +| `serving_chat.py` | `/v1/chat/completions` (streaming + non-streaming, stop, tools) | +| `prefix_cache.py` | turn-to-turn KV prefix-reuse policy over an `LLMSession` (opt-in) | +| `tool_parsers/` | Hermes/Qwen `` parser only | + +## Scope & caveats + +Deliberately narrow (reliability-first): Hermes/Qwen tool calling only; +unsupported sampling params are rejected, not ignored. `--num-runners` is a +*request*, not a guarantee — the engine's `serving_capacity()` is authoritative, +and an XNNPACK self-contained `.pte` is conservative **single-slot** for v1 +(packed weights may be per-method-instance, so extra physical sessions would +duplicate model memory): N>1 is clamped to 1 and concurrent requests queue on the +resident session. The engine serializes backend execution across sessions (op +kernels aren't assumed thread-safe — this is also what fixed the multi-runner +heap corruption). Prefix cache requires the LLMSession/engine path +(`--enable-prefix-cache` + `--hf-tokenizer`). Weight sharing across physical +sessions on a backend that supports it (e.g. CUDA/AOTI), adaptive thinking, and +multi-session subagents are future work. diff --git a/extension/llm/server/python/runner_pool.py b/extension/llm/server/python/runner_pool.py new file mode 100644 index 00000000000..8a85c6c12c0 --- /dev/null +++ b/extension/llm/server/python/runner_pool.py @@ -0,0 +1,300 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Pool of runner sessions + the streaming bridge. + +Each pooled item exposes a uniform ``generate(prompt, config, token_cb, +stats_cb)`` / ``stop()`` surface. With a tokenizer it's a PrefixCachingSession +(turn-to-turn KV prefix reuse via seek/prefill_tokens); without one it's a +_StatelessRunner. + +Session lifecycle: sessions are conversation-scoped, not request-scoped. The +pool keeps them warm for the process lifetime — acquire(prompt) routes a request +to the idle session whose KV already holds the longest matching prefix, and +release returns it to the pool with KV intact so the conversation's next turn +reuses it (concurrent conversations keep their caches instead of round-robin +eviction). A cache session is reset only on an unrecoverable error, and torn +down at shutdown — never reset per request, which would discard the cache reuse +exists to exploit. Stateless mode (no tokenizer) is the deliberate exception: +_StatelessRunner resets before each request because it does no prefix reuse. +There is no idle-eviction policy yet — N is fixed at construction. + +A pool of N gives concurrency (the pybindings release the GIL during +prefill/decode). The cost is N x the per-session KV cache, which dominates +memory (e.g. ~7.5 GB at 32K context for a 0.6B model vs ~0.5 GB of weights), so +N is bounded by RAM, not compute. +""" + +import asyncio +import logging +from concurrent.futures import ThreadPoolExecutor +from contextlib import asynccontextmanager +from dataclasses import dataclass +from typing import AsyncIterator, Callable, Optional + +from executorch.extension.llm.runner import GenerationConfig, TextLLMRunner + +from .prefix_cache import longest_common_prefix, PrefixCachingSession + +logger = logging.getLogger(__name__) + +_SENTINEL = object() +_MAX_CONTEXT_KEY = "get_max_context_len" +_MAX_SEQ_KEY = "get_max_seq_len" + + +def _register_llm_ops() -> None: + """Register ExecuTorch LLM custom + quantized ops in this process. + + LLM .pte files exported with use_sdpa_with_kv_cache use llama::custom_sdpa, + and quantized exports (e.g. embedding_quantize / 8da4w) use + quantized_decomposed ops. The C++ runners link these; the Python runtime + must import them or load_method('forward') fails with "Missing operator". + Harmless if a build has them statically registered. + """ + try: + import executorch.extension.llm.custom_ops.custom_ops # noqa: F401 + except Exception as e: # noqa: BLE001 + logger.debug("custom_ops not imported (%s); assuming statically linked", e) + try: + import executorch.kernels.quantized # noqa: F401 + except Exception as e: # noqa: BLE001 + logger.debug( + "quantized kernels not imported (%s); assuming statically linked", e + ) + + +@dataclass +class GenStats: + prompt_tokens: int = 0 + completion_tokens: int = 0 + + +def _admit_session_count(requested, engine, production, allow_weight_duplication): + """How many *physical* sessions to create. Bounded by what the backend hosts + without duplicating packed weights (engine.serving_capacity()), unless the + operator explicitly opts into duplication. XNNPACK on a self-contained .pte + is single-slot (1): N logical requests queue on one physical session rather + than loading N runtimes. Without an engine, production also forces 1 (a + standalone TextLLMRunner can't share weights or run concurrent backend + calls safely).""" + n = max(1, requested) + if engine is not None: + # The escape hatch relaxes ONLY the engine's weight-duplication clamp — + # the engine still serializes backend execution internally, so N sessions + # are safe (just memory-costly). + if allow_weight_duplication: + return n + cap = int( + engine.serving_capacity().get( + "max_physical_sessions_without_weight_duplication", 1 + ) + ) + limit = cap if cap > 0 else 1 + if n > limit: + logger.warning( + "Engine hosts %d physical session(s) without weight duplication; " + "clamping num_runners %d->%d. Concurrent requests queue on the " + "resident session(s).", + limit, + n, + limit, + ) + return limit + return n + # No engine: standalone TextLLMRunner. N>1 in production is unsafe REGARDLESS + # of weight-duplication willingness — concurrent backend calls into separate + # Modules corrupt the heap (no engine-owned serialization), so the escape + # hatch does NOT relax this. Fix thread-safety or use the engine path first. + if production and n > 1: + logger.warning( + "No shared-weight engine (no --hf-tokenizer, or data_path set); forcing " + "num_runners=1 — %d standalone runners would duplicate weights and run " + "unsafe concurrent backend calls.", + n, + ) + return 1 + return n + + +class _StatelessRunner: + """Resets the KV cache before each request (no prefix reuse). Used when no + tokenizer is available to do token-level prefix matching.""" + + def __init__(self, runner): + self._runner = runner + + def generate(self, prompt, config, token_callback=None, stats_callback=None): + self._runner.reset() + self._runner.generate(prompt, config, token_callback, stats_callback) + + def stop(self): + self._runner.stop() + + +class RunnerPool: + def __init__( + self, + model_path: Optional[str] = None, + tokenizer_path: Optional[str] = None, + data_path: Optional[str] = None, + num_runners: int = 1, + runner_factory: Optional[Callable[[], object]] = None, + tokenizer: Optional[object] = None, + allow_weight_duplication_for_parallel_runners: bool = False, + ): + # runner_factory is the test/extensibility seam: tests inject a fake. + # Production wiring: + # - With a tokenizer (prefix cache enabled, --hf-tokenizer) and engine + # support: one LLMEngine; each session reuses the engine resources and + # backend execution is serialized. The factory yields an LLMSession + # that PrefixCachingSession drives via prefill_tokens + decode_one + # (exact ids -> turn-to-turn reuse incl. completions). + # - Without a tokenizer (or with .ptd): the standalone TextLLMRunner + # generate() path (no shared weights / no engine serialization). + production = runner_factory is None + self._engine = None + self._max_context_len = None + self._max_seq_len = None + if runner_factory is None: + _register_llm_ops() + if tokenizer is not None and data_path is None: + from executorch.extension.llm.runner import LLMEngine + + self._engine = LLMEngine( + model_path=model_path, tokenizer_path=tokenizer_path + ) + metadata = self._engine.metadata() + self._max_context_len = metadata.get(_MAX_CONTEXT_KEY) + self._max_seq_len = metadata.get(_MAX_SEQ_KEY) + _engine = self._engine + + def runner_factory(): + return _engine.create_session() + + else: + if tokenizer is not None: + logger.warning( + "Prefix cache requested, but LLMEngine is unavailable for this " + "artifact path (for example data_path/.ptd). Falling back to " + "stateless TextLLMRunner; token-step APIs are exposed only " + "through LLMSession." + ) + + def runner_factory(): + return TextLLMRunner( + model_path=model_path, + tokenizer_path=tokenizer_path, + data_path=data_path, + ) + + def make_session(index): + runner = runner_factory() + if tokenizer is not None and (self._engine is not None or not production): + # Drive an LLMSession-shaped object via decode_one with prefix + # reuse. Production raw TextLLMRunner fallback deliberately does + # not use this path; token-step pybinds live only on LLMSession. + return PrefixCachingSession( + runner, + tokenizer, + index=index, + max_context_len=self._max_context_len, + max_seq_len=self._max_seq_len, + ) + return _StatelessRunner(runner) + + n = _admit_session_count( + num_runners, + self._engine, + production, + allow_weight_duplication_for_parallel_runners, + ) + self._tokenizer = tokenizer + self._sessions = [make_session(i) for i in range(n)] + self._busy = [False] * n + self._cond = asyncio.Condition() + self._executor = ThreadPoolExecutor(max_workers=n) + + def _pick(self, prompt: str) -> int: + """Index of an idle session, preferring the one whose KV already holds the + longest token prefix of `prompt` (so a conversation's next turn lands on + the runner that can reuse its cache). Tie -> emptiest cache, to avoid + evicting a longer cache that likely belongs to another live conversation.""" + idle = [i for i, b in enumerate(self._busy) if not b] + if self._tokenizer is None or not prompt: + return idle[0] + try: + pids = self._tokenizer.encode(prompt, add_special_tokens=False) + except ( + Exception + ): # noqa: BLE001 - routing is best-effort; fall back to any idle + return idle[0] + + def key(i: int): + cached = getattr(self._sessions[i], "cached_tokens", None) or [] + return (longest_common_prefix(cached, pids), -len(cached)) + + return max(idle, key=key) + + @asynccontextmanager + async def acquire(self, prompt: str = ""): + async with self._cond: + while all(self._busy): + await self._cond.wait() + idx = self._pick(prompt) + self._busy[idx] = True + try: + yield self._sessions[idx] + finally: + async with self._cond: + self._busy[idx] = False + self._cond.notify() + + async def generate_stream( + self, + runner, + prompt: str, + config: GenerationConfig, + stats: Optional[GenStats] = None, + ) -> AsyncIterator[str]: + """Yield generated text tokens. If `stats` is given it's filled in place + with token counts (per-request, so concurrent streams don't race).""" + loop = asyncio.get_running_loop() + queue: asyncio.Queue = asyncio.Queue() + out_stats = stats if stats is not None else GenStats() + + def token_cb(token: str) -> None: + loop.call_soon_threadsafe(queue.put_nowait, token) + + def stats_cb(s) -> None: + out_stats.prompt_tokens = s.num_prompt_tokens + out_stats.completion_tokens = s.num_generated_tokens + + def run() -> None: + try: + # `runner` is a session wrapper: PrefixCachingSession reuses the + # shared prefix; _StatelessRunner resets first. Cache policy lives + # in the wrapper, not here. + runner.generate(prompt, config, token_cb, stats_cb) + except Exception as e: # noqa: BLE001 - surface to the stream consumer + loop.call_soon_threadsafe(queue.put_nowait, e) + finally: + loop.call_soon_threadsafe(queue.put_nowait, _SENTINEL) + + fut = loop.run_in_executor(self._executor, run) + try: + while True: + item = await queue.get() + if item is _SENTINEL: + break + if isinstance(item, Exception): + raise item + yield item + except asyncio.CancelledError: + runner.stop() + raise + finally: + await fut diff --git a/extension/llm/server/python/server.py b/extension/llm/server/python/server.py new file mode 100644 index 00000000000..3b0d06b23f2 --- /dev/null +++ b/extension/llm/server/python/server.py @@ -0,0 +1,146 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""OpenAI-compatible HTTP server for ExecuTorch LLMs. + +Point any OpenAI-compatible agent harness (pi, opencode, ...) at +``http://:/v1``. + +Example: + python -m executorch.extension.llm.server.python.server \\ + --model-path model.pte --tokenizer-path tokenizer.bin \\ + --hf-tokenizer Qwen/Qwen2.5-Coder-7B-Instruct --model-id qwen2.5-coder +""" + +import argparse +import logging + +from fastapi import FastAPI +from fastapi.responses import JSONResponse, StreamingResponse + +from .chat_template import ChatTemplate +from .errors import APIError +from .protocol import ChatCompletionRequest, ModelCard, ModelList +from .runner_pool import RunnerPool +from .serving_chat import ServingChat +from .tool_parsers import HermesDetector + +logger = logging.getLogger(__name__) + + +def build_app(serving: ServingChat, model_id: str) -> FastAPI: + app = FastAPI(title="ExecuTorch LLM Server") + + @app.get("/health") + async def health(): + return {"status": "ok"} + + @app.get("/v1/models") + async def list_models() -> ModelList: + return ModelList(data=[ModelCard(id=model_id)]) + + @app.post("/v1/chat/completions") + async def chat_completions(req: ChatCompletionRequest): + # Typed param → FastAPI validates the body and returns 422 on bad input. + # APIError (e.g. context_length_exceeded) → structured 4xx/5xx, never a + # dropped connection. Mid-stream failures are handled inside the stream. + try: + result = await serving.create(req) + except APIError as e: + return JSONResponse(e.body(), status_code=e.status) + if req.stream: + return StreamingResponse(result, media_type="text/event-stream") + return JSONResponse(result.model_dump(exclude_none=True)) + + return app + + +def main() -> None: + p = argparse.ArgumentParser(description="ExecuTorch OpenAI-compatible LLM server") + p.add_argument("--model-path", required=True, help="Path to the .pte model") + p.add_argument("--tokenizer-path", required=True, help="Path to the tokenizer") + p.add_argument("--data-path", default=None, help="Optional .ptd weights file") + p.add_argument( + "--hf-tokenizer", + default=None, + help="HF tokenizer id/dir for model-correct chat templating (required unless " + "--allow-chatml-fallback). Also required for --enable-prefix-cache.", + ) + p.add_argument( + "--allow-chatml-fallback", + action="store_true", + help="Allow approximate generic ChatML templating when --hf-tokenizer is absent. " + "Off by default: the fallback can't reproduce model-specific controls.", + ) + p.add_argument( + "--model-id", default="executorch", help="Model id reported on /v1/models" + ) + p.add_argument( + "--no-think", + action="store_true", + help="Default reasoning off (sends enable_thinking=False to the chat template, " + "e.g. Qwen3). Per-request chat_template_kwargs still override this.", + ) + p.add_argument( + "--max-context", + type=int, + default=None, + help="Model context window; if set (and a tokenizer is available), prompts that " + "exceed it are rejected with 400 context_length_exceeded instead of failing mid-generation. " + "Set this to match the value used at export.", + ) + p.add_argument( + "--num-runners", type=int, default=1, help="KV-cache instances (N x memory)" + ) + p.add_argument( + "--enable-prefix-cache", + action="store_true", + help="Enable conservative per-runner turn-to-turn KV prefix reuse. Off by default; " + "requires --hf-tokenizer and a non-sliding-window model (falls back to full prefill " + "on any reuse failure).", + ) + p.add_argument("--host", default="127.0.0.1") + p.add_argument("--port", type=int, default=8000) + args = p.parse_args() + logging.basicConfig(level=logging.INFO) + + if args.enable_prefix_cache and not args.hf_tokenizer: + p.error( + "--enable-prefix-cache requires --hf-tokenizer (token-level prefix matching)." + ) + + default_template_kwargs = {"enable_thinking": False} if args.no_think else None + # Requires --hf-tokenizer unless --allow-chatml-fallback (raises otherwise). + template = ChatTemplate( + args.hf_tokenizer, + default_template_kwargs=default_template_kwargs, + allow_fallback=args.allow_chatml_fallback, + ) + cache_tokenizer = template.tokenizer() if args.enable_prefix_cache else None + if cache_tokenizer is not None: + logger.info("KV prefix caching enabled (conservative, per-runner).") + pool = RunnerPool( + model_path=args.model_path, + tokenizer_path=args.tokenizer_path, + data_path=args.data_path, + num_runners=args.num_runners, + tokenizer=cache_tokenizer, + ) + serving = ServingChat( + pool, + template, + args.model_id, + max_context=args.max_context, + tool_detector_cls=HermesDetector, + ) + + import uvicorn # imported here so build_app() is usable without the ASGI server + + uvicorn.run(build_app(serving, args.model_id), host=args.host, port=args.port) + + +if __name__ == "__main__": + main() diff --git a/extension/llm/server/python/serving_chat.py b/extension/llm/server/python/serving_chat.py new file mode 100644 index 00000000000..8bf2779a945 --- /dev/null +++ b/extension/llm/server/python/serving_chat.py @@ -0,0 +1,405 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""/v1/chat/completions handler: builds prompts, drives the runner, and emits +OpenAI-shaped responses (streaming and non-streaming).""" + +import json +import logging +import math +from typing import AsyncIterator, Optional + +from executorch.extension.llm.runner import GenerationConfig + +from .chat_template import ChatTemplate +from .errors import APIError, ContextLengthExceeded, GenerationError + +logger = logging.getLogger(__name__) +from .protocol import ( + _new_id, + ChatCompletionChunk, + ChatCompletionRequest, + ChatCompletionResponse, + Choice, + ChunkChoice, + DeltaMessage, + FunctionCall, + ResponseMessage, + ToolCall, + Usage, +) +from .runner_pool import GenStats, RunnerPool +from .tool_parsers import HermesDetector, ToolCallItem + + +def _earliest_stop(text: str, stops: list[str]) -> Optional[int]: + """Index of the earliest special-token occurrence in `text`, or None.""" + best = None + for s in stops: + i = text.find(s) + if i != -1 and (best is None or i < best): + best = i + return best + + +class ServingChat: + def __init__( + self, + pool: RunnerPool, + template: ChatTemplate, + model_id: str, + max_context: Optional[int] = None, + tool_detector_cls: Optional[type[HermesDetector]] = None, + ): + self._pool = pool + self._template = template + self._model_id = model_id + self._max_context = max_context + # Detector CLASS; a fresh instance is created per request so streaming + # state is never shared across concurrent requests. + self._tool_detector_cls = tool_detector_cls + # Special tokens (e.g. <|im_end|>) the runner decodes to text; we cut the + # visible content at the first one so they don't leak into responses. + self._stops = template.special_tokens() + + @staticmethod + def _tool_names(req: ChatCompletionRequest) -> set[str]: + names = set() + for t in req.tools or []: + fn = t.get("function", {}) if isinstance(t, dict) else {} + if fn.get("name"): + names.add(fn["name"]) + return names + + def _strip_specials(self, text: str) -> str: + cut = _earliest_stop(text, self._stops) + return text[:cut] if cut is not None else text + + @staticmethod + def _to_openai_tool_call(item: ToolCallItem) -> ToolCall: + return ToolCall( + index=item.tool_index, + id=_new_id("call"), + type="function", + function=FunctionCall(name=item.name, arguments=item.arguments), + ) + + def _tools_active(self, req: ChatCompletionRequest) -> bool: + # tool_choice="none" disables tools even when the client sends them. + return bool(self._tool_detector_cls and req.tools and req.tool_choice != "none") + + @staticmethod + def _request_stops(req: ChatCompletionRequest) -> list[str]: + s = req.stop + if not s: + return [] + return [s] if isinstance(s, str) else [x for x in s if x] + + @staticmethod + def _apply_stop(text: str, stops: list[str]) -> str: + """Truncate at the earliest stop string (the stop itself is excluded).""" + cut = _earliest_stop(text, stops) + return text[:cut] if cut is not None else text + + def _truncate_raw(self, text: str, req: ChatCompletionRequest) -> str: + """Cut raw model output at the earliest special token or request stop + sequence BEFORE tool parsing, so a tool call (or any text) past the stop + boundary is neither parsed nor emitted.""" + return self._apply_stop(text, self._stops + self._request_stops(req)) + + @staticmethod + async def _collect_until_stop(stream: AsyncIterator[str], runner, stops: list[str]): + """Accumulate a buffered (non-streamed) generation into one string, + halting the runner early once a stop string (special token or request + stop) appears, then draining so stats finalize. Returns (text, stopped): + `stopped` lets the caller force finish_reason="stop" even when tokens + queued before the runner observed stop() pushed the count to max_tokens.""" + text = "" + stopped = False + async for tok in stream: + text += tok + if stops and _earliest_stop(text, stops) is not None: + stopped = True + runner.stop() + async for _ in stream: # drain so stats_cb fires + pass + break + return text, stopped + + def _extract_tools(self, req: ChatCompletionRequest, text: str): + """Returns (tool_calls | None, content_text). Falls back to plain text.""" + if self._tools_active(req): + parsed = self._tool_detector_cls().detect_and_parse( + text, self._tool_names(req) + ) + if parsed.calls: + content = self._strip_specials(parsed.normal_text) or None + return [self._to_openai_tool_call(c) for c in parsed.calls], content + text = parsed.normal_text + return None, self._strip_specials(text) + + async def _clean( + self, stream: AsyncIterator[str], stops: list[str], on_stop=None + ) -> AsyncIterator[str]: + # Yield text up to the earliest stop string (special token or request + # `stop`), buffering across tokens so a stop spanning chunks is caught. + # On a hit: optionally stop the runner early, then drain the source so it + # finalizes (usage stats recorded, worker thread joined). + hold = ( + max((len(s) for s in stops), default=1) - 1 + ) # keep a possible partial-stop tail + buf = "" + async for token in stream: + buf += token + cut = _earliest_stop(buf, stops) + if cut is not None: + if cut > 0: + yield buf[:cut] + if on_stop is not None: + on_stop() + async for _ in stream: # drain so stats_cb fires + pass + return + if hold == 0: + yield buf + buf = "" + elif len(buf) > hold: + yield buf[:-hold] + buf = buf[-hold:] + if buf: + yield buf + + def _config(self, req: ChatCompletionRequest) -> GenerationConfig: + kwargs = {"echo": False, "max_new_tokens": req.resolved_max_tokens()} + if req.temperature is not None: + kwargs["temperature"] = req.temperature + return GenerationConfig(**kwargs) + + def _finish_reason( + self, + req: ChatCompletionRequest, + completion_tokens: int, + tool_calls=None, + stopped: bool = False, + ) -> str: + # Precedence: tool call > stop boundary > length. `stopped` (a stop + # sequence / special token was hit) wins over "length" even if tokens + # queued before the runner observed stop() reached max_tokens. + if tool_calls: + return "tool_calls" + if stopped: + return "stop" + mt = req.resolved_max_tokens() + return "length" if mt and mt > 0 and completion_tokens >= mt else "stop" + + @staticmethod + def _reject_invalid_values(req: ChatCompletionRequest) -> None: + """Reject out-of-range values (invalid_value); these take precedence over + the unsupported-parameter error.""" + if req.temperature is not None and ( + not math.isfinite(req.temperature) + or req.temperature < 0.0 + or req.temperature > 2.0 + ): + raise APIError( + 400, + f"temperature must be between 0 and 2 (got {req.temperature}).", + "invalid_request_error", + "invalid_value", + ) + # max_tokens / max_completion_tokens, if given, must be positive integers + # (OpenAI rejects 0 and negatives; our -1 sentinel means "unset/auto"). + for field in ("max_tokens", "max_completion_tokens"): + v = getattr(req, field) + if v is not None and v <= 0: + raise APIError( + 400, + f"{field} must be a positive integer (got {v}).", + "invalid_request_error", + "invalid_value", + ) + + @staticmethod + def _reject_unsupported_params(req: ChatCompletionRequest) -> None: + """Reject params we don't honor rather than silently ignoring them (a + client relying on e.g. top_p/seed/logprobs would otherwise get wrong + behavior). Only the no-op/default value of each passes: top_p exactly + 1.0; penalties 0; response_format type "text"; tool_choice none/auto/ + unset; parallel_tool_calls true (false can't be guaranteed without + constraining); logprobs are not returned at all.""" + rf = req.response_format + flags = [ + (req.n != 1, "n>1"), + (req.top_p is not None and req.top_p != 1.0, "top_p"), + (req.seed is not None, "seed"), + (req.reasoning_effort is not None, "reasoning_effort"), + (bool(req.frequency_penalty), "frequency_penalty"), + (bool(req.presence_penalty), "presence_penalty"), + (req.top_k is not None, "top_k"), + (bool(req.logit_bias), "logit_bias"), + ( + bool(rf) and rf.get("type", "text") != "text", + "response_format (only 'text')", + ), + (bool(req.logprobs), "logprobs"), + (req.top_logprobs is not None, "top_logprobs"), + (req.parallel_tool_calls is False, "parallel_tool_calls=false"), + ( + req.tool_choice not in (None, "none", "auto"), + "tool_choice (only 'none' or 'auto')", + ), + ] + unsupported = [label for cond, label in flags if cond] + if unsupported: + raise APIError( + 400, + f"Unsupported parameter(s): {', '.join(unsupported)}. This server honors " + "temperature, max_tokens/max_completion_tokens, stop, and tools (Hermes).", + "invalid_request_error", + "unsupported_parameter", + ) + + async def create(self, req: ChatCompletionRequest): + self._reject_invalid_values(req) + self._reject_unsupported_params(req) + prompt = self._template.render( + req.messages, tools=req.tools, template_kwargs=req.chat_template_kwargs + ) + # Pre-flight context check: reject cleanly instead of failing mid-generation + # (only possible when a tokenizer is available to count, e.g. --hf-tokenizer). + if self._max_context: + count = self._template.count_tokens(prompt) + if count is not None and count >= self._max_context: + raise ContextLengthExceeded(count, self._max_context) + config = self._config(req) + if req.stream: + return self._stream(req, prompt, config) + return await self._complete(req, prompt, config) + + async def _complete( + self, req: ChatCompletionRequest, prompt: str, config: GenerationConfig + ) -> ChatCompletionResponse: + stats = GenStats() + async with self._pool.acquire(prompt) as runner: + try: + # Collect raw text (markers intact for tool parsing), halting early + # at a stop boundary (special token or request stop). + text, stopped = await self._collect_until_stop( + self._pool.generate_stream(runner, prompt, config, stats), + runner, + self._stops + self._request_stops(req), + ) + except Exception as e: # noqa: BLE001 - surface as a structured API error + raise GenerationError(str(e)) + # Bound the raw output at the first stop/special token BEFORE tool + # parsing, so a call after the stop boundary is not parsed/emitted. + tool_calls, content = self._extract_tools(req, self._truncate_raw(text, req)) + finish = self._finish_reason(req, stats.completion_tokens, tool_calls, stopped) + return ChatCompletionResponse( + model=self._model_id, + choices=[ + Choice( + message=ResponseMessage(content=content, tool_calls=tool_calls), + finish_reason=finish, + ) + ], + usage=Usage( + prompt_tokens=stats.prompt_tokens, + completion_tokens=stats.completion_tokens, + total_tokens=stats.prompt_tokens + stats.completion_tokens, + ), + ) + + async def _stream( + self, req: ChatCompletionRequest, prompt: str, config: GenerationConfig + ) -> AsyncIterator[str]: + cid = _new_id("chatcmpl") + + def chunk(delta: DeltaMessage, finish=None) -> str: + c = ChatCompletionChunk( + id=cid, + model=self._model_id, + choices=[ChunkChoice(delta=delta, finish_reason=finish)], + ) + return f"data: {c.model_dump_json(exclude_none=True)}\n\n" + + yield chunk(DeltaMessage(role="assistant")) + error: Optional[Exception] = None + use_tools = self._tools_active(req) + tool_calls = None + content = None + + stats = GenStats() + stop_hit = [False] # set when a stop boundary is reached (forces finish="stop") + stops = self._stops + self._request_stops(req) + async with self._pool.acquire(prompt) as runner: + try: + if use_tools: + # v1: buffer the (usually short) tool response, parse once. + # Halt early at a stop boundary, and bound the raw output + # BEFORE parsing so post-stop tool calls / text don't leak. + raw, stop_hit[0] = await self._collect_until_stop( + self._pool.generate_stream(runner, prompt, config, stats), + runner, + stops, + ) + tool_calls, content = self._extract_tools( + req, self._truncate_raw(raw, req) + ) + else: + # Plain chat: stream tokens live (best UX), cutting at special + # tokens or request stop sequences and halting early on a hit. + def on_stop(): + stop_hit[0] = True + runner.stop() + + async for token in self._clean( + self._pool.generate_stream(runner, prompt, config, stats), + stops, + on_stop=on_stop, + ): + yield chunk(DeltaMessage(content=token)) + except ( + Exception + ) as e: # noqa: BLE001 - emit a structured error event, never drop the socket + error = e + + if error is not None: + err = { + "message": f"Generation failed: {error}", + "type": "server_error", + "code": None, + } + yield f"data: {json.dumps({'error': err})}\n\n" + yield "data: [DONE]\n\n" + return + + if use_tools: + if content: + yield chunk(DeltaMessage(content=content)) + for tc in tool_calls or []: + yield chunk(DeltaMessage(tool_calls=[tc])) + finish = self._finish_reason( + req, stats.completion_tokens, tool_calls, stop_hit[0] + ) + else: + finish = self._finish_reason( + req, stats.completion_tokens, stopped=stop_hit[0] + ) + yield chunk(DeltaMessage(), finish=finish) + if req.stream_options and req.stream_options.include_usage: + usage_chunk = ChatCompletionChunk( + id=cid, + model=self._model_id, + choices=[], + usage=Usage( + prompt_tokens=stats.prompt_tokens, + completion_tokens=stats.completion_tokens, + total_tokens=stats.prompt_tokens + stats.completion_tokens, + ), + ) + yield f"data: {usage_chunk.model_dump_json(exclude_none=True)}\n\n" + yield "data: [DONE]\n\n" diff --git a/extension/llm/server/python/tests/conftest.py b/extension/llm/server/python/tests/conftest.py new file mode 100644 index 00000000000..f6234e59ddf --- /dev/null +++ b/extension/llm/server/python/tests/conftest.py @@ -0,0 +1,101 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Fixtures for hermetic contract tests. + +We inject a FakeRunner through RunnerPool's runner_factory seam, so the tests +exercise the real server, protocol, templating, and streaming code over the +HTTP boundary with NO model, tokenizer, or GPU. This mirrors ExecuTorch's +fake_llm_executor approach: fake the engine, test the real surface. + +Requires a build where `executorch.extension.llm.runner` imports (for +GenerationConfig), plus fastapi, pydantic, httpx, pytest. +""" + +import pytest + +from executorch.extension.llm.server.python.chat_template import ChatTemplate +from executorch.extension.llm.server.python.runner_pool import RunnerPool +from executorch.extension.llm.server.python.server import build_app +from executorch.extension.llm.server.python.serving_chat import ServingChat +from executorch.extension.llm.server.python.tool_parsers import HermesDetector +from fastapi.testclient import TestClient + + +class _FakeStats: + num_prompt_tokens = 5 + num_generated_tokens = 0 + + +class FakeRunner: + """Canned engine: emits fixed tokens, records the config it was given.""" + + def __init__(self, tokens, fail=False): + self._tokens = list(tokens) + self._fail = fail + self.captured_config = None + self.stopped = False + self.reset_count = 0 + + def reset(self): + self.reset_count += 1 + + def stop(self): + self.stopped = True + + def generate(self, prompt, config, token_callback=None, stats_callback=None): + self.captured_config = config + if self._fail: + raise RuntimeError("Generation failed") + for tok in self._tokens: + if token_callback: + token_callback(tok) + if stats_callback: + stats = _FakeStats() + stats.num_generated_tokens = len(self._tokens) + stats_callback(stats) + + +class _FakeTokenizer: + """Minimal stand-in for an HF tokenizer (counting + templating).""" + + all_special_tokens: list = [] + + def __init__(self, prompt_tokens): + self._n = prompt_tokens + + def encode(self, text): + return [0] * self._n + + def apply_chat_template( + self, messages, tools, add_generation_prompt, tokenize, **kwargs + ): + return "PROMPT" + + +@pytest.fixture +def make_client(): + def _make( + tokens=("Hello", ", ", "world"), + max_context=None, + prompt_tokens=None, + fail=False, + ): + fake = FakeRunner(tokens, fail=fail) + pool = RunnerPool(runner_factory=lambda: fake, num_runners=1) + template = ChatTemplate(hf_tokenizer_path=None, allow_fallback=True) + if prompt_tokens is not None: + template._hf = _FakeTokenizer(prompt_tokens) + serving = ServingChat( + pool, + template, + "test-model", + max_context=max_context, + tool_detector_cls=HermesDetector, + ) + return TestClient(build_app(serving, "test-model")), fake + + return _make diff --git a/extension/llm/server/python/tests/test_contract.py b/extension/llm/server/python/tests/test_contract.py new file mode 100644 index 00000000000..3613d87b42f --- /dev/null +++ b/extension/llm/server/python/tests/test_contract.py @@ -0,0 +1,333 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Hermetic OpenAI-contract tests (fake engine, no model/GPU). + +These assert on the public HTTP/wire contract only — response object shapes, +the streaming chunk protocol, status codes — never on internal classes or +methods. Implementation can change freely as long as these pass. +""" + +import json + +import pytest + + +def _sse_chunks(text): + chunks, done = [], False + for line in text.splitlines(): + line = line.strip() + if not line.startswith("data:"): + continue + payload = line[len("data:") :].strip() + if payload == "[DONE]": + done = True + continue + chunks.append(json.loads(payload)) + return chunks, done + + +def test_health(make_client): + client, _ = make_client() + assert client.get("/health").json() == {"status": "ok"} + + +def test_models_listing_shape(make_client): + client, _ = make_client() + body = client.get("/v1/models").json() + assert body["object"] == "list" + assert body["data"][0]["id"] == "test-model" + assert body["data"][0]["object"] == "model" + + +def test_chat_nonstreaming_shape(make_client): + client, _ = make_client(tokens=["Hello", ", ", "world"]) + resp = client.post( + "/v1/chat/completions", + json={"model": "test-model", "messages": [{"role": "user", "content": "hi"}]}, + ) + assert resp.status_code == 200 + body = resp.json() + assert body["object"] == "chat.completion" + choice = body["choices"][0] + assert choice["message"]["role"] == "assistant" + assert choice["message"]["content"] == "Hello, world" + assert choice["finish_reason"] == "stop" + for k in ("prompt_tokens", "completion_tokens", "total_tokens"): + assert k in body["usage"] + assert body["usage"]["total_tokens"] == ( + body["usage"]["prompt_tokens"] + body["usage"]["completion_tokens"] + ) + + +def test_chat_streaming_protocol(make_client): + client, _ = make_client(tokens=["a", "b", "c"]) + resp = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "hi"}], + "stream": True, + }, + ) + assert resp.status_code == 200 + assert resp.headers["content-type"].startswith("text/event-stream") + chunks, done = _sse_chunks(resp.text) + assert done, "stream must terminate with data: [DONE]" + assert all(c["object"] == "chat.completion.chunk" for c in chunks) + assert chunks[0]["choices"][0]["delta"].get("role") == "assistant" + content = "".join(c["choices"][0]["delta"].get("content") or "" for c in chunks) + assert content == "abc" + assert chunks[-1]["choices"][0]["finish_reason"] == "stop" + + +def test_request_params_forwarded_to_generation(make_client): + # Contract behavior: the server must honor max_tokens/temperature. + client, fake = make_client() + client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "hi"}], + "max_tokens": 7, + "temperature": 0.1, + }, + ) + assert fake.captured_config.max_new_tokens == 7 + assert abs(fake.captured_config.temperature - 0.1) < 1e-6 + + +def test_tools_field_accepted(make_client): + # tools is part of the contract even before parsing is enforced (M2). + client, _ = make_client() + resp = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "hi"}], + "tools": [ + {"type": "function", "function": {"name": "f", "parameters": {}}} + ], + }, + ) + assert resp.status_code == 200 + assert resp.json()["choices"][0]["message"]["role"] == "assistant" + + +def test_invalid_request_returns_422(make_client): + client, _ = make_client() + resp = client.post( + "/v1/chat/completions", json={"model": "test-model"} + ) # no messages + assert resp.status_code == 422 + + +def test_special_tokens_stripped_nonstreaming(make_client): + # The runner may decode EOS/special tokens to text; they must not leak. + client, _ = make_client(tokens=["Hello", " world", "<|im_end|>", "LEAK"]) + resp = client.post( + "/v1/chat/completions", + json={"model": "test-model", "messages": [{"role": "user", "content": "hi"}]}, + ) + assert resp.json()["choices"][0]["message"]["content"] == "Hello world" + + +def test_usage_populated_when_special_token_cuts_early(make_client): + # Regression: cutting at a special token must not skip usage stats. + client, _ = make_client(tokens=["Hello", "<|im_end|>", "LEAK"]) + body = client.post( + "/v1/chat/completions", + json={"model": "test-model", "messages": [{"role": "user", "content": "hi"}]}, + ).json() + assert body["choices"][0]["message"]["content"] == "Hello" + assert body["usage"]["completion_tokens"] > 0 + + +def test_special_tokens_stripped_streaming(make_client): + client, _ = make_client(tokens=["Hello", " world", "<|im_end|>", "LEAK"]) + resp = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "hi"}], + "stream": True, + }, + ) + chunks, _ = _sse_chunks(resp.text) + content = "".join(c["choices"][0]["delta"].get("content") or "" for c in chunks) + assert content == "Hello world" + assert "LEAK" not in content and "<|im_end|>" not in content + + +# (1) Context-size-exceeded -> structured 400, both modes (not a dropped socket). +def test_context_length_exceeded_returns_400(make_client): + client, _ = make_client(max_context=2048, prompt_tokens=2940) + resp = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "x" * 100}], + }, + ) + assert resp.status_code == 400 + err = resp.json()["error"] + assert err["code"] == "context_length_exceeded" + assert err["type"] == "invalid_request_error" + + +def test_context_length_exceeded_streaming_returns_400(make_client): + client, _ = make_client(max_context=2048, prompt_tokens=2940) + resp = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "x"}], + "stream": True, + }, + ) + # Pre-flight check rejects before the stream starts -> clean 400, no SSE. + assert resp.status_code == 400 + assert resp.json()["error"]["code"] == "context_length_exceeded" + + +# (1) Mid-generation failure -> structured error, never a dropped connection. +def test_generation_failure_returns_structured_error(make_client): + client, _ = make_client(fail=True) + resp = client.post( + "/v1/chat/completions", + json={"model": "test-model", "messages": [{"role": "user", "content": "hi"}]}, + ) + assert resp.status_code == 500 + assert resp.json()["error"]["type"] == "server_error" + + +def test_generation_failure_streaming_emits_error_event(make_client): + client, _ = make_client(fail=True) + resp = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "hi"}], + "stream": True, + }, + ) + assert resp.status_code == 200 # headers already sent; error arrives as an event + chunks, done = _sse_chunks(resp.text) + assert done + assert any("error" in c for c in chunks) + + +# (3) finish_reason == "length" when max_tokens is hit. +def test_finish_reason_length_when_max_tokens_hit(make_client): + client, _ = make_client(tokens=["a", "b", "c"]) # 3 generated tokens + resp = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "hi"}], + "max_tokens": 3, + }, + ) + assert resp.json()["choices"][0]["finish_reason"] == "length" + + +def test_finish_reason_stop_when_under_max_tokens(make_client): + client, _ = make_client(tokens=["a", "b", "c"]) + resp = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "hi"}], + "max_tokens": 100, + }, + ) + assert resp.json()["choices"][0]["finish_reason"] == "stop" + + +# (4) Error-variant matrix: malformed requests -> consistent 422. +@pytest.mark.parametrize( + "bad_body", + [ + {"model": "m"}, # missing messages + {"model": "m", "messages": "not-a-list"}, + { + "model": "m", + "messages": [{"role": "user", "content": "hi"}], + "temperature": "hot", + }, + { + "model": "m", + "messages": [{"role": "user", "content": "hi"}], + "stream": "maybe", + }, + {"model": "m", "messages": [{"content": "no role"}]}, + ], +) +def test_invalid_requests_return_422(make_client, bad_body): + client, _ = make_client() + assert client.post("/v1/chat/completions", json=bad_body).status_code == 422 + + +# (6) Streaming usage when stream_options.include_usage is set. +def test_streaming_usage_included(make_client): + client, _ = make_client(tokens=["a", "b"]) + resp = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "hi"}], + "stream": True, + "stream_options": {"include_usage": True}, + }, + ) + chunks, _ = _sse_chunks(resp.text) + usage_chunks = [c for c in chunks if c.get("usage")] + assert usage_chunks, "expected a chunk carrying usage" + u = usage_chunks[-1]["usage"] + assert u["total_tokens"] == u["prompt_tokens"] + u["completion_tokens"] + + +def test_streaming_usage_absent_by_default(make_client): + client, _ = make_client(tokens=["a", "b"]) + resp = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "hi"}], + "stream": True, + }, + ) + chunks, _ = _sse_chunks(resp.text) + assert not any(c.get("usage") for c in chunks) + + +# (2) Unicode/multibyte content survives streaming intact. +def test_unicode_streaming_integrity(make_client): + pieces = ["café ", "日本語 ", "😀", "🎉"] + client, _ = make_client(tokens=pieces) + resp = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "hi"}], + "stream": True, + }, + ) + chunks, _ = _sse_chunks( + resp.text + ) # _sse_chunks uses json.loads -> validates UTF-8/JSON + content = "".join(c["choices"][0]["delta"].get("content") or "" for c in chunks) + assert content == "".join(pieces) + + +def test_unicode_nonstreaming_integrity(make_client): + pieces = ["café ", "日本語 ", "😀"] + client, _ = make_client(tokens=pieces) + resp = client.post( + "/v1/chat/completions", + json={"model": "test-model", "messages": [{"role": "user", "content": "hi"}]}, + ) + assert resp.json()["choices"][0]["message"]["content"] == "".join(pieces) diff --git a/extension/llm/server/python/tests/test_runner_pool.py b/extension/llm/server/python/tests/test_runner_pool.py new file mode 100644 index 00000000000..4491695e151 --- /dev/null +++ b/extension/llm/server/python/tests/test_runner_pool.py @@ -0,0 +1,276 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Pool-level contract tests: abort-on-cancel and concurrency isolation. + +Written with asyncio.run (sync test bodies) to avoid depending on an async +pytest plugin. +""" + +import asyncio +import threading + +from executorch.extension.llm.server.python.prefix_cache import PrefixCachingSession +from executorch.extension.llm.server.python.runner_pool import ( + _admit_session_count, + _StatelessRunner, + RunnerPool, +) + + +class _FakeEngine: + """serving_capacity()-shaped fake.""" + + def __init__(self, max_physical): + self._cap = max_physical + + def serving_capacity(self): + return { + "max_physical_sessions_without_weight_duplication": self._cap, + "estimated_bytes_per_session": 0, + } + + +# Admission clamps physical sessions to the engine's no-duplication capacity. +def test_admit_clamps_to_serving_capacity(): + # Single-slot engine (XNNPACK): 4 requested -> 1 physical session. + assert ( + _admit_session_count( + 4, _FakeEngine(1), production=True, allow_weight_duplication=False + ) + == 1 + ) + # Engine that hosts 4 without duplication: honor the request. + assert ( + _admit_session_count( + 4, _FakeEngine(4), production=True, allow_weight_duplication=False + ) + == 4 + ) + # Request below capacity is untouched. + assert ( + _admit_session_count( + 2, _FakeEngine(4), production=True, allow_weight_duplication=False + ) + == 2 + ) + # Unknown/zero capacity -> conservative 1. + assert ( + _admit_session_count( + 4, _FakeEngine(0), production=True, allow_weight_duplication=False + ) + == 1 + ) + # No engine in production -> force 1 (standalone can't share weights). + assert ( + _admit_session_count(4, None, production=True, allow_weight_duplication=False) + == 1 + ) + # Explicit opt-in relaxes the ENGINE weight-duplication clamp (engine still + # serializes execution, so N sessions are safe). + assert ( + _admit_session_count( + 4, _FakeEngine(1), production=True, allow_weight_duplication=True + ) + == 4 + ) + # ...but opt-in does NOT relax the no-engine standalone safety clamp: + # concurrent backend calls into separate Modules corrupt the heap. + assert ( + _admit_session_count(4, None, production=True, allow_weight_duplication=True) + == 1 + ) + # Injected test factory (production=False, no engine) is left alone. + assert ( + _admit_session_count(4, None, production=False, allow_weight_duplication=False) + == 4 + ) + + +class _BlockingRunner: + """Emits one token, then blocks until stop() is called.""" + + def __init__(self): + self._gate = threading.Event() + self.stopped = False + + def reset(self): + pass + + def stop(self): + self.stopped = True + self._gate.set() + + def generate(self, prompt, config, token_callback=None, stats_callback=None): + if token_callback: + token_callback("TOKEN") + self._gate.wait(timeout=5) + + +class _EchoRunner: + """Emits the prompt back as a single token; used to detect cross-talk.""" + + def reset(self): + pass + + def stop(self): + pass + + def generate(self, prompt, config, token_callback=None, stats_callback=None): + if token_callback: + token_callback(prompt) + + +# (7) Client disconnect / cancellation must stop the runner. +def test_cancellation_calls_stop(): + async def scenario(): + runner = _BlockingRunner() + pool = RunnerPool(runner_factory=lambda: runner, num_runners=1) + async with pool.acquire() as r: + agen = pool.generate_stream(r, "p", None).__aiter__() + assert await agen.__anext__() == "TOKEN" # runner now blocking + nxt = asyncio.ensure_future(agen.__anext__()) + await asyncio.sleep(0.05) + nxt.cancel() + try: + await nxt + except asyncio.CancelledError: + pass + for _ in range(100): # let the worker observe stop() + if runner.stopped: + break + await asyncio.sleep(0.02) + assert runner.stopped + + asyncio.run(scenario()) + + +# (8) Concurrent requests don't interleave / corrupt each other. +def test_concurrent_requests_isolated(): + async def scenario(): + pool = RunnerPool(runner_factory=_EchoRunner, num_runners=2) + + async def one(prompt): + async with pool.acquire() as r: + return "".join([t async for t in pool.generate_stream(r, prompt, None)]) + + out = await asyncio.gather(one("AAA"), one("BBB"), one("CCC")) + assert sorted(out) == ["AAA", "BBB", "CCC"] + + asyncio.run(scenario()) + + +class _FakeTok: + def encode(self, text, add_special_tokens=False): + return list(text.encode("utf-8")) + + +class _CachingSession: + """LLMSession-shaped fake: decode_one emits `gen_ids`, EOS on the last.""" + + def __init__(self, gen_ids=(33,)): # 33 -> byte '!' + self.gen_ids = list(gen_ids) + self.seeks = [] + self.prefilled = [] + self._pos = 0 + self._cursor = 0 + + def prefill_tokens(self, ids): + self.prefilled.append(list(ids)) + self._pos += len(ids) + + def decode_one(self, temperature=-1.0): + tid = self.gen_ids[self._cursor] if self._cursor < len(self.gen_ids) else 0 + self._cursor += 1 + self._pos += 1 + return { + "token_id": tid, + "text": bytes([tid % 128]), + "is_eos": self._cursor >= len(self.gen_ids), + } + + def seek(self, p): + self.seeks.append(p) + self._pos = p + self._cursor = 0 + + def position(self): + return self._pos + + def reset(self): + self._pos = 0 + self._cursor = 0 + + def stop(self): + pass + + +# A tokenizer makes the pool wrap sessions in PrefixCachingSession and reuse the +# shared prefix across requests. +def test_pool_prefix_caching_reuses_across_requests(): + async def scenario(): + fake = _CachingSession(gen_ids=[33]) # cache b"abc" + [33] + pool = RunnerPool( + runner_factory=lambda: fake, num_runners=1, tokenizer=_FakeTok() + ) + async with pool.acquire() as obj: + assert isinstance(obj, PrefixCachingSession) + _ = [t async for t in pool.generate_stream(obj, "abc", None)] + # Turn 2 = b"abc!XY"; cache holds b"abc"+[33]=b"abc!" (prompt + the exact + # generated id), so reuse is 4 (incl. the completion) and only b"XY" + # prefills — completion reuse, not just the static prefix. + async with pool.acquire() as obj: + _ = [t async for t in pool.generate_stream(obj, "abc!XY", None)] + assert fake.seeks[-1] == 4 + assert fake.prefilled[-1] == list(b"XY") + + asyncio.run(scenario()) + + +# No tokenizer -> stateless wrapper that resets each request (no reuse). +def test_pool_stateless_without_tokenizer(): + async def scenario(): + resets = {"n": 0} + + class R: + def reset(self): + resets["n"] += 1 + + def generate( + self, prompt, config, token_callback=None, stats_callback=None + ): + if token_callback: + token_callback("x") + + def stop(self): + pass + + pool = RunnerPool(runner_factory=R, num_runners=1) # no tokenizer + async with pool.acquire() as obj: + assert isinstance(obj, _StatelessRunner) + _ = [t async for t in pool.generate_stream(obj, "p", None)] + assert resets["n"] == 1 # stateless path resets before generation + + asyncio.run(scenario()) + + +# M4: acquire(prompt) routes to the idle session whose KV holds the longest +# matching prefix; a new conversation lands elsewhere instead of evicting it. +def test_pool_affinity_routing(): + async def scenario(): + pool = RunnerPool( + runner_factory=_CachingSession, num_runners=2, tokenizer=_FakeTok() + ) + + async with pool.acquire("AAAA") as sA: # conversation A caches on some session + _ = [t async for t in pool.generate_stream(sA, "AAAA", None)] + async with pool.acquire("AAAABB") as s2: # continuation shares "AAAA" + assert s2 is sA # affinity hit: routed back to A's session + _ = [t async for t in pool.generate_stream(s2, "AAAABB", None)] + async with pool.acquire("ZZZZ") as s3: # new conversation, no shared prefix + assert s3 is not sA # routed to the empty session, A's cache preserved + + asyncio.run(scenario()) diff --git a/extension/llm/server/python/tests/test_sampling_params.py b/extension/llm/server/python/tests/test_sampling_params.py new file mode 100644 index 00000000000..fe3166d1bb5 --- /dev/null +++ b/extension/llm/server/python/tests/test_sampling_params.py @@ -0,0 +1,269 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Contract tests for sampling/control params: stop sequences, n, tool_choice. + +These exercise the real server over the HTTP boundary with a FakeRunner. +""" + +import json + +CALL = ( + '\n{"name": "get_weather", "arguments": {"city": "Paris"}}\n' +) +WEATHER_TOOL = { + "type": "function", + "function": {"name": "get_weather", "parameters": {}}, +} + + +def _body(**kw): + b = {"model": "test-model", "messages": [{"role": "user", "content": "hi"}]} + b.update(kw) + return b + + +def test_n_greater_than_one_is_rejected(make_client): + client, _ = make_client() + r = client.post("/v1/chat/completions", json=_body(n=2)) + assert r.status_code == 400 + assert r.json()["error"]["type"] == "invalid_request_error" + assert r.json()["error"]["code"] == "unsupported_parameter" + + +def test_unsupported_params_rejected(make_client): + client, _ = make_client() + for param in ( + {"top_p": 0.5}, + {"top_p": 2.0}, # > 1.0 is not a no-op either — only exactly 1.0/unset is + {"seed": 42}, + {"reasoning_effort": "high"}, + {"frequency_penalty": 1.0}, + {"presence_penalty": -0.5}, + {"top_k": 40}, + {"logit_bias": {"123": 5.0}}, + {"response_format": {"type": "json_object"}}, + {"logprobs": True}, + {"top_logprobs": 5}, + {"parallel_tool_calls": False}, + ): + r = client.post("/v1/chat/completions", json=_body(**param)) + assert r.status_code == 400, param + assert r.json()["error"]["code"] == "unsupported_parameter", param + + +def test_nonpositive_max_tokens_rejected(make_client): + # max_tokens=0/-2 must not be silently treated as "unbounded" (the `or` bug); + # 0 and negatives are invalid for both field names. + client, _ = make_client() + for param in ( + {"max_tokens": 0}, + {"max_tokens": -2}, + {"max_completion_tokens": 0}, + {"max_completion_tokens": -2}, + ): + r = client.post("/v1/chat/completions", json=_body(**param)) + assert r.status_code == 400, param + assert r.json()["error"]["type"] == "invalid_request_error", param + + +def test_temperature_range_rejected(make_client): + client, _ = make_client() + for param in ( + {"temperature": -1}, + {"temperature": -0.1}, + {"temperature": 2.1}, + {"temperature": 3}, + ): + r = client.post("/v1/chat/completions", json=_body(**param)) + assert r.status_code == 400, param + assert r.json()["error"]["code"] == "invalid_value", param + + +def test_noop_output_contract_fields_accepted(make_client): + # The default/no-op forms must NOT be rejected (don't break OpenAI clients + # that send them explicitly). + client, _ = make_client() + r = client.post( + "/v1/chat/completions", + json=_body( + response_format={"type": "text"}, + parallel_tool_calls=True, + max_tokens=8, + ), + ) + assert r.status_code == 200 + + +def test_zero_penalties_and_unknown_fields_accepted(make_client): + # frequency/presence_penalty=0.0 are no-ops; unknown non-generation fields + # (user/store/metadata) are ignored, not rejected (don't break OpenAI clients). + client, _ = make_client() + r = client.post( + "/v1/chat/completions", + json=_body( + frequency_penalty=0.0, + presence_penalty=0.0, + user="abc", + store=False, + metadata={"k": "v"}, + max_tokens=8, + ), + ) + assert r.status_code == 200 + + +def test_unsupported_tool_choice_rejected(make_client): + # "required" / a specific-function choice would need constrained decoding to + # force/restrict the call; v1 rejects rather than silently treating as "auto". + client, _ = make_client() + for choice in ( + "required", + {"type": "function", "function": {"name": "get_weather"}}, + ): + r = client.post( + "/v1/chat/completions", + json=_body(tools=[WEATHER_TOOL], tool_choice=choice, max_tokens=8), + ) + assert r.status_code == 400, choice + assert r.json()["error"]["code"] == "unsupported_parameter", choice + + +def test_supported_params_accepted(make_client): + # top_p=1.0 (no-op) and temperature/max_tokens must NOT be rejected; neither + # should tool_choice "auto" / "none". + client, _ = make_client() + for temperature in (0.0, 1.0, 2.0): + r = client.post( + "/v1/chat/completions", + json=_body(top_p=1.0, temperature=temperature, max_tokens=8), + ) + assert r.status_code == 200, temperature + for choice in ("auto", "none"): + r = client.post( + "/v1/chat/completions", + json=_body(tools=[WEATHER_TOOL], tool_choice=choice, max_tokens=8), + ) + assert r.status_code == 200, choice + + +def test_stop_sequence_truncates_nonstreaming(make_client): + client, _ = make_client(tokens=["Hello ", "world ", "STOP", " ignored"]) + r = client.post("/v1/chat/completions", json=_body(stop=["STOP"], max_tokens=32)) + assert r.status_code == 200 + content = r.json()["choices"][0]["message"]["content"] + assert content == "Hello world " + assert "STOP" not in content + + +def test_stop_sequence_truncates_streaming(make_client): + client, _ = make_client(tokens=["Hello ", "world ", "STOP", " ignored"]) + content = "" + with client.stream( + "POST", "/v1/chat/completions", json=_body(stop=["STOP"], stream=True) + ) as r: + for line in r.iter_lines(): + if not line.startswith("data:"): + continue + payload = line[len("data:") :].strip() + if payload == "[DONE]": + break + delta = json.loads(payload)["choices"][0]["delta"] + content += delta.get("content") or "" + assert content == "Hello world " + assert "STOP" not in content + + +def test_stop_forces_finish_reason_over_length_nonstreaming(make_client): + # 4 tokens emitted (completion reaches max_tokens=4) AND a stop is hit: + # finish_reason must be "stop" (boundary), not "length". + client, _ = make_client(tokens=["a ", "b ", "STOP", " c"]) + body = client.post( + "/v1/chat/completions", json=_body(stop=["STOP"], max_tokens=4) + ).json() + assert body["choices"][0]["finish_reason"] == "stop" + assert "STOP" not in (body["choices"][0]["message"]["content"] or "") + + +def test_stop_forces_finish_reason_over_length_streaming(make_client): + client, _ = make_client(tokens=["a ", "b ", "STOP", " c"]) + finish = None + with client.stream( + "POST", + "/v1/chat/completions", + json=_body(stop=["STOP"], max_tokens=4, stream=True), + ) as r: + for line in r.iter_lines(): + if not line.startswith("data:"): + continue + p = line[len("data:") :].strip() + if p == "[DONE]": + break + fr = json.loads(p)["choices"][0].get("finish_reason") + if fr: + finish = fr + assert finish == "stop" + + +_STOP_THEN_CALL = ( + 'Answer STOP \n{"name": "get_weather", "arguments": {}}\n' +) + + +def test_stop_before_tool_call_nonstreaming(make_client): + # Stop boundary precedes a tool call (in one chunk, so truncation — not just + # early-stop — must catch it): the call must NOT be parsed/emitted. + client, _ = make_client(tokens=[_STOP_THEN_CALL]) + r = client.post( + "/v1/chat/completions", + json=_body(tools=[WEATHER_TOOL], stop=["STOP"], max_tokens=64), + ) + msg = r.json()["choices"][0]["message"] + assert msg.get("tool_calls") is None + assert "STOP" not in (msg.get("content") or "") + + +def test_stop_before_tool_call_streaming(make_client): + client, _ = make_client(tokens=[_STOP_THEN_CALL]) + saw_tool, content = False, "" + with client.stream( + "POST", + "/v1/chat/completions", + json=_body(tools=[WEATHER_TOOL], stop=["STOP"], stream=True), + ) as r: + for line in r.iter_lines(): + if not line.startswith("data:"): + continue + p = line[len("data:") :].strip() + if p == "[DONE]": + break + delta = json.loads(p)["choices"][0]["delta"] + if delta.get("tool_calls"): + saw_tool = True + content += delta.get("content") or "" + assert not saw_tool + assert "STOP" not in content + + +def test_tool_choice_none_disables_tools(make_client): + client, _ = make_client(tokens=[CALL]) + r = client.post( + "/v1/chat/completions", + json=_body(tools=[WEATHER_TOOL], tool_choice="none", max_tokens=64), + ) + msg = r.json()["choices"][0]["message"] + assert msg.get("tool_calls") is None # tools disabled -> returned as text + assert r.json()["choices"][0]["finish_reason"] != "tool_calls" + + +def test_tool_choice_default_still_parses(make_client): + # Sanity: without tool_choice="none", the same call IS parsed as a tool call. + client, _ = make_client(tokens=[CALL]) + r = client.post( + "/v1/chat/completions", json=_body(tools=[WEATHER_TOOL], max_tokens=64) + ) + calls = r.json()["choices"][0]["message"].get("tool_calls") + assert calls and calls[0]["function"]["name"] == "get_weather" diff --git a/extension/llm/server/python/tests/test_template.py b/extension/llm/server/python/tests/test_template.py new file mode 100644 index 00000000000..5aafca20eb6 --- /dev/null +++ b/extension/llm/server/python/tests/test_template.py @@ -0,0 +1,139 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Contract tests for chat-template kwargs (e.g. enable_thinking) passthrough.""" + +from executorch.extension.llm.server.python.chat_template import ChatTemplate +from executorch.extension.llm.server.python.protocol import ( + ChatMessage, + FunctionCall, + ToolCall, +) + + +class _FakeHF: + def __init__(self): + self.seen_kwargs = None + self.seen_messages = None + + def apply_chat_template( + self, messages, tools, add_generation_prompt, tokenize, **kwargs + ): + self.seen_kwargs = kwargs + self.seen_messages = messages + return "PROMPT" + + +def _template_with_fake(defaults=None): + t = ChatTemplate( + hf_tokenizer_path=None, allow_fallback=True, default_template_kwargs=defaults + ) + fake = _FakeHF() + t._hf = fake + return t, fake + + +def test_default_template_kwargs_applied(): + t, fake = _template_with_fake(defaults={"enable_thinking": False}) + t.render([ChatMessage(role="user", content="hi")]) + assert fake.seen_kwargs == {"enable_thinking": False} + + +def test_per_request_kwargs_override_defaults(): + t, fake = _template_with_fake(defaults={"enable_thinking": False}) + t.render( + [ChatMessage(role="user", content="hi")], + template_kwargs={"enable_thinking": True}, + ) + assert fake.seen_kwargs["enable_thinking"] is True + + +def test_no_kwargs_when_none(): + t, fake = _template_with_fake(defaults=None) + t.render([ChatMessage(role="user", content="hi")]) + assert fake.seen_kwargs == {} + + +def test_fallback_ignores_kwargs_without_hf(): + # No HF tokenizer → ChatML fallback, must not raise on kwargs. + t = ChatTemplate( + hf_tokenizer_path=None, + allow_fallback=True, + default_template_kwargs={"enable_thinking": False}, + ) + out = t.render([ChatMessage(role="user", content="hi")], template_kwargs={"x": 1}) + assert "<|im_start|>user" in out and out.endswith("<|im_start|>assistant\n") + + +# (5) Chat-template behaviors: multi-turn ordering, system message, roles. +def test_multi_turn_order_preserved(): + t = ChatTemplate(hf_tokenizer_path=None, allow_fallback=True) + out = t.render( + [ + ChatMessage(role="user", content="first"), + ChatMessage(role="assistant", content="second"), + ChatMessage(role="user", content="third"), + ] + ) + assert out.index("first") < out.index("second") < out.index("third") + assert out.endswith("<|im_start|>assistant\n") # generation prompt appended + + +def test_system_message_rendered(): + t = ChatTemplate(hf_tokenizer_path=None, allow_fallback=True) + out = t.render( + [ + ChatMessage(role="system", content="You are terse."), + ChatMessage(role="user", content="hi"), + ] + ) + assert "<|im_start|>system\nYou are terse." in out + + +def test_each_role_labeled(): + t = ChatTemplate(hf_tokenizer_path=None, allow_fallback=True) + out = t.render( + [ + ChatMessage(role="system", content="s"), + ChatMessage(role="user", content="u"), + ChatMessage(role="assistant", content="a"), + ] + ) + for role in ("system", "user", "assistant"): + assert f"<|im_start|>{role}" in out + + +# Tool round-trip: a turn-2 request (assistant tool_call + tool result) must +# serialize into the shape any HF chat template consumes — the multi-turn loop +# breaks at turn 2 otherwise. +def test_tool_call_roundtrip_messages_passthrough(): + t, fake = _template_with_fake() + t.render( + [ + ChatMessage(role="user", content="weather?"), + ChatMessage( + role="assistant", + content=None, + tool_calls=[ + ToolCall( + index=0, + id="c1", + type="function", + function=FunctionCall( + name="get_weather", arguments='{"city": "Paris"}' + ), + ) + ], + ), + ChatMessage(role="tool", tool_call_id="c1", content='{"temp_c": 18}'), + ] + ) + msgs = fake.seen_messages + asst = next(m for m in msgs if m["role"] == "assistant") + assert asst["tool_calls"][0]["function"]["name"] == "get_weather" + assert asst["tool_calls"][0]["function"]["arguments"] == '{"city": "Paris"}' + tool = next(m for m in msgs if m["role"] == "tool") + assert tool["tool_call_id"] == "c1" and "temp_c" in tool["content"] diff --git a/extension/llm/server/python/tests/test_tool_calls.py b/extension/llm/server/python/tests/test_tool_calls.py new file mode 100644 index 00000000000..06ffe040d70 --- /dev/null +++ b/extension/llm/server/python/tests/test_tool_calls.py @@ -0,0 +1,155 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Tool-calling tests (HTTP contract via the server). + +Hermes/Qwen format only. The server buffers the model's full output and parses +it once into complete OpenAI tool_calls; parse failures degrade to visible text. +""" + +import json + +WEATHER_TOOLS = [ + { + "type": "function", + "function": {"name": "get_weather", "parameters": {"type": "object"}}, + } +] + + +def _call_text(name, args): + return f'\n{{"name": "{name}", "arguments": {json.dumps(args)}}}\n' + + +# --- HTTP contract: non-streaming --------------------------------------- + + +def test_tool_call_nonstreaming(make_client): + client, _ = make_client(tokens=[_call_text("get_weather", {"city": "Paris"})]) + body = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "weather in Paris?"}], + "tools": WEATHER_TOOLS, + }, + ).json() + choice = body["choices"][0] + assert choice["finish_reason"] == "tool_calls" + tc = choice["message"]["tool_calls"][0] + assert tc["type"] == "function" + assert tc["function"]["name"] == "get_weather" + assert json.loads(tc["function"]["arguments"]) == {"city": "Paris"} + + +def test_tool_call_streaming(make_client): + client, _ = make_client(tokens=[_call_text("get_weather", {"city": "Paris"})]) + resp = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "weather?"}], + "tools": WEATHER_TOOLS, + "stream": True, + }, + ) + # Reuse the contract helper's SSE parsing. + chunks = [] + for line in resp.text.splitlines(): + line = line.strip() + if line.startswith("data:") and "[DONE]" not in line: + chunks.append(json.loads(line[len("data:") :].strip())) + tool_deltas = [ + c["choices"][0]["delta"]["tool_calls"][0] + for c in chunks + if c["choices"] and c["choices"][0]["delta"].get("tool_calls") + ] + assert tool_deltas and tool_deltas[0]["function"]["name"] == "get_weather" + assert chunks[-1]["choices"][0]["finish_reason"] == "tool_calls" + + +def test_undefined_tool_is_not_called(make_client): + # Model calls a tool not in the request's tools -> no tool_calls; the raw + # call stays visible as content (degrade to text, never silent drop). + client, _ = make_client(tokens=[_call_text("rm_rf", {"path": "/"})]) + body = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "hi"}], + "tools": WEATHER_TOOLS, + }, + ).json() + msg = body["choices"][0]["message"] + assert msg.get("tool_calls") is None + assert "rm_rf" in (msg.get("content") or "") # not dropped — visible as text + + +def test_mixed_valid_and_undefined_tool_degrades_to_text(make_client): + # A response with one valid + one undefined call must NOT emit the valid one + # while silently dropping the undefined one — the whole response degrades to + # visible text so the model's full intent is preserved. + client, _ = make_client( + tokens=[ + _call_text("get_weather", {"city": "Paris"}) + + _call_text("rm_rf", {"path": "/"}) + ] + ) + msg = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "hi"}], + "tools": WEATHER_TOOLS, + }, + ).json()["choices"][0]["message"] + assert msg.get("tool_calls") is None # no partial set + content = msg.get("content") or "" + assert "rm_rf" in content and "get_weather" in content # full intent visible + + +def test_malformed_tool_call_falls_back_to_text(make_client): + # Broken JSON inside the markers must not crash; degrade to visible text. + client, _ = make_client(tokens=["\n{not json}\n"]) + resp = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "hi"}], + "tools": WEATHER_TOOLS, + }, + ) + assert resp.status_code == 200 + assert resp.json()["choices"][0]["message"].get("tool_calls") is None + + +def test_no_tools_field_means_text_even_if_markers_present(make_client): + # Without a tools array, tool markers are just content (not parsed). + client, _ = make_client(tokens=[_call_text("get_weather", {"city": "X"})]) + body = client.post( + "/v1/chat/completions", + json={"model": "test-model", "messages": [{"role": "user", "content": "hi"}]}, + ).json() + assert body["choices"][0]["message"].get("tool_calls") is None + + +def test_parallel_calls_in_one_message(make_client): + # Two complete blocks in one output -> two structured calls. + tokens = [ + _call_text("get_weather", {"city": "A"}) + + _call_text("get_weather", {"city": "B"}) + ] + client, _ = make_client(tokens=tokens) + body = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "weather in A and B?"}], + "tools": WEATHER_TOOLS, + }, + ).json() + calls = body["choices"][0]["message"]["tool_calls"] + assert [json.loads(c["function"]["arguments"])["city"] for c in calls] == ["A", "B"] From c537c769e4f899f05d2c5ba4bbae0ea50e5b902b Mon Sep 17 00:00:00 2001 From: Mergen Nachin Date: Wed, 3 Jun 2026 16:14:35 -0700 Subject: [PATCH 2/4] [UPDATE] Update [ghstack-poisoned] --- extension/llm/server/README.md | 14 ++++++++------ extension/llm/server/python/README.md | 3 ++- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/extension/llm/server/README.md b/extension/llm/server/README.md index 7d2d2393e12..7c9e2a32b7a 100644 --- a/extension/llm/server/README.md +++ b/extension/llm/server/README.md @@ -14,8 +14,9 @@ extension/llm/server/ Why this layout: the OpenAI contract is identical across languages, so the **spec** and **conformance** suite are shared, and each language gets its own implementation directory. The real cross-language reuse comes from the C++ -`TextLLMRunner` (and the planned `Session` primitives) underneath — each server -is a thin protocol shell over that engine. See `python/README.md` to run it. +`LLMEngine`/`LLMSession` primitives underneath (with `TextLLMRunner` as the +current adapter) — each server is a thin protocol shell over that engine. See +`python/README.md` to run it. Status: experimental, reliability-first and deliberately narrow. Implemented: `/health`, `/v1/models`, `/v1/chat/completions` (streaming + non-streaming), @@ -23,7 +24,8 @@ Hugging Face chat templates (`--hf-tokenizer`), `temperature` / `max_tokens` / `max_completion_tokens` / `stop`, Hermes/Qwen tool calling (`...`, complete calls only) with `tool_choice="none"`, structured API errors, cancellation, and an opt-in conservative per-runner KV -prefix cache (`--enable-prefix-cache`). Unsupported params (`top_p`, `seed`, -`n>1`, `reasoning_effort`) are rejected with a structured 400 rather than -silently ignored. See `python/README.md` to run it and `spec/README.md` for the -exact contract. +prefix cache (`--enable-prefix-cache`). Unsupported params (including `top_p`, +`seed`, `n>1`, `reasoning_effort`, penalties, `logit_bias`, `response_format`, +`logprobs`, and `tool_choice="required"`) are rejected with a structured 400 +rather than silently ignored. See `python/README.md` to run it and +`spec/README.md` for the exact contract. diff --git a/extension/llm/server/python/README.md b/extension/llm/server/python/README.md index 1d535e9fdc7..99cb1a3abbb 100644 --- a/extension/llm/server/python/README.md +++ b/extension/llm/server/python/README.md @@ -1,6 +1,7 @@ # ExecuTorch LLM Server — Python -A thin OpenAI-compatible HTTP server over ExecuTorch's `TextLLMRunner`. +A thin OpenAI-compatible HTTP server over ExecuTorch's `LLMEngine`/`LLMSession` +serving API (with `TextLLMRunner` as the underlying adapter). ## Install From a2a707cab02bb753279dff48b9a394157ff0ad28 Mon Sep 17 00:00:00 2001 From: Mergen Nachin Date: Thu, 4 Jun 2026 07:26:26 -0700 Subject: [PATCH 3/4] [UPDATE] Update [ghstack-poisoned] --- extension/llm/server/python/serving_chat.py | 6 +- .../server/python/tests/test_tool_calls.py | 66 +++++++++++++++++++ 2 files changed, 71 insertions(+), 1 deletion(-) diff --git a/extension/llm/server/python/serving_chat.py b/extension/llm/server/python/serving_chat.py index 8bf2779a945..94271170a11 100644 --- a/extension/llm/server/python/serving_chat.py +++ b/extension/llm/server/python/serving_chat.py @@ -265,8 +265,12 @@ def _reject_unsupported_params(req: ChatCompletionRequest) -> None: async def create(self, req: ChatCompletionRequest): self._reject_invalid_values(req) self._reject_unsupported_params(req) + # tool_choice="none" must hide tools from the model: if we still render + # the tool schemas, the model can emit a that we'd surface as + # plain text (parsing is disabled), instead of a normal answer. + template_tools = None if req.tool_choice == "none" else req.tools prompt = self._template.render( - req.messages, tools=req.tools, template_kwargs=req.chat_template_kwargs + req.messages, tools=template_tools, template_kwargs=req.chat_template_kwargs ) # Pre-flight context check: reject cleanly instead of failing mid-generation # (only possible when a tokenizer is available to count, e.g. --hf-tokenizer). diff --git a/extension/llm/server/python/tests/test_tool_calls.py b/extension/llm/server/python/tests/test_tool_calls.py index 06ffe040d70..249c7b1827b 100644 --- a/extension/llm/server/python/tests/test_tool_calls.py +++ b/extension/llm/server/python/tests/test_tool_calls.py @@ -111,6 +111,72 @@ def test_mixed_valid_and_undefined_tool_degrades_to_text(make_client): assert "rm_rf" in content and "get_weather" in content # full intent visible +def test_tool_choice_none_omits_tools_from_prompt(): + from executorch.extension.llm.server.python.chat_template import ChatTemplate + from executorch.extension.llm.server.python.runner_pool import RunnerPool + from executorch.extension.llm.server.python.server import build_app + from executorch.extension.llm.server.python.serving_chat import ServingChat + from executorch.extension.llm.server.python.tool_parsers import HermesDetector + + # tool_choice="none" must NOT inject tool schemas into the chat template; if it + # did, the model could still emit a that we'd surface as plain text + # (parsing is disabled for "none"). Assert via a recording tokenizer. + from fastapi.testclient import TestClient + + class _RecordingTok: + all_special_tokens: list = [] + + def __init__(self): + self.tools_seen = "UNSET" + + def encode(self, text): + return [0] + + def apply_chat_template( + self, messages, tools, add_generation_prompt, tokenize, **kwargs + ): + self.tools_seen = tools + return "PROMPT" + + class _Runner: + def reset(self): + pass + + def stop(self): + pass + + def generate(self, prompt, config, token_callback=None, stats_callback=None): + if token_callback: + token_callback("ok") + + rec = _RecordingTok() + template = ChatTemplate(hf_tokenizer_path=None, allow_fallback=True) + template._hf = rec + pool = RunnerPool(runner_factory=lambda: _Runner(), num_runners=1) + serving = ServingChat( + pool, template, "test-model", tool_detector_cls=HermesDetector + ) + client = TestClient(build_app(serving, "test-model")) + body = { + "model": "test-model", + "messages": [{"role": "user", "content": "hi"}], + "tools": WEATHER_TOOLS, + "max_tokens": 8, + } + + assert ( + client.post( + "/v1/chat/completions", json={**body, "tool_choice": "none"} + ).status_code + == 200 + ) + assert rec.tools_seen is None # tools omitted from the rendered prompt + + # Control: default ("auto") still passes the tools through to the template. + assert client.post("/v1/chat/completions", json=body).status_code == 200 + assert rec.tools_seen == WEATHER_TOOLS + + def test_malformed_tool_call_falls_back_to_text(make_client): # Broken JSON inside the markers must not crash; degrade to visible text. client, _ = make_client(tokens=["\n{not json}\n"]) From b433c1bde9df7744ec7ec6ec63582d0192350628 Mon Sep 17 00:00:00 2001 From: Mergen Nachin Date: Tue, 9 Jun 2026 09:23:08 -0700 Subject: [PATCH 4/4] [UPDATE] Update [ghstack-poisoned] --- extension/llm/server/README.md | 7 +- extension/llm/server/python/README.md | 31 +-- extension/llm/server/python/runner_pool.py | 132 ------------ extension/llm/server/python/server.py | 23 ++- extension/llm/server/python/serving_chat.py | 140 ++++++------- .../llm/server/python/session_runtime.py | 193 ++++++++++++++++++ extension/llm/server/python/tests/conftest.py | 58 +++++- .../server/python/tests/test_runner_pool.py | 116 ----------- .../python/tests/test_session_runtime.py | 187 +++++++++++++++++ .../server/python/tests/test_tool_calls.py | 6 +- extension/llm/server/python/worker_client.py | 163 ++++++++++++--- 11 files changed, 661 insertions(+), 395 deletions(-) delete mode 100644 extension/llm/server/python/runner_pool.py create mode 100644 extension/llm/server/python/session_runtime.py delete mode 100644 extension/llm/server/python/tests/test_runner_pool.py create mode 100644 extension/llm/server/python/tests/test_session_runtime.py diff --git a/extension/llm/server/README.md b/extension/llm/server/README.md index f90da90675a..0b18d31cae5 100644 --- a/extension/llm/server/README.md +++ b/extension/llm/server/README.md @@ -25,9 +25,10 @@ Hugging Face chat templates (`--hf-tokenizer`), `temperature` / `max_tokens` / `max_completion_tokens` / `stop`, Hermes tool calling by default (`...` JSON, complete calls only; model-specific launchers may select the Qwen XML format) with `tool_choice="none"`, -structured API errors, and best-effort cancellation. V1 serving is single-slot -(one worker, one session) with no prefix cache; KV prefix reuse, if it returns, -lives inside the worker/session, not the control plane. Unsupported params (including `top_p`, +structured API errors, and best-effort cancellation. One worker process with +serialized execution; it hosts many isolated sessions on one weight load (warm +append-only resume across turns). KV/prefix state lives inside the +worker/session, not the control plane. Unsupported params (including `top_p`, `seed`, `n>1`, `reasoning_effort`, penalties, `logit_bias`, `response_format`, `logprobs`, and `tool_choice="required"`) are rejected with a structured 400 rather than silently ignored. See `python/README.md` to run it and diff --git a/extension/llm/server/python/README.md b/extension/llm/server/python/README.md index b5c59bb9648..e14e6176c81 100644 --- a/extension/llm/server/python/README.md +++ b/extension/llm/server/python/README.md @@ -69,7 +69,7 @@ Key flags: | `--allow-chatml-fallback` | opt into approximate ChatML when no HF tokenizer | | `--no-think` | default `enable_thinking=False` (e.g. Qwen3) | | `--max-context N` | reject over-long prompts with 400 instead of failing mid-gen | -| `--num-runners N` | V1 supports **1 only** (single-slot: one worker serves one session; concurrent requests queue) | +| `--num-runners N` | Worker processes — **1 only** (one worker hosts many isolated sessions on one weight load; more would duplicate weights) | | `--worker-bin PATH` | path to the `text_llm_worker` binary (default: `cmake-out/extension/llm/server/cpp/text_llm_worker`) | ## Use from an agent harness @@ -101,16 +101,19 @@ pytest tests/ OPENAI_BASE_URL=http://127.0.0.1:8000/v1 pytest ../conformance/test_openai_contract.py ``` -`tests/` builds a `RunnerPool` over a single `FakeRunner` worker handle, so the +`tests/` builds a `SessionRuntime` over a single `FakeRunner` worker, so the real server/protocol/streaming code is tested over HTTP without a `.pte`. The worker JSONL protocol is covered separately by `tests/test_worker_client.py`. ## Architecture -Control plane (this dir, Python): server, OpenAI protocol, chat templating, -streaming bridge, tool parsing — no CUDA, no model, no pybind. Data plane (C++): -a worker process (`text_llm_worker`) owns one model session and does all token -stepping and KV mutation; it speaks one JSON object per line on stdin/stdout. +Control plane (this dir, Python): an OpenAI adapter (`serving_chat`) over a +stateful `SessionRuntime` over one `WorkerClient` — server, protocol, chat +templating, streaming bridge, tool parsing — no CUDA, no model, no pybind. Data +plane (C++): a worker process (`text_llm_worker`) that owns all model state +(many isolated sessions on one weight load, warm-resume prefix logic) and does +all token stepping and KV mutation; it speaks one JSON object per line on +stdin/stdout. JSONL protocol (stdout carries protocol JSON only; logs go to stderr): @@ -132,9 +135,9 @@ does blocking pipe I/O on its executor thread. | `server.py` | FastAPI app, routes, CLI entrypoint, worker spawn | | `protocol.py` | OpenAI request/response schemas | | `chat_template.py` | messages (+tools) → prompt string | -| `worker_client.py` | spawn a worker process + drive it over JSONL | -| `runner_pool.py` | worker pool (one in-flight request per worker) + async streaming bridge | -| `serving_chat.py` | `/v1/chat/completions` (streaming + non-streaming, stop, tools) | +| `worker_client.py` | spawn a worker process + drive it over JSONL (raw transport) | +| `session_runtime.py` | stateful runtime over one worker: open/generate/reset/close + streaming bridge | +| `serving_chat.py` | `/v1/chat/completions` OpenAI adapter (streaming + non-streaming, stop, tools) | | `tool_parsers/` | Hermes/Qwen `` parser only | | `cpp/text_llm_worker.cpp` | the generic C++ worker binary | @@ -151,11 +154,11 @@ imports an example. Backend specifics (CUDA/AOTI, Metal) stay inside the worker. ## Scope & caveats Deliberately narrow (reliability-first): Hermes/Qwen tool calling only; -unsupported sampling params are rejected, not ignored. V1 is **single-slot**: one -worker hosts one session, so `--num-runners` accepts 1 and concurrent requests -queue. Serving capacity is worker capacity, chosen by the launcher (each worker -is its own process with its own weights, so N workers cost N × the weight memory) -— an operator decision, not something the pool infers. +unsupported sampling params are rejected, not ignored. **One worker process, +serialized execution** (one in-flight request; concurrent requests queue). +Session capacity is determined by the worker/engine — a single worker hosts many +isolated sessions on one weight load — so `--num-runners` accepts 1; extra worker +processes would each carry their own copy of the weights. Cancellation is best-effort: a worker request runs to completion and is not interruptible mid-generation in V1, so `runner.stop()` means "the control plane diff --git a/extension/llm/server/python/runner_pool.py b/extension/llm/server/python/runner_pool.py deleted file mode 100644 index 83fd8fe8245..00000000000 --- a/extension/llm/server/python/runner_pool.py +++ /dev/null @@ -1,132 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -"""Pool of model-execution workers + the streaming bridge. - -The Python server is HTTP/control plane only: it never loads a model, links a -backend, or imports a runtime pybind. Each pooled worker is a separate process -(a C++ worker binary in production; a fake in tests) that owns one model session -and is driven over JSONL by a WorkerClient. The pool hands an idle worker to a -request and bridges the worker's blocking generate() into an async token stream. - -One worker == one session; a request holds a worker exclusively for its -duration, so requests beyond the worker count queue. The number of workers is -the serving capacity, chosen by the launcher: each worker is its own process -with its own weights, so N workers cost N x the weight memory — an operator -decision, not something the pool infers. - -There is no prefix cache and no prefix-affinity routing here: caching, if any, -lives inside the worker/session, not the control plane. -""" - -import asyncio -import logging -from concurrent.futures import ThreadPoolExecutor -from contextlib import asynccontextmanager -from dataclasses import dataclass -from typing import AsyncIterator, Optional, Sequence - -logger = logging.getLogger(__name__) - -_SENTINEL = object() - - -@dataclass -class GenStats: - prompt_tokens: int = 0 - completion_tokens: int = 0 - # Worker-reported stop reason ("stop" | "length"), or None if not reported. - finish_reason: Optional[str] = None - - -class RunnerPool: - """A fixed pool of model-execution workers. - - `workers` is a non-empty sequence of worker handles, each exposing - ``generate(prompt, config, token_callback, stats_callback)`` / ``stop()`` - (a WorkerClient in production; a fake in tests). The pool owns scheduling - (one in-flight request per worker) and the blocking->async stream bridge; the - workers own all model execution. - """ - - def __init__(self, workers: Sequence[object]): - self._workers = list(workers) - if not self._workers: - raise ValueError("RunnerPool requires at least one worker") - n = len(self._workers) - self._busy = [False] * n - self._cond = asyncio.Condition() - # One executor thread per worker: generate() blocks on worker I/O, and a - # worker is never driven by two threads at once (the busy flags enforce - # exclusive acquisition). - self._executor = ThreadPoolExecutor(max_workers=n) - - @asynccontextmanager - async def acquire(self): - # No prefix cache, so no affinity routing — any idle worker will do. - # (V2 routes by session_id, not by prompt.) - async with self._cond: - while all(self._busy): - await self._cond.wait() - idx = next(i for i, busy in enumerate(self._busy) if not busy) - self._busy[idx] = True - try: - yield self._workers[idx] - finally: - async with self._cond: - self._busy[idx] = False - self._cond.notify() - - async def generate_stream( - self, - runner, - prompt: str, - config, - stats: Optional[GenStats] = None, - ) -> AsyncIterator[str]: - """Yield generated text pieces. If `stats` is given it's filled in place - with token counts (per-request, so concurrent streams don't race).""" - out_stats = stats if stats is not None else GenStats() - loop = asyncio.get_running_loop() - queue: asyncio.Queue = asyncio.Queue() - - def token_cb(token: str) -> None: - loop.call_soon_threadsafe(queue.put_nowait, token) - - def stats_cb(s) -> None: - out_stats.prompt_tokens = s.num_prompt_tokens - out_stats.completion_tokens = s.num_generated_tokens - out_stats.finish_reason = getattr(s, "finish_reason", None) - - def run() -> None: - try: - runner.generate(prompt, config, token_cb, stats_cb) - except Exception as e: # noqa: BLE001 - surface to the stream consumer - loop.call_soon_threadsafe(queue.put_nowait, e) - finally: - loop.call_soon_threadsafe(queue.put_nowait, _SENTINEL) - - fut = loop.run_in_executor(self._executor, run) - try: - while True: - item = await queue.get() - if item is _SENTINEL: - break - if isinstance(item, Exception): - raise item - yield item - except asyncio.CancelledError: - runner.stop() - raise - finally: - await fut - - def close(self) -> None: - """Shut down all workers (called at server shutdown).""" - for w in self._workers: - close = getattr(w, "close", None) - if close is not None: - close() diff --git a/extension/llm/server/python/server.py b/extension/llm/server/python/server.py index 1a88eef888f..94c55479275 100644 --- a/extension/llm/server/python/server.py +++ b/extension/llm/server/python/server.py @@ -14,9 +14,10 @@ no runtime pybind. Model execution lives in a separate C++ worker process (``text_llm_worker``) driven over JSONL via WorkerClient. -V1 is single-slot: one worker hosts one session, so concurrent requests queue. -There is no prefix cache in V1 serving; caching, if it returns, lives inside the -worker/session, not the control plane. +One worker process, serialized execution (one in-flight request; concurrent +requests queue). Session capacity is set by the worker/engine -- a single worker +hosts many isolated sessions on one weight load; extra worker processes would +duplicate the weights, so `--num-runners` accepts 1. Example: python -m executorch.extension.llm.server.python.server \\ @@ -35,8 +36,8 @@ from .chat_template import ChatTemplate from .errors import APIError from .protocol import ChatCompletionRequest, ModelCard, ModelList -from .runner_pool import RunnerPool from .serving_chat import ServingChat +from .session_runtime import SessionRuntime from .tool_parsers import HermesDetector from .worker_client import spawn_worker @@ -143,7 +144,8 @@ def main() -> None: "--num-runners", type=int, default=1, - help="V1 supports 1 only (single-slot: one worker serves one session).", + help="Worker processes. 1 only: one worker hosts many isolated sessions " + "on a single weight load; more workers would duplicate the weights.", ) p.add_argument( "--worker-bin", @@ -158,7 +160,8 @@ def main() -> None: if args.num_runners != 1: p.error( - "V1 is single-slot: one worker serves one session; concurrent requests queue." + "Only 1 worker process is supported (it hosts many isolated sessions " + "on one weight load); more workers would duplicate the weights." ) default_template_kwargs = {"enable_thinking": False} if args.no_think else None @@ -168,10 +171,10 @@ def main() -> None: default_template_kwargs=default_template_kwargs, allow_fallback=args.allow_chatml_fallback, ) - worker = _spawn(args) # one worker == one session (single-slot V1) - pool = RunnerPool([worker]) + worker = _spawn(args) # one worker hosting many isolated sessions + runtime = SessionRuntime(worker) serving = ServingChat( - pool, + runtime, template, args.model_id, max_context=args.max_context, @@ -184,7 +187,7 @@ def main() -> None: @app.on_event("shutdown") def _stop_worker(): - pool.close() + runtime.close_worker() import uvicorn # imported here so build_app() is usable without the ASGI server diff --git a/extension/llm/server/python/serving_chat.py b/extension/llm/server/python/serving_chat.py index 4541e292592..fdd8032a06f 100644 --- a/extension/llm/server/python/serving_chat.py +++ b/extension/llm/server/python/serving_chat.py @@ -4,13 +4,14 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -"""/v1/chat/completions handler: builds prompts, drives the runner, and emits -OpenAI-shaped responses (streaming and non-streaming).""" +"""/v1/chat/completions OpenAI adapter: validates requests, renders the chat +template, parses tool calls, and formats OpenAI responses. It owns no model or +session state -- generation goes through SessionRuntime, and the token-ID warm- +resume transcript lives in OpenAITranscriptState.""" import json import logging import math -from dataclasses import dataclass, field from typing import AsyncIterator, Optional from .chat_template import ChatTemplate @@ -28,29 +29,12 @@ ToolCall, Usage, ) -from .runner_pool import GenStats, RunnerPool +from .session_runtime import GenerationOptions, GenStats, PromptInput, SessionRuntime from .tool_parsers import HermesDetector, ToolCallItem logger = logging.getLogger(__name__) -@dataclass -class GenConfig: - """Generation parameters the control plane forwards to a worker. - - A plain dataclass (no runtime/pybind dependency): the control plane never - loads a model. The worker reads these fields off the JSONL request; only the - knobs we honor today are here (max_new_tokens, temperature, stop). - -1 max_new_tokens means "unset / let the worker pick from model metadata". - `stop` lets the worker terminate at a stop sequence instead of running to EOS - (the server still re-applies stops as a backstop). - """ - - max_new_tokens: int - temperature: float = 0.0 - stop: list[str] = field(default_factory=list) - - def _earliest_stop(text: str, stops: list[str]) -> Optional[int]: """Index of the earliest special-token occurrence in `text`, or None.""" best = None @@ -64,13 +48,13 @@ def _earliest_stop(text: str, stops: list[str]) -> Optional[int]: class ServingChat: def __init__( self, - pool: RunnerPool, + runtime: SessionRuntime, template: ChatTemplate, model_id: str, max_context: Optional[int] = None, tool_detector_cls: Optional[type[HermesDetector]] = None, ): - self._pool = pool + self._runtime = runtime self._template = template self._model_id = model_id self._max_context = max_context @@ -131,20 +115,19 @@ def _truncate_raw(self, text: str, req: ChatCompletionRequest) -> str: boundary is neither parsed nor emitted.""" return self._apply_stop(text, self._stops + self._request_stops(req)) - @staticmethod - async def _collect_until_stop(stream: AsyncIterator[str], runner, stops: list[str]): + async def _collect_until_stop(self, stream: AsyncIterator[str], stops: list[str]): """Accumulate a buffered (non-streamed) generation into one string, - halting the runner early once a stop string (special token or request + halting the runtime early once a stop string (special token or request stop) appears, then draining so stats finalize. Returns (text, stopped): `stopped` lets the caller force finish_reason="stop" even when tokens - queued before the runner observed stop() pushed the count to max_tokens.""" + queued before the runtime observed stop() pushed the count to max_tokens.""" text = "" stopped = False async for tok in stream: text += tok if stops and _earliest_stop(text, stops) is not None: stopped = True - runner.stop() + self._runtime.stop() async for _ in stream: # drain so stats_cb fires pass break @@ -193,8 +176,8 @@ async def _clean( if buf: yield buf - def _config(self, req: ChatCompletionRequest) -> GenConfig: - return GenConfig( + def _options(self, req: ChatCompletionRequest) -> GenerationOptions: + return GenerationOptions( max_new_tokens=req.resolved_max_tokens(), temperature=req.temperature if req.temperature is not None else 0.0, # Let the worker terminate at the same boundary the control plane @@ -320,26 +303,28 @@ async def create(self, req: ChatCompletionRequest): requested = req.resolved_max_tokens() if requested > 0 and count + requested > self._max_context: raise ContextLengthExceeded(count, self._max_context, requested) - config = self._config(req) + options = self._options(req) + prompt_input = PromptInput(text=prompt) if req.stream: - return self._stream(req, prompt, config) - return await self._complete(req, prompt, config) + return self._stream(req, prompt_input, options) + return await self._complete(req, prompt_input, options) async def _complete( - self, req: ChatCompletionRequest, prompt: str, config: GenConfig + self, + req: ChatCompletionRequest, + prompt: PromptInput, + options: GenerationOptions, ) -> ChatCompletionResponse: stats = GenStats() - async with self._pool.acquire() as runner: - try: - # Collect raw text (markers intact for tool parsing), halting early - # at a stop boundary (special token or request stop). - text, stopped = await self._collect_until_stop( - self._pool.generate_stream(runner, prompt, config, stats), - runner, - self._stops + self._request_stops(req), - ) - except Exception as e: # noqa: BLE001 - surface as a structured API error - raise GenerationError(str(e)) + try: + # Collect raw text (markers intact for tool parsing), halting early + # at a stop boundary (special token or request stop). + text, stopped = await self._collect_until_stop( + self._runtime.generate_stream(None, prompt, options, stats), + self._stops + self._request_stops(req), + ) + except Exception as e: # noqa: BLE001 - surface as a structured API error + raise GenerationError(str(e)) # Bound the raw output at the first stop/special token BEFORE tool # parsing, so a call after the stop boundary is not parsed/emitted. tool_calls, content = self._extract_tools(req, self._truncate_raw(text, req)) @@ -362,7 +347,10 @@ async def _complete( ) async def _stream( - self, req: ChatCompletionRequest, prompt: str, config: GenConfig + self, + req: ChatCompletionRequest, + prompt: PromptInput, + options: GenerationOptions, ) -> AsyncIterator[str]: cid = _new_id("chatcmpl") @@ -383,37 +371,35 @@ def chunk(delta: DeltaMessage, finish=None) -> str: stats = GenStats() stop_hit = [False] # set when a stop boundary is reached (forces finish="stop") stops = self._stops + self._request_stops(req) - async with self._pool.acquire() as runner: - try: - if use_tools: - # v1: buffer the (usually short) tool response, parse once. - # Halt early at a stop boundary, and bound the raw output - # BEFORE parsing so post-stop tool calls / text don't leak. - raw, stop_hit[0] = await self._collect_until_stop( - self._pool.generate_stream(runner, prompt, config, stats), - runner, - stops, - ) - tool_calls, content = self._extract_tools( - req, self._truncate_raw(raw, req) - ) - else: - # Plain chat: stream tokens live (best UX), cutting at special - # tokens or request stop sequences and halting early on a hit. - def on_stop(): - stop_hit[0] = True - runner.stop() - - async for token in self._clean( - self._pool.generate_stream(runner, prompt, config, stats), - stops, - on_stop=on_stop, - ): - yield chunk(DeltaMessage(content=token)) - except ( - Exception - ) as e: # noqa: BLE001 - emit a structured error event, never drop the socket - error = e + try: + if use_tools: + # v1: buffer the (usually short) tool response, parse once. + # Halt early at a stop boundary, and bound the raw output + # BEFORE parsing so post-stop tool calls / text don't leak. + raw, stop_hit[0] = await self._collect_until_stop( + self._runtime.generate_stream(None, prompt, options, stats), + stops, + ) + tool_calls, content = self._extract_tools( + req, self._truncate_raw(raw, req) + ) + else: + # Plain chat: stream tokens live (best UX), cutting at special + # tokens or request stop sequences and halting early on a hit. + def on_stop(): + stop_hit[0] = True + self._runtime.stop() + + async for token in self._clean( + self._runtime.generate_stream(None, prompt, options, stats), + stops, + on_stop=on_stop, + ): + yield chunk(DeltaMessage(content=token)) + except ( + Exception + ) as e: # noqa: BLE001 - emit a structured error event, never drop the socket + error = e if error is not None: err = { diff --git a/extension/llm/server/python/session_runtime.py b/extension/llm/server/python/session_runtime.py new file mode 100644 index 00000000000..e73dbafadc4 --- /dev/null +++ b/extension/llm/server/python/session_runtime.py @@ -0,0 +1,193 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Python's stateful local-LLM runtime over one C++ worker process. + +This is the internal boundary between protocol adapters (OpenAI chat, future +native/agent surfaces) and the worker. The adapter speaks sessions, prompts, and +generation parameters; the worker (driven over JSONL by a WorkerClient) owns all +model execution and session state (KV/recurrent, resident token ids, warm-resume +prefix logic). The Python server never loads a model, links a backend, or imports +a runtime pybind. + +A SessionRuntime owns exactly one worker and serializes access to it (one +in-flight request at a time), bridging the worker's blocking generate() into an +async token stream. Multi-worker scheduling / named-session affinity is out of +scope: a single worker already hosts many isolated sessions on one weight load, +routed by session_id inside the worker. +""" + +import asyncio +import logging +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass, field +from typing import AsyncIterator, Optional + +logger = logging.getLogger(__name__) + +_SENTINEL = object() + + +@dataclass +class PromptInput: + """A prompt as either a single rendered string or token-ID segments. Exactly + one of `text` / `segments` is set. Segments ([{"text": str} | {"ids": [int]}]) + let an adapter splice exact prior-turn token ids in place of a lossy + re-render (see openai_transcript).""" + + text: Optional[str] = None + segments: Optional[list] = None + + def __post_init__(self): + if (self.text is None) == (self.segments is None): + raise ValueError("exactly one of PromptInput.text / .segments must be set") + if self.segments is not None and not self.segments: + raise ValueError("PromptInput.segments must be non-empty") + + +@dataclass +class GenerationOptions: + """Sampling/length knobs forwarded to the worker (only what we honor today).""" + + max_new_tokens: int + temperature: float = 0.0 + stop: list[str] = field(default_factory=list) + + +@dataclass +class GenStats: + """Per-request metadata the worker reports at the end of generation.""" + + prompt_tokens: int = 0 + completion_tokens: int = 0 + # Worker-reported stop reason ("stop" | "length"), or None if not reported. + finish_reason: Optional[str] = None + # Warm-resume accounting (V2b.1): tokens served from the session's resident + # state vs prefilled this request, and why. + reused_prompt_tokens: int = 0 + prefilled_prompt_tokens: int = 0 + session_reset_reason: Optional[str] = None + # Exact token ids generated this turn (V2b.1.5), for an adapter's transcript + # store. Empty when the worker doesn't report them (e.g. a stop-trimmed turn). + generated_token_ids: list = field(default_factory=list) + + +# Forwarded to WorkerClient.generate() as the per-request config it reads fields +# off; keeps that low-level contract unchanged while the runtime's public surface +# is PromptInput + GenerationOptions + session_id. +@dataclass +class _WorkerRequest: + max_new_tokens: int + temperature: float + stop: list[str] + session_id: Optional[str] + prompt_segments: Optional[list] + + +class SessionRuntime: + """Stateful runtime over a single worker. `worker` is a WorkerClient (a fake + in tests) exposing generate()/stop()/close() and the session ops + open_session/reset_session/close_session.""" + + def __init__(self, worker): + self._worker = worker + # One executor thread; the lock guarantees the worker is never driven by + # two requests at once (it is single-in-flight). + self._executor = ThreadPoolExecutor(max_workers=1) + self._lock = asyncio.Lock() + + async def open(self, session_id: str) -> None: + """Admit a named session before generation so a capacity refusal surfaces + up front (the adapter maps it to an HTTP status) rather than mid-stream. + Idempotent.""" + await self._session_op("open_session", session_id) + + async def reset(self, session_id: str) -> None: + """Clear a named session's context (KV/recurrent + resident ids) but keep + its capacity slot. Idempotent.""" + await self._session_op("reset_session", session_id) + + async def close(self, session_id: str) -> None: + """Destroy a named session, freeing its state and slot. Idempotent.""" + await self._session_op("close_session", session_id) + + async def _session_op(self, method: str, session_id: str) -> None: + op = getattr(self._worker, method, None) + if op is None: + return # worker doesn't support sessions (e.g. a minimal fake) + loop = asyncio.get_running_loop() + async with self._lock: + await loop.run_in_executor(self._executor, op, session_id) + + def stop(self) -> None: + """Request the in-flight generation stop at the next token boundary.""" + self._worker.stop() + + async def generate_stream( + self, + session_id: Optional[str], + prompt: PromptInput, + options: GenerationOptions, + stats: Optional[GenStats] = None, + ) -> AsyncIterator[str]: + """Yield generated text pieces from the worker, holding the worker lock + for the whole generation. `stats` (if given) is filled in place with the + worker's terminal metadata (per-request, so concurrent callers don't + race). session_id None routes to the worker's anonymous scratch session.""" + out_stats = stats if stats is not None else GenStats() + req = _WorkerRequest( + max_new_tokens=options.max_new_tokens, + temperature=options.temperature, + stop=list(options.stop), + session_id=session_id, + prompt_segments=prompt.segments, + ) + prompt_text = prompt.text or "" + loop = asyncio.get_running_loop() + queue: asyncio.Queue = asyncio.Queue() + + def token_cb(token: str) -> None: + loop.call_soon_threadsafe(queue.put_nowait, token) + + def stats_cb(s) -> None: + out_stats.prompt_tokens = s.num_prompt_tokens + out_stats.completion_tokens = s.num_generated_tokens + out_stats.finish_reason = getattr(s, "finish_reason", None) + out_stats.reused_prompt_tokens = getattr(s, "reused_prompt_tokens", 0) + out_stats.prefilled_prompt_tokens = getattr(s, "prefilled_prompt_tokens", 0) + out_stats.session_reset_reason = getattr(s, "session_reset_reason", None) + out_stats.generated_token_ids = getattr(s, "generated_token_ids", []) + + def run() -> None: + try: + self._worker.generate(prompt_text, req, token_cb, stats_cb) + except Exception as e: # noqa: BLE001 - surface to the stream consumer + loop.call_soon_threadsafe(queue.put_nowait, e) + finally: + loop.call_soon_threadsafe(queue.put_nowait, _SENTINEL) + + async with self._lock: + fut = loop.run_in_executor(self._executor, run) + try: + while True: + item = await queue.get() + if item is _SENTINEL: + break + if isinstance(item, Exception): + raise item + yield item + except asyncio.CancelledError: + self._worker.stop() + raise + finally: + await fut + + def close_worker(self) -> None: + """Shut the worker process down and the executor (called at shutdown).""" + close = getattr(self._worker, "close", None) + if close is not None: + close() + self._executor.shutdown(wait=False, cancel_futures=True) diff --git a/extension/llm/server/python/tests/conftest.py b/extension/llm/server/python/tests/conftest.py index 084fc4792b1..b91f0aec26e 100644 --- a/extension/llm/server/python/tests/conftest.py +++ b/extension/llm/server/python/tests/conftest.py @@ -6,10 +6,10 @@ """Fixtures for hermetic contract tests. -We build a RunnerPool over a single FakeRunner worker handle, so the tests -exercise the real server, protocol, templating, and streaming code over the -HTTP boundary with NO model, tokenizer, GPU, or worker subprocess. This mirrors -ExecuTorch's fake_llm_executor approach: fake the worker, test the real surface. +We build a SessionRuntime over a single FakeRunner worker, so the tests exercise +the real server, protocol, templating, and streaming code over the HTTP boundary +with NO model, tokenizer, GPU, or worker subprocess. This mirrors ExecuTorch's +fake_llm_executor approach: fake the worker, test the real surface. The control plane imports no runtime pybind; only fastapi, pydantic, httpx, and pytest are required. @@ -18,10 +18,11 @@ import pytest from executorch.extension.llm.server.python.chat_template import ChatTemplate -from executorch.extension.llm.server.python.runner_pool import RunnerPool from executorch.extension.llm.server.python.server import build_app from executorch.extension.llm.server.python.serving_chat import ServingChat +from executorch.extension.llm.server.python.session_runtime import SessionRuntime from executorch.extension.llm.server.python.tool_parsers import HermesDetector +from executorch.extension.llm.server.python.worker_client import WorkerError from fastapi.testclient import TestClient @@ -32,15 +33,26 @@ class _FakeStats: class FakeRunner: - """Canned engine: emits fixed tokens, records the config it was given.""" + """Canned engine: emits fixed tokens, records the config it was given. - def __init__(self, tokens, fail=False, finish_reason=None): + With max_named_sessions > 0 it also models the worker's session admission: + open_session() enforces capacity and reports structured WorkerError codes, + matching the real worker's contract.""" + + def __init__( + self, tokens, fail=False, finish_reason=None, max_named_sessions=0, gen_ids=None + ): self._tokens = list(tokens) self._fail = fail self._finish_reason = finish_reason # worker-reported stop reason, if any + self._gen_ids = list(gen_ids or []) # ids reported per turn (V2b.1.5) self.captured_config = None self.stopped = False self.reset_count = 0 + self.max_named_sessions = max_named_sessions + self.open_named = set() # currently-open named sessions + self.opened_log = [] # every successful open, in order + self.reset_log = [] # every reset_session call, in order def reset(self): self.reset_count += 1 @@ -48,6 +60,23 @@ def reset(self): def stop(self): self.stopped = True + def open_session(self, session_id): + if session_id in self.open_named: + return # idempotent + if self.max_named_sessions == 0: + raise WorkerError("no named sessions", code="unsupported_session") + if len(self.open_named) >= self.max_named_sessions: + raise WorkerError("session capacity exhausted", code="capacity_exhausted") + self.open_named.add(session_id) + self.opened_log.append(session_id) + + def close_session(self, session_id): + self.open_named.discard(session_id) + + def reset_session(self, session_id): + # Clears context but keeps the slot (stays in open_named). + self.reset_log.append(session_id) + def generate(self, prompt, config, token_callback=None, stats_callback=None): self.captured_config = config if self._fail: @@ -59,6 +88,7 @@ def generate(self, prompt, config, token_callback=None, stats_callback=None): stats = _FakeStats() stats.num_generated_tokens = len(self._tokens) stats.finish_reason = self._finish_reason + stats.generated_token_ids = list(self._gen_ids) stats_callback(stats) @@ -87,14 +117,22 @@ def _make( prompt_tokens=None, fail=False, finish_reason=None, + max_named_sessions=0, + gen_ids=None, ): - fake = FakeRunner(tokens, fail=fail, finish_reason=finish_reason) - pool = RunnerPool([fake]) # one fake worker handle + fake = FakeRunner( + tokens, + fail=fail, + finish_reason=finish_reason, + max_named_sessions=max_named_sessions, + gen_ids=gen_ids, + ) + runtime = SessionRuntime(fake) # one fake worker template = ChatTemplate(hf_tokenizer_path=None, allow_fallback=True) if prompt_tokens is not None: template._hf = _FakeTokenizer(prompt_tokens) serving = ServingChat( - pool, + runtime, template, "test-model", max_context=max_context, diff --git a/extension/llm/server/python/tests/test_runner_pool.py b/extension/llm/server/python/tests/test_runner_pool.py deleted file mode 100644 index a95682cb3e0..00000000000 --- a/extension/llm/server/python/tests/test_runner_pool.py +++ /dev/null @@ -1,116 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -"""Pool tests: idle-worker scheduling, cancellation, and concurrency isolation. - -The pool drives worker handles (a WorkerClient over a subprocess in production); -here we inject fakes — no model, GPU, or subprocess. Written with asyncio.run -(sync test bodies) to avoid depending on an async pytest plugin. -""" - -import asyncio -import threading - -import pytest - -from executorch.extension.llm.server.python.runner_pool import RunnerPool - - -class _BlockingRunner: - """Emits one token, then blocks until stop() is called.""" - - def __init__(self): - self._gate = threading.Event() - self.stopped = False - - def reset(self): - pass - - def stop(self): - self.stopped = True - self._gate.set() - - def generate(self, prompt, config, token_callback=None, stats_callback=None): - if token_callback: - token_callback("TOKEN") - self._gate.wait(timeout=5) - - -class _EchoRunner: - """Emits the prompt back as a single token; used to detect cross-talk.""" - - def reset(self): - pass - - def stop(self): - pass - - def generate(self, prompt, config, token_callback=None, stats_callback=None): - if token_callback: - token_callback(prompt) - - -def test_pool_requires_at_least_one_worker(): - with pytest.raises(ValueError): - RunnerPool([]) - - -# Client disconnect / cancellation invokes the worker's stop() HOOK — the pool's -# contract. Whether that actually halts generation is up to the worker: a -# production WorkerClient.stop() is a no-op (see worker_client.py), so early -# termination comes from worker-side stop strings / EOS, not this hook. This test -# asserts only that the pool calls the hook. -def test_cancellation_calls_stop_hook(): - async def scenario(): - worker = _BlockingRunner() - pool = RunnerPool([worker]) - async with pool.acquire() as r: - agen = pool.generate_stream(r, "p", None).__aiter__() - assert await agen.__anext__() == "TOKEN" # worker now blocking - nxt = asyncio.ensure_future(agen.__anext__()) - await asyncio.sleep(0.05) - nxt.cancel() - try: - await nxt - except asyncio.CancelledError: - pass - for _ in range(100): # let the worker observe stop() - if worker.stopped: - break - await asyncio.sleep(0.02) - assert worker.stopped - - asyncio.run(scenario()) - - -# Concurrent requests across workers don't interleave / corrupt each other, and -# requests beyond the worker count queue for an idle worker rather than failing. -def test_concurrent_requests_isolated_and_queued(): - async def scenario(): - pool = RunnerPool([_EchoRunner(), _EchoRunner()]) # two workers - - async def one(prompt): - async with pool.acquire() as r: - return "".join([t async for t in pool.generate_stream(r, prompt, None)]) - - # Three requests, two workers: the third queues; all echo correctly. - out = await asyncio.gather(one("AAA"), one("BBB"), one("CCC")) - assert sorted(out) == ["AAA", "BBB", "CCC"] - - asyncio.run(scenario()) - - -def test_close_shuts_down_workers(): - class _Closable: - def __init__(self): - self.closed = False - - def close(self): - self.closed = True - - workers = [_Closable(), _Closable()] - RunnerPool(workers).close() - assert all(w.closed for w in workers) diff --git a/extension/llm/server/python/tests/test_session_runtime.py b/extension/llm/server/python/tests/test_session_runtime.py new file mode 100644 index 00000000000..14f138cc007 --- /dev/null +++ b/extension/llm/server/python/tests/test_session_runtime.py @@ -0,0 +1,187 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""SessionRuntime tests: session-op routing, the blocking->async stream bridge, +cancellation, and worker shutdown. A fake worker stands in for the WorkerClient +(no model, GPU, or subprocess). asyncio.run keeps the test bodies sync.""" + +import asyncio +import threading + +from executorch.extension.llm.server.python.session_runtime import ( + GenerationOptions, + GenStats, + PromptInput, + SessionRuntime, +) + +_OPTS = GenerationOptions(max_new_tokens=8) + + +def _text(s="hi") -> PromptInput: + return PromptInput(text=s) + + +class _Worker: + """Records session ops + process close; emits nothing on generate.""" + + def __init__(self): + self.opened, self.reset_ids, self.closed_ids = [], [], [] + self.proc_closed = False + + def open_session(self, sid): + self.opened.append(sid) + + def reset_session(self, sid): + self.reset_ids.append(sid) + + def close_session(self, sid): + self.closed_ids.append(sid) + + def close(self): + self.proc_closed = True + + def stop(self): + pass + + def generate(self, prompt, config, token_callback=None, stats_callback=None): + pass + + +def test_session_ops_route_to_worker(): + async def scenario(): + w = _Worker() + rt = SessionRuntime(w) + await rt.open("a") + await rt.reset("a") + await rt.close("a") + return w + + w = asyncio.run(scenario()) + assert w.opened == ["a"] and w.reset_ids == ["a"] and w.closed_ids == ["a"] + + +def test_session_ops_noop_when_worker_lacks_support(): + # A minimal worker without session ops: the runtime silently no-ops. + class _Bare: + def stop(self): + pass + + def generate(self, *a, **k): + pass + + async def scenario(): + rt = SessionRuntime(_Bare()) + await rt.open("a") + await rt.reset("a") + await rt.close("a") + + asyncio.run(scenario()) # must not raise + + +def test_generate_stream_yields_and_fills_stats(): + class _Echo: + def stop(self): + pass + + def generate(self, prompt, config, token_callback=None, stats_callback=None): + token_callback("Hello") + token_callback(" world") + + class S: + num_prompt_tokens = 3 + num_generated_tokens = 2 + finish_reason = "stop" + generated_token_ids = [10, 11] + + stats_callback(S()) + + async def scenario(): + rt = SessionRuntime(_Echo()) + stats = GenStats() + out = [t async for t in rt.generate_stream("a", _text(), _OPTS, stats)] + return out, stats + + out, stats = asyncio.run(scenario()) + assert "".join(out) == "Hello world" + assert stats.completion_tokens == 2 + assert stats.finish_reason == "stop" + assert stats.generated_token_ids == [10, 11] + + +def test_generate_stream_forwards_session_and_segments_to_worker(): + captured = {} + + class _Cap: + def stop(self): + pass + + def generate(self, prompt, config, token_callback=None, stats_callback=None): + captured["session_id"] = config.session_id + captured["segments"] = config.prompt_segments + captured["prompt"] = prompt + + async def scenario(): + rt = SessionRuntime(_Cap()) + seg = PromptInput(segments=[{"text": "a"}, {"ids": [1, 2]}]) + async for _ in rt.generate_stream("sess", seg, _OPTS, GenStats()): + pass + + asyncio.run(scenario()) + assert captured["session_id"] == "sess" + assert captured["segments"] == [{"text": "a"}, {"ids": [1, 2]}] + + +def test_cancellation_calls_worker_stop(): + class _Blocking: + def __init__(self): + self._gate = threading.Event() + self.stopped = False + + def stop(self): + self.stopped = True + self._gate.set() + + def generate(self, prompt, config, token_callback=None, stats_callback=None): + token_callback("TOKEN") + self._gate.wait(timeout=5) + + async def scenario(): + w = _Blocking() + rt = SessionRuntime(w) + agen = rt.generate_stream("a", _text(), _OPTS).__aiter__() + assert await agen.__anext__() == "TOKEN" # worker now blocking + nxt = asyncio.ensure_future(agen.__anext__()) + await asyncio.sleep(0.05) + nxt.cancel() + try: + await nxt + except asyncio.CancelledError: + pass + for _ in range(100): # let the worker observe stop() + if w.stopped: + break + await asyncio.sleep(0.02) + await agen.aclose() + return w + + w = asyncio.run(scenario()) + assert w.stopped + + +def test_close_worker_shuts_down_worker(): + w = _Worker() + SessionRuntime(w).close_worker() + assert w.proc_closed + + +def test_prompt_input_requires_exactly_one(): + import pytest + + with pytest.raises(ValueError): + PromptInput() + with pytest.raises(ValueError): + PromptInput(text="x", segments=[{"text": "y"}]) diff --git a/extension/llm/server/python/tests/test_tool_calls.py b/extension/llm/server/python/tests/test_tool_calls.py index b4dcb7b5a7e..7b165fac287 100644 --- a/extension/llm/server/python/tests/test_tool_calls.py +++ b/extension/llm/server/python/tests/test_tool_calls.py @@ -113,9 +113,9 @@ def test_mixed_valid_and_undefined_tool_degrades_to_text(make_client): def test_tool_choice_none_omits_tools_from_prompt(): from executorch.extension.llm.server.python.chat_template import ChatTemplate - from executorch.extension.llm.server.python.runner_pool import RunnerPool from executorch.extension.llm.server.python.server import build_app from executorch.extension.llm.server.python.serving_chat import ServingChat + from executorch.extension.llm.server.python.session_runtime import SessionRuntime from executorch.extension.llm.server.python.tool_parsers import HermesDetector # tool_choice="none" must NOT inject tool schemas into the chat template; if it @@ -152,9 +152,9 @@ def generate(self, prompt, config, token_callback=None, stats_callback=None): rec = _RecordingTok() template = ChatTemplate(hf_tokenizer_path=None, allow_fallback=True) template._hf = rec - pool = RunnerPool([_Runner()]) + runtime = SessionRuntime(_Runner()) serving = ServingChat( - pool, template, "test-model", tool_detector_cls=HermesDetector + runtime, template, "test-model", tool_detector_cls=HermesDetector ) client = TestClient(build_app(serving, "test-model")) body = { diff --git a/extension/llm/server/python/worker_client.py b/extension/llm/server/python/worker_client.py index 80e3ac207c8..6b7d4e84132 100644 --- a/extension/llm/server/python/worker_client.py +++ b/extension/llm/server/python/worker_client.py @@ -14,24 +14,41 @@ the binary and its launch args differ. Protocol (one JSON object per line): - worker -> stdout, once at startup: {"ready": true} - client -> stdin, per request: {"prompt": str, "max_new_tokens": int, - "temperature": float, "stop": [str, ...]} - worker -> stdout, per request: {"token": str} * (streamed) - {"done": true, "prompt_tokens": int, - "completion_tokens": int, - "finish_reason": "stop" | "length"} - or {"error": str} + worker -> stdout, once at startup: {"ready": true, "max_sessions": int, + "max_named_sessions": int} + client -> stdin: + generate: {"max_new_tokens": int, "temperature": float, "stop": [str, ...], + "session_id"?: str, and exactly one prompt form: + "prompt": str + "prompt_segments": [{"text": str} | {"ids": [int, ...]}]} + open: {"op": "open", "session_id": str} + close: {"op": "close", "session_id": str} + reset: {"op": "reset", "session_id": str} # clear context, keep slot + worker -> stdout: + generate: {"token": str} * (streamed) + {"done": true, "prompt_tokens": int, "completion_tokens": int, + "finish_reason": "stop" | "length", + "reused_prompt_tokens": int, "prefilled_prompt_tokens": int, + "session_reset_reason": "new"|"exact_prefix"|"dirty"|"mismatch" + |"equal", + "generated_token_ids"?: [int, ...]} # omitted if stop-trimmed + open: {"opened": true, "session_id": str} + close: {"closed": true, "session_id": str} + reset: {"reset": true, "session_id": str} + error: {"error": str, "code"?: str} # capacity_exhausted / + # unsupported_session The worker's stdout carries ONLY protocol JSON; its logs go to stderr. One -request at a time per worker (one worker == one session); the caller serializes. +request at a time per worker; the caller (SessionRuntime) serializes. A worker +hosts one engine and routes requests to per-session_id state (anonymous requests +share a scratch session); execution is synchronous. """ import json import logging import subprocess import threading -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Callable, Optional, Sequence logger = logging.getLogger(__name__) @@ -47,25 +64,48 @@ class WorkerStats: # stop) or "length" (ran to max_new, possibly clamped to the context window). # None if the worker didn't report it (older worker / fake). finish_reason: Optional[str] = None + # Warm-resume accounting (V2b.1): how many prompt tokens were served from the + # session's resident KV state vs actually prefilled this request, and why + # ("new"|"exact_prefix"|"dirty"|"mismatch"|"equal"). Not exposed as OpenAI + # usage; logged for measuring warm-resume hit rate. None on older workers. + reused_prompt_tokens: int = 0 + prefilled_prompt_tokens: int = 0 + session_reset_reason: Optional[str] = None + # The exact (non-terminal) token ids generated this turn. The control plane + # stores these per session and splices them back as an `ids` prompt segment + # next turn, so a prior assistant span is an exact token extension instead of + # a lossy chat-template re-render (V2b.1.5). Empty on older workers. + generated_token_ids: list = field(default_factory=list) class WorkerError(RuntimeError): - """A worker process failed, exited, or reported a generation error.""" + """A worker process failed, exited, or reported a generation error. + + `code` carries the worker's structured error code when present + ("capacity_exhausted", "unsupported_session"), so the HTTP layer can map it + to the right status; None for unstructured failures. + """ + + def __init__(self, message: str, code: Optional[str] = None): + super().__init__(message) + self.code = code class WorkerClient: - """Drives one model-execution worker process over JSONL. + """Drives one model-execution worker process over JSONL (raw transport). - Exposes the same ``generate(prompt, config, token_callback, stats_callback)`` - / ``reset()`` / ``stop()`` surface the runner pool expects, so a pool of - workers is a drop-in for the (retired) in-process session pool. One worker - hosts one session; calls are serialized by a lock (and by the pool's single - slot per worker). The control plane never executes model code. + Exposes ``generate(prompt, config, token_callback, stats_callback)`` / + ``stop()`` plus the session ops ``open_session`` / ``close_session`` / + ``reset_session`` that SessionRuntime drives. Calls are serialized by a lock + and by SessionRuntime (one in-flight request). The control plane never + executes model code. """ - def __init__(self, proc: subprocess.Popen): + def __init__(self, proc: subprocess.Popen, max_named_sessions: int = 0): self._proc = proc self._lock = threading.Lock() + # Named sessions this worker can host (0 = scratch-only / single session). + self.max_named_sessions = max_named_sessions def reset(self) -> None: # The worker resets its session at the start of each request; nothing to @@ -77,13 +117,82 @@ def stop(self) -> None: # generation in V1. pass + def open_session(self, session_id: str) -> None: + """Admit a named session (idempotent). Raises WorkerError with a `code` + ("capacity_exhausted" / "unsupported_session") if the worker refuses.""" + self._op({"op": "open", "session_id": session_id}, ack_key="opened") + + def close_session(self, session_id: str) -> None: + """Destroy a named session, freeing its state (idempotent).""" + self._op({"op": "close", "session_id": session_id}, ack_key="closed") + + def reset_session(self, session_id: str) -> None: + """Clear a named session's context (KV/recurrent + resident tokens) but + keep its capacity slot allocated (idempotent).""" + self._op({"op": "reset", "session_id": session_id}, ack_key="reset") + + def _op(self, request: dict, ack_key: str) -> None: + with self._lock: + if self._proc.poll() is not None: + raise WorkerError( + f"worker exited (code {self._proc.returncode}); restart the server" + ) + try: + self._proc.stdin.write(json.dumps(request) + "\n") + self._proc.stdin.flush() + except (BrokenPipeError, ValueError) as e: + raise WorkerError("worker stdin is closed") from e + line = self._proc.stdout.readline() + if not line: + raise WorkerError("worker exited mid-request") + msg = json.loads(line) + if msg.get(ack_key): + return + if "error" in msg: + raise WorkerError(msg["error"], code=msg.get("code")) + raise WorkerError(f"unexpected worker response: {msg}") + + @staticmethod + def _on_done(msg: dict, stats_callback) -> None: + reason = msg.get("session_reset_reason") + if reason is not None and logger.isEnabledFor(logging.DEBUG): + logger.debug( + "warm-resume: reason=%s reused=%d prefilled=%d", + reason, + msg.get("reused_prompt_tokens", 0), + msg.get("prefilled_prompt_tokens", 0), + ) + if stats_callback is not None: + stats_callback( + WorkerStats( + num_prompt_tokens=msg.get("prompt_tokens", 0), + num_generated_tokens=msg.get("completion_tokens", 0), + finish_reason=msg.get("finish_reason"), + reused_prompt_tokens=msg.get("reused_prompt_tokens", 0), + prefilled_prompt_tokens=msg.get("prefilled_prompt_tokens", 0), + session_reset_reason=reason, + generated_token_ids=msg.get("generated_token_ids", []), + ) + ) + def generate(self, prompt, config, token_callback=None, stats_callback=None): request = { - "prompt": prompt, "max_new_tokens": getattr(config, "max_new_tokens", -1), "temperature": getattr(config, "temperature", 0.0), "stop": list(getattr(config, "stop", []) or []), } + # Token-ID segments (V2b.1.5) take precedence over the rendered string: + # they let prior assistant spans be exact id runs, not lossy re-renders. + # `is not None` (not truthiness): segments is a distinct prompt form, kept + # whatever its content (the worker validates non-empty). + segments = getattr(config, "prompt_segments", None) + if segments is not None: + request["prompt_segments"] = segments + else: + request["prompt"] = prompt + session_id = getattr(config, "session_id", None) + if session_id: + request["session_id"] = session_id with self._lock: if self._proc.poll() is not None: raise WorkerError( @@ -104,17 +213,10 @@ def generate(self, prompt, config, token_callback=None, stats_callback=None): if token_callback is not None: token_callback(msg["token"]) elif msg.get("done"): - if stats_callback is not None: - stats_callback( - WorkerStats( - msg.get("prompt_tokens", 0), - msg.get("completion_tokens", 0), - msg.get("finish_reason"), - ) - ) + self._on_done(msg, stats_callback) return elif "error" in msg: - raise WorkerError(msg["error"]) + raise WorkerError(msg["error"], code=msg.get("code")) def close(self) -> None: """Terminate the worker process (called at server shutdown).""" @@ -160,5 +262,6 @@ def spawn_worker( msg = json.loads(line) if not msg.get("ready"): raise WorkerError(f"worker did not report ready: {msg}") - logger.info("Model worker ready.") - return WorkerClient(proc) + max_named = int(msg.get("max_named_sessions", 0)) + logger.info("Model worker ready (max_named_sessions=%d).", max_named) + return WorkerClient(proc, max_named_sessions=max_named)