diff --git a/extension/llm/server/README.md b/extension/llm/server/README.md
new file mode 100644
index 00000000000..0b18d31cae5
--- /dev/null
+++ b/extension/llm/server/README.md
@@ -0,0 +1,35 @@
+# ExecuTorch LLM Server
+
+OpenAI-compatible serving for ExecuTorch LLMs, so any OpenAI-compatible agent
+harness (pi, opencode, ...) can use ExecuTorch as a local backend.
+
+```
+extension/llm/server/
+ spec/ # language-neutral OpenAI contract ExecuTorch targets
+ conformance/ # one test suite every language server must pass
+ python/ # Python server implementation (current)
+ # cpp/ # future: no-Python single-binary server
+```
+
+Why this layout: the OpenAI contract is identical across languages, so the
+**spec** and **conformance** suite are shared, and each language gets its own
+implementation directory. The real cross-language reuse comes from the C++
+`LLMEngine`/`LLMSession` primitives underneath, packaged as a process-isolated
+**worker binary** (`text_llm_worker`) that any control plane drives over a small
+JSONL protocol — the server is a thin protocol shell that spawns and talks to
+that worker. See `python/README.md` to run it.
+
+Status: experimental, reliability-first and deliberately narrow. Implemented:
+`/health`, `/v1/models`, `/v1/chat/completions` (streaming + non-streaming),
+Hugging Face chat templates (`--hf-tokenizer`), `temperature` / `max_tokens` /
+`max_completion_tokens` / `stop`, Hermes tool calling by default
+(`...` JSON, complete calls only; model-specific launchers
+may select the Qwen XML format) with `tool_choice="none"`,
+structured API errors, and best-effort cancellation. One worker process with
+serialized execution; it hosts many isolated sessions on one weight load (warm
+append-only resume across turns). KV/prefix state lives inside the
+worker/session, not the control plane. Unsupported params (including `top_p`,
+`seed`, `n>1`, `reasoning_effort`, penalties, `logit_bias`, `response_format`,
+`logprobs`, and `tool_choice="required"`) are rejected with a structured 400
+rather than silently ignored. See `python/README.md` to run it and
+`spec/README.md` for the exact contract.
diff --git a/extension/llm/server/conformance/test_openai_contract.py b/extension/llm/server/conformance/test_openai_contract.py
new file mode 100644
index 00000000000..3347126b5f4
--- /dev/null
+++ b/extension/llm/server/conformance/test_openai_contract.py
@@ -0,0 +1,207 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Language-neutral OpenAI-contract conformance tests.
+
+Runs against any base URL (ExecuTorch, llama.cpp, mlx-lm, ...) so every server
+implementation is validated against one shared spec. Point it at a running
+server:
+
+ OPENAI_BASE_URL=http://127.0.0.1:8000/v1 pytest test_openai_contract.py
+
+Skips automatically if no server is reachable.
+"""
+
+import json
+import os
+import urllib.error
+import urllib.request
+
+import pytest
+
+BASE_URL = os.environ.get("OPENAI_BASE_URL", "http://127.0.0.1:8000/v1").rstrip("/")
+MODEL = os.environ.get("OPENAI_MODEL", "executorch")
+
+
+def _post(path: str, body: dict, stream: bool = False):
+ req = urllib.request.Request(
+ f"{BASE_URL}{path}",
+ data=json.dumps(body).encode(),
+ headers={"Content-Type": "application/json"},
+ method="POST",
+ )
+ return urllib.request.urlopen(req, timeout=120)
+
+
+def _server_up() -> bool:
+ try:
+ urllib.request.urlopen(f"{BASE_URL}/models", timeout=5)
+ return True
+ except Exception:
+ return False
+
+
+pytestmark = pytest.mark.skipif(
+ not _server_up(), reason="no OpenAI server at OPENAI_BASE_URL"
+)
+
+
+def test_models_listing():
+ with urllib.request.urlopen(f"{BASE_URL}/models", timeout=10) as r:
+ data = json.loads(r.read())
+ assert data["object"] == "list"
+ assert any("id" in m for m in data["data"])
+
+
+def test_chat_completion_nonstreaming():
+ body = {
+ "model": MODEL,
+ "messages": [{"role": "user", "content": "Say hello in one word."}],
+ "max_tokens": 16,
+ "temperature": 0.0,
+ }
+ with _post("/chat/completions", body) as r:
+ data = json.loads(r.read())
+ assert data["object"] == "chat.completion"
+ assert data["choices"][0]["message"]["role"] == "assistant"
+ assert isinstance(data["choices"][0]["message"]["content"], str)
+ assert data["choices"][0]["finish_reason"] is not None
+
+
+def test_chat_completion_streaming():
+ body = {
+ "model": MODEL,
+ "messages": [{"role": "user", "content": "Count to three."}],
+ "max_tokens": 32,
+ "stream": True,
+ }
+ saw_role = saw_content = saw_done = False
+ with _post("/chat/completions", body, stream=True) as r:
+ for raw in r:
+ line = raw.decode().strip()
+ if not line.startswith("data:"):
+ continue
+ payload = line[len("data:") :].strip()
+ if payload == "[DONE]":
+ saw_done = True
+ break
+ chunk = json.loads(payload)
+ assert chunk["object"] == "chat.completion.chunk"
+ delta = chunk["choices"][0]["delta"]
+ saw_role = saw_role or delta.get("role") == "assistant"
+ saw_content = saw_content or bool(delta.get("content"))
+ assert saw_role and saw_content and saw_done
+
+
+def test_multibyte_streaming_integrity():
+ # Byte-level BPE can split a multi-byte character across tokens; the stream
+ # must reassemble it, not abort with a UTF-8 decode error.
+ body = {
+ "model": MODEL,
+ "messages": [
+ {"role": "user", "content": "Reply with exactly: 你好世界 🌍 café"}
+ ],
+ "max_tokens": 32,
+ "temperature": 0.0,
+ "stream": True,
+ }
+ content, saw_done, saw_error = "", False, False
+ with _post("/chat/completions", body, stream=True) as r:
+ for raw in r:
+ line = raw.decode().strip()
+ if not line.startswith("data:"):
+ continue
+ payload = line[len("data:") :].strip()
+ if payload == "[DONE]":
+ saw_done = True
+ break
+ chunk = json.loads(payload)
+ if "error" in chunk:
+ saw_error = True
+ content += (
+ chunk["choices"][0]["delta"].get("content", "")
+ if chunk.get("choices")
+ else ""
+ )
+ assert saw_done and not saw_error
+ assert isinstance(content, str) and content # reassembled, valid UTF-8
+
+
+def test_usage_chunk_in_stream():
+ body = {
+ "model": MODEL,
+ "messages": [{"role": "user", "content": "Say hi."}],
+ "max_tokens": 16,
+ "stream": True,
+ "stream_options": {"include_usage": True},
+ }
+ usage = None
+ with _post("/chat/completions", body, stream=True) as r:
+ for raw in r:
+ line = raw.decode().strip()
+ if not line.startswith("data:"):
+ continue
+ payload = line[len("data:") :].strip()
+ if payload == "[DONE]":
+ break
+ chunk = json.loads(payload)
+ if chunk.get("usage"):
+ usage = chunk["usage"]
+ assert usage is not None, "no usage chunk emitted with include_usage"
+ assert usage["prompt_tokens"] > 0 and usage["completion_tokens"] > 0
+ assert usage["total_tokens"] == usage["prompt_tokens"] + usage["completion_tokens"]
+
+
+WEATHER_TOOL = {
+ "type": "function",
+ "function": {
+ "name": "get_weather",
+ "description": "Get the current weather for a city.",
+ "parameters": {
+ "type": "object",
+ "properties": {"city": {"type": "string"}},
+ "required": ["city"],
+ },
+ },
+}
+
+
+def test_tool_call_response_shape():
+ body = {
+ "model": MODEL,
+ "messages": [
+ {"role": "user", "content": "What is the weather in Paris? Use the tool."}
+ ],
+ "tools": [WEATHER_TOOL],
+ "max_tokens": 128,
+ "temperature": 0.0,
+ }
+ with _post("/chat/completions", body) as r:
+ data = json.loads(r.read())
+ calls = data["choices"][0]["message"].get("tool_calls")
+ assert calls, "expected tool_calls in response"
+ tc = calls[0]
+ assert tc["type"] == "function"
+ assert tc["id"]
+ assert tc["function"]["name"] == "get_weather"
+ json.loads(tc["function"]["arguments"]) # arguments is a JSON string
+ assert data["choices"][0]["finish_reason"] == "tool_calls"
+
+
+def test_error_body_shape():
+ # Over-long prompt -> structured 400 (OpenAI error envelope), not a 500/drop.
+ body = {
+ "model": MODEL,
+ "messages": [{"role": "user", "content": "word " * 40000}],
+ "max_tokens": 8,
+ }
+ try:
+ _post("/chat/completions", body)
+ raise AssertionError("expected an HTTP error for over-long prompt")
+ except urllib.error.HTTPError as e:
+ assert 400 <= e.code < 500
+ err = json.loads(e.read())["error"]
+ assert err["message"] and err["type"]
diff --git a/extension/llm/server/cpp/CMakeLists.txt b/extension/llm/server/cpp/CMakeLists.txt
new file mode 100644
index 00000000000..653cf61bea8
--- /dev/null
+++ b/extension/llm/server/cpp/CMakeLists.txt
@@ -0,0 +1,88 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Generic model-execution worker for standard .pte TextLLM models. One binary,
+# no registry/factory: it constructs TextLLMEngine/TextLLMSession directly and
+# speaks the JSONL worker protocol (worker_client.py). Model execution is C++
+# only — the Python server is HTTP/control plane.
+#
+# Build like the example runners (standalone), e.g. from this directory: cmake
+# -S . -B /extension/llm/server/cpp \
+# -DCMAKE_PREFIX_PATH= -DEXECUTORCH_BUILD_XNNPACK=ON cmake
+# --build <...>/extension/llm/server/cpp --target text_llm_worker
+
+cmake_minimum_required(VERSION 3.24)
+project(llm_server_workers)
+
+set(CMAKE_CXX_STANDARD 17)
+set(CMAKE_CXX_STANDARD_REQUIRED ON)
+
+set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../../..)
+
+include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)
+
+set(_common_include_directories ${EXECUTORCH_ROOT}/..)
+# Vendored single-include nlohmann/json for the worker protocol (no new dep).
+set(_json_include
+ ${EXECUTORCH_ROOT}/extension/llm/tokenizers/third-party/json/single_include
+)
+
+# gflags
+set(gflags_DIR ${CMAKE_CURRENT_BINARY_DIR}/../../../../third-party/gflags)
+find_package(gflags REQUIRED)
+
+# executorch
+list(APPEND CMAKE_FIND_ROOT_PATH ${CMAKE_CURRENT_BINARY_DIR}/../../../..)
+find_package(executorch CONFIG REQUIRED FIND_ROOT_PATH_BOTH)
+executorch_target_link_options_shared_lib(executorch)
+
+set(link_libraries executorch gflags)
+
+# CPU ops
+list(APPEND link_libraries optimized_native_cpu_ops_lib cpublas eigen_blas)
+executorch_target_link_options_shared_lib(optimized_native_cpu_ops_lib)
+
+# Custom + quantized kernels that export_llm models need, whole-archived so the
+# static op registrations survive the linker: llama::custom_sdpa (from
+# use_sdpa_with_kv_cache) and quantized_decomposed ops (from quantized exports).
+# Without these the model loads but execution fails with "Missing operator".
+if(TARGET custom_ops)
+ executorch_target_link_options_shared_lib(custom_ops)
+ list(APPEND link_libraries custom_ops)
+endif()
+if(TARGET quantized_ops_lib)
+ list(APPEND link_libraries quantized_kernels quantized_ops_lib)
+ executorch_target_link_options_shared_lib(quantized_ops_lib)
+endif()
+
+# Extensions (Engine/Session lives in extension_llm_runner)
+list(
+ APPEND
+ link_libraries
+ extension_llm_runner
+ extension_module
+ extension_data_loader
+ extension_tensor
+ extension_flat_tensor
+)
+
+# XNNPACK: the standard CPU backend for normal .pte TextLLM models.
+list(APPEND link_libraries xnnpack_backend)
+executorch_target_link_options_shared_lib(xnnpack_backend)
+
+# Tokenizer
+list(APPEND link_libraries tokenizers::tokenizers)
+
+add_executable(text_llm_worker text_llm_worker.cpp)
+target_include_directories(
+ text_llm_worker PUBLIC ${_common_include_directories} ${_json_include}
+)
+target_link_libraries(text_llm_worker PUBLIC ${link_libraries})
+
+if(NOT CMAKE_BUILD_TYPE STREQUAL "Debug")
+ target_link_options_gc_sections(text_llm_worker)
+ target_link_options(text_llm_worker PRIVATE "LINKER:-s")
+endif()
diff --git a/extension/llm/server/cpp/text_llm_worker.cpp b/extension/llm/server/cpp/text_llm_worker.cpp
new file mode 100644
index 00000000000..f7bb9d69915
--- /dev/null
+++ b/extension/llm/server/cpp/text_llm_worker.cpp
@@ -0,0 +1,69 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+// Generic model-execution worker for standard .pte TextLLM models.
+//
+// All model execution lives here in C++ (via TextLLMEngine / TextLLMSession,
+// the stable serving abstraction) — no Python model code, no pybind, no
+// in-process Python serving. The OpenAI control plane (Python) spawns this
+// process and drives it over JSONL on stdin/stdout (see worker_client.py). The
+// JSONL protocol and the decode loop are shared across all workers in
+// worker_loop.h; this file only constructs the engine/session/tokenizer.
+
+#include
+
+#include
+#include
+#include
+#include
+
+#include
+#include
+
+DEFINE_string(model_path, "", "Self-contained model .pte file path.");
+DEFINE_string(tokenizer_path, "", "HuggingFace tokenizer.json path.");
+
+namespace {
+namespace llm = ::executorch::extension::llm;
+using ::executorch::runtime::Error;
+} // namespace
+
+int main(int argc, char** argv) {
+ gflags::ParseCommandLineFlags(&argc, &argv, true);
+
+ if (FLAGS_model_path.empty() || FLAGS_tokenizer_path.empty()) {
+ ET_LOG(
+ Error, "text_llm_worker: --model_path and --tokenizer_path required");
+ return 1;
+ }
+
+ // TextLLMEngine requires a self-contained .pte: external .ptd weights are not
+ // supported for shared sessions (a model-specific worker handles that path).
+ auto engine = llm::TextLLMEngine::create(
+ FLAGS_model_path, FLAGS_tokenizer_path, std::nullopt);
+ if (!engine) {
+ ET_LOG(Error, "text_llm_worker: failed to create engine");
+ return 1;
+ }
+ auto session_result = engine->create_session();
+ if (session_result.error() != Error::Ok) {
+ ET_LOG(Error, "text_llm_worker: failed to create session");
+ return 1;
+ }
+ auto session = std::move(session_result.get());
+
+ // The session decodes token ids to text internally; this tokenizer encodes
+ // the rendered prompt to ids. Same tokenizer.json -> same vocabulary.
+ auto tokenizer = llm::load_tokenizer(FLAGS_tokenizer_path);
+ if (!tokenizer) {
+ ET_LOG(Error, "text_llm_worker: failed to load tokenizer");
+ return 1;
+ }
+
+ return llm::run_worker_stdio_loop(*session, *tokenizer, engine->metadata());
+}
diff --git a/extension/llm/server/cpp/worker_loop.h b/extension/llm/server/cpp/worker_loop.h
new file mode 100644
index 00000000000..883bcac69cd
--- /dev/null
+++ b/extension/llm/server/cpp/worker_loop.h
@@ -0,0 +1,192 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+// Shared model-worker generation loop + JSONL protocol, used by every model
+// worker (the generic text_llm_worker and model-specific workers like
+// qwen3_5_moe_worker). A worker only constructs its engine/session/tokenizer
+// and calls run_worker_stdio_loop(); the protocol and the decode loop live here
+// once, so protocol changes (e.g. multi-session) land in a single place.
+//
+// Protocol (one JSON object per line; matches worker_client.py):
+// worker -> stdout, once: {"ready": true}
+// client -> stdin, per request: {"prompt": str, "max_new_tokens": int,
+// "temperature": float, "stop": [str, ...]}
+// worker -> stdout, per request: {"token": str} * (streamed)
+// {"done": true, "prompt_tokens": int,
+// "completion_tokens": int,
+// "finish_reason": "stop" | "length"}
+// or {"error": str}
+//
+// stdout carries ONLY protocol JSON; all logs go to stderr (ET_LOG). One
+// request at a time (the control plane serializes; V1 is one worker == one
+// session).
+
+#include
+
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace executorch {
+namespace extension {
+namespace llm {
+
+// Emit one protocol object as a JSON line on stdout. error_handler::replace
+// keeps a stray invalid UTF-8 byte (byte-level BPE) from aborting
+// serialization.
+inline void worker_emit(const nlohmann::json& obj) {
+ std::cout << obj.dump(
+ -1, ' ', false, nlohmann::json::error_handler_t::replace)
+ << "\n";
+ std::cout.flush();
+}
+
+// One generation request: reset the session, encode the prompt, prefill, then
+// loop decode_one() streaming complete-UTF-8 text pieces. A terminal step (EOS
+// or cooperative stop) ends generation and is not emitted or counted. Throws
+// std::runtime_error on failure; the caller reports it as {"error": ...}.
+inline void worker_handle_request(
+ LLMSession& session,
+ ::tokenizers::Tokenizer& tokenizer,
+ const std::unordered_map& metadata,
+ const nlohmann::json& req) {
+ const std::string prompt = req.at("prompt").get();
+ int64_t max_new = req.value("max_new_tokens", static_cast(-1));
+ const float temperature = req.value("temperature", 0.0f);
+ // Stop strings (the request's `stop` sequences): terminate at the token
+ // boundary where one appears so we don't generate to EOS/max_new past it. The
+ // control plane also enforces these as a backstop.
+ const std::vector stops =
+ req.value("stop", std::vector{});
+
+ if (session.reset() != ::executorch::runtime::Error::Ok) {
+ throw std::runtime_error("session reset failed");
+ }
+ // No special tokens: the prompt is already rendered (the control plane
+ // applied the chat template), matching the runner's own encode path.
+ auto encode_result = tokenizer.encode(prompt, /*bos=*/0, /*eos=*/0);
+ if (!encode_result.ok()) {
+ throw std::runtime_error("prompt encode failed");
+ }
+ std::vector ids = std::move(*encode_result);
+ if (ids.empty()) {
+ throw std::runtime_error("empty prompt");
+ }
+ const int64_t num_prompt = static_cast(ids.size());
+
+ // Bound generation to the context window: default to filling the remaining
+ // room, and clamp an explicit max_new_tokens too, so decode never steps past
+ // the window (which would error mid-generation after partial output).
+ const auto ctx_it = metadata.find(kMaxContextLen);
+ if (ctx_it != metadata.end()) {
+ const int64_t room = ctx_it->second - num_prompt;
+ if (room <= 0) {
+ throw std::runtime_error(
+ "prompt fills the context window; no room to generate");
+ }
+ if (max_new <= 0 || max_new > room) {
+ max_new = room;
+ }
+ } else if (max_new <= 0) {
+ max_new = 2048;
+ }
+
+ SamplingConfig sampling;
+ sampling.temperature = temperature;
+ if (session.prefill_tokens(std::move(ids), &sampling) !=
+ ::executorch::runtime::Error::Ok) {
+ throw std::runtime_error("prefill failed");
+ }
+
+ std::string buf; // bytes not yet forming a complete UTF-8 prefix
+ std::string pending; // complete-UTF-8 text held back for stop-string matching
+ int64_t num_generated = 0;
+ std::string finish = "length"; // EOS or stop string -> "stop"
+ bool stop_string = false; // a request stop string was matched
+ for (int64_t step = 0; step < max_new; ++step) {
+ auto step_result = session.decode_one(sampling);
+ if (step_result.error() != ::executorch::runtime::Error::Ok) {
+ throw std::runtime_error("decode failed");
+ }
+ const auto& d = step_result.get();
+ if (d.is_terminal) {
+ finish = "stop";
+ break; // terminal step (EOS / cooperative stop): not emitted or counted
+ }
+ ++num_generated;
+ buf += d.text_piece;
+ const size_t cut = utf8_complete_prefix_len(buf);
+ if (cut > 0) {
+ pending += buf.substr(0, cut);
+ buf.erase(0, cut);
+ }
+ bool stop_hit = false;
+ const size_t safe = stop_safe_prefix_len(pending, stops, stop_hit);
+ if (safe > 0) {
+ worker_emit({{"token", pending.substr(0, safe)}});
+ pending.erase(0, safe);
+ }
+ if (stop_hit) {
+ finish = "stop"; // reached a stop string: drop it and everything after
+ stop_string = true;
+ break;
+ }
+ }
+ if (!stop_string) {
+ // EOS or length: flush held-back text + any trailing incomplete bytes
+ // (replaced if invalid). A stop-string hit drops the remainder instead.
+ pending += buf;
+ if (!pending.empty()) {
+ worker_emit({{"token", pending}});
+ }
+ }
+ // finish_reason: "stop" if the model emitted EOS or hit a stop string, else
+ // "length" — it ran to max_new (possibly clamped to the context window).
+ worker_emit(
+ {{"done", true},
+ {"prompt_tokens", num_prompt},
+ {"completion_tokens", num_generated},
+ {"finish_reason", finish}});
+}
+
+// Emit {"ready": true}, then read JSONL requests from stdin and dispatch each
+// to worker_handle_request, reporting exceptions as {"error": ...} and
+// continuing to serve. Returns 0 when stdin closes.
+inline int run_worker_stdio_loop(
+ LLMSession& session,
+ ::tokenizers::Tokenizer& tokenizer,
+ const std::unordered_map& metadata) {
+ worker_emit({{"ready", true}});
+ std::string line;
+ while (std::getline(std::cin, line)) {
+ if (line.empty()) {
+ continue;
+ }
+ try {
+ worker_handle_request(
+ session, tokenizer, metadata, nlohmann::json::parse(line));
+ } catch (const std::exception& e) { // report and keep serving
+ worker_emit({{"error", std::string(e.what())}});
+ }
+ }
+ return 0;
+}
+
+} // namespace llm
+} // namespace extension
+} // namespace executorch
diff --git a/extension/llm/server/python/README.md b/extension/llm/server/python/README.md
new file mode 100644
index 00000000000..e14e6176c81
--- /dev/null
+++ b/extension/llm/server/python/README.md
@@ -0,0 +1,169 @@
+# ExecuTorch LLM Server — Python
+
+A thin OpenAI-compatible HTTP server for ExecuTorch LLMs. The Python process is
+the **control plane** only — HTTP, OpenAI protocol, chat templating, tool
+parsing, request validation. Model execution runs in a separate **C++ worker
+process** (`text_llm_worker`) that the server drives over a small JSONL protocol.
+The control plane never loads a model, links a backend, or imports a runtime
+pybind.
+
+## Install
+
+```bash
+pip install -r requirements.txt
+# transformers is optional but recommended for model-correct chat templates
+pip install transformers
+```
+
+The server itself is pure Python (fastapi, pydantic, httpx). The model runs in
+the C++ worker, which you build standalone (like the example runners) from
+`../cpp`:
+
+```bash
+cmake -S ../cpp -B /extension/llm/server/cpp \
+ -DCMAKE_PREFIX_PATH= -DEXECUTORCH_BUILD_XNNPACK=ON
+cmake --build /extension/llm/server/cpp --target text_llm_worker
+# -> /extension/llm/server/cpp/text_llm_worker
+```
+
+### Model & runtime requirements
+
+LLM `.pte` files exported via `export_llm` use ExecuTorch custom/quantized ops:
+`use_sdpa_with_kv_cache` → `llama::custom_sdpa`, and quantized exports
+(`embedding_quantize`, `8da4w`, ...) → `quantized_decomposed` ops. The worker
+binary links the kernel libraries that provide them (the C++ equivalents of
+`-DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON` /
+`-DEXECUTORCH_BUILD_EXTENSION_LLM_RUNNER=ON`); see the canonical
+[Llama README](../../../../examples/models/llama/README.md). Without them the
+worker fails to load the method.
+
+Tokenizer: pass the model's tokenizer — `tokenizer.json` (HF, e.g. Qwen3) or
+`tokenizer.model` (Llama); the worker auto-detects. If you see an RE2 lookahead
+warning it falls back to PCRE2 and still works (build with
+`-DSUPPORT_REGEX_LOOKAHEAD=ON` for the native regex path).
+
+## Run
+
+```bash
+python -m executorch.extension.llm.server.python.server \
+ --model-path /path/to/model.pte \
+ --tokenizer-path /path/to/tokenizer.bin \
+ --hf-tokenizer Qwen/Qwen2.5-Coder-7B-Instruct \
+ --model-id qwen2.5-coder \
+ --host 127.0.0.1 --port 8000
+```
+
+The server spawns the worker (it blocks until the worker has loaded the model and
+reported ready, so a slow load surfaces at startup, not on the first request).
+
+`--hf-tokenizer` is **required** (it applies the model's real `chat_template`)
+unless you pass `--allow-chatml-fallback` to opt into approximate generic ChatML
+— which is wrong for many instruct/tool models and can't reproduce controls like
+`enable_thinking`.
+
+Key flags:
+
+| Flag | Effect |
+|------|--------|
+| `--hf-tokenizer` | model's HF chat template (required unless fallback) |
+| `--allow-chatml-fallback` | opt into approximate ChatML when no HF tokenizer |
+| `--no-think` | default `enable_thinking=False` (e.g. Qwen3) |
+| `--max-context N` | reject over-long prompts with 400 instead of failing mid-gen |
+| `--num-runners N` | Worker processes — **1 only** (one worker hosts many isolated sessions on one weight load; more would duplicate weights) |
+| `--worker-bin PATH` | path to the `text_llm_worker` binary (default: `cmake-out/extension/llm/server/cpp/text_llm_worker`) |
+
+## Use from an agent harness
+
+- **opencode** (`opencode.json`):
+ ```json
+ { "provider": { "executorch": {
+ "npm": "@ai-sdk/openai-compatible",
+ "options": { "baseURL": "http://127.0.0.1:8000/v1" },
+ "models": { "qwen2.5-coder": { "name": "Qwen2.5-Coder (ExecuTorch)" } } } } }
+ ```
+- **pi** (`~/.pi/agent/models.json`):
+ ```json
+ { "providers": { "executorch": {
+ "baseUrl": "http://127.0.0.1:8000/v1", "api": "openai-completions",
+ "apiKey": "x", "models": [ { "id": "qwen2.5-coder" } ] } } }
+ ```
+
+## Validate
+
+Two layers, both contract-focused (assert on the wire, not internals):
+
+```bash
+# 1. Hermetic unit tests — fake worker, no model/GPU/subprocess, fast (CI-friendly).
+pip install pytest httpx
+pytest tests/
+
+# 2. Conformance — black-box, against a LIVE server (real model, or llama.cpp/mlx-lm).
+OPENAI_BASE_URL=http://127.0.0.1:8000/v1 pytest ../conformance/test_openai_contract.py
+```
+
+`tests/` builds a `SessionRuntime` over a single `FakeRunner` worker, so the
+real server/protocol/streaming code is tested over HTTP without a `.pte`. The
+worker JSONL protocol is covered separately by `tests/test_worker_client.py`.
+
+## Architecture
+
+Control plane (this dir, Python): an OpenAI adapter (`serving_chat`) over a
+stateful `SessionRuntime` over one `WorkerClient` — server, protocol, chat
+templating, streaming bridge, tool parsing — no CUDA, no model, no pybind. Data
+plane (C++): a worker process (`text_llm_worker`) that owns all model state
+(many isolated sessions on one weight load, warm-resume prefix logic) and does
+all token stepping and KV mutation; it speaks one JSON object per line on
+stdin/stdout.
+
+JSONL protocol (stdout carries protocol JSON only; logs go to stderr):
+
+```
+worker -> stdout, once at startup: {"ready": true}
+client -> stdin, per request: {"prompt", "max_new_tokens", "temperature"}
+worker -> stdout, per request: {"token": str} * (streamed)
+ {"done": true, "prompt_tokens", "completion_tokens"}
+ or {"error": str}
+```
+
+Process isolation is the reliable shape for CUDA/AOTI models: executing the model
+inside a live asyncio server process can segfault (validated with Qwen3.5-MoE);
+the worker is a plain process with no asyncio loop, and the control plane only
+does blocking pipe I/O on its executor thread.
+
+| File | Role |
+|------|------|
+| `server.py` | FastAPI app, routes, CLI entrypoint, worker spawn |
+| `protocol.py` | OpenAI request/response schemas |
+| `chat_template.py` | messages (+tools) → prompt string |
+| `worker_client.py` | spawn a worker process + drive it over JSONL (raw transport) |
+| `session_runtime.py` | stateful runtime over one worker: open/generate/reset/close + streaming bridge |
+| `serving_chat.py` | `/v1/chat/completions` OpenAI adapter (streaming + non-streaming, stop, tools) |
+| `tool_parsers/` | Hermes/Qwen `` parser only |
+| `cpp/text_llm_worker.cpp` | the generic C++ worker binary |
+
+### Model workers
+
+The generic `text_llm_worker` serves the text model (`TextLLMEngine`). A new
+model ships its own worker binary under its example (e.g.
+`examples/models/qwen3_5_moe/qwen35_moe_worker.cpp` constructs `Qwen35MoEEngine`)
+that speaks the same JSONL protocol, plus a launcher that points the same control
+plane at that binary via `--worker-bin`. The dependency points one way: a model
+example may reuse the generic control plane; the generic control plane never
+imports an example. Backend specifics (CUDA/AOTI, Metal) stay inside the worker.
+
+## Scope & caveats
+
+Deliberately narrow (reliability-first): Hermes/Qwen tool calling only;
+unsupported sampling params are rejected, not ignored. **One worker process,
+serialized execution** (one in-flight request; concurrent requests queue).
+Session capacity is determined by the worker/engine — a single worker hosts many
+isolated sessions on one weight load — so `--num-runners` accepts 1; extra worker
+processes would each carry their own copy of the weights.
+
+Cancellation is best-effort: a worker request runs to completion and is not
+interruptible mid-generation in V1, so `runner.stop()` means "the control plane
+stops consuming and the worker finishes the current request" rather than a hard
+cancel. There is **no prefix cache in V1 serving**; if KV prefix reuse returns it
+will live inside the worker/session, not in the Python control plane. Multiple
+workers, weight sharing across sessions on a backend that supports it, adaptive
+thinking, and multi-session subagents are future work.
diff --git a/extension/llm/server/python/server.py b/extension/llm/server/python/server.py
new file mode 100644
index 00000000000..94c55479275
--- /dev/null
+++ b/extension/llm/server/python/server.py
@@ -0,0 +1,198 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""OpenAI-compatible HTTP server for ExecuTorch LLMs.
+
+Point any OpenAI-compatible agent harness (pi, opencode, ...) at
+``http://:/v1``.
+
+This process is the CONTROL PLANE only: FastAPI/uvicorn + OpenAI protocol, chat
+templating, tool parsing, request validation. It runs NO model code and imports
+no runtime pybind. Model execution lives in a separate C++ worker process
+(``text_llm_worker``) driven over JSONL via WorkerClient.
+
+One worker process, serialized execution (one in-flight request; concurrent
+requests queue). Session capacity is set by the worker/engine -- a single worker
+hosts many isolated sessions on one weight load; extra worker processes would
+duplicate the weights, so `--num-runners` accepts 1.
+
+Example:
+ python -m executorch.extension.llm.server.python.server \\
+ --model-path model.pte --tokenizer-path tokenizer.bin \\
+ --hf-tokenizer Qwen/Qwen2.5-Coder-7B-Instruct --model-id qwen2.5-coder
+"""
+
+import argparse
+import logging
+import os
+from pathlib import Path
+
+from fastapi import FastAPI
+from fastapi.responses import JSONResponse, StreamingResponse
+
+from .chat_template import ChatTemplate
+from .errors import APIError
+from .protocol import ChatCompletionRequest, ModelCard, ModelList
+from .serving_chat import ServingChat
+from .session_runtime import SessionRuntime
+from .tool_parsers import HermesDetector
+from .worker_client import spawn_worker
+
+logger = logging.getLogger(__name__)
+
+
+def _default_worker_bin() -> str:
+ repo_root = Path(__file__).resolve().parents[4]
+ return str(
+ repo_root
+ / "cmake-out"
+ / "extension"
+ / "llm"
+ / "server"
+ / "cpp"
+ / "text_llm_worker"
+ )
+
+
+def _spawn(args):
+ """Spawn the C++ text_llm_worker and return a ready WorkerClient."""
+ env = dict(os.environ)
+ conda = os.environ.get("CONDA_PREFIX")
+ if conda:
+ env["LD_LIBRARY_PATH"] = f"{conda}/lib:" + env.get("LD_LIBRARY_PATH", "")
+ worker_bin = args.worker_bin or _default_worker_bin()
+ cmd = [
+ worker_bin,
+ "--model_path",
+ args.model_path,
+ "--tokenizer_path",
+ args.tokenizer_path,
+ ]
+ logger.info("Starting model worker subprocess (loads the model once)...")
+ return spawn_worker(cmd, env=env)
+
+
+def build_app(serving: ServingChat, model_id: str) -> FastAPI:
+ app = FastAPI(title="ExecuTorch LLM Server")
+
+ @app.get("/health")
+ async def health():
+ return {"status": "ok"}
+
+ @app.get("/v1/models")
+ async def list_models() -> ModelList:
+ return ModelList(data=[ModelCard(id=model_id)])
+
+ @app.post("/v1/chat/completions")
+ async def chat_completions(req: ChatCompletionRequest):
+ # Typed param → FastAPI validates the body and returns 422 on bad input.
+ # APIError (e.g. context_length_exceeded) → structured 4xx/5xx, never a
+ # dropped connection. Mid-stream failures are handled inside the stream.
+ try:
+ result = await serving.create(req)
+ except APIError as e:
+ return JSONResponse(e.body(), status_code=e.status)
+ if req.stream:
+ return StreamingResponse(result, media_type="text/event-stream")
+ return JSONResponse(result.model_dump(exclude_none=True))
+
+ return app
+
+
+def main() -> None:
+ p = argparse.ArgumentParser(description="ExecuTorch OpenAI-compatible LLM server")
+ p.add_argument(
+ "--model-path",
+ required=True,
+ help="Path to the self-contained .pte model (external .ptd weights are not "
+ "supported by the generic text worker; use a model-specific launcher).",
+ )
+ p.add_argument("--tokenizer-path", required=True, help="Path to the tokenizer")
+ p.add_argument(
+ "--hf-tokenizer",
+ default=None,
+ help="HF tokenizer id/dir for model-correct chat templating (required unless "
+ "--allow-chatml-fallback).",
+ )
+ p.add_argument(
+ "--allow-chatml-fallback",
+ action="store_true",
+ help="Allow approximate generic ChatML templating when --hf-tokenizer is absent. "
+ "Off by default: the fallback can't reproduce model-specific controls.",
+ )
+ p.add_argument(
+ "--model-id", default="executorch", help="Model id reported on /v1/models"
+ )
+ p.add_argument(
+ "--no-think",
+ action="store_true",
+ help="Default reasoning off (sends enable_thinking=False to the chat template, "
+ "e.g. Qwen3). Per-request chat_template_kwargs still override this.",
+ )
+ p.add_argument(
+ "--max-context",
+ type=int,
+ default=None,
+ help="Model context window; if set (and a tokenizer is available), prompts that "
+ "exceed it are rejected with 400 context_length_exceeded instead of failing mid-generation. "
+ "Set this to match the value used at export.",
+ )
+ p.add_argument(
+ "--num-runners",
+ type=int,
+ default=1,
+ help="Worker processes. 1 only: one worker hosts many isolated sessions "
+ "on a single weight load; more workers would duplicate the weights.",
+ )
+ p.add_argument(
+ "--worker-bin",
+ default=None,
+ help="Path to the text_llm_worker binary "
+ "(default: cmake-out/extension/llm/server/cpp/text_llm_worker).",
+ )
+ p.add_argument("--host", default="127.0.0.1")
+ p.add_argument("--port", type=int, default=8000)
+ args = p.parse_args()
+ logging.basicConfig(level=logging.INFO)
+
+ if args.num_runners != 1:
+ p.error(
+ "Only 1 worker process is supported (it hosts many isolated sessions "
+ "on one weight load); more workers would duplicate the weights."
+ )
+
+ default_template_kwargs = {"enable_thinking": False} if args.no_think else None
+ # Requires --hf-tokenizer unless --allow-chatml-fallback (raises otherwise).
+ template = ChatTemplate(
+ args.hf_tokenizer,
+ default_template_kwargs=default_template_kwargs,
+ allow_fallback=args.allow_chatml_fallback,
+ )
+ worker = _spawn(args) # one worker hosting many isolated sessions
+ runtime = SessionRuntime(worker)
+ serving = ServingChat(
+ runtime,
+ template,
+ args.model_id,
+ max_context=args.max_context,
+ # Hermes JSON is the generic default; a model-specific server (e.g. a
+ # Qwen launcher) selects the Qwen XML detector instead.
+ tool_detector_cls=HermesDetector,
+ )
+
+ app = build_app(serving, args.model_id)
+
+ @app.on_event("shutdown")
+ def _stop_worker():
+ runtime.close_worker()
+
+ import uvicorn # imported here so build_app() is usable without the ASGI server
+
+ uvicorn.run(app, host=args.host, port=args.port)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/extension/llm/server/python/serving_chat.py b/extension/llm/server/python/serving_chat.py
new file mode 100644
index 00000000000..fdd8032a06f
--- /dev/null
+++ b/extension/llm/server/python/serving_chat.py
@@ -0,0 +1,446 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""/v1/chat/completions OpenAI adapter: validates requests, renders the chat
+template, parses tool calls, and formats OpenAI responses. It owns no model or
+session state -- generation goes through SessionRuntime, and the token-ID warm-
+resume transcript lives in OpenAITranscriptState."""
+
+import json
+import logging
+import math
+from typing import AsyncIterator, Optional
+
+from .chat_template import ChatTemplate
+from .errors import APIError, ContextLengthExceeded, GenerationError
+from .protocol import (
+ _new_id,
+ ChatCompletionChunk,
+ ChatCompletionRequest,
+ ChatCompletionResponse,
+ Choice,
+ ChunkChoice,
+ DeltaMessage,
+ FunctionCall,
+ ResponseMessage,
+ ToolCall,
+ Usage,
+)
+from .session_runtime import GenerationOptions, GenStats, PromptInput, SessionRuntime
+from .tool_parsers import HermesDetector, ToolCallItem
+
+logger = logging.getLogger(__name__)
+
+
+def _earliest_stop(text: str, stops: list[str]) -> Optional[int]:
+ """Index of the earliest special-token occurrence in `text`, or None."""
+ best = None
+ for s in stops:
+ i = text.find(s)
+ if i != -1 and (best is None or i < best):
+ best = i
+ return best
+
+
+class ServingChat:
+ def __init__(
+ self,
+ runtime: SessionRuntime,
+ template: ChatTemplate,
+ model_id: str,
+ max_context: Optional[int] = None,
+ tool_detector_cls: Optional[type[HermesDetector]] = None,
+ ):
+ self._runtime = runtime
+ self._template = template
+ self._model_id = model_id
+ self._max_context = max_context
+ # Detector CLASS; a fresh instance is created per request so streaming
+ # state is never shared across concurrent requests.
+ self._tool_detector_cls = tool_detector_cls
+ # Special tokens (e.g. <|im_end|>) the runner decodes to text; we cut the
+ # visible content at the first one so they don't leak into responses.
+ self._stops = template.special_tokens()
+
+ @staticmethod
+ def _tool_schemas(req: ChatCompletionRequest) -> dict[str, dict]:
+ """Map each defined tool name to its JSON-schema ``parameters`` object.
+
+ The detector uses the key set to validate names and the schema to coerce
+ values to their declared types (the Qwen XML format is stringly-typed)."""
+ schemas = {}
+ for t in req.tools or []:
+ fn = t.get("function", {}) if isinstance(t, dict) else {}
+ name = fn.get("name")
+ if name:
+ schemas[name] = fn.get("parameters") or {}
+ return schemas
+
+ def _strip_specials(self, text: str) -> str:
+ cut = _earliest_stop(text, self._stops)
+ return text[:cut] if cut is not None else text
+
+ @staticmethod
+ def _to_openai_tool_call(item: ToolCallItem) -> ToolCall:
+ return ToolCall(
+ index=item.tool_index,
+ id=_new_id("call"),
+ type="function",
+ function=FunctionCall(name=item.name, arguments=item.arguments),
+ )
+
+ def _tools_active(self, req: ChatCompletionRequest) -> bool:
+ # tool_choice="none" disables tools even when the client sends them.
+ return bool(self._tool_detector_cls and req.tools and req.tool_choice != "none")
+
+ @staticmethod
+ def _request_stops(req: ChatCompletionRequest) -> list[str]:
+ s = req.stop
+ if not s:
+ return []
+ return [s] if isinstance(s, str) else [x for x in s if x]
+
+ @staticmethod
+ def _apply_stop(text: str, stops: list[str]) -> str:
+ """Truncate at the earliest stop string (the stop itself is excluded)."""
+ cut = _earliest_stop(text, stops)
+ return text[:cut] if cut is not None else text
+
+ def _truncate_raw(self, text: str, req: ChatCompletionRequest) -> str:
+ """Cut raw model output at the earliest special token or request stop
+ sequence BEFORE tool parsing, so a tool call (or any text) past the stop
+ boundary is neither parsed nor emitted."""
+ return self._apply_stop(text, self._stops + self._request_stops(req))
+
+ async def _collect_until_stop(self, stream: AsyncIterator[str], stops: list[str]):
+ """Accumulate a buffered (non-streamed) generation into one string,
+ halting the runtime early once a stop string (special token or request
+ stop) appears, then draining so stats finalize. Returns (text, stopped):
+ `stopped` lets the caller force finish_reason="stop" even when tokens
+ queued before the runtime observed stop() pushed the count to max_tokens."""
+ text = ""
+ stopped = False
+ async for tok in stream:
+ text += tok
+ if stops and _earliest_stop(text, stops) is not None:
+ stopped = True
+ self._runtime.stop()
+ async for _ in stream: # drain so stats_cb fires
+ pass
+ break
+ return text, stopped
+
+ def _extract_tools(self, req: ChatCompletionRequest, text: str):
+ """Returns (tool_calls | None, content_text). Falls back to plain text."""
+ if self._tools_active(req):
+ parsed = self._tool_detector_cls().detect_and_parse(
+ text, self._tool_schemas(req)
+ )
+ if parsed.calls:
+ content = self._strip_specials(parsed.normal_text) or None
+ return [self._to_openai_tool_call(c) for c in parsed.calls], content
+ text = parsed.normal_text
+ return None, self._strip_specials(text)
+
+ async def _clean(
+ self, stream: AsyncIterator[str], stops: list[str], on_stop=None
+ ) -> AsyncIterator[str]:
+ # Yield text up to the earliest stop string (special token or request
+ # `stop`), buffering across tokens so a stop spanning chunks is caught.
+ # On a hit: optionally stop the runner early, then drain the source so it
+ # finalizes (usage stats recorded, worker thread joined).
+ hold = (
+ max((len(s) for s in stops), default=1) - 1
+ ) # keep a possible partial-stop tail
+ buf = ""
+ async for token in stream:
+ buf += token
+ cut = _earliest_stop(buf, stops)
+ if cut is not None:
+ if cut > 0:
+ yield buf[:cut]
+ if on_stop is not None:
+ on_stop()
+ async for _ in stream: # drain so stats_cb fires
+ pass
+ return
+ if hold == 0:
+ yield buf
+ buf = ""
+ elif len(buf) > hold:
+ yield buf[:-hold]
+ buf = buf[-hold:]
+ if buf:
+ yield buf
+
+ def _options(self, req: ChatCompletionRequest) -> GenerationOptions:
+ return GenerationOptions(
+ max_new_tokens=req.resolved_max_tokens(),
+ temperature=req.temperature if req.temperature is not None else 0.0,
+ # Let the worker terminate at the same boundary the control plane
+ # would cut: the model's special tokens (e.g. <|im_end|>) AND request
+ # stop sequences. This stops generation at end-of-turn even when the
+ # worker's EOS-by-token-id check misses it, instead of running to
+ # max_new (or erroring) past the turn. The server's
+ # _clean/_collect_until_stop still re-apply these as a backstop.
+ stop=self._stops + self._request_stops(req),
+ )
+
+ def _finish_reason(
+ self,
+ req: ChatCompletionRequest,
+ completion_tokens: int,
+ tool_calls=None,
+ stopped: bool = False,
+ worker_finish: Optional[str] = None,
+ ) -> str:
+ # Precedence: tool call > stop boundary > worker reason > length heuristic.
+ # `stopped` (a server-side stop sequence / special token) wins even over
+ # the worker, since that truncation happened in the control plane.
+ if tool_calls:
+ return "tool_calls"
+ if stopped:
+ return "stop"
+ # The worker knows whether it hit EOS ("stop") or ran to max_new ("length",
+ # possibly a clamp to the context window) — trust it over the token-count
+ # heuristic, which can't see a silent clamp.
+ if worker_finish in ("stop", "length"):
+ return worker_finish
+ mt = req.resolved_max_tokens()
+ return "length" if mt and mt > 0 and completion_tokens >= mt else "stop"
+
+ @staticmethod
+ def _reject_invalid_values(req: ChatCompletionRequest) -> None:
+ """Reject out-of-range values (invalid_value); these take precedence over
+ the unsupported-parameter error."""
+ if req.temperature is not None and (
+ not math.isfinite(req.temperature)
+ or req.temperature < 0.0
+ or req.temperature > 2.0
+ ):
+ raise APIError(
+ 400,
+ f"temperature must be between 0 and 2 (got {req.temperature}).",
+ "invalid_request_error",
+ "invalid_value",
+ )
+ # max_tokens / max_completion_tokens, if given, must be positive integers
+ # (OpenAI rejects 0 and negatives; our -1 sentinel means "unset/auto").
+ for field_name in ("max_tokens", "max_completion_tokens"):
+ v = getattr(req, field_name)
+ if v is not None and v <= 0:
+ raise APIError(
+ 400,
+ f"{field_name} must be a positive integer (got {v}).",
+ "invalid_request_error",
+ "invalid_value",
+ )
+
+ @staticmethod
+ def _reject_unsupported_params(req: ChatCompletionRequest) -> None:
+ """Reject params we don't honor rather than silently ignoring them (a
+ client relying on e.g. top_p/seed/logprobs would otherwise get wrong
+ behavior). Only the no-op/default value of each passes: top_p exactly
+ 1.0; penalties 0; response_format type "text"; tool_choice none/auto/
+ unset; parallel_tool_calls true (false can't be guaranteed without
+ constraining); logprobs are not returned at all."""
+ rf = req.response_format
+ flags = [
+ (req.n != 1, "n>1"),
+ (req.top_p is not None and req.top_p != 1.0, "top_p"),
+ (req.seed is not None, "seed"),
+ (req.reasoning_effort is not None, "reasoning_effort"),
+ (bool(req.frequency_penalty), "frequency_penalty"),
+ (bool(req.presence_penalty), "presence_penalty"),
+ (req.top_k is not None, "top_k"),
+ (bool(req.logit_bias), "logit_bias"),
+ (
+ bool(rf) and rf.get("type", "text") != "text",
+ "response_format (only 'text')",
+ ),
+ (bool(req.logprobs), "logprobs"),
+ (req.top_logprobs is not None, "top_logprobs"),
+ (req.parallel_tool_calls is False, "parallel_tool_calls=false"),
+ (
+ req.tool_choice not in (None, "none", "auto"),
+ "tool_choice (only 'none' or 'auto')",
+ ),
+ ]
+ unsupported = [label for cond, label in flags if cond]
+ if unsupported:
+ raise APIError(
+ 400,
+ f"Unsupported parameter(s): {', '.join(unsupported)}. This server honors "
+ "temperature, max_tokens/max_completion_tokens, stop, and tools for the "
+ "configured tool-call format.",
+ "invalid_request_error",
+ "unsupported_parameter",
+ )
+
+ async def create(self, req: ChatCompletionRequest):
+ self._reject_invalid_values(req)
+ self._reject_unsupported_params(req)
+ # tool_choice="none" must hide tools from the model: if we still render
+ # the tool schemas, the model can emit a that we'd surface as
+ # plain text (parsing is disabled), instead of a normal answer.
+ template_tools = None if req.tool_choice == "none" else req.tools
+ prompt = self._template.render(
+ req.messages, tools=template_tools, template_kwargs=req.chat_template_kwargs
+ )
+ # Pre-flight context check: reject cleanly instead of failing mid-generation
+ # (only possible when a tokenizer is available to count, e.g. --hf-tokenizer).
+ if self._max_context:
+ count = self._template.count_tokens(prompt)
+ if count is not None:
+ if count >= self._max_context:
+ raise ContextLengthExceeded(count, self._max_context)
+ # An explicit max_tokens that wouldn't fit alongside the prompt
+ # must be rejected here, not run until the worker hits the context
+ # limit mid-decode (a 500 / streaming error after partial output).
+ requested = req.resolved_max_tokens()
+ if requested > 0 and count + requested > self._max_context:
+ raise ContextLengthExceeded(count, self._max_context, requested)
+ options = self._options(req)
+ prompt_input = PromptInput(text=prompt)
+ if req.stream:
+ return self._stream(req, prompt_input, options)
+ return await self._complete(req, prompt_input, options)
+
+ async def _complete(
+ self,
+ req: ChatCompletionRequest,
+ prompt: PromptInput,
+ options: GenerationOptions,
+ ) -> ChatCompletionResponse:
+ stats = GenStats()
+ try:
+ # Collect raw text (markers intact for tool parsing), halting early
+ # at a stop boundary (special token or request stop).
+ text, stopped = await self._collect_until_stop(
+ self._runtime.generate_stream(None, prompt, options, stats),
+ self._stops + self._request_stops(req),
+ )
+ except Exception as e: # noqa: BLE001 - surface as a structured API error
+ raise GenerationError(str(e))
+ # Bound the raw output at the first stop/special token BEFORE tool
+ # parsing, so a call after the stop boundary is not parsed/emitted.
+ tool_calls, content = self._extract_tools(req, self._truncate_raw(text, req))
+ finish = self._finish_reason(
+ req, stats.completion_tokens, tool_calls, stopped, stats.finish_reason
+ )
+ return ChatCompletionResponse(
+ model=self._model_id,
+ choices=[
+ Choice(
+ message=ResponseMessage(content=content, tool_calls=tool_calls),
+ finish_reason=finish,
+ )
+ ],
+ usage=Usage(
+ prompt_tokens=stats.prompt_tokens,
+ completion_tokens=stats.completion_tokens,
+ total_tokens=stats.prompt_tokens + stats.completion_tokens,
+ ),
+ )
+
+ async def _stream(
+ self,
+ req: ChatCompletionRequest,
+ prompt: PromptInput,
+ options: GenerationOptions,
+ ) -> AsyncIterator[str]:
+ cid = _new_id("chatcmpl")
+
+ def chunk(delta: DeltaMessage, finish=None) -> str:
+ c = ChatCompletionChunk(
+ id=cid,
+ model=self._model_id,
+ choices=[ChunkChoice(delta=delta, finish_reason=finish)],
+ )
+ return f"data: {c.model_dump_json(exclude_none=True)}\n\n"
+
+ yield chunk(DeltaMessage(role="assistant"))
+ error: Optional[Exception] = None
+ use_tools = self._tools_active(req)
+ tool_calls = None
+ content = None
+
+ stats = GenStats()
+ stop_hit = [False] # set when a stop boundary is reached (forces finish="stop")
+ stops = self._stops + self._request_stops(req)
+ try:
+ if use_tools:
+ # v1: buffer the (usually short) tool response, parse once.
+ # Halt early at a stop boundary, and bound the raw output
+ # BEFORE parsing so post-stop tool calls / text don't leak.
+ raw, stop_hit[0] = await self._collect_until_stop(
+ self._runtime.generate_stream(None, prompt, options, stats),
+ stops,
+ )
+ tool_calls, content = self._extract_tools(
+ req, self._truncate_raw(raw, req)
+ )
+ else:
+ # Plain chat: stream tokens live (best UX), cutting at special
+ # tokens or request stop sequences and halting early on a hit.
+ def on_stop():
+ stop_hit[0] = True
+ self._runtime.stop()
+
+ async for token in self._clean(
+ self._runtime.generate_stream(None, prompt, options, stats),
+ stops,
+ on_stop=on_stop,
+ ):
+ yield chunk(DeltaMessage(content=token))
+ except (
+ Exception
+ ) as e: # noqa: BLE001 - emit a structured error event, never drop the socket
+ error = e
+
+ if error is not None:
+ err = {
+ "message": f"Generation failed: {error}",
+ "type": "server_error",
+ "code": None,
+ }
+ yield f"data: {json.dumps({'error': err})}\n\n"
+ yield "data: [DONE]\n\n"
+ return
+
+ if use_tools:
+ if content:
+ yield chunk(DeltaMessage(content=content))
+ for tc in tool_calls or []:
+ yield chunk(DeltaMessage(tool_calls=[tc]))
+ finish = self._finish_reason(
+ req,
+ stats.completion_tokens,
+ tool_calls,
+ stop_hit[0],
+ stats.finish_reason,
+ )
+ else:
+ finish = self._finish_reason(
+ req,
+ stats.completion_tokens,
+ stopped=stop_hit[0],
+ worker_finish=stats.finish_reason,
+ )
+ yield chunk(DeltaMessage(), finish=finish)
+ if req.stream_options and req.stream_options.include_usage:
+ usage_chunk = ChatCompletionChunk(
+ id=cid,
+ model=self._model_id,
+ choices=[],
+ usage=Usage(
+ prompt_tokens=stats.prompt_tokens,
+ completion_tokens=stats.completion_tokens,
+ total_tokens=stats.prompt_tokens + stats.completion_tokens,
+ ),
+ )
+ yield f"data: {usage_chunk.model_dump_json(exclude_none=True)}\n\n"
+ yield "data: [DONE]\n\n"
diff --git a/extension/llm/server/python/session_runtime.py b/extension/llm/server/python/session_runtime.py
new file mode 100644
index 00000000000..e73dbafadc4
--- /dev/null
+++ b/extension/llm/server/python/session_runtime.py
@@ -0,0 +1,193 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Python's stateful local-LLM runtime over one C++ worker process.
+
+This is the internal boundary between protocol adapters (OpenAI chat, future
+native/agent surfaces) and the worker. The adapter speaks sessions, prompts, and
+generation parameters; the worker (driven over JSONL by a WorkerClient) owns all
+model execution and session state (KV/recurrent, resident token ids, warm-resume
+prefix logic). The Python server never loads a model, links a backend, or imports
+a runtime pybind.
+
+A SessionRuntime owns exactly one worker and serializes access to it (one
+in-flight request at a time), bridging the worker's blocking generate() into an
+async token stream. Multi-worker scheduling / named-session affinity is out of
+scope: a single worker already hosts many isolated sessions on one weight load,
+routed by session_id inside the worker.
+"""
+
+import asyncio
+import logging
+from concurrent.futures import ThreadPoolExecutor
+from dataclasses import dataclass, field
+from typing import AsyncIterator, Optional
+
+logger = logging.getLogger(__name__)
+
+_SENTINEL = object()
+
+
+@dataclass
+class PromptInput:
+ """A prompt as either a single rendered string or token-ID segments. Exactly
+ one of `text` / `segments` is set. Segments ([{"text": str} | {"ids": [int]}])
+ let an adapter splice exact prior-turn token ids in place of a lossy
+ re-render (see openai_transcript)."""
+
+ text: Optional[str] = None
+ segments: Optional[list] = None
+
+ def __post_init__(self):
+ if (self.text is None) == (self.segments is None):
+ raise ValueError("exactly one of PromptInput.text / .segments must be set")
+ if self.segments is not None and not self.segments:
+ raise ValueError("PromptInput.segments must be non-empty")
+
+
+@dataclass
+class GenerationOptions:
+ """Sampling/length knobs forwarded to the worker (only what we honor today)."""
+
+ max_new_tokens: int
+ temperature: float = 0.0
+ stop: list[str] = field(default_factory=list)
+
+
+@dataclass
+class GenStats:
+ """Per-request metadata the worker reports at the end of generation."""
+
+ prompt_tokens: int = 0
+ completion_tokens: int = 0
+ # Worker-reported stop reason ("stop" | "length"), or None if not reported.
+ finish_reason: Optional[str] = None
+ # Warm-resume accounting (V2b.1): tokens served from the session's resident
+ # state vs prefilled this request, and why.
+ reused_prompt_tokens: int = 0
+ prefilled_prompt_tokens: int = 0
+ session_reset_reason: Optional[str] = None
+ # Exact token ids generated this turn (V2b.1.5), for an adapter's transcript
+ # store. Empty when the worker doesn't report them (e.g. a stop-trimmed turn).
+ generated_token_ids: list = field(default_factory=list)
+
+
+# Forwarded to WorkerClient.generate() as the per-request config it reads fields
+# off; keeps that low-level contract unchanged while the runtime's public surface
+# is PromptInput + GenerationOptions + session_id.
+@dataclass
+class _WorkerRequest:
+ max_new_tokens: int
+ temperature: float
+ stop: list[str]
+ session_id: Optional[str]
+ prompt_segments: Optional[list]
+
+
+class SessionRuntime:
+ """Stateful runtime over a single worker. `worker` is a WorkerClient (a fake
+ in tests) exposing generate()/stop()/close() and the session ops
+ open_session/reset_session/close_session."""
+
+ def __init__(self, worker):
+ self._worker = worker
+ # One executor thread; the lock guarantees the worker is never driven by
+ # two requests at once (it is single-in-flight).
+ self._executor = ThreadPoolExecutor(max_workers=1)
+ self._lock = asyncio.Lock()
+
+ async def open(self, session_id: str) -> None:
+ """Admit a named session before generation so a capacity refusal surfaces
+ up front (the adapter maps it to an HTTP status) rather than mid-stream.
+ Idempotent."""
+ await self._session_op("open_session", session_id)
+
+ async def reset(self, session_id: str) -> None:
+ """Clear a named session's context (KV/recurrent + resident ids) but keep
+ its capacity slot. Idempotent."""
+ await self._session_op("reset_session", session_id)
+
+ async def close(self, session_id: str) -> None:
+ """Destroy a named session, freeing its state and slot. Idempotent."""
+ await self._session_op("close_session", session_id)
+
+ async def _session_op(self, method: str, session_id: str) -> None:
+ op = getattr(self._worker, method, None)
+ if op is None:
+ return # worker doesn't support sessions (e.g. a minimal fake)
+ loop = asyncio.get_running_loop()
+ async with self._lock:
+ await loop.run_in_executor(self._executor, op, session_id)
+
+ def stop(self) -> None:
+ """Request the in-flight generation stop at the next token boundary."""
+ self._worker.stop()
+
+ async def generate_stream(
+ self,
+ session_id: Optional[str],
+ prompt: PromptInput,
+ options: GenerationOptions,
+ stats: Optional[GenStats] = None,
+ ) -> AsyncIterator[str]:
+ """Yield generated text pieces from the worker, holding the worker lock
+ for the whole generation. `stats` (if given) is filled in place with the
+ worker's terminal metadata (per-request, so concurrent callers don't
+ race). session_id None routes to the worker's anonymous scratch session."""
+ out_stats = stats if stats is not None else GenStats()
+ req = _WorkerRequest(
+ max_new_tokens=options.max_new_tokens,
+ temperature=options.temperature,
+ stop=list(options.stop),
+ session_id=session_id,
+ prompt_segments=prompt.segments,
+ )
+ prompt_text = prompt.text or ""
+ loop = asyncio.get_running_loop()
+ queue: asyncio.Queue = asyncio.Queue()
+
+ def token_cb(token: str) -> None:
+ loop.call_soon_threadsafe(queue.put_nowait, token)
+
+ def stats_cb(s) -> None:
+ out_stats.prompt_tokens = s.num_prompt_tokens
+ out_stats.completion_tokens = s.num_generated_tokens
+ out_stats.finish_reason = getattr(s, "finish_reason", None)
+ out_stats.reused_prompt_tokens = getattr(s, "reused_prompt_tokens", 0)
+ out_stats.prefilled_prompt_tokens = getattr(s, "prefilled_prompt_tokens", 0)
+ out_stats.session_reset_reason = getattr(s, "session_reset_reason", None)
+ out_stats.generated_token_ids = getattr(s, "generated_token_ids", [])
+
+ def run() -> None:
+ try:
+ self._worker.generate(prompt_text, req, token_cb, stats_cb)
+ except Exception as e: # noqa: BLE001 - surface to the stream consumer
+ loop.call_soon_threadsafe(queue.put_nowait, e)
+ finally:
+ loop.call_soon_threadsafe(queue.put_nowait, _SENTINEL)
+
+ async with self._lock:
+ fut = loop.run_in_executor(self._executor, run)
+ try:
+ while True:
+ item = await queue.get()
+ if item is _SENTINEL:
+ break
+ if isinstance(item, Exception):
+ raise item
+ yield item
+ except asyncio.CancelledError:
+ self._worker.stop()
+ raise
+ finally:
+ await fut
+
+ def close_worker(self) -> None:
+ """Shut the worker process down and the executor (called at shutdown)."""
+ close = getattr(self._worker, "close", None)
+ if close is not None:
+ close()
+ self._executor.shutdown(wait=False, cancel_futures=True)
diff --git a/extension/llm/server/python/tests/conftest.py b/extension/llm/server/python/tests/conftest.py
new file mode 100644
index 00000000000..b91f0aec26e
--- /dev/null
+++ b/extension/llm/server/python/tests/conftest.py
@@ -0,0 +1,143 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Fixtures for hermetic contract tests.
+
+We build a SessionRuntime over a single FakeRunner worker, so the tests exercise
+the real server, protocol, templating, and streaming code over the HTTP boundary
+with NO model, tokenizer, GPU, or worker subprocess. This mirrors ExecuTorch's
+fake_llm_executor approach: fake the worker, test the real surface.
+
+The control plane imports no runtime pybind; only fastapi, pydantic, httpx, and
+pytest are required.
+"""
+
+import pytest
+
+from executorch.extension.llm.server.python.chat_template import ChatTemplate
+from executorch.extension.llm.server.python.server import build_app
+from executorch.extension.llm.server.python.serving_chat import ServingChat
+from executorch.extension.llm.server.python.session_runtime import SessionRuntime
+from executorch.extension.llm.server.python.tool_parsers import HermesDetector
+from executorch.extension.llm.server.python.worker_client import WorkerError
+from fastapi.testclient import TestClient
+
+
+class _FakeStats:
+ num_prompt_tokens = 5
+ num_generated_tokens = 0
+ finish_reason = None
+
+
+class FakeRunner:
+ """Canned engine: emits fixed tokens, records the config it was given.
+
+ With max_named_sessions > 0 it also models the worker's session admission:
+ open_session() enforces capacity and reports structured WorkerError codes,
+ matching the real worker's contract."""
+
+ def __init__(
+ self, tokens, fail=False, finish_reason=None, max_named_sessions=0, gen_ids=None
+ ):
+ self._tokens = list(tokens)
+ self._fail = fail
+ self._finish_reason = finish_reason # worker-reported stop reason, if any
+ self._gen_ids = list(gen_ids or []) # ids reported per turn (V2b.1.5)
+ self.captured_config = None
+ self.stopped = False
+ self.reset_count = 0
+ self.max_named_sessions = max_named_sessions
+ self.open_named = set() # currently-open named sessions
+ self.opened_log = [] # every successful open, in order
+ self.reset_log = [] # every reset_session call, in order
+
+ def reset(self):
+ self.reset_count += 1
+
+ def stop(self):
+ self.stopped = True
+
+ def open_session(self, session_id):
+ if session_id in self.open_named:
+ return # idempotent
+ if self.max_named_sessions == 0:
+ raise WorkerError("no named sessions", code="unsupported_session")
+ if len(self.open_named) >= self.max_named_sessions:
+ raise WorkerError("session capacity exhausted", code="capacity_exhausted")
+ self.open_named.add(session_id)
+ self.opened_log.append(session_id)
+
+ def close_session(self, session_id):
+ self.open_named.discard(session_id)
+
+ def reset_session(self, session_id):
+ # Clears context but keeps the slot (stays in open_named).
+ self.reset_log.append(session_id)
+
+ def generate(self, prompt, config, token_callback=None, stats_callback=None):
+ self.captured_config = config
+ if self._fail:
+ raise RuntimeError("Generation failed")
+ for tok in self._tokens:
+ if token_callback:
+ token_callback(tok)
+ if stats_callback:
+ stats = _FakeStats()
+ stats.num_generated_tokens = len(self._tokens)
+ stats.finish_reason = self._finish_reason
+ stats.generated_token_ids = list(self._gen_ids)
+ stats_callback(stats)
+
+
+class _FakeTokenizer:
+ """Minimal stand-in for an HF tokenizer (counting + templating)."""
+
+ all_special_tokens: list = []
+
+ def __init__(self, prompt_tokens):
+ self._n = prompt_tokens
+
+ def encode(self, text, add_special_tokens=False):
+ return [0] * self._n
+
+ def apply_chat_template(
+ self, messages, tools, add_generation_prompt, tokenize, **kwargs
+ ):
+ return "PROMPT"
+
+
+@pytest.fixture
+def make_client():
+ def _make(
+ tokens=("Hello", ", ", "world"),
+ max_context=None,
+ prompt_tokens=None,
+ fail=False,
+ finish_reason=None,
+ max_named_sessions=0,
+ gen_ids=None,
+ ):
+ fake = FakeRunner(
+ tokens,
+ fail=fail,
+ finish_reason=finish_reason,
+ max_named_sessions=max_named_sessions,
+ gen_ids=gen_ids,
+ )
+ runtime = SessionRuntime(fake) # one fake worker
+ template = ChatTemplate(hf_tokenizer_path=None, allow_fallback=True)
+ if prompt_tokens is not None:
+ template._hf = _FakeTokenizer(prompt_tokens)
+ serving = ServingChat(
+ runtime,
+ template,
+ "test-model",
+ max_context=max_context,
+ tool_detector_cls=HermesDetector,
+ )
+ return TestClient(build_app(serving, "test-model")), fake
+
+ return _make
diff --git a/extension/llm/server/python/tests/test_contract.py b/extension/llm/server/python/tests/test_contract.py
new file mode 100644
index 00000000000..6a460150b18
--- /dev/null
+++ b/extension/llm/server/python/tests/test_contract.py
@@ -0,0 +1,446 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Hermetic OpenAI-contract tests (fake engine, no model/GPU).
+
+These assert on the public HTTP/wire contract only — response object shapes,
+the streaming chunk protocol, status codes — never on internal classes or
+methods. Implementation can change freely as long as these pass.
+"""
+
+import json
+
+import pytest
+
+
+def _sse_chunks(text):
+ chunks, done = [], False
+ for line in text.splitlines():
+ line = line.strip()
+ if not line.startswith("data:"):
+ continue
+ payload = line[len("data:") :].strip()
+ if payload == "[DONE]":
+ done = True
+ continue
+ chunks.append(json.loads(payload))
+ return chunks, done
+
+
+def test_health(make_client):
+ client, _ = make_client()
+ assert client.get("/health").json() == {"status": "ok"}
+
+
+def test_models_listing_shape(make_client):
+ client, _ = make_client()
+ body = client.get("/v1/models").json()
+ assert body["object"] == "list"
+ assert body["data"][0]["id"] == "test-model"
+ assert body["data"][0]["object"] == "model"
+
+
+def test_chat_nonstreaming_shape(make_client):
+ client, _ = make_client(tokens=["Hello", ", ", "world"])
+ resp = client.post(
+ "/v1/chat/completions",
+ json={"model": "test-model", "messages": [{"role": "user", "content": "hi"}]},
+ )
+ assert resp.status_code == 200
+ body = resp.json()
+ assert body["object"] == "chat.completion"
+ choice = body["choices"][0]
+ assert choice["message"]["role"] == "assistant"
+ assert choice["message"]["content"] == "Hello, world"
+ assert choice["finish_reason"] == "stop"
+ for k in ("prompt_tokens", "completion_tokens", "total_tokens"):
+ assert k in body["usage"]
+ assert body["usage"]["total_tokens"] == (
+ body["usage"]["prompt_tokens"] + body["usage"]["completion_tokens"]
+ )
+
+
+def test_chat_streaming_protocol(make_client):
+ client, _ = make_client(tokens=["a", "b", "c"])
+ resp = client.post(
+ "/v1/chat/completions",
+ json={
+ "model": "test-model",
+ "messages": [{"role": "user", "content": "hi"}],
+ "stream": True,
+ },
+ )
+ assert resp.status_code == 200
+ assert resp.headers["content-type"].startswith("text/event-stream")
+ chunks, done = _sse_chunks(resp.text)
+ assert done, "stream must terminate with data: [DONE]"
+ assert all(c["object"] == "chat.completion.chunk" for c in chunks)
+ assert chunks[0]["choices"][0]["delta"].get("role") == "assistant"
+ content = "".join(c["choices"][0]["delta"].get("content") or "" for c in chunks)
+ assert content == "abc"
+ assert chunks[-1]["choices"][0]["finish_reason"] == "stop"
+
+
+def test_request_params_forwarded_to_generation(make_client):
+ # Contract behavior: the server must honor max_tokens/temperature.
+ client, fake = make_client()
+ client.post(
+ "/v1/chat/completions",
+ json={
+ "model": "test-model",
+ "messages": [{"role": "user", "content": "hi"}],
+ "max_tokens": 7,
+ "temperature": 0.1,
+ },
+ )
+ assert fake.captured_config.max_new_tokens == 7
+ assert abs(fake.captured_config.temperature - 0.1) < 1e-6
+
+
+def test_special_tokens_forwarded_to_worker_as_stops(make_client):
+ # The worker must be told to stop at the model's end-of-turn special tokens
+ # (e.g. <|im_end|>), not just request `stop` sequences. Otherwise a worker
+ # whose EOS-by-token-id check misses the turn end runs to max_new (or errors
+ # forwarding its own end token) past it — the text_llm_worker "decode failed"
+ # seen on a real model. (Forwarding only request stops would leave this [].)
+ client, fake = make_client()
+ client.post(
+ "/v1/chat/completions",
+ json={"model": "test-model", "messages": [{"role": "user", "content": "hi"}]},
+ )
+ assert "<|im_end|>" in (fake.captured_config.stop or [])
+
+
+def test_request_stop_forwarded_to_worker(make_client):
+ # A request `stop` sequence must reach the worker (so it can terminate early),
+ # not only be applied by the Python backstop. The stop tests elsewhere pass
+ # via that backstop even if forwarding regresses; this asserts forwarding.
+ client, fake = make_client()
+ client.post(
+ "/v1/chat/completions",
+ json={
+ "model": "test-model",
+ "messages": [{"role": "user", "content": "hi"}],
+ "stop": ["STOP"],
+ },
+ )
+ # stop == special tokens + request stops, so check membership (not equality).
+ assert "STOP" in (fake.captured_config.stop or [])
+
+
+def test_tools_field_accepted(make_client):
+ # tools is part of the contract even before parsing is enforced (M2).
+ client, _ = make_client()
+ resp = client.post(
+ "/v1/chat/completions",
+ json={
+ "model": "test-model",
+ "messages": [{"role": "user", "content": "hi"}],
+ "tools": [
+ {"type": "function", "function": {"name": "f", "parameters": {}}}
+ ],
+ },
+ )
+ assert resp.status_code == 200
+ assert resp.json()["choices"][0]["message"]["role"] == "assistant"
+
+
+def test_invalid_request_returns_422(make_client):
+ client, _ = make_client()
+ resp = client.post(
+ "/v1/chat/completions", json={"model": "test-model"}
+ ) # no messages
+ assert resp.status_code == 422
+
+
+def test_special_tokens_stripped_nonstreaming(make_client):
+ # The runner may decode EOS/special tokens to text; they must not leak.
+ client, _ = make_client(tokens=["Hello", " world", "<|im_end|>", "LEAK"])
+ resp = client.post(
+ "/v1/chat/completions",
+ json={"model": "test-model", "messages": [{"role": "user", "content": "hi"}]},
+ )
+ assert resp.json()["choices"][0]["message"]["content"] == "Hello world"
+
+
+def test_usage_populated_when_special_token_cuts_early(make_client):
+ # Regression: cutting at a special token must not skip usage stats.
+ client, _ = make_client(tokens=["Hello", "<|im_end|>", "LEAK"])
+ body = client.post(
+ "/v1/chat/completions",
+ json={"model": "test-model", "messages": [{"role": "user", "content": "hi"}]},
+ ).json()
+ assert body["choices"][0]["message"]["content"] == "Hello"
+ assert body["usage"]["completion_tokens"] > 0
+
+
+def test_special_tokens_stripped_streaming(make_client):
+ client, _ = make_client(tokens=["Hello", " world", "<|im_end|>", "LEAK"])
+ resp = client.post(
+ "/v1/chat/completions",
+ json={
+ "model": "test-model",
+ "messages": [{"role": "user", "content": "hi"}],
+ "stream": True,
+ },
+ )
+ chunks, _ = _sse_chunks(resp.text)
+ content = "".join(c["choices"][0]["delta"].get("content") or "" for c in chunks)
+ assert content == "Hello world"
+ assert "LEAK" not in content and "<|im_end|>" not in content
+
+
+# (1) Context-size-exceeded -> structured 400, both modes (not a dropped socket).
+def test_context_length_exceeded_returns_400(make_client):
+ client, _ = make_client(max_context=2048, prompt_tokens=2940)
+ resp = client.post(
+ "/v1/chat/completions",
+ json={
+ "model": "test-model",
+ "messages": [{"role": "user", "content": "x" * 100}],
+ },
+ )
+ assert resp.status_code == 400
+ err = resp.json()["error"]
+ assert err["code"] == "context_length_exceeded"
+ assert err["type"] == "invalid_request_error"
+
+
+def test_context_length_exceeded_streaming_returns_400(make_client):
+ client, _ = make_client(max_context=2048, prompt_tokens=2940)
+ resp = client.post(
+ "/v1/chat/completions",
+ json={
+ "model": "test-model",
+ "messages": [{"role": "user", "content": "x"}],
+ "stream": True,
+ },
+ )
+ # Pre-flight check rejects before the stream starts -> clean 400, no SSE.
+ assert resp.status_code == 400
+ assert resp.json()["error"]["code"] == "context_length_exceeded"
+
+
+def test_prompt_plus_max_tokens_exceeding_context_returns_400(make_client):
+ # Prompt fits (100 < 2048) but prompt + max_tokens (100 + 2000) > 2048: must
+ # reject up front, not run until the worker hits the limit mid-decode.
+ client, _ = make_client(max_context=2048, prompt_tokens=100)
+ resp = client.post(
+ "/v1/chat/completions",
+ json={
+ "model": "test-model",
+ "messages": [{"role": "user", "content": "hi"}],
+ "max_tokens": 2000,
+ },
+ )
+ assert resp.status_code == 400
+ assert resp.json()["error"]["code"] == "context_length_exceeded"
+
+
+def test_prompt_plus_max_tokens_within_context_ok(make_client):
+ # Prompt + max_tokens (100 + 100) <= 2048: must NOT be rejected.
+ client, _ = make_client(max_context=2048, prompt_tokens=100)
+ resp = client.post(
+ "/v1/chat/completions",
+ json={
+ "model": "test-model",
+ "messages": [{"role": "user", "content": "hi"}],
+ "max_tokens": 100,
+ },
+ )
+ assert resp.status_code == 200
+
+
+# (1) Mid-generation failure -> structured error, never a dropped connection.
+def test_generation_failure_returns_structured_error(make_client):
+ client, _ = make_client(fail=True)
+ resp = client.post(
+ "/v1/chat/completions",
+ json={"model": "test-model", "messages": [{"role": "user", "content": "hi"}]},
+ )
+ assert resp.status_code == 500
+ assert resp.json()["error"]["type"] == "server_error"
+
+
+def test_generation_failure_streaming_emits_error_event(make_client):
+ client, _ = make_client(fail=True)
+ resp = client.post(
+ "/v1/chat/completions",
+ json={
+ "model": "test-model",
+ "messages": [{"role": "user", "content": "hi"}],
+ "stream": True,
+ },
+ )
+ assert resp.status_code == 200 # headers already sent; error arrives as an event
+ chunks, done = _sse_chunks(resp.text)
+ assert done
+ assert any("error" in c for c in chunks)
+
+
+# (3) finish_reason == "length" when max_tokens is hit.
+def test_finish_reason_length_when_max_tokens_hit(make_client):
+ client, _ = make_client(tokens=["a", "b", "c"]) # 3 generated tokens
+ resp = client.post(
+ "/v1/chat/completions",
+ json={
+ "model": "test-model",
+ "messages": [{"role": "user", "content": "hi"}],
+ "max_tokens": 3,
+ },
+ )
+ assert resp.json()["choices"][0]["finish_reason"] == "length"
+
+
+def test_finish_reason_stop_when_under_max_tokens(make_client):
+ client, _ = make_client(tokens=["a", "b", "c"])
+ resp = client.post(
+ "/v1/chat/completions",
+ json={
+ "model": "test-model",
+ "messages": [{"role": "user", "content": "hi"}],
+ "max_tokens": 100,
+ },
+ )
+ assert resp.json()["choices"][0]["finish_reason"] == "stop"
+
+
+# Worker-reported finish_reason: the worker may silently clamp max_new to the
+# context window, so the token-count heuristic can't tell a real stop from a
+# truncation. Trust the worker's reason.
+def test_worker_reported_length_overrides_token_count(make_client):
+ # 3 tokens generated (< requested 100) but the worker says it ran to the cap
+ # (a context clamp): finish_reason must be "length", not "stop".
+ client, _ = make_client(tokens=["a", "b", "c"], finish_reason="length")
+ body = client.post(
+ "/v1/chat/completions",
+ json={
+ "model": "test-model",
+ "messages": [{"role": "user", "content": "hi"}],
+ "max_tokens": 100,
+ },
+ ).json()
+ assert body["choices"][0]["finish_reason"] == "length"
+
+
+def test_worker_reported_stop_overrides_token_count(make_client):
+ # 3 tokens with max_tokens=3 (heuristic would say "length"), but the worker
+ # reports EOS: finish_reason must be "stop".
+ client, _ = make_client(tokens=["a", "b", "c"], finish_reason="stop")
+ body = client.post(
+ "/v1/chat/completions",
+ json={
+ "model": "test-model",
+ "messages": [{"role": "user", "content": "hi"}],
+ "max_tokens": 3,
+ },
+ ).json()
+ assert body["choices"][0]["finish_reason"] == "stop"
+
+
+def test_server_stop_sequence_wins_over_worker_length(make_client):
+ # A server-side stop sequence is truncation in the control plane; it must
+ # win even when the worker would report "length".
+ client, _ = make_client(
+ tokens=["Hello ", "world ", "STOP", " x"], finish_reason="length"
+ )
+ body = client.post(
+ "/v1/chat/completions",
+ json={
+ "model": "test-model",
+ "messages": [{"role": "user", "content": "hi"}],
+ "stop": ["STOP"],
+ "max_tokens": 100,
+ },
+ ).json()
+ assert body["choices"][0]["finish_reason"] == "stop"
+ assert "STOP" not in (body["choices"][0]["message"]["content"] or "")
+
+
+# (4) Error-variant matrix: malformed requests -> consistent 422.
+@pytest.mark.parametrize(
+ "bad_body",
+ [
+ {"model": "m"}, # missing messages
+ {"model": "m", "messages": "not-a-list"},
+ {
+ "model": "m",
+ "messages": [{"role": "user", "content": "hi"}],
+ "temperature": "hot",
+ },
+ {
+ "model": "m",
+ "messages": [{"role": "user", "content": "hi"}],
+ "stream": "maybe",
+ },
+ {"model": "m", "messages": [{"content": "no role"}]},
+ ],
+)
+def test_invalid_requests_return_422(make_client, bad_body):
+ client, _ = make_client()
+ assert client.post("/v1/chat/completions", json=bad_body).status_code == 422
+
+
+# (6) Streaming usage when stream_options.include_usage is set.
+def test_streaming_usage_included(make_client):
+ client, _ = make_client(tokens=["a", "b"])
+ resp = client.post(
+ "/v1/chat/completions",
+ json={
+ "model": "test-model",
+ "messages": [{"role": "user", "content": "hi"}],
+ "stream": True,
+ "stream_options": {"include_usage": True},
+ },
+ )
+ chunks, _ = _sse_chunks(resp.text)
+ usage_chunks = [c for c in chunks if c.get("usage")]
+ assert usage_chunks, "expected a chunk carrying usage"
+ u = usage_chunks[-1]["usage"]
+ assert u["total_tokens"] == u["prompt_tokens"] + u["completion_tokens"]
+
+
+def test_streaming_usage_absent_by_default(make_client):
+ client, _ = make_client(tokens=["a", "b"])
+ resp = client.post(
+ "/v1/chat/completions",
+ json={
+ "model": "test-model",
+ "messages": [{"role": "user", "content": "hi"}],
+ "stream": True,
+ },
+ )
+ chunks, _ = _sse_chunks(resp.text)
+ assert not any(c.get("usage") for c in chunks)
+
+
+# (2) Unicode/multibyte content survives streaming intact.
+def test_unicode_streaming_integrity(make_client):
+ pieces = ["café ", "日本語 ", "😀", "🎉"]
+ client, _ = make_client(tokens=pieces)
+ resp = client.post(
+ "/v1/chat/completions",
+ json={
+ "model": "test-model",
+ "messages": [{"role": "user", "content": "hi"}],
+ "stream": True,
+ },
+ )
+ chunks, _ = _sse_chunks(
+ resp.text
+ ) # _sse_chunks uses json.loads -> validates UTF-8/JSON
+ content = "".join(c["choices"][0]["delta"].get("content") or "" for c in chunks)
+ assert content == "".join(pieces)
+
+
+def test_unicode_nonstreaming_integrity(make_client):
+ pieces = ["café ", "日本語 ", "😀"]
+ client, _ = make_client(tokens=pieces)
+ resp = client.post(
+ "/v1/chat/completions",
+ json={"model": "test-model", "messages": [{"role": "user", "content": "hi"}]},
+ )
+ assert resp.json()["choices"][0]["message"]["content"] == "".join(pieces)
diff --git a/extension/llm/server/python/tests/test_sampling_params.py b/extension/llm/server/python/tests/test_sampling_params.py
new file mode 100644
index 00000000000..fe3166d1bb5
--- /dev/null
+++ b/extension/llm/server/python/tests/test_sampling_params.py
@@ -0,0 +1,269 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Contract tests for sampling/control params: stop sequences, n, tool_choice.
+
+These exercise the real server over the HTTP boundary with a FakeRunner.
+"""
+
+import json
+
+CALL = (
+ '\n{"name": "get_weather", "arguments": {"city": "Paris"}}\n'
+)
+WEATHER_TOOL = {
+ "type": "function",
+ "function": {"name": "get_weather", "parameters": {}},
+}
+
+
+def _body(**kw):
+ b = {"model": "test-model", "messages": [{"role": "user", "content": "hi"}]}
+ b.update(kw)
+ return b
+
+
+def test_n_greater_than_one_is_rejected(make_client):
+ client, _ = make_client()
+ r = client.post("/v1/chat/completions", json=_body(n=2))
+ assert r.status_code == 400
+ assert r.json()["error"]["type"] == "invalid_request_error"
+ assert r.json()["error"]["code"] == "unsupported_parameter"
+
+
+def test_unsupported_params_rejected(make_client):
+ client, _ = make_client()
+ for param in (
+ {"top_p": 0.5},
+ {"top_p": 2.0}, # > 1.0 is not a no-op either — only exactly 1.0/unset is
+ {"seed": 42},
+ {"reasoning_effort": "high"},
+ {"frequency_penalty": 1.0},
+ {"presence_penalty": -0.5},
+ {"top_k": 40},
+ {"logit_bias": {"123": 5.0}},
+ {"response_format": {"type": "json_object"}},
+ {"logprobs": True},
+ {"top_logprobs": 5},
+ {"parallel_tool_calls": False},
+ ):
+ r = client.post("/v1/chat/completions", json=_body(**param))
+ assert r.status_code == 400, param
+ assert r.json()["error"]["code"] == "unsupported_parameter", param
+
+
+def test_nonpositive_max_tokens_rejected(make_client):
+ # max_tokens=0/-2 must not be silently treated as "unbounded" (the `or` bug);
+ # 0 and negatives are invalid for both field names.
+ client, _ = make_client()
+ for param in (
+ {"max_tokens": 0},
+ {"max_tokens": -2},
+ {"max_completion_tokens": 0},
+ {"max_completion_tokens": -2},
+ ):
+ r = client.post("/v1/chat/completions", json=_body(**param))
+ assert r.status_code == 400, param
+ assert r.json()["error"]["type"] == "invalid_request_error", param
+
+
+def test_temperature_range_rejected(make_client):
+ client, _ = make_client()
+ for param in (
+ {"temperature": -1},
+ {"temperature": -0.1},
+ {"temperature": 2.1},
+ {"temperature": 3},
+ ):
+ r = client.post("/v1/chat/completions", json=_body(**param))
+ assert r.status_code == 400, param
+ assert r.json()["error"]["code"] == "invalid_value", param
+
+
+def test_noop_output_contract_fields_accepted(make_client):
+ # The default/no-op forms must NOT be rejected (don't break OpenAI clients
+ # that send them explicitly).
+ client, _ = make_client()
+ r = client.post(
+ "/v1/chat/completions",
+ json=_body(
+ response_format={"type": "text"},
+ parallel_tool_calls=True,
+ max_tokens=8,
+ ),
+ )
+ assert r.status_code == 200
+
+
+def test_zero_penalties_and_unknown_fields_accepted(make_client):
+ # frequency/presence_penalty=0.0 are no-ops; unknown non-generation fields
+ # (user/store/metadata) are ignored, not rejected (don't break OpenAI clients).
+ client, _ = make_client()
+ r = client.post(
+ "/v1/chat/completions",
+ json=_body(
+ frequency_penalty=0.0,
+ presence_penalty=0.0,
+ user="abc",
+ store=False,
+ metadata={"k": "v"},
+ max_tokens=8,
+ ),
+ )
+ assert r.status_code == 200
+
+
+def test_unsupported_tool_choice_rejected(make_client):
+ # "required" / a specific-function choice would need constrained decoding to
+ # force/restrict the call; v1 rejects rather than silently treating as "auto".
+ client, _ = make_client()
+ for choice in (
+ "required",
+ {"type": "function", "function": {"name": "get_weather"}},
+ ):
+ r = client.post(
+ "/v1/chat/completions",
+ json=_body(tools=[WEATHER_TOOL], tool_choice=choice, max_tokens=8),
+ )
+ assert r.status_code == 400, choice
+ assert r.json()["error"]["code"] == "unsupported_parameter", choice
+
+
+def test_supported_params_accepted(make_client):
+ # top_p=1.0 (no-op) and temperature/max_tokens must NOT be rejected; neither
+ # should tool_choice "auto" / "none".
+ client, _ = make_client()
+ for temperature in (0.0, 1.0, 2.0):
+ r = client.post(
+ "/v1/chat/completions",
+ json=_body(top_p=1.0, temperature=temperature, max_tokens=8),
+ )
+ assert r.status_code == 200, temperature
+ for choice in ("auto", "none"):
+ r = client.post(
+ "/v1/chat/completions",
+ json=_body(tools=[WEATHER_TOOL], tool_choice=choice, max_tokens=8),
+ )
+ assert r.status_code == 200, choice
+
+
+def test_stop_sequence_truncates_nonstreaming(make_client):
+ client, _ = make_client(tokens=["Hello ", "world ", "STOP", " ignored"])
+ r = client.post("/v1/chat/completions", json=_body(stop=["STOP"], max_tokens=32))
+ assert r.status_code == 200
+ content = r.json()["choices"][0]["message"]["content"]
+ assert content == "Hello world "
+ assert "STOP" not in content
+
+
+def test_stop_sequence_truncates_streaming(make_client):
+ client, _ = make_client(tokens=["Hello ", "world ", "STOP", " ignored"])
+ content = ""
+ with client.stream(
+ "POST", "/v1/chat/completions", json=_body(stop=["STOP"], stream=True)
+ ) as r:
+ for line in r.iter_lines():
+ if not line.startswith("data:"):
+ continue
+ payload = line[len("data:") :].strip()
+ if payload == "[DONE]":
+ break
+ delta = json.loads(payload)["choices"][0]["delta"]
+ content += delta.get("content") or ""
+ assert content == "Hello world "
+ assert "STOP" not in content
+
+
+def test_stop_forces_finish_reason_over_length_nonstreaming(make_client):
+ # 4 tokens emitted (completion reaches max_tokens=4) AND a stop is hit:
+ # finish_reason must be "stop" (boundary), not "length".
+ client, _ = make_client(tokens=["a ", "b ", "STOP", " c"])
+ body = client.post(
+ "/v1/chat/completions", json=_body(stop=["STOP"], max_tokens=4)
+ ).json()
+ assert body["choices"][0]["finish_reason"] == "stop"
+ assert "STOP" not in (body["choices"][0]["message"]["content"] or "")
+
+
+def test_stop_forces_finish_reason_over_length_streaming(make_client):
+ client, _ = make_client(tokens=["a ", "b ", "STOP", " c"])
+ finish = None
+ with client.stream(
+ "POST",
+ "/v1/chat/completions",
+ json=_body(stop=["STOP"], max_tokens=4, stream=True),
+ ) as r:
+ for line in r.iter_lines():
+ if not line.startswith("data:"):
+ continue
+ p = line[len("data:") :].strip()
+ if p == "[DONE]":
+ break
+ fr = json.loads(p)["choices"][0].get("finish_reason")
+ if fr:
+ finish = fr
+ assert finish == "stop"
+
+
+_STOP_THEN_CALL = (
+ 'Answer STOP \n{"name": "get_weather", "arguments": {}}\n'
+)
+
+
+def test_stop_before_tool_call_nonstreaming(make_client):
+ # Stop boundary precedes a tool call (in one chunk, so truncation — not just
+ # early-stop — must catch it): the call must NOT be parsed/emitted.
+ client, _ = make_client(tokens=[_STOP_THEN_CALL])
+ r = client.post(
+ "/v1/chat/completions",
+ json=_body(tools=[WEATHER_TOOL], stop=["STOP"], max_tokens=64),
+ )
+ msg = r.json()["choices"][0]["message"]
+ assert msg.get("tool_calls") is None
+ assert "STOP" not in (msg.get("content") or "")
+
+
+def test_stop_before_tool_call_streaming(make_client):
+ client, _ = make_client(tokens=[_STOP_THEN_CALL])
+ saw_tool, content = False, ""
+ with client.stream(
+ "POST",
+ "/v1/chat/completions",
+ json=_body(tools=[WEATHER_TOOL], stop=["STOP"], stream=True),
+ ) as r:
+ for line in r.iter_lines():
+ if not line.startswith("data:"):
+ continue
+ p = line[len("data:") :].strip()
+ if p == "[DONE]":
+ break
+ delta = json.loads(p)["choices"][0]["delta"]
+ if delta.get("tool_calls"):
+ saw_tool = True
+ content += delta.get("content") or ""
+ assert not saw_tool
+ assert "STOP" not in content
+
+
+def test_tool_choice_none_disables_tools(make_client):
+ client, _ = make_client(tokens=[CALL])
+ r = client.post(
+ "/v1/chat/completions",
+ json=_body(tools=[WEATHER_TOOL], tool_choice="none", max_tokens=64),
+ )
+ msg = r.json()["choices"][0]["message"]
+ assert msg.get("tool_calls") is None # tools disabled -> returned as text
+ assert r.json()["choices"][0]["finish_reason"] != "tool_calls"
+
+
+def test_tool_choice_default_still_parses(make_client):
+ # Sanity: without tool_choice="none", the same call IS parsed as a tool call.
+ client, _ = make_client(tokens=[CALL])
+ r = client.post(
+ "/v1/chat/completions", json=_body(tools=[WEATHER_TOOL], max_tokens=64)
+ )
+ calls = r.json()["choices"][0]["message"].get("tool_calls")
+ assert calls and calls[0]["function"]["name"] == "get_weather"
diff --git a/extension/llm/server/python/tests/test_session_runtime.py b/extension/llm/server/python/tests/test_session_runtime.py
new file mode 100644
index 00000000000..14f138cc007
--- /dev/null
+++ b/extension/llm/server/python/tests/test_session_runtime.py
@@ -0,0 +1,187 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""SessionRuntime tests: session-op routing, the blocking->async stream bridge,
+cancellation, and worker shutdown. A fake worker stands in for the WorkerClient
+(no model, GPU, or subprocess). asyncio.run keeps the test bodies sync."""
+
+import asyncio
+import threading
+
+from executorch.extension.llm.server.python.session_runtime import (
+ GenerationOptions,
+ GenStats,
+ PromptInput,
+ SessionRuntime,
+)
+
+_OPTS = GenerationOptions(max_new_tokens=8)
+
+
+def _text(s="hi") -> PromptInput:
+ return PromptInput(text=s)
+
+
+class _Worker:
+ """Records session ops + process close; emits nothing on generate."""
+
+ def __init__(self):
+ self.opened, self.reset_ids, self.closed_ids = [], [], []
+ self.proc_closed = False
+
+ def open_session(self, sid):
+ self.opened.append(sid)
+
+ def reset_session(self, sid):
+ self.reset_ids.append(sid)
+
+ def close_session(self, sid):
+ self.closed_ids.append(sid)
+
+ def close(self):
+ self.proc_closed = True
+
+ def stop(self):
+ pass
+
+ def generate(self, prompt, config, token_callback=None, stats_callback=None):
+ pass
+
+
+def test_session_ops_route_to_worker():
+ async def scenario():
+ w = _Worker()
+ rt = SessionRuntime(w)
+ await rt.open("a")
+ await rt.reset("a")
+ await rt.close("a")
+ return w
+
+ w = asyncio.run(scenario())
+ assert w.opened == ["a"] and w.reset_ids == ["a"] and w.closed_ids == ["a"]
+
+
+def test_session_ops_noop_when_worker_lacks_support():
+ # A minimal worker without session ops: the runtime silently no-ops.
+ class _Bare:
+ def stop(self):
+ pass
+
+ def generate(self, *a, **k):
+ pass
+
+ async def scenario():
+ rt = SessionRuntime(_Bare())
+ await rt.open("a")
+ await rt.reset("a")
+ await rt.close("a")
+
+ asyncio.run(scenario()) # must not raise
+
+
+def test_generate_stream_yields_and_fills_stats():
+ class _Echo:
+ def stop(self):
+ pass
+
+ def generate(self, prompt, config, token_callback=None, stats_callback=None):
+ token_callback("Hello")
+ token_callback(" world")
+
+ class S:
+ num_prompt_tokens = 3
+ num_generated_tokens = 2
+ finish_reason = "stop"
+ generated_token_ids = [10, 11]
+
+ stats_callback(S())
+
+ async def scenario():
+ rt = SessionRuntime(_Echo())
+ stats = GenStats()
+ out = [t async for t in rt.generate_stream("a", _text(), _OPTS, stats)]
+ return out, stats
+
+ out, stats = asyncio.run(scenario())
+ assert "".join(out) == "Hello world"
+ assert stats.completion_tokens == 2
+ assert stats.finish_reason == "stop"
+ assert stats.generated_token_ids == [10, 11]
+
+
+def test_generate_stream_forwards_session_and_segments_to_worker():
+ captured = {}
+
+ class _Cap:
+ def stop(self):
+ pass
+
+ def generate(self, prompt, config, token_callback=None, stats_callback=None):
+ captured["session_id"] = config.session_id
+ captured["segments"] = config.prompt_segments
+ captured["prompt"] = prompt
+
+ async def scenario():
+ rt = SessionRuntime(_Cap())
+ seg = PromptInput(segments=[{"text": "a"}, {"ids": [1, 2]}])
+ async for _ in rt.generate_stream("sess", seg, _OPTS, GenStats()):
+ pass
+
+ asyncio.run(scenario())
+ assert captured["session_id"] == "sess"
+ assert captured["segments"] == [{"text": "a"}, {"ids": [1, 2]}]
+
+
+def test_cancellation_calls_worker_stop():
+ class _Blocking:
+ def __init__(self):
+ self._gate = threading.Event()
+ self.stopped = False
+
+ def stop(self):
+ self.stopped = True
+ self._gate.set()
+
+ def generate(self, prompt, config, token_callback=None, stats_callback=None):
+ token_callback("TOKEN")
+ self._gate.wait(timeout=5)
+
+ async def scenario():
+ w = _Blocking()
+ rt = SessionRuntime(w)
+ agen = rt.generate_stream("a", _text(), _OPTS).__aiter__()
+ assert await agen.__anext__() == "TOKEN" # worker now blocking
+ nxt = asyncio.ensure_future(agen.__anext__())
+ await asyncio.sleep(0.05)
+ nxt.cancel()
+ try:
+ await nxt
+ except asyncio.CancelledError:
+ pass
+ for _ in range(100): # let the worker observe stop()
+ if w.stopped:
+ break
+ await asyncio.sleep(0.02)
+ await agen.aclose()
+ return w
+
+ w = asyncio.run(scenario())
+ assert w.stopped
+
+
+def test_close_worker_shuts_down_worker():
+ w = _Worker()
+ SessionRuntime(w).close_worker()
+ assert w.proc_closed
+
+
+def test_prompt_input_requires_exactly_one():
+ import pytest
+
+ with pytest.raises(ValueError):
+ PromptInput()
+ with pytest.raises(ValueError):
+ PromptInput(text="x", segments=[{"text": "y"}])
diff --git a/extension/llm/server/python/tests/test_template.py b/extension/llm/server/python/tests/test_template.py
new file mode 100644
index 00000000000..43c2f8f3973
--- /dev/null
+++ b/extension/llm/server/python/tests/test_template.py
@@ -0,0 +1,183 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Contract tests for chat-template kwargs (e.g. enable_thinking) passthrough."""
+
+from executorch.extension.llm.server.python.chat_template import ChatTemplate
+from executorch.extension.llm.server.python.protocol import (
+ ChatMessage,
+ FunctionCall,
+ ToolCall,
+)
+
+
+class _FakeHF:
+ def __init__(self):
+ self.seen_kwargs = None
+ self.seen_messages = None
+ self.encode_add_special = None
+
+ def apply_chat_template(
+ self, messages, tools, add_generation_prompt, tokenize, **kwargs
+ ):
+ self.seen_kwargs = kwargs
+ self.seen_messages = messages
+ return "PROMPT"
+
+ # Default add_special_tokens=True mirrors real HF tokenizers (so a caller
+ # that forgets to disable specials would over-count).
+ def encode(self, text, add_special_tokens=True):
+ self.encode_add_special = add_special_tokens
+ return list(range(len(text))) # 1 id per char, deterministic
+
+
+def _template_with_fake(defaults=None):
+ t = ChatTemplate(
+ hf_tokenizer_path=None, allow_fallback=True, default_template_kwargs=defaults
+ )
+ fake = _FakeHF()
+ t._hf = fake
+ return t, fake
+
+
+def test_count_tokens_excludes_special_tokens():
+ # The rendered prompt already carries control tokens, so count_tokens must
+ # encode with add_special_tokens=False (matching the session/prefix-cache
+ # paths) — not the tokenizer's default True, which double-counts BOS/EOS and
+ # can falsely reject near-limit requests under --max-context.
+ t, fake = _template_with_fake()
+ n = t.count_tokens("PROMPT")
+ assert fake.encode_add_special is False
+ assert n == len("PROMPT")
+
+
+def test_default_template_kwargs_applied():
+ t, fake = _template_with_fake(defaults={"enable_thinking": False})
+ t.render([ChatMessage(role="user", content="hi")])
+ assert fake.seen_kwargs == {"enable_thinking": False}
+
+
+def test_per_request_kwargs_override_defaults():
+ t, fake = _template_with_fake(defaults={"enable_thinking": False})
+ t.render(
+ [ChatMessage(role="user", content="hi")],
+ template_kwargs={"enable_thinking": True},
+ )
+ assert fake.seen_kwargs["enable_thinking"] is True
+
+
+def test_no_kwargs_when_none():
+ t, fake = _template_with_fake(defaults=None)
+ t.render([ChatMessage(role="user", content="hi")])
+ assert fake.seen_kwargs == {}
+
+
+def test_fallback_ignores_kwargs_without_hf():
+ # No HF tokenizer → ChatML fallback, must not raise on kwargs.
+ t = ChatTemplate(
+ hf_tokenizer_path=None,
+ allow_fallback=True,
+ default_template_kwargs={"enable_thinking": False},
+ )
+ out = t.render([ChatMessage(role="user", content="hi")], template_kwargs={"x": 1})
+ assert "<|im_start|>user" in out and out.endswith("<|im_start|>assistant\n")
+
+
+# (5) Chat-template behaviors: multi-turn ordering, system message, roles.
+def test_multi_turn_order_preserved():
+ t = ChatTemplate(hf_tokenizer_path=None, allow_fallback=True)
+ out = t.render(
+ [
+ ChatMessage(role="user", content="first"),
+ ChatMessage(role="assistant", content="second"),
+ ChatMessage(role="user", content="third"),
+ ]
+ )
+ assert out.index("first") < out.index("second") < out.index("third")
+ assert out.endswith("<|im_start|>assistant\n") # generation prompt appended
+
+
+def test_system_message_rendered():
+ t = ChatTemplate(hf_tokenizer_path=None, allow_fallback=True)
+ out = t.render(
+ [
+ ChatMessage(role="system", content="You are terse."),
+ ChatMessage(role="user", content="hi"),
+ ]
+ )
+ assert "<|im_start|>system\nYou are terse." in out
+
+
+def test_each_role_labeled():
+ t = ChatTemplate(hf_tokenizer_path=None, allow_fallback=True)
+ out = t.render(
+ [
+ ChatMessage(role="system", content="s"),
+ ChatMessage(role="user", content="u"),
+ ChatMessage(role="assistant", content="a"),
+ ]
+ )
+ for role in ("system", "user", "assistant"):
+ assert f"<|im_start|>{role}" in out
+
+
+# Tool round-trip: a turn-2 request (assistant tool_call + tool result) must
+# serialize into the shape any HF chat template consumes — the multi-turn loop
+# breaks at turn 2 otherwise. OpenAI sends tool-call arguments as a JSON string;
+# HF templates expect a mapping (Qwen renders `arguments|items`), so the server
+# decodes it before templating.
+def test_tool_call_arguments_decoded_for_template():
+ t, fake = _template_with_fake()
+ t.render(
+ [
+ ChatMessage(role="user", content="weather?"),
+ ChatMessage(
+ role="assistant",
+ content=None,
+ tool_calls=[
+ ToolCall(
+ index=0,
+ id="c1",
+ type="function",
+ function=FunctionCall(
+ name="get_weather", arguments='{"city": "Paris"}'
+ ),
+ )
+ ],
+ ),
+ ChatMessage(role="tool", tool_call_id="c1", content='{"temp_c": 18}'),
+ ]
+ )
+ msgs = fake.seen_messages
+ asst = next(m for m in msgs if m["role"] == "assistant")
+ assert asst["tool_calls"][0]["function"]["name"] == "get_weather"
+ # Decoded from the JSON string into a mapping the template can iterate.
+ assert asst["tool_calls"][0]["function"]["arguments"] == {"city": "Paris"}
+ tool = next(m for m in msgs if m["role"] == "tool")
+ assert tool["tool_call_id"] == "c1" and "temp_c" in tool["content"]
+
+
+def test_tool_call_non_json_arguments_left_as_string():
+ # A non-JSON arguments value must not crash; it passes through unchanged.
+ t, fake = _template_with_fake()
+ t.render(
+ [
+ ChatMessage(
+ role="assistant",
+ content=None,
+ tool_calls=[
+ ToolCall(
+ index=0,
+ id="c1",
+ type="function",
+ function=FunctionCall(name="f", arguments="not json"),
+ )
+ ],
+ )
+ ]
+ )
+ asst = next(m for m in fake.seen_messages if m["role"] == "assistant")
+ assert asst["tool_calls"][0]["function"]["arguments"] == "not json"
diff --git a/extension/llm/server/python/tests/test_tool_calls.py b/extension/llm/server/python/tests/test_tool_calls.py
new file mode 100644
index 00000000000..7b165fac287
--- /dev/null
+++ b/extension/llm/server/python/tests/test_tool_calls.py
@@ -0,0 +1,221 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Tool-calling tests (HTTP contract via the server).
+
+Hermes/Qwen format only. The server buffers the model's full output and parses
+it once into complete OpenAI tool_calls; parse failures degrade to visible text.
+"""
+
+import json
+
+WEATHER_TOOLS = [
+ {
+ "type": "function",
+ "function": {"name": "get_weather", "parameters": {"type": "object"}},
+ }
+]
+
+
+def _call_text(name, args):
+ return f'\n{{"name": "{name}", "arguments": {json.dumps(args)}}}\n'
+
+
+# --- HTTP contract: non-streaming ---------------------------------------
+
+
+def test_tool_call_nonstreaming(make_client):
+ client, _ = make_client(tokens=[_call_text("get_weather", {"city": "Paris"})])
+ body = client.post(
+ "/v1/chat/completions",
+ json={
+ "model": "test-model",
+ "messages": [{"role": "user", "content": "weather in Paris?"}],
+ "tools": WEATHER_TOOLS,
+ },
+ ).json()
+ choice = body["choices"][0]
+ assert choice["finish_reason"] == "tool_calls"
+ tc = choice["message"]["tool_calls"][0]
+ assert tc["type"] == "function"
+ assert tc["function"]["name"] == "get_weather"
+ assert json.loads(tc["function"]["arguments"]) == {"city": "Paris"}
+
+
+def test_tool_call_streaming(make_client):
+ client, _ = make_client(tokens=[_call_text("get_weather", {"city": "Paris"})])
+ resp = client.post(
+ "/v1/chat/completions",
+ json={
+ "model": "test-model",
+ "messages": [{"role": "user", "content": "weather?"}],
+ "tools": WEATHER_TOOLS,
+ "stream": True,
+ },
+ )
+ # Reuse the contract helper's SSE parsing.
+ chunks = []
+ for line in resp.text.splitlines():
+ line = line.strip()
+ if line.startswith("data:") and "[DONE]" not in line:
+ chunks.append(json.loads(line[len("data:") :].strip()))
+ tool_deltas = [
+ c["choices"][0]["delta"]["tool_calls"][0]
+ for c in chunks
+ if c["choices"] and c["choices"][0]["delta"].get("tool_calls")
+ ]
+ assert tool_deltas and tool_deltas[0]["function"]["name"] == "get_weather"
+ assert chunks[-1]["choices"][0]["finish_reason"] == "tool_calls"
+
+
+def test_undefined_tool_is_not_called(make_client):
+ # Model calls a tool not in the request's tools -> no tool_calls; the raw
+ # call stays visible as content (degrade to text, never silent drop).
+ client, _ = make_client(tokens=[_call_text("rm_rf", {"path": "/"})])
+ body = client.post(
+ "/v1/chat/completions",
+ json={
+ "model": "test-model",
+ "messages": [{"role": "user", "content": "hi"}],
+ "tools": WEATHER_TOOLS,
+ },
+ ).json()
+ msg = body["choices"][0]["message"]
+ assert msg.get("tool_calls") is None
+ assert "rm_rf" in (msg.get("content") or "") # not dropped — visible as text
+
+
+def test_mixed_valid_and_undefined_tool_degrades_to_text(make_client):
+ # A response with one valid + one undefined call must NOT emit the valid one
+ # while silently dropping the undefined one — the whole response degrades to
+ # visible text so the model's full intent is preserved.
+ client, _ = make_client(
+ tokens=[
+ _call_text("get_weather", {"city": "Paris"})
+ + _call_text("rm_rf", {"path": "/"})
+ ]
+ )
+ msg = client.post(
+ "/v1/chat/completions",
+ json={
+ "model": "test-model",
+ "messages": [{"role": "user", "content": "hi"}],
+ "tools": WEATHER_TOOLS,
+ },
+ ).json()["choices"][0]["message"]
+ assert msg.get("tool_calls") is None # no partial set
+ content = msg.get("content") or ""
+ assert "rm_rf" in content and "get_weather" in content # full intent visible
+
+
+def test_tool_choice_none_omits_tools_from_prompt():
+ from executorch.extension.llm.server.python.chat_template import ChatTemplate
+ from executorch.extension.llm.server.python.server import build_app
+ from executorch.extension.llm.server.python.serving_chat import ServingChat
+ from executorch.extension.llm.server.python.session_runtime import SessionRuntime
+ from executorch.extension.llm.server.python.tool_parsers import HermesDetector
+
+ # tool_choice="none" must NOT inject tool schemas into the chat template; if it
+ # did, the model could still emit a that we'd surface as plain text
+ # (parsing is disabled for "none"). Assert via a recording tokenizer.
+ from fastapi.testclient import TestClient
+
+ class _RecordingTok:
+ all_special_tokens: list = []
+
+ def __init__(self):
+ self.tools_seen = "UNSET"
+
+ def encode(self, text):
+ return [0]
+
+ def apply_chat_template(
+ self, messages, tools, add_generation_prompt, tokenize, **kwargs
+ ):
+ self.tools_seen = tools
+ return "PROMPT"
+
+ class _Runner:
+ def reset(self):
+ pass
+
+ def stop(self):
+ pass
+
+ def generate(self, prompt, config, token_callback=None, stats_callback=None):
+ if token_callback:
+ token_callback("ok")
+
+ rec = _RecordingTok()
+ template = ChatTemplate(hf_tokenizer_path=None, allow_fallback=True)
+ template._hf = rec
+ runtime = SessionRuntime(_Runner())
+ serving = ServingChat(
+ runtime, template, "test-model", tool_detector_cls=HermesDetector
+ )
+ client = TestClient(build_app(serving, "test-model"))
+ body = {
+ "model": "test-model",
+ "messages": [{"role": "user", "content": "hi"}],
+ "tools": WEATHER_TOOLS,
+ "max_tokens": 8,
+ }
+
+ assert (
+ client.post(
+ "/v1/chat/completions", json={**body, "tool_choice": "none"}
+ ).status_code
+ == 200
+ )
+ assert rec.tools_seen is None # tools omitted from the rendered prompt
+
+ # Control: default ("auto") still passes the tools through to the template.
+ assert client.post("/v1/chat/completions", json=body).status_code == 200
+ assert rec.tools_seen == WEATHER_TOOLS
+
+
+def test_malformed_tool_call_falls_back_to_text(make_client):
+ # Broken JSON inside the markers must not crash; degrade to visible text.
+ client, _ = make_client(tokens=["\n{not json}\n"])
+ resp = client.post(
+ "/v1/chat/completions",
+ json={
+ "model": "test-model",
+ "messages": [{"role": "user", "content": "hi"}],
+ "tools": WEATHER_TOOLS,
+ },
+ )
+ assert resp.status_code == 200
+ assert resp.json()["choices"][0]["message"].get("tool_calls") is None
+
+
+def test_no_tools_field_means_text_even_if_markers_present(make_client):
+ # Without a tools array, tool markers are just content (not parsed).
+ client, _ = make_client(tokens=[_call_text("get_weather", {"city": "X"})])
+ body = client.post(
+ "/v1/chat/completions",
+ json={"model": "test-model", "messages": [{"role": "user", "content": "hi"}]},
+ ).json()
+ assert body["choices"][0]["message"].get("tool_calls") is None
+
+
+def test_parallel_calls_in_one_message(make_client):
+ # Two complete blocks in one output -> two structured calls.
+ tokens = [
+ _call_text("get_weather", {"city": "A"})
+ + _call_text("get_weather", {"city": "B"})
+ ]
+ client, _ = make_client(tokens=tokens)
+ body = client.post(
+ "/v1/chat/completions",
+ json={
+ "model": "test-model",
+ "messages": [{"role": "user", "content": "weather in A and B?"}],
+ "tools": WEATHER_TOOLS,
+ },
+ ).json()
+ calls = body["choices"][0]["message"]["tool_calls"]
+ assert [json.loads(c["function"]["arguments"])["city"] for c in calls] == ["A", "B"]
diff --git a/extension/llm/server/python/tests/test_worker_client.py b/extension/llm/server/python/tests/test_worker_client.py
new file mode 100644
index 00000000000..dbed8d396f3
--- /dev/null
+++ b/extension/llm/server/python/tests/test_worker_client.py
@@ -0,0 +1,163 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Tests for the generic WorkerClient JSONL protocol (no model/GPU/subprocess).
+
+A fake process stands in for the C++ worker: it records what the client writes
+and replays a scripted sequence of JSONL response lines.
+"""
+
+import json
+from dataclasses import dataclass, field
+
+import pytest
+
+from executorch.extension.llm.server.python.worker_client import (
+ spawn_worker,
+ WorkerClient,
+ WorkerError,
+)
+
+
+class _FakeStdin:
+ def __init__(self):
+ self.written = []
+
+ def write(self, s):
+ self.written.append(s)
+
+ def flush(self):
+ pass
+
+ def close(self):
+ pass
+
+
+class _FakeStdout:
+ def __init__(self, lines):
+ self._lines = list(lines)
+
+ def readline(self):
+ return self._lines.pop(0) if self._lines else ""
+
+
+class _FakeProc:
+ def __init__(self, stdout_lines, returncode=None):
+ self.stdin = _FakeStdin()
+ self.stdout = _FakeStdout(stdout_lines)
+ self._returncode = returncode
+
+ def poll(self):
+ return self._returncode
+
+ @property
+ def returncode(self):
+ return self._returncode
+
+
+@dataclass
+class _Cfg:
+ max_new_tokens: int = 64
+ temperature: float = 0.0
+ stop: list = field(default_factory=list)
+
+
+def _lines(*objs):
+ return [json.dumps(o) + "\n" for o in objs]
+
+
+def test_generate_streams_tokens_then_stats():
+ proc = _FakeProc(
+ _lines(
+ {"token": "Hello"},
+ {"token": " world"},
+ {"done": True, "prompt_tokens": 4, "completion_tokens": 2},
+ )
+ )
+ client = WorkerClient(proc)
+ out, stats = [], {}
+ client.generate(
+ "hi",
+ _Cfg(temperature=0.7),
+ token_callback=out.append,
+ stats_callback=lambda s: stats.update(
+ prompt=s.num_prompt_tokens, gen=s.num_generated_tokens
+ ),
+ )
+ assert "".join(out) == "Hello world"
+ assert stats == {"prompt": 4, "gen": 2}
+ # The request carried prompt + sampling, one JSON line.
+ sent = json.loads(proc.stdin.written[0])
+ assert sent == {
+ "prompt": "hi",
+ "max_new_tokens": 64,
+ "temperature": 0.7,
+ "stop": [],
+ }
+
+
+def test_generate_forwards_stop_sequences():
+ proc = _FakeProc(_lines({"done": True, "prompt_tokens": 1, "completion_tokens": 0}))
+ WorkerClient(proc).generate("hi", _Cfg(stop=["STOP", "\n\n"]))
+ sent = json.loads(proc.stdin.written[0])
+ assert sent["stop"] == ["STOP", "\n\n"]
+
+
+def test_generate_reports_finish_reason():
+ proc = _FakeProc(
+ _lines(
+ {"token": "hi"},
+ {
+ "done": True,
+ "prompt_tokens": 2,
+ "completion_tokens": 1,
+ "finish_reason": "length",
+ },
+ )
+ )
+ seen = {}
+ WorkerClient(proc).generate(
+ "hi", _Cfg(), stats_callback=lambda s: seen.update(fr=s.finish_reason)
+ )
+ assert seen["fr"] == "length"
+
+
+def test_error_message_raises_worker_error():
+ proc = _FakeProc(_lines({"error": "boom"}))
+ with pytest.raises(WorkerError, match="boom"):
+ WorkerClient(proc).generate("hi", _Cfg())
+
+
+def test_exit_mid_request_raises():
+ proc = _FakeProc([]) # readline() -> "" means the worker exited
+ with pytest.raises(WorkerError, match="exited mid-request"):
+ WorkerClient(proc).generate("hi", _Cfg())
+
+
+def test_generate_on_dead_worker_raises():
+ proc = _FakeProc([], returncode=1)
+ with pytest.raises(WorkerError, match="worker exited"):
+ WorkerClient(proc).generate("hi", _Cfg())
+
+
+def test_spawn_worker_waits_for_ready():
+ proc = _FakeProc(_lines({"ready": True}))
+ client = spawn_worker(
+ ["/fake/worker", "--model_path", "m"], popen=lambda *a, **k: proc
+ )
+ assert isinstance(client, WorkerClient)
+
+
+def test_spawn_worker_not_ready_raises():
+ proc = _FakeProc(_lines({"oops": True}))
+ with pytest.raises(WorkerError, match="did not report ready"):
+ spawn_worker(["/fake/worker"], popen=lambda *a, **k: proc)
+
+
+def test_spawn_worker_no_output_raises():
+ proc = _FakeProc([])
+ with pytest.raises(WorkerError, match="failed to start"):
+ spawn_worker(["/fake/worker"], popen=lambda *a, **k: proc)
diff --git a/extension/llm/server/python/worker_client.py b/extension/llm/server/python/worker_client.py
new file mode 100644
index 00000000000..6b7d4e84132
--- /dev/null
+++ b/extension/llm/server/python/worker_client.py
@@ -0,0 +1,267 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Generic control-plane client for a model-execution worker process.
+
+Model execution runs in a separate C++ worker process — the Python server is
+HTTP/control plane only and never loads a model, links a backend, or imports a
+pybind module. This client spawns a worker binary and drives generation over
+JSONL on the worker's stdin/stdout. The protocol is model-agnostic: the same
+client serves a TextLLM worker, a Qwen worker, or any future model worker; only
+the binary and its launch args differ.
+
+Protocol (one JSON object per line):
+ worker -> stdout, once at startup: {"ready": true, "max_sessions": int,
+ "max_named_sessions": int}
+ client -> stdin:
+ generate: {"max_new_tokens": int, "temperature": float, "stop": [str, ...],
+ "session_id"?: str, and exactly one prompt form:
+ "prompt": str
+ "prompt_segments": [{"text": str} | {"ids": [int, ...]}]}
+ open: {"op": "open", "session_id": str}
+ close: {"op": "close", "session_id": str}
+ reset: {"op": "reset", "session_id": str} # clear context, keep slot
+ worker -> stdout:
+ generate: {"token": str} * (streamed)
+ {"done": true, "prompt_tokens": int, "completion_tokens": int,
+ "finish_reason": "stop" | "length",
+ "reused_prompt_tokens": int, "prefilled_prompt_tokens": int,
+ "session_reset_reason": "new"|"exact_prefix"|"dirty"|"mismatch"
+ |"equal",
+ "generated_token_ids"?: [int, ...]} # omitted if stop-trimmed
+ open: {"opened": true, "session_id": str}
+ close: {"closed": true, "session_id": str}
+ reset: {"reset": true, "session_id": str}
+ error: {"error": str, "code"?: str} # capacity_exhausted /
+ # unsupported_session
+
+The worker's stdout carries ONLY protocol JSON; its logs go to stderr. One
+request at a time per worker; the caller (SessionRuntime) serializes. A worker
+hosts one engine and routes requests to per-session_id state (anonymous requests
+share a scratch session); execution is synchronous.
+"""
+
+import json
+import logging
+import subprocess
+import threading
+from dataclasses import dataclass, field
+from typing import Callable, Optional, Sequence
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class WorkerStats:
+ """Usage reported by a worker at the end of a request."""
+
+ num_prompt_tokens: int = 0
+ num_generated_tokens: int = 0
+ # Why generation stopped, as the worker saw it: "stop" (EOS / cooperative
+ # stop) or "length" (ran to max_new, possibly clamped to the context window).
+ # None if the worker didn't report it (older worker / fake).
+ finish_reason: Optional[str] = None
+ # Warm-resume accounting (V2b.1): how many prompt tokens were served from the
+ # session's resident KV state vs actually prefilled this request, and why
+ # ("new"|"exact_prefix"|"dirty"|"mismatch"|"equal"). Not exposed as OpenAI
+ # usage; logged for measuring warm-resume hit rate. None on older workers.
+ reused_prompt_tokens: int = 0
+ prefilled_prompt_tokens: int = 0
+ session_reset_reason: Optional[str] = None
+ # The exact (non-terminal) token ids generated this turn. The control plane
+ # stores these per session and splices them back as an `ids` prompt segment
+ # next turn, so a prior assistant span is an exact token extension instead of
+ # a lossy chat-template re-render (V2b.1.5). Empty on older workers.
+ generated_token_ids: list = field(default_factory=list)
+
+
+class WorkerError(RuntimeError):
+ """A worker process failed, exited, or reported a generation error.
+
+ `code` carries the worker's structured error code when present
+ ("capacity_exhausted", "unsupported_session"), so the HTTP layer can map it
+ to the right status; None for unstructured failures.
+ """
+
+ def __init__(self, message: str, code: Optional[str] = None):
+ super().__init__(message)
+ self.code = code
+
+
+class WorkerClient:
+ """Drives one model-execution worker process over JSONL (raw transport).
+
+ Exposes ``generate(prompt, config, token_callback, stats_callback)`` /
+ ``stop()`` plus the session ops ``open_session`` / ``close_session`` /
+ ``reset_session`` that SessionRuntime drives. Calls are serialized by a lock
+ and by SessionRuntime (one in-flight request). The control plane never
+ executes model code.
+ """
+
+ def __init__(self, proc: subprocess.Popen, max_named_sessions: int = 0):
+ self._proc = proc
+ self._lock = threading.Lock()
+ # Named sessions this worker can host (0 = scratch-only / single session).
+ self.max_named_sessions = max_named_sessions
+
+ def reset(self) -> None:
+ # The worker resets its session at the start of each request; nothing to
+ # do here.
+ pass
+
+ def stop(self) -> None:
+ # Best-effort: a request is synchronous and not interruptible mid-
+ # generation in V1.
+ pass
+
+ def open_session(self, session_id: str) -> None:
+ """Admit a named session (idempotent). Raises WorkerError with a `code`
+ ("capacity_exhausted" / "unsupported_session") if the worker refuses."""
+ self._op({"op": "open", "session_id": session_id}, ack_key="opened")
+
+ def close_session(self, session_id: str) -> None:
+ """Destroy a named session, freeing its state (idempotent)."""
+ self._op({"op": "close", "session_id": session_id}, ack_key="closed")
+
+ def reset_session(self, session_id: str) -> None:
+ """Clear a named session's context (KV/recurrent + resident tokens) but
+ keep its capacity slot allocated (idempotent)."""
+ self._op({"op": "reset", "session_id": session_id}, ack_key="reset")
+
+ def _op(self, request: dict, ack_key: str) -> None:
+ with self._lock:
+ if self._proc.poll() is not None:
+ raise WorkerError(
+ f"worker exited (code {self._proc.returncode}); restart the server"
+ )
+ try:
+ self._proc.stdin.write(json.dumps(request) + "\n")
+ self._proc.stdin.flush()
+ except (BrokenPipeError, ValueError) as e:
+ raise WorkerError("worker stdin is closed") from e
+ line = self._proc.stdout.readline()
+ if not line:
+ raise WorkerError("worker exited mid-request")
+ msg = json.loads(line)
+ if msg.get(ack_key):
+ return
+ if "error" in msg:
+ raise WorkerError(msg["error"], code=msg.get("code"))
+ raise WorkerError(f"unexpected worker response: {msg}")
+
+ @staticmethod
+ def _on_done(msg: dict, stats_callback) -> None:
+ reason = msg.get("session_reset_reason")
+ if reason is not None and logger.isEnabledFor(logging.DEBUG):
+ logger.debug(
+ "warm-resume: reason=%s reused=%d prefilled=%d",
+ reason,
+ msg.get("reused_prompt_tokens", 0),
+ msg.get("prefilled_prompt_tokens", 0),
+ )
+ if stats_callback is not None:
+ stats_callback(
+ WorkerStats(
+ num_prompt_tokens=msg.get("prompt_tokens", 0),
+ num_generated_tokens=msg.get("completion_tokens", 0),
+ finish_reason=msg.get("finish_reason"),
+ reused_prompt_tokens=msg.get("reused_prompt_tokens", 0),
+ prefilled_prompt_tokens=msg.get("prefilled_prompt_tokens", 0),
+ session_reset_reason=reason,
+ generated_token_ids=msg.get("generated_token_ids", []),
+ )
+ )
+
+ def generate(self, prompt, config, token_callback=None, stats_callback=None):
+ request = {
+ "max_new_tokens": getattr(config, "max_new_tokens", -1),
+ "temperature": getattr(config, "temperature", 0.0),
+ "stop": list(getattr(config, "stop", []) or []),
+ }
+ # Token-ID segments (V2b.1.5) take precedence over the rendered string:
+ # they let prior assistant spans be exact id runs, not lossy re-renders.
+ # `is not None` (not truthiness): segments is a distinct prompt form, kept
+ # whatever its content (the worker validates non-empty).
+ segments = getattr(config, "prompt_segments", None)
+ if segments is not None:
+ request["prompt_segments"] = segments
+ else:
+ request["prompt"] = prompt
+ session_id = getattr(config, "session_id", None)
+ if session_id:
+ request["session_id"] = session_id
+ with self._lock:
+ if self._proc.poll() is not None:
+ raise WorkerError(
+ f"worker exited (code {self._proc.returncode}); restart the server"
+ )
+ try:
+ self._proc.stdin.write(json.dumps(request) + "\n")
+ self._proc.stdin.flush()
+ except (BrokenPipeError, ValueError) as e:
+ raise WorkerError("worker stdin is closed") from e
+
+ while True:
+ line = self._proc.stdout.readline()
+ if not line:
+ raise WorkerError("worker exited mid-request")
+ msg = json.loads(line)
+ if "token" in msg:
+ if token_callback is not None:
+ token_callback(msg["token"])
+ elif msg.get("done"):
+ self._on_done(msg, stats_callback)
+ return
+ elif "error" in msg:
+ raise WorkerError(msg["error"], code=msg.get("code"))
+
+ def close(self) -> None:
+ """Terminate the worker process (called at server shutdown)."""
+ if self._proc.poll() is not None:
+ return
+ try:
+ if self._proc.stdin is not None:
+ self._proc.stdin.close()
+ except OSError:
+ pass
+ try:
+ self._proc.terminate()
+ self._proc.wait(timeout=5)
+ except Exception: # noqa: BLE001 - shutdown best-effort
+ self._proc.kill()
+
+
+def spawn_worker(
+ cmd: Sequence[str],
+ env: Optional[dict] = None,
+ cwd: Optional[str] = None,
+ popen: Callable[..., subprocess.Popen] = subprocess.Popen,
+) -> WorkerClient:
+ """Start a worker process and block until it reports ``{"ready": true}``.
+
+ `cmd` is the worker binary and its launch args (model/tokenizer paths). The
+ worker loads the model once before reporting ready, so a slow load surfaces
+ here rather than on the first request.
+ """
+ logger.info("Starting model worker: %s", cmd[0])
+ proc = popen(
+ list(cmd),
+ stdin=subprocess.PIPE,
+ stdout=subprocess.PIPE,
+ text=True,
+ bufsize=1,
+ env=env,
+ cwd=cwd,
+ )
+ line = proc.stdout.readline()
+ if not line:
+ raise WorkerError("worker failed to start (no output; check its stderr).")
+ msg = json.loads(line)
+ if not msg.get("ready"):
+ raise WorkerError(f"worker did not report ready: {msg}")
+ max_named = int(msg.get("max_named_sessions", 0))
+ logger.info("Model worker ready (max_named_sessions=%d).", max_named)
+ return WorkerClient(proc, max_named_sessions=max_named)