diff --git a/extension/llm/server/README.md b/extension/llm/server/README.md new file mode 100644 index 00000000000..0b18d31cae5 --- /dev/null +++ b/extension/llm/server/README.md @@ -0,0 +1,35 @@ +# ExecuTorch LLM Server + +OpenAI-compatible serving for ExecuTorch LLMs, so any OpenAI-compatible agent +harness (pi, opencode, ...) can use ExecuTorch as a local backend. + +``` +extension/llm/server/ + spec/ # language-neutral OpenAI contract ExecuTorch targets + conformance/ # one test suite every language server must pass + python/ # Python server implementation (current) + # cpp/ # future: no-Python single-binary server +``` + +Why this layout: the OpenAI contract is identical across languages, so the +**spec** and **conformance** suite are shared, and each language gets its own +implementation directory. The real cross-language reuse comes from the C++ +`LLMEngine`/`LLMSession` primitives underneath, packaged as a process-isolated +**worker binary** (`text_llm_worker`) that any control plane drives over a small +JSONL protocol — the server is a thin protocol shell that spawns and talks to +that worker. See `python/README.md` to run it. + +Status: experimental, reliability-first and deliberately narrow. Implemented: +`/health`, `/v1/models`, `/v1/chat/completions` (streaming + non-streaming), +Hugging Face chat templates (`--hf-tokenizer`), `temperature` / `max_tokens` / +`max_completion_tokens` / `stop`, Hermes tool calling by default +(`...` JSON, complete calls only; model-specific launchers +may select the Qwen XML format) with `tool_choice="none"`, +structured API errors, and best-effort cancellation. One worker process with +serialized execution; it hosts many isolated sessions on one weight load (warm +append-only resume across turns). KV/prefix state lives inside the +worker/session, not the control plane. Unsupported params (including `top_p`, +`seed`, `n>1`, `reasoning_effort`, penalties, `logit_bias`, `response_format`, +`logprobs`, and `tool_choice="required"`) are rejected with a structured 400 +rather than silently ignored. See `python/README.md` to run it and +`spec/README.md` for the exact contract. diff --git a/extension/llm/server/conformance/test_openai_contract.py b/extension/llm/server/conformance/test_openai_contract.py new file mode 100644 index 00000000000..3347126b5f4 --- /dev/null +++ b/extension/llm/server/conformance/test_openai_contract.py @@ -0,0 +1,207 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Language-neutral OpenAI-contract conformance tests. + +Runs against any base URL (ExecuTorch, llama.cpp, mlx-lm, ...) so every server +implementation is validated against one shared spec. Point it at a running +server: + + OPENAI_BASE_URL=http://127.0.0.1:8000/v1 pytest test_openai_contract.py + +Skips automatically if no server is reachable. +""" + +import json +import os +import urllib.error +import urllib.request + +import pytest + +BASE_URL = os.environ.get("OPENAI_BASE_URL", "http://127.0.0.1:8000/v1").rstrip("/") +MODEL = os.environ.get("OPENAI_MODEL", "executorch") + + +def _post(path: str, body: dict, stream: bool = False): + req = urllib.request.Request( + f"{BASE_URL}{path}", + data=json.dumps(body).encode(), + headers={"Content-Type": "application/json"}, + method="POST", + ) + return urllib.request.urlopen(req, timeout=120) + + +def _server_up() -> bool: + try: + urllib.request.urlopen(f"{BASE_URL}/models", timeout=5) + return True + except Exception: + return False + + +pytestmark = pytest.mark.skipif( + not _server_up(), reason="no OpenAI server at OPENAI_BASE_URL" +) + + +def test_models_listing(): + with urllib.request.urlopen(f"{BASE_URL}/models", timeout=10) as r: + data = json.loads(r.read()) + assert data["object"] == "list" + assert any("id" in m for m in data["data"]) + + +def test_chat_completion_nonstreaming(): + body = { + "model": MODEL, + "messages": [{"role": "user", "content": "Say hello in one word."}], + "max_tokens": 16, + "temperature": 0.0, + } + with _post("/chat/completions", body) as r: + data = json.loads(r.read()) + assert data["object"] == "chat.completion" + assert data["choices"][0]["message"]["role"] == "assistant" + assert isinstance(data["choices"][0]["message"]["content"], str) + assert data["choices"][0]["finish_reason"] is not None + + +def test_chat_completion_streaming(): + body = { + "model": MODEL, + "messages": [{"role": "user", "content": "Count to three."}], + "max_tokens": 32, + "stream": True, + } + saw_role = saw_content = saw_done = False + with _post("/chat/completions", body, stream=True) as r: + for raw in r: + line = raw.decode().strip() + if not line.startswith("data:"): + continue + payload = line[len("data:") :].strip() + if payload == "[DONE]": + saw_done = True + break + chunk = json.loads(payload) + assert chunk["object"] == "chat.completion.chunk" + delta = chunk["choices"][0]["delta"] + saw_role = saw_role or delta.get("role") == "assistant" + saw_content = saw_content or bool(delta.get("content")) + assert saw_role and saw_content and saw_done + + +def test_multibyte_streaming_integrity(): + # Byte-level BPE can split a multi-byte character across tokens; the stream + # must reassemble it, not abort with a UTF-8 decode error. + body = { + "model": MODEL, + "messages": [ + {"role": "user", "content": "Reply with exactly: 你好世界 🌍 café"} + ], + "max_tokens": 32, + "temperature": 0.0, + "stream": True, + } + content, saw_done, saw_error = "", False, False + with _post("/chat/completions", body, stream=True) as r: + for raw in r: + line = raw.decode().strip() + if not line.startswith("data:"): + continue + payload = line[len("data:") :].strip() + if payload == "[DONE]": + saw_done = True + break + chunk = json.loads(payload) + if "error" in chunk: + saw_error = True + content += ( + chunk["choices"][0]["delta"].get("content", "") + if chunk.get("choices") + else "" + ) + assert saw_done and not saw_error + assert isinstance(content, str) and content # reassembled, valid UTF-8 + + +def test_usage_chunk_in_stream(): + body = { + "model": MODEL, + "messages": [{"role": "user", "content": "Say hi."}], + "max_tokens": 16, + "stream": True, + "stream_options": {"include_usage": True}, + } + usage = None + with _post("/chat/completions", body, stream=True) as r: + for raw in r: + line = raw.decode().strip() + if not line.startswith("data:"): + continue + payload = line[len("data:") :].strip() + if payload == "[DONE]": + break + chunk = json.loads(payload) + if chunk.get("usage"): + usage = chunk["usage"] + assert usage is not None, "no usage chunk emitted with include_usage" + assert usage["prompt_tokens"] > 0 and usage["completion_tokens"] > 0 + assert usage["total_tokens"] == usage["prompt_tokens"] + usage["completion_tokens"] + + +WEATHER_TOOL = { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the current weather for a city.", + "parameters": { + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + }, + }, +} + + +def test_tool_call_response_shape(): + body = { + "model": MODEL, + "messages": [ + {"role": "user", "content": "What is the weather in Paris? Use the tool."} + ], + "tools": [WEATHER_TOOL], + "max_tokens": 128, + "temperature": 0.0, + } + with _post("/chat/completions", body) as r: + data = json.loads(r.read()) + calls = data["choices"][0]["message"].get("tool_calls") + assert calls, "expected tool_calls in response" + tc = calls[0] + assert tc["type"] == "function" + assert tc["id"] + assert tc["function"]["name"] == "get_weather" + json.loads(tc["function"]["arguments"]) # arguments is a JSON string + assert data["choices"][0]["finish_reason"] == "tool_calls" + + +def test_error_body_shape(): + # Over-long prompt -> structured 400 (OpenAI error envelope), not a 500/drop. + body = { + "model": MODEL, + "messages": [{"role": "user", "content": "word " * 40000}], + "max_tokens": 8, + } + try: + _post("/chat/completions", body) + raise AssertionError("expected an HTTP error for over-long prompt") + except urllib.error.HTTPError as e: + assert 400 <= e.code < 500 + err = json.loads(e.read())["error"] + assert err["message"] and err["type"] diff --git a/extension/llm/server/cpp/CMakeLists.txt b/extension/llm/server/cpp/CMakeLists.txt new file mode 100644 index 00000000000..653cf61bea8 --- /dev/null +++ b/extension/llm/server/cpp/CMakeLists.txt @@ -0,0 +1,88 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Generic model-execution worker for standard .pte TextLLM models. One binary, +# no registry/factory: it constructs TextLLMEngine/TextLLMSession directly and +# speaks the JSONL worker protocol (worker_client.py). Model execution is C++ +# only — the Python server is HTTP/control plane. +# +# Build like the example runners (standalone), e.g. from this directory: cmake +# -S . -B /extension/llm/server/cpp \ +# -DCMAKE_PREFIX_PATH= -DEXECUTORCH_BUILD_XNNPACK=ON cmake +# --build <...>/extension/llm/server/cpp --target text_llm_worker + +cmake_minimum_required(VERSION 3.24) +project(llm_server_workers) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../../..) + +include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake) + +set(_common_include_directories ${EXECUTORCH_ROOT}/..) +# Vendored single-include nlohmann/json for the worker protocol (no new dep). +set(_json_include + ${EXECUTORCH_ROOT}/extension/llm/tokenizers/third-party/json/single_include +) + +# gflags +set(gflags_DIR ${CMAKE_CURRENT_BINARY_DIR}/../../../../third-party/gflags) +find_package(gflags REQUIRED) + +# executorch +list(APPEND CMAKE_FIND_ROOT_PATH ${CMAKE_CURRENT_BINARY_DIR}/../../../..) +find_package(executorch CONFIG REQUIRED FIND_ROOT_PATH_BOTH) +executorch_target_link_options_shared_lib(executorch) + +set(link_libraries executorch gflags) + +# CPU ops +list(APPEND link_libraries optimized_native_cpu_ops_lib cpublas eigen_blas) +executorch_target_link_options_shared_lib(optimized_native_cpu_ops_lib) + +# Custom + quantized kernels that export_llm models need, whole-archived so the +# static op registrations survive the linker: llama::custom_sdpa (from +# use_sdpa_with_kv_cache) and quantized_decomposed ops (from quantized exports). +# Without these the model loads but execution fails with "Missing operator". +if(TARGET custom_ops) + executorch_target_link_options_shared_lib(custom_ops) + list(APPEND link_libraries custom_ops) +endif() +if(TARGET quantized_ops_lib) + list(APPEND link_libraries quantized_kernels quantized_ops_lib) + executorch_target_link_options_shared_lib(quantized_ops_lib) +endif() + +# Extensions (Engine/Session lives in extension_llm_runner) +list( + APPEND + link_libraries + extension_llm_runner + extension_module + extension_data_loader + extension_tensor + extension_flat_tensor +) + +# XNNPACK: the standard CPU backend for normal .pte TextLLM models. +list(APPEND link_libraries xnnpack_backend) +executorch_target_link_options_shared_lib(xnnpack_backend) + +# Tokenizer +list(APPEND link_libraries tokenizers::tokenizers) + +add_executable(text_llm_worker text_llm_worker.cpp) +target_include_directories( + text_llm_worker PUBLIC ${_common_include_directories} ${_json_include} +) +target_link_libraries(text_llm_worker PUBLIC ${link_libraries}) + +if(NOT CMAKE_BUILD_TYPE STREQUAL "Debug") + target_link_options_gc_sections(text_llm_worker) + target_link_options(text_llm_worker PRIVATE "LINKER:-s") +endif() diff --git a/extension/llm/server/cpp/text_llm_worker.cpp b/extension/llm/server/cpp/text_llm_worker.cpp new file mode 100644 index 00000000000..f7bb9d69915 --- /dev/null +++ b/extension/llm/server/cpp/text_llm_worker.cpp @@ -0,0 +1,69 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Generic model-execution worker for standard .pte TextLLM models. +// +// All model execution lives here in C++ (via TextLLMEngine / TextLLMSession, +// the stable serving abstraction) — no Python model code, no pybind, no +// in-process Python serving. The OpenAI control plane (Python) spawns this +// process and drives it over JSONL on stdin/stdout (see worker_client.py). The +// JSONL protocol and the decode loop are shared across all workers in +// worker_loop.h; this file only constructs the engine/session/tokenizer. + +#include + +#include +#include +#include +#include + +#include +#include + +DEFINE_string(model_path, "", "Self-contained model .pte file path."); +DEFINE_string(tokenizer_path, "", "HuggingFace tokenizer.json path."); + +namespace { +namespace llm = ::executorch::extension::llm; +using ::executorch::runtime::Error; +} // namespace + +int main(int argc, char** argv) { + gflags::ParseCommandLineFlags(&argc, &argv, true); + + if (FLAGS_model_path.empty() || FLAGS_tokenizer_path.empty()) { + ET_LOG( + Error, "text_llm_worker: --model_path and --tokenizer_path required"); + return 1; + } + + // TextLLMEngine requires a self-contained .pte: external .ptd weights are not + // supported for shared sessions (a model-specific worker handles that path). + auto engine = llm::TextLLMEngine::create( + FLAGS_model_path, FLAGS_tokenizer_path, std::nullopt); + if (!engine) { + ET_LOG(Error, "text_llm_worker: failed to create engine"); + return 1; + } + auto session_result = engine->create_session(); + if (session_result.error() != Error::Ok) { + ET_LOG(Error, "text_llm_worker: failed to create session"); + return 1; + } + auto session = std::move(session_result.get()); + + // The session decodes token ids to text internally; this tokenizer encodes + // the rendered prompt to ids. Same tokenizer.json -> same vocabulary. + auto tokenizer = llm::load_tokenizer(FLAGS_tokenizer_path); + if (!tokenizer) { + ET_LOG(Error, "text_llm_worker: failed to load tokenizer"); + return 1; + } + + return llm::run_worker_stdio_loop(*session, *tokenizer, engine->metadata()); +} diff --git a/extension/llm/server/cpp/worker_loop.h b/extension/llm/server/cpp/worker_loop.h new file mode 100644 index 00000000000..883bcac69cd --- /dev/null +++ b/extension/llm/server/cpp/worker_loop.h @@ -0,0 +1,192 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +// Shared model-worker generation loop + JSONL protocol, used by every model +// worker (the generic text_llm_worker and model-specific workers like +// qwen3_5_moe_worker). A worker only constructs its engine/session/tokenizer +// and calls run_worker_stdio_loop(); the protocol and the decode loop live here +// once, so protocol changes (e.g. multi-session) land in a single place. +// +// Protocol (one JSON object per line; matches worker_client.py): +// worker -> stdout, once: {"ready": true} +// client -> stdin, per request: {"prompt": str, "max_new_tokens": int, +// "temperature": float, "stop": [str, ...]} +// worker -> stdout, per request: {"token": str} * (streamed) +// {"done": true, "prompt_tokens": int, +// "completion_tokens": int, +// "finish_reason": "stop" | "length"} +// or {"error": str} +// +// stdout carries ONLY protocol JSON; all logs go to stderr (ET_LOG). One +// request at a time (the control plane serializes; V1 is one worker == one +// session). + +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace executorch { +namespace extension { +namespace llm { + +// Emit one protocol object as a JSON line on stdout. error_handler::replace +// keeps a stray invalid UTF-8 byte (byte-level BPE) from aborting +// serialization. +inline void worker_emit(const nlohmann::json& obj) { + std::cout << obj.dump( + -1, ' ', false, nlohmann::json::error_handler_t::replace) + << "\n"; + std::cout.flush(); +} + +// One generation request: reset the session, encode the prompt, prefill, then +// loop decode_one() streaming complete-UTF-8 text pieces. A terminal step (EOS +// or cooperative stop) ends generation and is not emitted or counted. Throws +// std::runtime_error on failure; the caller reports it as {"error": ...}. +inline void worker_handle_request( + LLMSession& session, + ::tokenizers::Tokenizer& tokenizer, + const std::unordered_map& metadata, + const nlohmann::json& req) { + const std::string prompt = req.at("prompt").get(); + int64_t max_new = req.value("max_new_tokens", static_cast(-1)); + const float temperature = req.value("temperature", 0.0f); + // Stop strings (the request's `stop` sequences): terminate at the token + // boundary where one appears so we don't generate to EOS/max_new past it. The + // control plane also enforces these as a backstop. + const std::vector stops = + req.value("stop", std::vector{}); + + if (session.reset() != ::executorch::runtime::Error::Ok) { + throw std::runtime_error("session reset failed"); + } + // No special tokens: the prompt is already rendered (the control plane + // applied the chat template), matching the runner's own encode path. + auto encode_result = tokenizer.encode(prompt, /*bos=*/0, /*eos=*/0); + if (!encode_result.ok()) { + throw std::runtime_error("prompt encode failed"); + } + std::vector ids = std::move(*encode_result); + if (ids.empty()) { + throw std::runtime_error("empty prompt"); + } + const int64_t num_prompt = static_cast(ids.size()); + + // Bound generation to the context window: default to filling the remaining + // room, and clamp an explicit max_new_tokens too, so decode never steps past + // the window (which would error mid-generation after partial output). + const auto ctx_it = metadata.find(kMaxContextLen); + if (ctx_it != metadata.end()) { + const int64_t room = ctx_it->second - num_prompt; + if (room <= 0) { + throw std::runtime_error( + "prompt fills the context window; no room to generate"); + } + if (max_new <= 0 || max_new > room) { + max_new = room; + } + } else if (max_new <= 0) { + max_new = 2048; + } + + SamplingConfig sampling; + sampling.temperature = temperature; + if (session.prefill_tokens(std::move(ids), &sampling) != + ::executorch::runtime::Error::Ok) { + throw std::runtime_error("prefill failed"); + } + + std::string buf; // bytes not yet forming a complete UTF-8 prefix + std::string pending; // complete-UTF-8 text held back for stop-string matching + int64_t num_generated = 0; + std::string finish = "length"; // EOS or stop string -> "stop" + bool stop_string = false; // a request stop string was matched + for (int64_t step = 0; step < max_new; ++step) { + auto step_result = session.decode_one(sampling); + if (step_result.error() != ::executorch::runtime::Error::Ok) { + throw std::runtime_error("decode failed"); + } + const auto& d = step_result.get(); + if (d.is_terminal) { + finish = "stop"; + break; // terminal step (EOS / cooperative stop): not emitted or counted + } + ++num_generated; + buf += d.text_piece; + const size_t cut = utf8_complete_prefix_len(buf); + if (cut > 0) { + pending += buf.substr(0, cut); + buf.erase(0, cut); + } + bool stop_hit = false; + const size_t safe = stop_safe_prefix_len(pending, stops, stop_hit); + if (safe > 0) { + worker_emit({{"token", pending.substr(0, safe)}}); + pending.erase(0, safe); + } + if (stop_hit) { + finish = "stop"; // reached a stop string: drop it and everything after + stop_string = true; + break; + } + } + if (!stop_string) { + // EOS or length: flush held-back text + any trailing incomplete bytes + // (replaced if invalid). A stop-string hit drops the remainder instead. + pending += buf; + if (!pending.empty()) { + worker_emit({{"token", pending}}); + } + } + // finish_reason: "stop" if the model emitted EOS or hit a stop string, else + // "length" — it ran to max_new (possibly clamped to the context window). + worker_emit( + {{"done", true}, + {"prompt_tokens", num_prompt}, + {"completion_tokens", num_generated}, + {"finish_reason", finish}}); +} + +// Emit {"ready": true}, then read JSONL requests from stdin and dispatch each +// to worker_handle_request, reporting exceptions as {"error": ...} and +// continuing to serve. Returns 0 when stdin closes. +inline int run_worker_stdio_loop( + LLMSession& session, + ::tokenizers::Tokenizer& tokenizer, + const std::unordered_map& metadata) { + worker_emit({{"ready", true}}); + std::string line; + while (std::getline(std::cin, line)) { + if (line.empty()) { + continue; + } + try { + worker_handle_request( + session, tokenizer, metadata, nlohmann::json::parse(line)); + } catch (const std::exception& e) { // report and keep serving + worker_emit({{"error", std::string(e.what())}}); + } + } + return 0; +} + +} // namespace llm +} // namespace extension +} // namespace executorch diff --git a/extension/llm/server/python/README.md b/extension/llm/server/python/README.md new file mode 100644 index 00000000000..e14e6176c81 --- /dev/null +++ b/extension/llm/server/python/README.md @@ -0,0 +1,169 @@ +# ExecuTorch LLM Server — Python + +A thin OpenAI-compatible HTTP server for ExecuTorch LLMs. The Python process is +the **control plane** only — HTTP, OpenAI protocol, chat templating, tool +parsing, request validation. Model execution runs in a separate **C++ worker +process** (`text_llm_worker`) that the server drives over a small JSONL protocol. +The control plane never loads a model, links a backend, or imports a runtime +pybind. + +## Install + +```bash +pip install -r requirements.txt +# transformers is optional but recommended for model-correct chat templates +pip install transformers +``` + +The server itself is pure Python (fastapi, pydantic, httpx). The model runs in +the C++ worker, which you build standalone (like the example runners) from +`../cpp`: + +```bash +cmake -S ../cpp -B /extension/llm/server/cpp \ + -DCMAKE_PREFIX_PATH= -DEXECUTORCH_BUILD_XNNPACK=ON +cmake --build /extension/llm/server/cpp --target text_llm_worker +# -> /extension/llm/server/cpp/text_llm_worker +``` + +### Model & runtime requirements + +LLM `.pte` files exported via `export_llm` use ExecuTorch custom/quantized ops: +`use_sdpa_with_kv_cache` → `llama::custom_sdpa`, and quantized exports +(`embedding_quantize`, `8da4w`, ...) → `quantized_decomposed` ops. The worker +binary links the kernel libraries that provide them (the C++ equivalents of +`-DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON` / +`-DEXECUTORCH_BUILD_EXTENSION_LLM_RUNNER=ON`); see the canonical +[Llama README](../../../../examples/models/llama/README.md). Without them the +worker fails to load the method. + +Tokenizer: pass the model's tokenizer — `tokenizer.json` (HF, e.g. Qwen3) or +`tokenizer.model` (Llama); the worker auto-detects. If you see an RE2 lookahead +warning it falls back to PCRE2 and still works (build with +`-DSUPPORT_REGEX_LOOKAHEAD=ON` for the native regex path). + +## Run + +```bash +python -m executorch.extension.llm.server.python.server \ + --model-path /path/to/model.pte \ + --tokenizer-path /path/to/tokenizer.bin \ + --hf-tokenizer Qwen/Qwen2.5-Coder-7B-Instruct \ + --model-id qwen2.5-coder \ + --host 127.0.0.1 --port 8000 +``` + +The server spawns the worker (it blocks until the worker has loaded the model and +reported ready, so a slow load surfaces at startup, not on the first request). + +`--hf-tokenizer` is **required** (it applies the model's real `chat_template`) +unless you pass `--allow-chatml-fallback` to opt into approximate generic ChatML +— which is wrong for many instruct/tool models and can't reproduce controls like +`enable_thinking`. + +Key flags: + +| Flag | Effect | +|------|--------| +| `--hf-tokenizer` | model's HF chat template (required unless fallback) | +| `--allow-chatml-fallback` | opt into approximate ChatML when no HF tokenizer | +| `--no-think` | default `enable_thinking=False` (e.g. Qwen3) | +| `--max-context N` | reject over-long prompts with 400 instead of failing mid-gen | +| `--num-runners N` | Worker processes — **1 only** (one worker hosts many isolated sessions on one weight load; more would duplicate weights) | +| `--worker-bin PATH` | path to the `text_llm_worker` binary (default: `cmake-out/extension/llm/server/cpp/text_llm_worker`) | + +## Use from an agent harness + +- **opencode** (`opencode.json`): + ```json + { "provider": { "executorch": { + "npm": "@ai-sdk/openai-compatible", + "options": { "baseURL": "http://127.0.0.1:8000/v1" }, + "models": { "qwen2.5-coder": { "name": "Qwen2.5-Coder (ExecuTorch)" } } } } } + ``` +- **pi** (`~/.pi/agent/models.json`): + ```json + { "providers": { "executorch": { + "baseUrl": "http://127.0.0.1:8000/v1", "api": "openai-completions", + "apiKey": "x", "models": [ { "id": "qwen2.5-coder" } ] } } } + ``` + +## Validate + +Two layers, both contract-focused (assert on the wire, not internals): + +```bash +# 1. Hermetic unit tests — fake worker, no model/GPU/subprocess, fast (CI-friendly). +pip install pytest httpx +pytest tests/ + +# 2. Conformance — black-box, against a LIVE server (real model, or llama.cpp/mlx-lm). +OPENAI_BASE_URL=http://127.0.0.1:8000/v1 pytest ../conformance/test_openai_contract.py +``` + +`tests/` builds a `SessionRuntime` over a single `FakeRunner` worker, so the +real server/protocol/streaming code is tested over HTTP without a `.pte`. The +worker JSONL protocol is covered separately by `tests/test_worker_client.py`. + +## Architecture + +Control plane (this dir, Python): an OpenAI adapter (`serving_chat`) over a +stateful `SessionRuntime` over one `WorkerClient` — server, protocol, chat +templating, streaming bridge, tool parsing — no CUDA, no model, no pybind. Data +plane (C++): a worker process (`text_llm_worker`) that owns all model state +(many isolated sessions on one weight load, warm-resume prefix logic) and does +all token stepping and KV mutation; it speaks one JSON object per line on +stdin/stdout. + +JSONL protocol (stdout carries protocol JSON only; logs go to stderr): + +``` +worker -> stdout, once at startup: {"ready": true} +client -> stdin, per request: {"prompt", "max_new_tokens", "temperature"} +worker -> stdout, per request: {"token": str} * (streamed) + {"done": true, "prompt_tokens", "completion_tokens"} + or {"error": str} +``` + +Process isolation is the reliable shape for CUDA/AOTI models: executing the model +inside a live asyncio server process can segfault (validated with Qwen3.5-MoE); +the worker is a plain process with no asyncio loop, and the control plane only +does blocking pipe I/O on its executor thread. + +| File | Role | +|------|------| +| `server.py` | FastAPI app, routes, CLI entrypoint, worker spawn | +| `protocol.py` | OpenAI request/response schemas | +| `chat_template.py` | messages (+tools) → prompt string | +| `worker_client.py` | spawn a worker process + drive it over JSONL (raw transport) | +| `session_runtime.py` | stateful runtime over one worker: open/generate/reset/close + streaming bridge | +| `serving_chat.py` | `/v1/chat/completions` OpenAI adapter (streaming + non-streaming, stop, tools) | +| `tool_parsers/` | Hermes/Qwen `` parser only | +| `cpp/text_llm_worker.cpp` | the generic C++ worker binary | + +### Model workers + +The generic `text_llm_worker` serves the text model (`TextLLMEngine`). A new +model ships its own worker binary under its example (e.g. +`examples/models/qwen3_5_moe/qwen35_moe_worker.cpp` constructs `Qwen35MoEEngine`) +that speaks the same JSONL protocol, plus a launcher that points the same control +plane at that binary via `--worker-bin`. The dependency points one way: a model +example may reuse the generic control plane; the generic control plane never +imports an example. Backend specifics (CUDA/AOTI, Metal) stay inside the worker. + +## Scope & caveats + +Deliberately narrow (reliability-first): Hermes/Qwen tool calling only; +unsupported sampling params are rejected, not ignored. **One worker process, +serialized execution** (one in-flight request; concurrent requests queue). +Session capacity is determined by the worker/engine — a single worker hosts many +isolated sessions on one weight load — so `--num-runners` accepts 1; extra worker +processes would each carry their own copy of the weights. + +Cancellation is best-effort: a worker request runs to completion and is not +interruptible mid-generation in V1, so `runner.stop()` means "the control plane +stops consuming and the worker finishes the current request" rather than a hard +cancel. There is **no prefix cache in V1 serving**; if KV prefix reuse returns it +will live inside the worker/session, not in the Python control plane. Multiple +workers, weight sharing across sessions on a backend that supports it, adaptive +thinking, and multi-session subagents are future work. diff --git a/extension/llm/server/python/server.py b/extension/llm/server/python/server.py new file mode 100644 index 00000000000..94c55479275 --- /dev/null +++ b/extension/llm/server/python/server.py @@ -0,0 +1,198 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""OpenAI-compatible HTTP server for ExecuTorch LLMs. + +Point any OpenAI-compatible agent harness (pi, opencode, ...) at +``http://:/v1``. + +This process is the CONTROL PLANE only: FastAPI/uvicorn + OpenAI protocol, chat +templating, tool parsing, request validation. It runs NO model code and imports +no runtime pybind. Model execution lives in a separate C++ worker process +(``text_llm_worker``) driven over JSONL via WorkerClient. + +One worker process, serialized execution (one in-flight request; concurrent +requests queue). Session capacity is set by the worker/engine -- a single worker +hosts many isolated sessions on one weight load; extra worker processes would +duplicate the weights, so `--num-runners` accepts 1. + +Example: + python -m executorch.extension.llm.server.python.server \\ + --model-path model.pte --tokenizer-path tokenizer.bin \\ + --hf-tokenizer Qwen/Qwen2.5-Coder-7B-Instruct --model-id qwen2.5-coder +""" + +import argparse +import logging +import os +from pathlib import Path + +from fastapi import FastAPI +from fastapi.responses import JSONResponse, StreamingResponse + +from .chat_template import ChatTemplate +from .errors import APIError +from .protocol import ChatCompletionRequest, ModelCard, ModelList +from .serving_chat import ServingChat +from .session_runtime import SessionRuntime +from .tool_parsers import HermesDetector +from .worker_client import spawn_worker + +logger = logging.getLogger(__name__) + + +def _default_worker_bin() -> str: + repo_root = Path(__file__).resolve().parents[4] + return str( + repo_root + / "cmake-out" + / "extension" + / "llm" + / "server" + / "cpp" + / "text_llm_worker" + ) + + +def _spawn(args): + """Spawn the C++ text_llm_worker and return a ready WorkerClient.""" + env = dict(os.environ) + conda = os.environ.get("CONDA_PREFIX") + if conda: + env["LD_LIBRARY_PATH"] = f"{conda}/lib:" + env.get("LD_LIBRARY_PATH", "") + worker_bin = args.worker_bin or _default_worker_bin() + cmd = [ + worker_bin, + "--model_path", + args.model_path, + "--tokenizer_path", + args.tokenizer_path, + ] + logger.info("Starting model worker subprocess (loads the model once)...") + return spawn_worker(cmd, env=env) + + +def build_app(serving: ServingChat, model_id: str) -> FastAPI: + app = FastAPI(title="ExecuTorch LLM Server") + + @app.get("/health") + async def health(): + return {"status": "ok"} + + @app.get("/v1/models") + async def list_models() -> ModelList: + return ModelList(data=[ModelCard(id=model_id)]) + + @app.post("/v1/chat/completions") + async def chat_completions(req: ChatCompletionRequest): + # Typed param → FastAPI validates the body and returns 422 on bad input. + # APIError (e.g. context_length_exceeded) → structured 4xx/5xx, never a + # dropped connection. Mid-stream failures are handled inside the stream. + try: + result = await serving.create(req) + except APIError as e: + return JSONResponse(e.body(), status_code=e.status) + if req.stream: + return StreamingResponse(result, media_type="text/event-stream") + return JSONResponse(result.model_dump(exclude_none=True)) + + return app + + +def main() -> None: + p = argparse.ArgumentParser(description="ExecuTorch OpenAI-compatible LLM server") + p.add_argument( + "--model-path", + required=True, + help="Path to the self-contained .pte model (external .ptd weights are not " + "supported by the generic text worker; use a model-specific launcher).", + ) + p.add_argument("--tokenizer-path", required=True, help="Path to the tokenizer") + p.add_argument( + "--hf-tokenizer", + default=None, + help="HF tokenizer id/dir for model-correct chat templating (required unless " + "--allow-chatml-fallback).", + ) + p.add_argument( + "--allow-chatml-fallback", + action="store_true", + help="Allow approximate generic ChatML templating when --hf-tokenizer is absent. " + "Off by default: the fallback can't reproduce model-specific controls.", + ) + p.add_argument( + "--model-id", default="executorch", help="Model id reported on /v1/models" + ) + p.add_argument( + "--no-think", + action="store_true", + help="Default reasoning off (sends enable_thinking=False to the chat template, " + "e.g. Qwen3). Per-request chat_template_kwargs still override this.", + ) + p.add_argument( + "--max-context", + type=int, + default=None, + help="Model context window; if set (and a tokenizer is available), prompts that " + "exceed it are rejected with 400 context_length_exceeded instead of failing mid-generation. " + "Set this to match the value used at export.", + ) + p.add_argument( + "--num-runners", + type=int, + default=1, + help="Worker processes. 1 only: one worker hosts many isolated sessions " + "on a single weight load; more workers would duplicate the weights.", + ) + p.add_argument( + "--worker-bin", + default=None, + help="Path to the text_llm_worker binary " + "(default: cmake-out/extension/llm/server/cpp/text_llm_worker).", + ) + p.add_argument("--host", default="127.0.0.1") + p.add_argument("--port", type=int, default=8000) + args = p.parse_args() + logging.basicConfig(level=logging.INFO) + + if args.num_runners != 1: + p.error( + "Only 1 worker process is supported (it hosts many isolated sessions " + "on one weight load); more workers would duplicate the weights." + ) + + default_template_kwargs = {"enable_thinking": False} if args.no_think else None + # Requires --hf-tokenizer unless --allow-chatml-fallback (raises otherwise). + template = ChatTemplate( + args.hf_tokenizer, + default_template_kwargs=default_template_kwargs, + allow_fallback=args.allow_chatml_fallback, + ) + worker = _spawn(args) # one worker hosting many isolated sessions + runtime = SessionRuntime(worker) + serving = ServingChat( + runtime, + template, + args.model_id, + max_context=args.max_context, + # Hermes JSON is the generic default; a model-specific server (e.g. a + # Qwen launcher) selects the Qwen XML detector instead. + tool_detector_cls=HermesDetector, + ) + + app = build_app(serving, args.model_id) + + @app.on_event("shutdown") + def _stop_worker(): + runtime.close_worker() + + import uvicorn # imported here so build_app() is usable without the ASGI server + + uvicorn.run(app, host=args.host, port=args.port) + + +if __name__ == "__main__": + main() diff --git a/extension/llm/server/python/serving_chat.py b/extension/llm/server/python/serving_chat.py new file mode 100644 index 00000000000..fdd8032a06f --- /dev/null +++ b/extension/llm/server/python/serving_chat.py @@ -0,0 +1,446 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""/v1/chat/completions OpenAI adapter: validates requests, renders the chat +template, parses tool calls, and formats OpenAI responses. It owns no model or +session state -- generation goes through SessionRuntime, and the token-ID warm- +resume transcript lives in OpenAITranscriptState.""" + +import json +import logging +import math +from typing import AsyncIterator, Optional + +from .chat_template import ChatTemplate +from .errors import APIError, ContextLengthExceeded, GenerationError +from .protocol import ( + _new_id, + ChatCompletionChunk, + ChatCompletionRequest, + ChatCompletionResponse, + Choice, + ChunkChoice, + DeltaMessage, + FunctionCall, + ResponseMessage, + ToolCall, + Usage, +) +from .session_runtime import GenerationOptions, GenStats, PromptInput, SessionRuntime +from .tool_parsers import HermesDetector, ToolCallItem + +logger = logging.getLogger(__name__) + + +def _earliest_stop(text: str, stops: list[str]) -> Optional[int]: + """Index of the earliest special-token occurrence in `text`, or None.""" + best = None + for s in stops: + i = text.find(s) + if i != -1 and (best is None or i < best): + best = i + return best + + +class ServingChat: + def __init__( + self, + runtime: SessionRuntime, + template: ChatTemplate, + model_id: str, + max_context: Optional[int] = None, + tool_detector_cls: Optional[type[HermesDetector]] = None, + ): + self._runtime = runtime + self._template = template + self._model_id = model_id + self._max_context = max_context + # Detector CLASS; a fresh instance is created per request so streaming + # state is never shared across concurrent requests. + self._tool_detector_cls = tool_detector_cls + # Special tokens (e.g. <|im_end|>) the runner decodes to text; we cut the + # visible content at the first one so they don't leak into responses. + self._stops = template.special_tokens() + + @staticmethod + def _tool_schemas(req: ChatCompletionRequest) -> dict[str, dict]: + """Map each defined tool name to its JSON-schema ``parameters`` object. + + The detector uses the key set to validate names and the schema to coerce + values to their declared types (the Qwen XML format is stringly-typed).""" + schemas = {} + for t in req.tools or []: + fn = t.get("function", {}) if isinstance(t, dict) else {} + name = fn.get("name") + if name: + schemas[name] = fn.get("parameters") or {} + return schemas + + def _strip_specials(self, text: str) -> str: + cut = _earliest_stop(text, self._stops) + return text[:cut] if cut is not None else text + + @staticmethod + def _to_openai_tool_call(item: ToolCallItem) -> ToolCall: + return ToolCall( + index=item.tool_index, + id=_new_id("call"), + type="function", + function=FunctionCall(name=item.name, arguments=item.arguments), + ) + + def _tools_active(self, req: ChatCompletionRequest) -> bool: + # tool_choice="none" disables tools even when the client sends them. + return bool(self._tool_detector_cls and req.tools and req.tool_choice != "none") + + @staticmethod + def _request_stops(req: ChatCompletionRequest) -> list[str]: + s = req.stop + if not s: + return [] + return [s] if isinstance(s, str) else [x for x in s if x] + + @staticmethod + def _apply_stop(text: str, stops: list[str]) -> str: + """Truncate at the earliest stop string (the stop itself is excluded).""" + cut = _earliest_stop(text, stops) + return text[:cut] if cut is not None else text + + def _truncate_raw(self, text: str, req: ChatCompletionRequest) -> str: + """Cut raw model output at the earliest special token or request stop + sequence BEFORE tool parsing, so a tool call (or any text) past the stop + boundary is neither parsed nor emitted.""" + return self._apply_stop(text, self._stops + self._request_stops(req)) + + async def _collect_until_stop(self, stream: AsyncIterator[str], stops: list[str]): + """Accumulate a buffered (non-streamed) generation into one string, + halting the runtime early once a stop string (special token or request + stop) appears, then draining so stats finalize. Returns (text, stopped): + `stopped` lets the caller force finish_reason="stop" even when tokens + queued before the runtime observed stop() pushed the count to max_tokens.""" + text = "" + stopped = False + async for tok in stream: + text += tok + if stops and _earliest_stop(text, stops) is not None: + stopped = True + self._runtime.stop() + async for _ in stream: # drain so stats_cb fires + pass + break + return text, stopped + + def _extract_tools(self, req: ChatCompletionRequest, text: str): + """Returns (tool_calls | None, content_text). Falls back to plain text.""" + if self._tools_active(req): + parsed = self._tool_detector_cls().detect_and_parse( + text, self._tool_schemas(req) + ) + if parsed.calls: + content = self._strip_specials(parsed.normal_text) or None + return [self._to_openai_tool_call(c) for c in parsed.calls], content + text = parsed.normal_text + return None, self._strip_specials(text) + + async def _clean( + self, stream: AsyncIterator[str], stops: list[str], on_stop=None + ) -> AsyncIterator[str]: + # Yield text up to the earliest stop string (special token or request + # `stop`), buffering across tokens so a stop spanning chunks is caught. + # On a hit: optionally stop the runner early, then drain the source so it + # finalizes (usage stats recorded, worker thread joined). + hold = ( + max((len(s) for s in stops), default=1) - 1 + ) # keep a possible partial-stop tail + buf = "" + async for token in stream: + buf += token + cut = _earliest_stop(buf, stops) + if cut is not None: + if cut > 0: + yield buf[:cut] + if on_stop is not None: + on_stop() + async for _ in stream: # drain so stats_cb fires + pass + return + if hold == 0: + yield buf + buf = "" + elif len(buf) > hold: + yield buf[:-hold] + buf = buf[-hold:] + if buf: + yield buf + + def _options(self, req: ChatCompletionRequest) -> GenerationOptions: + return GenerationOptions( + max_new_tokens=req.resolved_max_tokens(), + temperature=req.temperature if req.temperature is not None else 0.0, + # Let the worker terminate at the same boundary the control plane + # would cut: the model's special tokens (e.g. <|im_end|>) AND request + # stop sequences. This stops generation at end-of-turn even when the + # worker's EOS-by-token-id check misses it, instead of running to + # max_new (or erroring) past the turn. The server's + # _clean/_collect_until_stop still re-apply these as a backstop. + stop=self._stops + self._request_stops(req), + ) + + def _finish_reason( + self, + req: ChatCompletionRequest, + completion_tokens: int, + tool_calls=None, + stopped: bool = False, + worker_finish: Optional[str] = None, + ) -> str: + # Precedence: tool call > stop boundary > worker reason > length heuristic. + # `stopped` (a server-side stop sequence / special token) wins even over + # the worker, since that truncation happened in the control plane. + if tool_calls: + return "tool_calls" + if stopped: + return "stop" + # The worker knows whether it hit EOS ("stop") or ran to max_new ("length", + # possibly a clamp to the context window) — trust it over the token-count + # heuristic, which can't see a silent clamp. + if worker_finish in ("stop", "length"): + return worker_finish + mt = req.resolved_max_tokens() + return "length" if mt and mt > 0 and completion_tokens >= mt else "stop" + + @staticmethod + def _reject_invalid_values(req: ChatCompletionRequest) -> None: + """Reject out-of-range values (invalid_value); these take precedence over + the unsupported-parameter error.""" + if req.temperature is not None and ( + not math.isfinite(req.temperature) + or req.temperature < 0.0 + or req.temperature > 2.0 + ): + raise APIError( + 400, + f"temperature must be between 0 and 2 (got {req.temperature}).", + "invalid_request_error", + "invalid_value", + ) + # max_tokens / max_completion_tokens, if given, must be positive integers + # (OpenAI rejects 0 and negatives; our -1 sentinel means "unset/auto"). + for field_name in ("max_tokens", "max_completion_tokens"): + v = getattr(req, field_name) + if v is not None and v <= 0: + raise APIError( + 400, + f"{field_name} must be a positive integer (got {v}).", + "invalid_request_error", + "invalid_value", + ) + + @staticmethod + def _reject_unsupported_params(req: ChatCompletionRequest) -> None: + """Reject params we don't honor rather than silently ignoring them (a + client relying on e.g. top_p/seed/logprobs would otherwise get wrong + behavior). Only the no-op/default value of each passes: top_p exactly + 1.0; penalties 0; response_format type "text"; tool_choice none/auto/ + unset; parallel_tool_calls true (false can't be guaranteed without + constraining); logprobs are not returned at all.""" + rf = req.response_format + flags = [ + (req.n != 1, "n>1"), + (req.top_p is not None and req.top_p != 1.0, "top_p"), + (req.seed is not None, "seed"), + (req.reasoning_effort is not None, "reasoning_effort"), + (bool(req.frequency_penalty), "frequency_penalty"), + (bool(req.presence_penalty), "presence_penalty"), + (req.top_k is not None, "top_k"), + (bool(req.logit_bias), "logit_bias"), + ( + bool(rf) and rf.get("type", "text") != "text", + "response_format (only 'text')", + ), + (bool(req.logprobs), "logprobs"), + (req.top_logprobs is not None, "top_logprobs"), + (req.parallel_tool_calls is False, "parallel_tool_calls=false"), + ( + req.tool_choice not in (None, "none", "auto"), + "tool_choice (only 'none' or 'auto')", + ), + ] + unsupported = [label for cond, label in flags if cond] + if unsupported: + raise APIError( + 400, + f"Unsupported parameter(s): {', '.join(unsupported)}. This server honors " + "temperature, max_tokens/max_completion_tokens, stop, and tools for the " + "configured tool-call format.", + "invalid_request_error", + "unsupported_parameter", + ) + + async def create(self, req: ChatCompletionRequest): + self._reject_invalid_values(req) + self._reject_unsupported_params(req) + # tool_choice="none" must hide tools from the model: if we still render + # the tool schemas, the model can emit a that we'd surface as + # plain text (parsing is disabled), instead of a normal answer. + template_tools = None if req.tool_choice == "none" else req.tools + prompt = self._template.render( + req.messages, tools=template_tools, template_kwargs=req.chat_template_kwargs + ) + # Pre-flight context check: reject cleanly instead of failing mid-generation + # (only possible when a tokenizer is available to count, e.g. --hf-tokenizer). + if self._max_context: + count = self._template.count_tokens(prompt) + if count is not None: + if count >= self._max_context: + raise ContextLengthExceeded(count, self._max_context) + # An explicit max_tokens that wouldn't fit alongside the prompt + # must be rejected here, not run until the worker hits the context + # limit mid-decode (a 500 / streaming error after partial output). + requested = req.resolved_max_tokens() + if requested > 0 and count + requested > self._max_context: + raise ContextLengthExceeded(count, self._max_context, requested) + options = self._options(req) + prompt_input = PromptInput(text=prompt) + if req.stream: + return self._stream(req, prompt_input, options) + return await self._complete(req, prompt_input, options) + + async def _complete( + self, + req: ChatCompletionRequest, + prompt: PromptInput, + options: GenerationOptions, + ) -> ChatCompletionResponse: + stats = GenStats() + try: + # Collect raw text (markers intact for tool parsing), halting early + # at a stop boundary (special token or request stop). + text, stopped = await self._collect_until_stop( + self._runtime.generate_stream(None, prompt, options, stats), + self._stops + self._request_stops(req), + ) + except Exception as e: # noqa: BLE001 - surface as a structured API error + raise GenerationError(str(e)) + # Bound the raw output at the first stop/special token BEFORE tool + # parsing, so a call after the stop boundary is not parsed/emitted. + tool_calls, content = self._extract_tools(req, self._truncate_raw(text, req)) + finish = self._finish_reason( + req, stats.completion_tokens, tool_calls, stopped, stats.finish_reason + ) + return ChatCompletionResponse( + model=self._model_id, + choices=[ + Choice( + message=ResponseMessage(content=content, tool_calls=tool_calls), + finish_reason=finish, + ) + ], + usage=Usage( + prompt_tokens=stats.prompt_tokens, + completion_tokens=stats.completion_tokens, + total_tokens=stats.prompt_tokens + stats.completion_tokens, + ), + ) + + async def _stream( + self, + req: ChatCompletionRequest, + prompt: PromptInput, + options: GenerationOptions, + ) -> AsyncIterator[str]: + cid = _new_id("chatcmpl") + + def chunk(delta: DeltaMessage, finish=None) -> str: + c = ChatCompletionChunk( + id=cid, + model=self._model_id, + choices=[ChunkChoice(delta=delta, finish_reason=finish)], + ) + return f"data: {c.model_dump_json(exclude_none=True)}\n\n" + + yield chunk(DeltaMessage(role="assistant")) + error: Optional[Exception] = None + use_tools = self._tools_active(req) + tool_calls = None + content = None + + stats = GenStats() + stop_hit = [False] # set when a stop boundary is reached (forces finish="stop") + stops = self._stops + self._request_stops(req) + try: + if use_tools: + # v1: buffer the (usually short) tool response, parse once. + # Halt early at a stop boundary, and bound the raw output + # BEFORE parsing so post-stop tool calls / text don't leak. + raw, stop_hit[0] = await self._collect_until_stop( + self._runtime.generate_stream(None, prompt, options, stats), + stops, + ) + tool_calls, content = self._extract_tools( + req, self._truncate_raw(raw, req) + ) + else: + # Plain chat: stream tokens live (best UX), cutting at special + # tokens or request stop sequences and halting early on a hit. + def on_stop(): + stop_hit[0] = True + self._runtime.stop() + + async for token in self._clean( + self._runtime.generate_stream(None, prompt, options, stats), + stops, + on_stop=on_stop, + ): + yield chunk(DeltaMessage(content=token)) + except ( + Exception + ) as e: # noqa: BLE001 - emit a structured error event, never drop the socket + error = e + + if error is not None: + err = { + "message": f"Generation failed: {error}", + "type": "server_error", + "code": None, + } + yield f"data: {json.dumps({'error': err})}\n\n" + yield "data: [DONE]\n\n" + return + + if use_tools: + if content: + yield chunk(DeltaMessage(content=content)) + for tc in tool_calls or []: + yield chunk(DeltaMessage(tool_calls=[tc])) + finish = self._finish_reason( + req, + stats.completion_tokens, + tool_calls, + stop_hit[0], + stats.finish_reason, + ) + else: + finish = self._finish_reason( + req, + stats.completion_tokens, + stopped=stop_hit[0], + worker_finish=stats.finish_reason, + ) + yield chunk(DeltaMessage(), finish=finish) + if req.stream_options and req.stream_options.include_usage: + usage_chunk = ChatCompletionChunk( + id=cid, + model=self._model_id, + choices=[], + usage=Usage( + prompt_tokens=stats.prompt_tokens, + completion_tokens=stats.completion_tokens, + total_tokens=stats.prompt_tokens + stats.completion_tokens, + ), + ) + yield f"data: {usage_chunk.model_dump_json(exclude_none=True)}\n\n" + yield "data: [DONE]\n\n" diff --git a/extension/llm/server/python/session_runtime.py b/extension/llm/server/python/session_runtime.py new file mode 100644 index 00000000000..e73dbafadc4 --- /dev/null +++ b/extension/llm/server/python/session_runtime.py @@ -0,0 +1,193 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Python's stateful local-LLM runtime over one C++ worker process. + +This is the internal boundary between protocol adapters (OpenAI chat, future +native/agent surfaces) and the worker. The adapter speaks sessions, prompts, and +generation parameters; the worker (driven over JSONL by a WorkerClient) owns all +model execution and session state (KV/recurrent, resident token ids, warm-resume +prefix logic). The Python server never loads a model, links a backend, or imports +a runtime pybind. + +A SessionRuntime owns exactly one worker and serializes access to it (one +in-flight request at a time), bridging the worker's blocking generate() into an +async token stream. Multi-worker scheduling / named-session affinity is out of +scope: a single worker already hosts many isolated sessions on one weight load, +routed by session_id inside the worker. +""" + +import asyncio +import logging +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass, field +from typing import AsyncIterator, Optional + +logger = logging.getLogger(__name__) + +_SENTINEL = object() + + +@dataclass +class PromptInput: + """A prompt as either a single rendered string or token-ID segments. Exactly + one of `text` / `segments` is set. Segments ([{"text": str} | {"ids": [int]}]) + let an adapter splice exact prior-turn token ids in place of a lossy + re-render (see openai_transcript).""" + + text: Optional[str] = None + segments: Optional[list] = None + + def __post_init__(self): + if (self.text is None) == (self.segments is None): + raise ValueError("exactly one of PromptInput.text / .segments must be set") + if self.segments is not None and not self.segments: + raise ValueError("PromptInput.segments must be non-empty") + + +@dataclass +class GenerationOptions: + """Sampling/length knobs forwarded to the worker (only what we honor today).""" + + max_new_tokens: int + temperature: float = 0.0 + stop: list[str] = field(default_factory=list) + + +@dataclass +class GenStats: + """Per-request metadata the worker reports at the end of generation.""" + + prompt_tokens: int = 0 + completion_tokens: int = 0 + # Worker-reported stop reason ("stop" | "length"), or None if not reported. + finish_reason: Optional[str] = None + # Warm-resume accounting (V2b.1): tokens served from the session's resident + # state vs prefilled this request, and why. + reused_prompt_tokens: int = 0 + prefilled_prompt_tokens: int = 0 + session_reset_reason: Optional[str] = None + # Exact token ids generated this turn (V2b.1.5), for an adapter's transcript + # store. Empty when the worker doesn't report them (e.g. a stop-trimmed turn). + generated_token_ids: list = field(default_factory=list) + + +# Forwarded to WorkerClient.generate() as the per-request config it reads fields +# off; keeps that low-level contract unchanged while the runtime's public surface +# is PromptInput + GenerationOptions + session_id. +@dataclass +class _WorkerRequest: + max_new_tokens: int + temperature: float + stop: list[str] + session_id: Optional[str] + prompt_segments: Optional[list] + + +class SessionRuntime: + """Stateful runtime over a single worker. `worker` is a WorkerClient (a fake + in tests) exposing generate()/stop()/close() and the session ops + open_session/reset_session/close_session.""" + + def __init__(self, worker): + self._worker = worker + # One executor thread; the lock guarantees the worker is never driven by + # two requests at once (it is single-in-flight). + self._executor = ThreadPoolExecutor(max_workers=1) + self._lock = asyncio.Lock() + + async def open(self, session_id: str) -> None: + """Admit a named session before generation so a capacity refusal surfaces + up front (the adapter maps it to an HTTP status) rather than mid-stream. + Idempotent.""" + await self._session_op("open_session", session_id) + + async def reset(self, session_id: str) -> None: + """Clear a named session's context (KV/recurrent + resident ids) but keep + its capacity slot. Idempotent.""" + await self._session_op("reset_session", session_id) + + async def close(self, session_id: str) -> None: + """Destroy a named session, freeing its state and slot. Idempotent.""" + await self._session_op("close_session", session_id) + + async def _session_op(self, method: str, session_id: str) -> None: + op = getattr(self._worker, method, None) + if op is None: + return # worker doesn't support sessions (e.g. a minimal fake) + loop = asyncio.get_running_loop() + async with self._lock: + await loop.run_in_executor(self._executor, op, session_id) + + def stop(self) -> None: + """Request the in-flight generation stop at the next token boundary.""" + self._worker.stop() + + async def generate_stream( + self, + session_id: Optional[str], + prompt: PromptInput, + options: GenerationOptions, + stats: Optional[GenStats] = None, + ) -> AsyncIterator[str]: + """Yield generated text pieces from the worker, holding the worker lock + for the whole generation. `stats` (if given) is filled in place with the + worker's terminal metadata (per-request, so concurrent callers don't + race). session_id None routes to the worker's anonymous scratch session.""" + out_stats = stats if stats is not None else GenStats() + req = _WorkerRequest( + max_new_tokens=options.max_new_tokens, + temperature=options.temperature, + stop=list(options.stop), + session_id=session_id, + prompt_segments=prompt.segments, + ) + prompt_text = prompt.text or "" + loop = asyncio.get_running_loop() + queue: asyncio.Queue = asyncio.Queue() + + def token_cb(token: str) -> None: + loop.call_soon_threadsafe(queue.put_nowait, token) + + def stats_cb(s) -> None: + out_stats.prompt_tokens = s.num_prompt_tokens + out_stats.completion_tokens = s.num_generated_tokens + out_stats.finish_reason = getattr(s, "finish_reason", None) + out_stats.reused_prompt_tokens = getattr(s, "reused_prompt_tokens", 0) + out_stats.prefilled_prompt_tokens = getattr(s, "prefilled_prompt_tokens", 0) + out_stats.session_reset_reason = getattr(s, "session_reset_reason", None) + out_stats.generated_token_ids = getattr(s, "generated_token_ids", []) + + def run() -> None: + try: + self._worker.generate(prompt_text, req, token_cb, stats_cb) + except Exception as e: # noqa: BLE001 - surface to the stream consumer + loop.call_soon_threadsafe(queue.put_nowait, e) + finally: + loop.call_soon_threadsafe(queue.put_nowait, _SENTINEL) + + async with self._lock: + fut = loop.run_in_executor(self._executor, run) + try: + while True: + item = await queue.get() + if item is _SENTINEL: + break + if isinstance(item, Exception): + raise item + yield item + except asyncio.CancelledError: + self._worker.stop() + raise + finally: + await fut + + def close_worker(self) -> None: + """Shut the worker process down and the executor (called at shutdown).""" + close = getattr(self._worker, "close", None) + if close is not None: + close() + self._executor.shutdown(wait=False, cancel_futures=True) diff --git a/extension/llm/server/python/tests/conftest.py b/extension/llm/server/python/tests/conftest.py new file mode 100644 index 00000000000..b91f0aec26e --- /dev/null +++ b/extension/llm/server/python/tests/conftest.py @@ -0,0 +1,143 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Fixtures for hermetic contract tests. + +We build a SessionRuntime over a single FakeRunner worker, so the tests exercise +the real server, protocol, templating, and streaming code over the HTTP boundary +with NO model, tokenizer, GPU, or worker subprocess. This mirrors ExecuTorch's +fake_llm_executor approach: fake the worker, test the real surface. + +The control plane imports no runtime pybind; only fastapi, pydantic, httpx, and +pytest are required. +""" + +import pytest + +from executorch.extension.llm.server.python.chat_template import ChatTemplate +from executorch.extension.llm.server.python.server import build_app +from executorch.extension.llm.server.python.serving_chat import ServingChat +from executorch.extension.llm.server.python.session_runtime import SessionRuntime +from executorch.extension.llm.server.python.tool_parsers import HermesDetector +from executorch.extension.llm.server.python.worker_client import WorkerError +from fastapi.testclient import TestClient + + +class _FakeStats: + num_prompt_tokens = 5 + num_generated_tokens = 0 + finish_reason = None + + +class FakeRunner: + """Canned engine: emits fixed tokens, records the config it was given. + + With max_named_sessions > 0 it also models the worker's session admission: + open_session() enforces capacity and reports structured WorkerError codes, + matching the real worker's contract.""" + + def __init__( + self, tokens, fail=False, finish_reason=None, max_named_sessions=0, gen_ids=None + ): + self._tokens = list(tokens) + self._fail = fail + self._finish_reason = finish_reason # worker-reported stop reason, if any + self._gen_ids = list(gen_ids or []) # ids reported per turn (V2b.1.5) + self.captured_config = None + self.stopped = False + self.reset_count = 0 + self.max_named_sessions = max_named_sessions + self.open_named = set() # currently-open named sessions + self.opened_log = [] # every successful open, in order + self.reset_log = [] # every reset_session call, in order + + def reset(self): + self.reset_count += 1 + + def stop(self): + self.stopped = True + + def open_session(self, session_id): + if session_id in self.open_named: + return # idempotent + if self.max_named_sessions == 0: + raise WorkerError("no named sessions", code="unsupported_session") + if len(self.open_named) >= self.max_named_sessions: + raise WorkerError("session capacity exhausted", code="capacity_exhausted") + self.open_named.add(session_id) + self.opened_log.append(session_id) + + def close_session(self, session_id): + self.open_named.discard(session_id) + + def reset_session(self, session_id): + # Clears context but keeps the slot (stays in open_named). + self.reset_log.append(session_id) + + def generate(self, prompt, config, token_callback=None, stats_callback=None): + self.captured_config = config + if self._fail: + raise RuntimeError("Generation failed") + for tok in self._tokens: + if token_callback: + token_callback(tok) + if stats_callback: + stats = _FakeStats() + stats.num_generated_tokens = len(self._tokens) + stats.finish_reason = self._finish_reason + stats.generated_token_ids = list(self._gen_ids) + stats_callback(stats) + + +class _FakeTokenizer: + """Minimal stand-in for an HF tokenizer (counting + templating).""" + + all_special_tokens: list = [] + + def __init__(self, prompt_tokens): + self._n = prompt_tokens + + def encode(self, text, add_special_tokens=False): + return [0] * self._n + + def apply_chat_template( + self, messages, tools, add_generation_prompt, tokenize, **kwargs + ): + return "PROMPT" + + +@pytest.fixture +def make_client(): + def _make( + tokens=("Hello", ", ", "world"), + max_context=None, + prompt_tokens=None, + fail=False, + finish_reason=None, + max_named_sessions=0, + gen_ids=None, + ): + fake = FakeRunner( + tokens, + fail=fail, + finish_reason=finish_reason, + max_named_sessions=max_named_sessions, + gen_ids=gen_ids, + ) + runtime = SessionRuntime(fake) # one fake worker + template = ChatTemplate(hf_tokenizer_path=None, allow_fallback=True) + if prompt_tokens is not None: + template._hf = _FakeTokenizer(prompt_tokens) + serving = ServingChat( + runtime, + template, + "test-model", + max_context=max_context, + tool_detector_cls=HermesDetector, + ) + return TestClient(build_app(serving, "test-model")), fake + + return _make diff --git a/extension/llm/server/python/tests/test_contract.py b/extension/llm/server/python/tests/test_contract.py new file mode 100644 index 00000000000..6a460150b18 --- /dev/null +++ b/extension/llm/server/python/tests/test_contract.py @@ -0,0 +1,446 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Hermetic OpenAI-contract tests (fake engine, no model/GPU). + +These assert on the public HTTP/wire contract only — response object shapes, +the streaming chunk protocol, status codes — never on internal classes or +methods. Implementation can change freely as long as these pass. +""" + +import json + +import pytest + + +def _sse_chunks(text): + chunks, done = [], False + for line in text.splitlines(): + line = line.strip() + if not line.startswith("data:"): + continue + payload = line[len("data:") :].strip() + if payload == "[DONE]": + done = True + continue + chunks.append(json.loads(payload)) + return chunks, done + + +def test_health(make_client): + client, _ = make_client() + assert client.get("/health").json() == {"status": "ok"} + + +def test_models_listing_shape(make_client): + client, _ = make_client() + body = client.get("/v1/models").json() + assert body["object"] == "list" + assert body["data"][0]["id"] == "test-model" + assert body["data"][0]["object"] == "model" + + +def test_chat_nonstreaming_shape(make_client): + client, _ = make_client(tokens=["Hello", ", ", "world"]) + resp = client.post( + "/v1/chat/completions", + json={"model": "test-model", "messages": [{"role": "user", "content": "hi"}]}, + ) + assert resp.status_code == 200 + body = resp.json() + assert body["object"] == "chat.completion" + choice = body["choices"][0] + assert choice["message"]["role"] == "assistant" + assert choice["message"]["content"] == "Hello, world" + assert choice["finish_reason"] == "stop" + for k in ("prompt_tokens", "completion_tokens", "total_tokens"): + assert k in body["usage"] + assert body["usage"]["total_tokens"] == ( + body["usage"]["prompt_tokens"] + body["usage"]["completion_tokens"] + ) + + +def test_chat_streaming_protocol(make_client): + client, _ = make_client(tokens=["a", "b", "c"]) + resp = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "hi"}], + "stream": True, + }, + ) + assert resp.status_code == 200 + assert resp.headers["content-type"].startswith("text/event-stream") + chunks, done = _sse_chunks(resp.text) + assert done, "stream must terminate with data: [DONE]" + assert all(c["object"] == "chat.completion.chunk" for c in chunks) + assert chunks[0]["choices"][0]["delta"].get("role") == "assistant" + content = "".join(c["choices"][0]["delta"].get("content") or "" for c in chunks) + assert content == "abc" + assert chunks[-1]["choices"][0]["finish_reason"] == "stop" + + +def test_request_params_forwarded_to_generation(make_client): + # Contract behavior: the server must honor max_tokens/temperature. + client, fake = make_client() + client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "hi"}], + "max_tokens": 7, + "temperature": 0.1, + }, + ) + assert fake.captured_config.max_new_tokens == 7 + assert abs(fake.captured_config.temperature - 0.1) < 1e-6 + + +def test_special_tokens_forwarded_to_worker_as_stops(make_client): + # The worker must be told to stop at the model's end-of-turn special tokens + # (e.g. <|im_end|>), not just request `stop` sequences. Otherwise a worker + # whose EOS-by-token-id check misses the turn end runs to max_new (or errors + # forwarding its own end token) past it — the text_llm_worker "decode failed" + # seen on a real model. (Forwarding only request stops would leave this [].) + client, fake = make_client() + client.post( + "/v1/chat/completions", + json={"model": "test-model", "messages": [{"role": "user", "content": "hi"}]}, + ) + assert "<|im_end|>" in (fake.captured_config.stop or []) + + +def test_request_stop_forwarded_to_worker(make_client): + # A request `stop` sequence must reach the worker (so it can terminate early), + # not only be applied by the Python backstop. The stop tests elsewhere pass + # via that backstop even if forwarding regresses; this asserts forwarding. + client, fake = make_client() + client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "hi"}], + "stop": ["STOP"], + }, + ) + # stop == special tokens + request stops, so check membership (not equality). + assert "STOP" in (fake.captured_config.stop or []) + + +def test_tools_field_accepted(make_client): + # tools is part of the contract even before parsing is enforced (M2). + client, _ = make_client() + resp = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "hi"}], + "tools": [ + {"type": "function", "function": {"name": "f", "parameters": {}}} + ], + }, + ) + assert resp.status_code == 200 + assert resp.json()["choices"][0]["message"]["role"] == "assistant" + + +def test_invalid_request_returns_422(make_client): + client, _ = make_client() + resp = client.post( + "/v1/chat/completions", json={"model": "test-model"} + ) # no messages + assert resp.status_code == 422 + + +def test_special_tokens_stripped_nonstreaming(make_client): + # The runner may decode EOS/special tokens to text; they must not leak. + client, _ = make_client(tokens=["Hello", " world", "<|im_end|>", "LEAK"]) + resp = client.post( + "/v1/chat/completions", + json={"model": "test-model", "messages": [{"role": "user", "content": "hi"}]}, + ) + assert resp.json()["choices"][0]["message"]["content"] == "Hello world" + + +def test_usage_populated_when_special_token_cuts_early(make_client): + # Regression: cutting at a special token must not skip usage stats. + client, _ = make_client(tokens=["Hello", "<|im_end|>", "LEAK"]) + body = client.post( + "/v1/chat/completions", + json={"model": "test-model", "messages": [{"role": "user", "content": "hi"}]}, + ).json() + assert body["choices"][0]["message"]["content"] == "Hello" + assert body["usage"]["completion_tokens"] > 0 + + +def test_special_tokens_stripped_streaming(make_client): + client, _ = make_client(tokens=["Hello", " world", "<|im_end|>", "LEAK"]) + resp = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "hi"}], + "stream": True, + }, + ) + chunks, _ = _sse_chunks(resp.text) + content = "".join(c["choices"][0]["delta"].get("content") or "" for c in chunks) + assert content == "Hello world" + assert "LEAK" not in content and "<|im_end|>" not in content + + +# (1) Context-size-exceeded -> structured 400, both modes (not a dropped socket). +def test_context_length_exceeded_returns_400(make_client): + client, _ = make_client(max_context=2048, prompt_tokens=2940) + resp = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "x" * 100}], + }, + ) + assert resp.status_code == 400 + err = resp.json()["error"] + assert err["code"] == "context_length_exceeded" + assert err["type"] == "invalid_request_error" + + +def test_context_length_exceeded_streaming_returns_400(make_client): + client, _ = make_client(max_context=2048, prompt_tokens=2940) + resp = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "x"}], + "stream": True, + }, + ) + # Pre-flight check rejects before the stream starts -> clean 400, no SSE. + assert resp.status_code == 400 + assert resp.json()["error"]["code"] == "context_length_exceeded" + + +def test_prompt_plus_max_tokens_exceeding_context_returns_400(make_client): + # Prompt fits (100 < 2048) but prompt + max_tokens (100 + 2000) > 2048: must + # reject up front, not run until the worker hits the limit mid-decode. + client, _ = make_client(max_context=2048, prompt_tokens=100) + resp = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "hi"}], + "max_tokens": 2000, + }, + ) + assert resp.status_code == 400 + assert resp.json()["error"]["code"] == "context_length_exceeded" + + +def test_prompt_plus_max_tokens_within_context_ok(make_client): + # Prompt + max_tokens (100 + 100) <= 2048: must NOT be rejected. + client, _ = make_client(max_context=2048, prompt_tokens=100) + resp = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "hi"}], + "max_tokens": 100, + }, + ) + assert resp.status_code == 200 + + +# (1) Mid-generation failure -> structured error, never a dropped connection. +def test_generation_failure_returns_structured_error(make_client): + client, _ = make_client(fail=True) + resp = client.post( + "/v1/chat/completions", + json={"model": "test-model", "messages": [{"role": "user", "content": "hi"}]}, + ) + assert resp.status_code == 500 + assert resp.json()["error"]["type"] == "server_error" + + +def test_generation_failure_streaming_emits_error_event(make_client): + client, _ = make_client(fail=True) + resp = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "hi"}], + "stream": True, + }, + ) + assert resp.status_code == 200 # headers already sent; error arrives as an event + chunks, done = _sse_chunks(resp.text) + assert done + assert any("error" in c for c in chunks) + + +# (3) finish_reason == "length" when max_tokens is hit. +def test_finish_reason_length_when_max_tokens_hit(make_client): + client, _ = make_client(tokens=["a", "b", "c"]) # 3 generated tokens + resp = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "hi"}], + "max_tokens": 3, + }, + ) + assert resp.json()["choices"][0]["finish_reason"] == "length" + + +def test_finish_reason_stop_when_under_max_tokens(make_client): + client, _ = make_client(tokens=["a", "b", "c"]) + resp = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "hi"}], + "max_tokens": 100, + }, + ) + assert resp.json()["choices"][0]["finish_reason"] == "stop" + + +# Worker-reported finish_reason: the worker may silently clamp max_new to the +# context window, so the token-count heuristic can't tell a real stop from a +# truncation. Trust the worker's reason. +def test_worker_reported_length_overrides_token_count(make_client): + # 3 tokens generated (< requested 100) but the worker says it ran to the cap + # (a context clamp): finish_reason must be "length", not "stop". + client, _ = make_client(tokens=["a", "b", "c"], finish_reason="length") + body = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "hi"}], + "max_tokens": 100, + }, + ).json() + assert body["choices"][0]["finish_reason"] == "length" + + +def test_worker_reported_stop_overrides_token_count(make_client): + # 3 tokens with max_tokens=3 (heuristic would say "length"), but the worker + # reports EOS: finish_reason must be "stop". + client, _ = make_client(tokens=["a", "b", "c"], finish_reason="stop") + body = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "hi"}], + "max_tokens": 3, + }, + ).json() + assert body["choices"][0]["finish_reason"] == "stop" + + +def test_server_stop_sequence_wins_over_worker_length(make_client): + # A server-side stop sequence is truncation in the control plane; it must + # win even when the worker would report "length". + client, _ = make_client( + tokens=["Hello ", "world ", "STOP", " x"], finish_reason="length" + ) + body = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "hi"}], + "stop": ["STOP"], + "max_tokens": 100, + }, + ).json() + assert body["choices"][0]["finish_reason"] == "stop" + assert "STOP" not in (body["choices"][0]["message"]["content"] or "") + + +# (4) Error-variant matrix: malformed requests -> consistent 422. +@pytest.mark.parametrize( + "bad_body", + [ + {"model": "m"}, # missing messages + {"model": "m", "messages": "not-a-list"}, + { + "model": "m", + "messages": [{"role": "user", "content": "hi"}], + "temperature": "hot", + }, + { + "model": "m", + "messages": [{"role": "user", "content": "hi"}], + "stream": "maybe", + }, + {"model": "m", "messages": [{"content": "no role"}]}, + ], +) +def test_invalid_requests_return_422(make_client, bad_body): + client, _ = make_client() + assert client.post("/v1/chat/completions", json=bad_body).status_code == 422 + + +# (6) Streaming usage when stream_options.include_usage is set. +def test_streaming_usage_included(make_client): + client, _ = make_client(tokens=["a", "b"]) + resp = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "hi"}], + "stream": True, + "stream_options": {"include_usage": True}, + }, + ) + chunks, _ = _sse_chunks(resp.text) + usage_chunks = [c for c in chunks if c.get("usage")] + assert usage_chunks, "expected a chunk carrying usage" + u = usage_chunks[-1]["usage"] + assert u["total_tokens"] == u["prompt_tokens"] + u["completion_tokens"] + + +def test_streaming_usage_absent_by_default(make_client): + client, _ = make_client(tokens=["a", "b"]) + resp = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "hi"}], + "stream": True, + }, + ) + chunks, _ = _sse_chunks(resp.text) + assert not any(c.get("usage") for c in chunks) + + +# (2) Unicode/multibyte content survives streaming intact. +def test_unicode_streaming_integrity(make_client): + pieces = ["café ", "日本語 ", "😀", "🎉"] + client, _ = make_client(tokens=pieces) + resp = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "hi"}], + "stream": True, + }, + ) + chunks, _ = _sse_chunks( + resp.text + ) # _sse_chunks uses json.loads -> validates UTF-8/JSON + content = "".join(c["choices"][0]["delta"].get("content") or "" for c in chunks) + assert content == "".join(pieces) + + +def test_unicode_nonstreaming_integrity(make_client): + pieces = ["café ", "日本語 ", "😀"] + client, _ = make_client(tokens=pieces) + resp = client.post( + "/v1/chat/completions", + json={"model": "test-model", "messages": [{"role": "user", "content": "hi"}]}, + ) + assert resp.json()["choices"][0]["message"]["content"] == "".join(pieces) diff --git a/extension/llm/server/python/tests/test_sampling_params.py b/extension/llm/server/python/tests/test_sampling_params.py new file mode 100644 index 00000000000..fe3166d1bb5 --- /dev/null +++ b/extension/llm/server/python/tests/test_sampling_params.py @@ -0,0 +1,269 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Contract tests for sampling/control params: stop sequences, n, tool_choice. + +These exercise the real server over the HTTP boundary with a FakeRunner. +""" + +import json + +CALL = ( + '\n{"name": "get_weather", "arguments": {"city": "Paris"}}\n' +) +WEATHER_TOOL = { + "type": "function", + "function": {"name": "get_weather", "parameters": {}}, +} + + +def _body(**kw): + b = {"model": "test-model", "messages": [{"role": "user", "content": "hi"}]} + b.update(kw) + return b + + +def test_n_greater_than_one_is_rejected(make_client): + client, _ = make_client() + r = client.post("/v1/chat/completions", json=_body(n=2)) + assert r.status_code == 400 + assert r.json()["error"]["type"] == "invalid_request_error" + assert r.json()["error"]["code"] == "unsupported_parameter" + + +def test_unsupported_params_rejected(make_client): + client, _ = make_client() + for param in ( + {"top_p": 0.5}, + {"top_p": 2.0}, # > 1.0 is not a no-op either — only exactly 1.0/unset is + {"seed": 42}, + {"reasoning_effort": "high"}, + {"frequency_penalty": 1.0}, + {"presence_penalty": -0.5}, + {"top_k": 40}, + {"logit_bias": {"123": 5.0}}, + {"response_format": {"type": "json_object"}}, + {"logprobs": True}, + {"top_logprobs": 5}, + {"parallel_tool_calls": False}, + ): + r = client.post("/v1/chat/completions", json=_body(**param)) + assert r.status_code == 400, param + assert r.json()["error"]["code"] == "unsupported_parameter", param + + +def test_nonpositive_max_tokens_rejected(make_client): + # max_tokens=0/-2 must not be silently treated as "unbounded" (the `or` bug); + # 0 and negatives are invalid for both field names. + client, _ = make_client() + for param in ( + {"max_tokens": 0}, + {"max_tokens": -2}, + {"max_completion_tokens": 0}, + {"max_completion_tokens": -2}, + ): + r = client.post("/v1/chat/completions", json=_body(**param)) + assert r.status_code == 400, param + assert r.json()["error"]["type"] == "invalid_request_error", param + + +def test_temperature_range_rejected(make_client): + client, _ = make_client() + for param in ( + {"temperature": -1}, + {"temperature": -0.1}, + {"temperature": 2.1}, + {"temperature": 3}, + ): + r = client.post("/v1/chat/completions", json=_body(**param)) + assert r.status_code == 400, param + assert r.json()["error"]["code"] == "invalid_value", param + + +def test_noop_output_contract_fields_accepted(make_client): + # The default/no-op forms must NOT be rejected (don't break OpenAI clients + # that send them explicitly). + client, _ = make_client() + r = client.post( + "/v1/chat/completions", + json=_body( + response_format={"type": "text"}, + parallel_tool_calls=True, + max_tokens=8, + ), + ) + assert r.status_code == 200 + + +def test_zero_penalties_and_unknown_fields_accepted(make_client): + # frequency/presence_penalty=0.0 are no-ops; unknown non-generation fields + # (user/store/metadata) are ignored, not rejected (don't break OpenAI clients). + client, _ = make_client() + r = client.post( + "/v1/chat/completions", + json=_body( + frequency_penalty=0.0, + presence_penalty=0.0, + user="abc", + store=False, + metadata={"k": "v"}, + max_tokens=8, + ), + ) + assert r.status_code == 200 + + +def test_unsupported_tool_choice_rejected(make_client): + # "required" / a specific-function choice would need constrained decoding to + # force/restrict the call; v1 rejects rather than silently treating as "auto". + client, _ = make_client() + for choice in ( + "required", + {"type": "function", "function": {"name": "get_weather"}}, + ): + r = client.post( + "/v1/chat/completions", + json=_body(tools=[WEATHER_TOOL], tool_choice=choice, max_tokens=8), + ) + assert r.status_code == 400, choice + assert r.json()["error"]["code"] == "unsupported_parameter", choice + + +def test_supported_params_accepted(make_client): + # top_p=1.0 (no-op) and temperature/max_tokens must NOT be rejected; neither + # should tool_choice "auto" / "none". + client, _ = make_client() + for temperature in (0.0, 1.0, 2.0): + r = client.post( + "/v1/chat/completions", + json=_body(top_p=1.0, temperature=temperature, max_tokens=8), + ) + assert r.status_code == 200, temperature + for choice in ("auto", "none"): + r = client.post( + "/v1/chat/completions", + json=_body(tools=[WEATHER_TOOL], tool_choice=choice, max_tokens=8), + ) + assert r.status_code == 200, choice + + +def test_stop_sequence_truncates_nonstreaming(make_client): + client, _ = make_client(tokens=["Hello ", "world ", "STOP", " ignored"]) + r = client.post("/v1/chat/completions", json=_body(stop=["STOP"], max_tokens=32)) + assert r.status_code == 200 + content = r.json()["choices"][0]["message"]["content"] + assert content == "Hello world " + assert "STOP" not in content + + +def test_stop_sequence_truncates_streaming(make_client): + client, _ = make_client(tokens=["Hello ", "world ", "STOP", " ignored"]) + content = "" + with client.stream( + "POST", "/v1/chat/completions", json=_body(stop=["STOP"], stream=True) + ) as r: + for line in r.iter_lines(): + if not line.startswith("data:"): + continue + payload = line[len("data:") :].strip() + if payload == "[DONE]": + break + delta = json.loads(payload)["choices"][0]["delta"] + content += delta.get("content") or "" + assert content == "Hello world " + assert "STOP" not in content + + +def test_stop_forces_finish_reason_over_length_nonstreaming(make_client): + # 4 tokens emitted (completion reaches max_tokens=4) AND a stop is hit: + # finish_reason must be "stop" (boundary), not "length". + client, _ = make_client(tokens=["a ", "b ", "STOP", " c"]) + body = client.post( + "/v1/chat/completions", json=_body(stop=["STOP"], max_tokens=4) + ).json() + assert body["choices"][0]["finish_reason"] == "stop" + assert "STOP" not in (body["choices"][0]["message"]["content"] or "") + + +def test_stop_forces_finish_reason_over_length_streaming(make_client): + client, _ = make_client(tokens=["a ", "b ", "STOP", " c"]) + finish = None + with client.stream( + "POST", + "/v1/chat/completions", + json=_body(stop=["STOP"], max_tokens=4, stream=True), + ) as r: + for line in r.iter_lines(): + if not line.startswith("data:"): + continue + p = line[len("data:") :].strip() + if p == "[DONE]": + break + fr = json.loads(p)["choices"][0].get("finish_reason") + if fr: + finish = fr + assert finish == "stop" + + +_STOP_THEN_CALL = ( + 'Answer STOP \n{"name": "get_weather", "arguments": {}}\n' +) + + +def test_stop_before_tool_call_nonstreaming(make_client): + # Stop boundary precedes a tool call (in one chunk, so truncation — not just + # early-stop — must catch it): the call must NOT be parsed/emitted. + client, _ = make_client(tokens=[_STOP_THEN_CALL]) + r = client.post( + "/v1/chat/completions", + json=_body(tools=[WEATHER_TOOL], stop=["STOP"], max_tokens=64), + ) + msg = r.json()["choices"][0]["message"] + assert msg.get("tool_calls") is None + assert "STOP" not in (msg.get("content") or "") + + +def test_stop_before_tool_call_streaming(make_client): + client, _ = make_client(tokens=[_STOP_THEN_CALL]) + saw_tool, content = False, "" + with client.stream( + "POST", + "/v1/chat/completions", + json=_body(tools=[WEATHER_TOOL], stop=["STOP"], stream=True), + ) as r: + for line in r.iter_lines(): + if not line.startswith("data:"): + continue + p = line[len("data:") :].strip() + if p == "[DONE]": + break + delta = json.loads(p)["choices"][0]["delta"] + if delta.get("tool_calls"): + saw_tool = True + content += delta.get("content") or "" + assert not saw_tool + assert "STOP" not in content + + +def test_tool_choice_none_disables_tools(make_client): + client, _ = make_client(tokens=[CALL]) + r = client.post( + "/v1/chat/completions", + json=_body(tools=[WEATHER_TOOL], tool_choice="none", max_tokens=64), + ) + msg = r.json()["choices"][0]["message"] + assert msg.get("tool_calls") is None # tools disabled -> returned as text + assert r.json()["choices"][0]["finish_reason"] != "tool_calls" + + +def test_tool_choice_default_still_parses(make_client): + # Sanity: without tool_choice="none", the same call IS parsed as a tool call. + client, _ = make_client(tokens=[CALL]) + r = client.post( + "/v1/chat/completions", json=_body(tools=[WEATHER_TOOL], max_tokens=64) + ) + calls = r.json()["choices"][0]["message"].get("tool_calls") + assert calls and calls[0]["function"]["name"] == "get_weather" diff --git a/extension/llm/server/python/tests/test_session_runtime.py b/extension/llm/server/python/tests/test_session_runtime.py new file mode 100644 index 00000000000..14f138cc007 --- /dev/null +++ b/extension/llm/server/python/tests/test_session_runtime.py @@ -0,0 +1,187 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""SessionRuntime tests: session-op routing, the blocking->async stream bridge, +cancellation, and worker shutdown. A fake worker stands in for the WorkerClient +(no model, GPU, or subprocess). asyncio.run keeps the test bodies sync.""" + +import asyncio +import threading + +from executorch.extension.llm.server.python.session_runtime import ( + GenerationOptions, + GenStats, + PromptInput, + SessionRuntime, +) + +_OPTS = GenerationOptions(max_new_tokens=8) + + +def _text(s="hi") -> PromptInput: + return PromptInput(text=s) + + +class _Worker: + """Records session ops + process close; emits nothing on generate.""" + + def __init__(self): + self.opened, self.reset_ids, self.closed_ids = [], [], [] + self.proc_closed = False + + def open_session(self, sid): + self.opened.append(sid) + + def reset_session(self, sid): + self.reset_ids.append(sid) + + def close_session(self, sid): + self.closed_ids.append(sid) + + def close(self): + self.proc_closed = True + + def stop(self): + pass + + def generate(self, prompt, config, token_callback=None, stats_callback=None): + pass + + +def test_session_ops_route_to_worker(): + async def scenario(): + w = _Worker() + rt = SessionRuntime(w) + await rt.open("a") + await rt.reset("a") + await rt.close("a") + return w + + w = asyncio.run(scenario()) + assert w.opened == ["a"] and w.reset_ids == ["a"] and w.closed_ids == ["a"] + + +def test_session_ops_noop_when_worker_lacks_support(): + # A minimal worker without session ops: the runtime silently no-ops. + class _Bare: + def stop(self): + pass + + def generate(self, *a, **k): + pass + + async def scenario(): + rt = SessionRuntime(_Bare()) + await rt.open("a") + await rt.reset("a") + await rt.close("a") + + asyncio.run(scenario()) # must not raise + + +def test_generate_stream_yields_and_fills_stats(): + class _Echo: + def stop(self): + pass + + def generate(self, prompt, config, token_callback=None, stats_callback=None): + token_callback("Hello") + token_callback(" world") + + class S: + num_prompt_tokens = 3 + num_generated_tokens = 2 + finish_reason = "stop" + generated_token_ids = [10, 11] + + stats_callback(S()) + + async def scenario(): + rt = SessionRuntime(_Echo()) + stats = GenStats() + out = [t async for t in rt.generate_stream("a", _text(), _OPTS, stats)] + return out, stats + + out, stats = asyncio.run(scenario()) + assert "".join(out) == "Hello world" + assert stats.completion_tokens == 2 + assert stats.finish_reason == "stop" + assert stats.generated_token_ids == [10, 11] + + +def test_generate_stream_forwards_session_and_segments_to_worker(): + captured = {} + + class _Cap: + def stop(self): + pass + + def generate(self, prompt, config, token_callback=None, stats_callback=None): + captured["session_id"] = config.session_id + captured["segments"] = config.prompt_segments + captured["prompt"] = prompt + + async def scenario(): + rt = SessionRuntime(_Cap()) + seg = PromptInput(segments=[{"text": "a"}, {"ids": [1, 2]}]) + async for _ in rt.generate_stream("sess", seg, _OPTS, GenStats()): + pass + + asyncio.run(scenario()) + assert captured["session_id"] == "sess" + assert captured["segments"] == [{"text": "a"}, {"ids": [1, 2]}] + + +def test_cancellation_calls_worker_stop(): + class _Blocking: + def __init__(self): + self._gate = threading.Event() + self.stopped = False + + def stop(self): + self.stopped = True + self._gate.set() + + def generate(self, prompt, config, token_callback=None, stats_callback=None): + token_callback("TOKEN") + self._gate.wait(timeout=5) + + async def scenario(): + w = _Blocking() + rt = SessionRuntime(w) + agen = rt.generate_stream("a", _text(), _OPTS).__aiter__() + assert await agen.__anext__() == "TOKEN" # worker now blocking + nxt = asyncio.ensure_future(agen.__anext__()) + await asyncio.sleep(0.05) + nxt.cancel() + try: + await nxt + except asyncio.CancelledError: + pass + for _ in range(100): # let the worker observe stop() + if w.stopped: + break + await asyncio.sleep(0.02) + await agen.aclose() + return w + + w = asyncio.run(scenario()) + assert w.stopped + + +def test_close_worker_shuts_down_worker(): + w = _Worker() + SessionRuntime(w).close_worker() + assert w.proc_closed + + +def test_prompt_input_requires_exactly_one(): + import pytest + + with pytest.raises(ValueError): + PromptInput() + with pytest.raises(ValueError): + PromptInput(text="x", segments=[{"text": "y"}]) diff --git a/extension/llm/server/python/tests/test_template.py b/extension/llm/server/python/tests/test_template.py new file mode 100644 index 00000000000..43c2f8f3973 --- /dev/null +++ b/extension/llm/server/python/tests/test_template.py @@ -0,0 +1,183 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Contract tests for chat-template kwargs (e.g. enable_thinking) passthrough.""" + +from executorch.extension.llm.server.python.chat_template import ChatTemplate +from executorch.extension.llm.server.python.protocol import ( + ChatMessage, + FunctionCall, + ToolCall, +) + + +class _FakeHF: + def __init__(self): + self.seen_kwargs = None + self.seen_messages = None + self.encode_add_special = None + + def apply_chat_template( + self, messages, tools, add_generation_prompt, tokenize, **kwargs + ): + self.seen_kwargs = kwargs + self.seen_messages = messages + return "PROMPT" + + # Default add_special_tokens=True mirrors real HF tokenizers (so a caller + # that forgets to disable specials would over-count). + def encode(self, text, add_special_tokens=True): + self.encode_add_special = add_special_tokens + return list(range(len(text))) # 1 id per char, deterministic + + +def _template_with_fake(defaults=None): + t = ChatTemplate( + hf_tokenizer_path=None, allow_fallback=True, default_template_kwargs=defaults + ) + fake = _FakeHF() + t._hf = fake + return t, fake + + +def test_count_tokens_excludes_special_tokens(): + # The rendered prompt already carries control tokens, so count_tokens must + # encode with add_special_tokens=False (matching the session/prefix-cache + # paths) — not the tokenizer's default True, which double-counts BOS/EOS and + # can falsely reject near-limit requests under --max-context. + t, fake = _template_with_fake() + n = t.count_tokens("PROMPT") + assert fake.encode_add_special is False + assert n == len("PROMPT") + + +def test_default_template_kwargs_applied(): + t, fake = _template_with_fake(defaults={"enable_thinking": False}) + t.render([ChatMessage(role="user", content="hi")]) + assert fake.seen_kwargs == {"enable_thinking": False} + + +def test_per_request_kwargs_override_defaults(): + t, fake = _template_with_fake(defaults={"enable_thinking": False}) + t.render( + [ChatMessage(role="user", content="hi")], + template_kwargs={"enable_thinking": True}, + ) + assert fake.seen_kwargs["enable_thinking"] is True + + +def test_no_kwargs_when_none(): + t, fake = _template_with_fake(defaults=None) + t.render([ChatMessage(role="user", content="hi")]) + assert fake.seen_kwargs == {} + + +def test_fallback_ignores_kwargs_without_hf(): + # No HF tokenizer → ChatML fallback, must not raise on kwargs. + t = ChatTemplate( + hf_tokenizer_path=None, + allow_fallback=True, + default_template_kwargs={"enable_thinking": False}, + ) + out = t.render([ChatMessage(role="user", content="hi")], template_kwargs={"x": 1}) + assert "<|im_start|>user" in out and out.endswith("<|im_start|>assistant\n") + + +# (5) Chat-template behaviors: multi-turn ordering, system message, roles. +def test_multi_turn_order_preserved(): + t = ChatTemplate(hf_tokenizer_path=None, allow_fallback=True) + out = t.render( + [ + ChatMessage(role="user", content="first"), + ChatMessage(role="assistant", content="second"), + ChatMessage(role="user", content="third"), + ] + ) + assert out.index("first") < out.index("second") < out.index("third") + assert out.endswith("<|im_start|>assistant\n") # generation prompt appended + + +def test_system_message_rendered(): + t = ChatTemplate(hf_tokenizer_path=None, allow_fallback=True) + out = t.render( + [ + ChatMessage(role="system", content="You are terse."), + ChatMessage(role="user", content="hi"), + ] + ) + assert "<|im_start|>system\nYou are terse." in out + + +def test_each_role_labeled(): + t = ChatTemplate(hf_tokenizer_path=None, allow_fallback=True) + out = t.render( + [ + ChatMessage(role="system", content="s"), + ChatMessage(role="user", content="u"), + ChatMessage(role="assistant", content="a"), + ] + ) + for role in ("system", "user", "assistant"): + assert f"<|im_start|>{role}" in out + + +# Tool round-trip: a turn-2 request (assistant tool_call + tool result) must +# serialize into the shape any HF chat template consumes — the multi-turn loop +# breaks at turn 2 otherwise. OpenAI sends tool-call arguments as a JSON string; +# HF templates expect a mapping (Qwen renders `arguments|items`), so the server +# decodes it before templating. +def test_tool_call_arguments_decoded_for_template(): + t, fake = _template_with_fake() + t.render( + [ + ChatMessage(role="user", content="weather?"), + ChatMessage( + role="assistant", + content=None, + tool_calls=[ + ToolCall( + index=0, + id="c1", + type="function", + function=FunctionCall( + name="get_weather", arguments='{"city": "Paris"}' + ), + ) + ], + ), + ChatMessage(role="tool", tool_call_id="c1", content='{"temp_c": 18}'), + ] + ) + msgs = fake.seen_messages + asst = next(m for m in msgs if m["role"] == "assistant") + assert asst["tool_calls"][0]["function"]["name"] == "get_weather" + # Decoded from the JSON string into a mapping the template can iterate. + assert asst["tool_calls"][0]["function"]["arguments"] == {"city": "Paris"} + tool = next(m for m in msgs if m["role"] == "tool") + assert tool["tool_call_id"] == "c1" and "temp_c" in tool["content"] + + +def test_tool_call_non_json_arguments_left_as_string(): + # A non-JSON arguments value must not crash; it passes through unchanged. + t, fake = _template_with_fake() + t.render( + [ + ChatMessage( + role="assistant", + content=None, + tool_calls=[ + ToolCall( + index=0, + id="c1", + type="function", + function=FunctionCall(name="f", arguments="not json"), + ) + ], + ) + ] + ) + asst = next(m for m in fake.seen_messages if m["role"] == "assistant") + assert asst["tool_calls"][0]["function"]["arguments"] == "not json" diff --git a/extension/llm/server/python/tests/test_tool_calls.py b/extension/llm/server/python/tests/test_tool_calls.py new file mode 100644 index 00000000000..7b165fac287 --- /dev/null +++ b/extension/llm/server/python/tests/test_tool_calls.py @@ -0,0 +1,221 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Tool-calling tests (HTTP contract via the server). + +Hermes/Qwen format only. The server buffers the model's full output and parses +it once into complete OpenAI tool_calls; parse failures degrade to visible text. +""" + +import json + +WEATHER_TOOLS = [ + { + "type": "function", + "function": {"name": "get_weather", "parameters": {"type": "object"}}, + } +] + + +def _call_text(name, args): + return f'\n{{"name": "{name}", "arguments": {json.dumps(args)}}}\n' + + +# --- HTTP contract: non-streaming --------------------------------------- + + +def test_tool_call_nonstreaming(make_client): + client, _ = make_client(tokens=[_call_text("get_weather", {"city": "Paris"})]) + body = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "weather in Paris?"}], + "tools": WEATHER_TOOLS, + }, + ).json() + choice = body["choices"][0] + assert choice["finish_reason"] == "tool_calls" + tc = choice["message"]["tool_calls"][0] + assert tc["type"] == "function" + assert tc["function"]["name"] == "get_weather" + assert json.loads(tc["function"]["arguments"]) == {"city": "Paris"} + + +def test_tool_call_streaming(make_client): + client, _ = make_client(tokens=[_call_text("get_weather", {"city": "Paris"})]) + resp = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "weather?"}], + "tools": WEATHER_TOOLS, + "stream": True, + }, + ) + # Reuse the contract helper's SSE parsing. + chunks = [] + for line in resp.text.splitlines(): + line = line.strip() + if line.startswith("data:") and "[DONE]" not in line: + chunks.append(json.loads(line[len("data:") :].strip())) + tool_deltas = [ + c["choices"][0]["delta"]["tool_calls"][0] + for c in chunks + if c["choices"] and c["choices"][0]["delta"].get("tool_calls") + ] + assert tool_deltas and tool_deltas[0]["function"]["name"] == "get_weather" + assert chunks[-1]["choices"][0]["finish_reason"] == "tool_calls" + + +def test_undefined_tool_is_not_called(make_client): + # Model calls a tool not in the request's tools -> no tool_calls; the raw + # call stays visible as content (degrade to text, never silent drop). + client, _ = make_client(tokens=[_call_text("rm_rf", {"path": "/"})]) + body = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "hi"}], + "tools": WEATHER_TOOLS, + }, + ).json() + msg = body["choices"][0]["message"] + assert msg.get("tool_calls") is None + assert "rm_rf" in (msg.get("content") or "") # not dropped — visible as text + + +def test_mixed_valid_and_undefined_tool_degrades_to_text(make_client): + # A response with one valid + one undefined call must NOT emit the valid one + # while silently dropping the undefined one — the whole response degrades to + # visible text so the model's full intent is preserved. + client, _ = make_client( + tokens=[ + _call_text("get_weather", {"city": "Paris"}) + + _call_text("rm_rf", {"path": "/"}) + ] + ) + msg = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "hi"}], + "tools": WEATHER_TOOLS, + }, + ).json()["choices"][0]["message"] + assert msg.get("tool_calls") is None # no partial set + content = msg.get("content") or "" + assert "rm_rf" in content and "get_weather" in content # full intent visible + + +def test_tool_choice_none_omits_tools_from_prompt(): + from executorch.extension.llm.server.python.chat_template import ChatTemplate + from executorch.extension.llm.server.python.server import build_app + from executorch.extension.llm.server.python.serving_chat import ServingChat + from executorch.extension.llm.server.python.session_runtime import SessionRuntime + from executorch.extension.llm.server.python.tool_parsers import HermesDetector + + # tool_choice="none" must NOT inject tool schemas into the chat template; if it + # did, the model could still emit a that we'd surface as plain text + # (parsing is disabled for "none"). Assert via a recording tokenizer. + from fastapi.testclient import TestClient + + class _RecordingTok: + all_special_tokens: list = [] + + def __init__(self): + self.tools_seen = "UNSET" + + def encode(self, text): + return [0] + + def apply_chat_template( + self, messages, tools, add_generation_prompt, tokenize, **kwargs + ): + self.tools_seen = tools + return "PROMPT" + + class _Runner: + def reset(self): + pass + + def stop(self): + pass + + def generate(self, prompt, config, token_callback=None, stats_callback=None): + if token_callback: + token_callback("ok") + + rec = _RecordingTok() + template = ChatTemplate(hf_tokenizer_path=None, allow_fallback=True) + template._hf = rec + runtime = SessionRuntime(_Runner()) + serving = ServingChat( + runtime, template, "test-model", tool_detector_cls=HermesDetector + ) + client = TestClient(build_app(serving, "test-model")) + body = { + "model": "test-model", + "messages": [{"role": "user", "content": "hi"}], + "tools": WEATHER_TOOLS, + "max_tokens": 8, + } + + assert ( + client.post( + "/v1/chat/completions", json={**body, "tool_choice": "none"} + ).status_code + == 200 + ) + assert rec.tools_seen is None # tools omitted from the rendered prompt + + # Control: default ("auto") still passes the tools through to the template. + assert client.post("/v1/chat/completions", json=body).status_code == 200 + assert rec.tools_seen == WEATHER_TOOLS + + +def test_malformed_tool_call_falls_back_to_text(make_client): + # Broken JSON inside the markers must not crash; degrade to visible text. + client, _ = make_client(tokens=["\n{not json}\n"]) + resp = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "hi"}], + "tools": WEATHER_TOOLS, + }, + ) + assert resp.status_code == 200 + assert resp.json()["choices"][0]["message"].get("tool_calls") is None + + +def test_no_tools_field_means_text_even_if_markers_present(make_client): + # Without a tools array, tool markers are just content (not parsed). + client, _ = make_client(tokens=[_call_text("get_weather", {"city": "X"})]) + body = client.post( + "/v1/chat/completions", + json={"model": "test-model", "messages": [{"role": "user", "content": "hi"}]}, + ).json() + assert body["choices"][0]["message"].get("tool_calls") is None + + +def test_parallel_calls_in_one_message(make_client): + # Two complete blocks in one output -> two structured calls. + tokens = [ + _call_text("get_weather", {"city": "A"}) + + _call_text("get_weather", {"city": "B"}) + ] + client, _ = make_client(tokens=tokens) + body = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "weather in A and B?"}], + "tools": WEATHER_TOOLS, + }, + ).json() + calls = body["choices"][0]["message"]["tool_calls"] + assert [json.loads(c["function"]["arguments"])["city"] for c in calls] == ["A", "B"] diff --git a/extension/llm/server/python/tests/test_worker_client.py b/extension/llm/server/python/tests/test_worker_client.py new file mode 100644 index 00000000000..dbed8d396f3 --- /dev/null +++ b/extension/llm/server/python/tests/test_worker_client.py @@ -0,0 +1,163 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Tests for the generic WorkerClient JSONL protocol (no model/GPU/subprocess). + +A fake process stands in for the C++ worker: it records what the client writes +and replays a scripted sequence of JSONL response lines. +""" + +import json +from dataclasses import dataclass, field + +import pytest + +from executorch.extension.llm.server.python.worker_client import ( + spawn_worker, + WorkerClient, + WorkerError, +) + + +class _FakeStdin: + def __init__(self): + self.written = [] + + def write(self, s): + self.written.append(s) + + def flush(self): + pass + + def close(self): + pass + + +class _FakeStdout: + def __init__(self, lines): + self._lines = list(lines) + + def readline(self): + return self._lines.pop(0) if self._lines else "" + + +class _FakeProc: + def __init__(self, stdout_lines, returncode=None): + self.stdin = _FakeStdin() + self.stdout = _FakeStdout(stdout_lines) + self._returncode = returncode + + def poll(self): + return self._returncode + + @property + def returncode(self): + return self._returncode + + +@dataclass +class _Cfg: + max_new_tokens: int = 64 + temperature: float = 0.0 + stop: list = field(default_factory=list) + + +def _lines(*objs): + return [json.dumps(o) + "\n" for o in objs] + + +def test_generate_streams_tokens_then_stats(): + proc = _FakeProc( + _lines( + {"token": "Hello"}, + {"token": " world"}, + {"done": True, "prompt_tokens": 4, "completion_tokens": 2}, + ) + ) + client = WorkerClient(proc) + out, stats = [], {} + client.generate( + "hi", + _Cfg(temperature=0.7), + token_callback=out.append, + stats_callback=lambda s: stats.update( + prompt=s.num_prompt_tokens, gen=s.num_generated_tokens + ), + ) + assert "".join(out) == "Hello world" + assert stats == {"prompt": 4, "gen": 2} + # The request carried prompt + sampling, one JSON line. + sent = json.loads(proc.stdin.written[0]) + assert sent == { + "prompt": "hi", + "max_new_tokens": 64, + "temperature": 0.7, + "stop": [], + } + + +def test_generate_forwards_stop_sequences(): + proc = _FakeProc(_lines({"done": True, "prompt_tokens": 1, "completion_tokens": 0})) + WorkerClient(proc).generate("hi", _Cfg(stop=["STOP", "\n\n"])) + sent = json.loads(proc.stdin.written[0]) + assert sent["stop"] == ["STOP", "\n\n"] + + +def test_generate_reports_finish_reason(): + proc = _FakeProc( + _lines( + {"token": "hi"}, + { + "done": True, + "prompt_tokens": 2, + "completion_tokens": 1, + "finish_reason": "length", + }, + ) + ) + seen = {} + WorkerClient(proc).generate( + "hi", _Cfg(), stats_callback=lambda s: seen.update(fr=s.finish_reason) + ) + assert seen["fr"] == "length" + + +def test_error_message_raises_worker_error(): + proc = _FakeProc(_lines({"error": "boom"})) + with pytest.raises(WorkerError, match="boom"): + WorkerClient(proc).generate("hi", _Cfg()) + + +def test_exit_mid_request_raises(): + proc = _FakeProc([]) # readline() -> "" means the worker exited + with pytest.raises(WorkerError, match="exited mid-request"): + WorkerClient(proc).generate("hi", _Cfg()) + + +def test_generate_on_dead_worker_raises(): + proc = _FakeProc([], returncode=1) + with pytest.raises(WorkerError, match="worker exited"): + WorkerClient(proc).generate("hi", _Cfg()) + + +def test_spawn_worker_waits_for_ready(): + proc = _FakeProc(_lines({"ready": True})) + client = spawn_worker( + ["/fake/worker", "--model_path", "m"], popen=lambda *a, **k: proc + ) + assert isinstance(client, WorkerClient) + + +def test_spawn_worker_not_ready_raises(): + proc = _FakeProc(_lines({"oops": True})) + with pytest.raises(WorkerError, match="did not report ready"): + spawn_worker(["/fake/worker"], popen=lambda *a, **k: proc) + + +def test_spawn_worker_no_output_raises(): + proc = _FakeProc([]) + with pytest.raises(WorkerError, match="failed to start"): + spawn_worker(["/fake/worker"], popen=lambda *a, **k: proc) diff --git a/extension/llm/server/python/worker_client.py b/extension/llm/server/python/worker_client.py new file mode 100644 index 00000000000..6b7d4e84132 --- /dev/null +++ b/extension/llm/server/python/worker_client.py @@ -0,0 +1,267 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Generic control-plane client for a model-execution worker process. + +Model execution runs in a separate C++ worker process — the Python server is +HTTP/control plane only and never loads a model, links a backend, or imports a +pybind module. This client spawns a worker binary and drives generation over +JSONL on the worker's stdin/stdout. The protocol is model-agnostic: the same +client serves a TextLLM worker, a Qwen worker, or any future model worker; only +the binary and its launch args differ. + +Protocol (one JSON object per line): + worker -> stdout, once at startup: {"ready": true, "max_sessions": int, + "max_named_sessions": int} + client -> stdin: + generate: {"max_new_tokens": int, "temperature": float, "stop": [str, ...], + "session_id"?: str, and exactly one prompt form: + "prompt": str + "prompt_segments": [{"text": str} | {"ids": [int, ...]}]} + open: {"op": "open", "session_id": str} + close: {"op": "close", "session_id": str} + reset: {"op": "reset", "session_id": str} # clear context, keep slot + worker -> stdout: + generate: {"token": str} * (streamed) + {"done": true, "prompt_tokens": int, "completion_tokens": int, + "finish_reason": "stop" | "length", + "reused_prompt_tokens": int, "prefilled_prompt_tokens": int, + "session_reset_reason": "new"|"exact_prefix"|"dirty"|"mismatch" + |"equal", + "generated_token_ids"?: [int, ...]} # omitted if stop-trimmed + open: {"opened": true, "session_id": str} + close: {"closed": true, "session_id": str} + reset: {"reset": true, "session_id": str} + error: {"error": str, "code"?: str} # capacity_exhausted / + # unsupported_session + +The worker's stdout carries ONLY protocol JSON; its logs go to stderr. One +request at a time per worker; the caller (SessionRuntime) serializes. A worker +hosts one engine and routes requests to per-session_id state (anonymous requests +share a scratch session); execution is synchronous. +""" + +import json +import logging +import subprocess +import threading +from dataclasses import dataclass, field +from typing import Callable, Optional, Sequence + +logger = logging.getLogger(__name__) + + +@dataclass +class WorkerStats: + """Usage reported by a worker at the end of a request.""" + + num_prompt_tokens: int = 0 + num_generated_tokens: int = 0 + # Why generation stopped, as the worker saw it: "stop" (EOS / cooperative + # stop) or "length" (ran to max_new, possibly clamped to the context window). + # None if the worker didn't report it (older worker / fake). + finish_reason: Optional[str] = None + # Warm-resume accounting (V2b.1): how many prompt tokens were served from the + # session's resident KV state vs actually prefilled this request, and why + # ("new"|"exact_prefix"|"dirty"|"mismatch"|"equal"). Not exposed as OpenAI + # usage; logged for measuring warm-resume hit rate. None on older workers. + reused_prompt_tokens: int = 0 + prefilled_prompt_tokens: int = 0 + session_reset_reason: Optional[str] = None + # The exact (non-terminal) token ids generated this turn. The control plane + # stores these per session and splices them back as an `ids` prompt segment + # next turn, so a prior assistant span is an exact token extension instead of + # a lossy chat-template re-render (V2b.1.5). Empty on older workers. + generated_token_ids: list = field(default_factory=list) + + +class WorkerError(RuntimeError): + """A worker process failed, exited, or reported a generation error. + + `code` carries the worker's structured error code when present + ("capacity_exhausted", "unsupported_session"), so the HTTP layer can map it + to the right status; None for unstructured failures. + """ + + def __init__(self, message: str, code: Optional[str] = None): + super().__init__(message) + self.code = code + + +class WorkerClient: + """Drives one model-execution worker process over JSONL (raw transport). + + Exposes ``generate(prompt, config, token_callback, stats_callback)`` / + ``stop()`` plus the session ops ``open_session`` / ``close_session`` / + ``reset_session`` that SessionRuntime drives. Calls are serialized by a lock + and by SessionRuntime (one in-flight request). The control plane never + executes model code. + """ + + def __init__(self, proc: subprocess.Popen, max_named_sessions: int = 0): + self._proc = proc + self._lock = threading.Lock() + # Named sessions this worker can host (0 = scratch-only / single session). + self.max_named_sessions = max_named_sessions + + def reset(self) -> None: + # The worker resets its session at the start of each request; nothing to + # do here. + pass + + def stop(self) -> None: + # Best-effort: a request is synchronous and not interruptible mid- + # generation in V1. + pass + + def open_session(self, session_id: str) -> None: + """Admit a named session (idempotent). Raises WorkerError with a `code` + ("capacity_exhausted" / "unsupported_session") if the worker refuses.""" + self._op({"op": "open", "session_id": session_id}, ack_key="opened") + + def close_session(self, session_id: str) -> None: + """Destroy a named session, freeing its state (idempotent).""" + self._op({"op": "close", "session_id": session_id}, ack_key="closed") + + def reset_session(self, session_id: str) -> None: + """Clear a named session's context (KV/recurrent + resident tokens) but + keep its capacity slot allocated (idempotent).""" + self._op({"op": "reset", "session_id": session_id}, ack_key="reset") + + def _op(self, request: dict, ack_key: str) -> None: + with self._lock: + if self._proc.poll() is not None: + raise WorkerError( + f"worker exited (code {self._proc.returncode}); restart the server" + ) + try: + self._proc.stdin.write(json.dumps(request) + "\n") + self._proc.stdin.flush() + except (BrokenPipeError, ValueError) as e: + raise WorkerError("worker stdin is closed") from e + line = self._proc.stdout.readline() + if not line: + raise WorkerError("worker exited mid-request") + msg = json.loads(line) + if msg.get(ack_key): + return + if "error" in msg: + raise WorkerError(msg["error"], code=msg.get("code")) + raise WorkerError(f"unexpected worker response: {msg}") + + @staticmethod + def _on_done(msg: dict, stats_callback) -> None: + reason = msg.get("session_reset_reason") + if reason is not None and logger.isEnabledFor(logging.DEBUG): + logger.debug( + "warm-resume: reason=%s reused=%d prefilled=%d", + reason, + msg.get("reused_prompt_tokens", 0), + msg.get("prefilled_prompt_tokens", 0), + ) + if stats_callback is not None: + stats_callback( + WorkerStats( + num_prompt_tokens=msg.get("prompt_tokens", 0), + num_generated_tokens=msg.get("completion_tokens", 0), + finish_reason=msg.get("finish_reason"), + reused_prompt_tokens=msg.get("reused_prompt_tokens", 0), + prefilled_prompt_tokens=msg.get("prefilled_prompt_tokens", 0), + session_reset_reason=reason, + generated_token_ids=msg.get("generated_token_ids", []), + ) + ) + + def generate(self, prompt, config, token_callback=None, stats_callback=None): + request = { + "max_new_tokens": getattr(config, "max_new_tokens", -1), + "temperature": getattr(config, "temperature", 0.0), + "stop": list(getattr(config, "stop", []) or []), + } + # Token-ID segments (V2b.1.5) take precedence over the rendered string: + # they let prior assistant spans be exact id runs, not lossy re-renders. + # `is not None` (not truthiness): segments is a distinct prompt form, kept + # whatever its content (the worker validates non-empty). + segments = getattr(config, "prompt_segments", None) + if segments is not None: + request["prompt_segments"] = segments + else: + request["prompt"] = prompt + session_id = getattr(config, "session_id", None) + if session_id: + request["session_id"] = session_id + with self._lock: + if self._proc.poll() is not None: + raise WorkerError( + f"worker exited (code {self._proc.returncode}); restart the server" + ) + try: + self._proc.stdin.write(json.dumps(request) + "\n") + self._proc.stdin.flush() + except (BrokenPipeError, ValueError) as e: + raise WorkerError("worker stdin is closed") from e + + while True: + line = self._proc.stdout.readline() + if not line: + raise WorkerError("worker exited mid-request") + msg = json.loads(line) + if "token" in msg: + if token_callback is not None: + token_callback(msg["token"]) + elif msg.get("done"): + self._on_done(msg, stats_callback) + return + elif "error" in msg: + raise WorkerError(msg["error"], code=msg.get("code")) + + def close(self) -> None: + """Terminate the worker process (called at server shutdown).""" + if self._proc.poll() is not None: + return + try: + if self._proc.stdin is not None: + self._proc.stdin.close() + except OSError: + pass + try: + self._proc.terminate() + self._proc.wait(timeout=5) + except Exception: # noqa: BLE001 - shutdown best-effort + self._proc.kill() + + +def spawn_worker( + cmd: Sequence[str], + env: Optional[dict] = None, + cwd: Optional[str] = None, + popen: Callable[..., subprocess.Popen] = subprocess.Popen, +) -> WorkerClient: + """Start a worker process and block until it reports ``{"ready": true}``. + + `cmd` is the worker binary and its launch args (model/tokenizer paths). The + worker loads the model once before reporting ready, so a slow load surfaces + here rather than on the first request. + """ + logger.info("Starting model worker: %s", cmd[0]) + proc = popen( + list(cmd), + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + text=True, + bufsize=1, + env=env, + cwd=cwd, + ) + line = proc.stdout.readline() + if not line: + raise WorkerError("worker failed to start (no output; check its stderr).") + msg = json.loads(line) + if not msg.get("ready"): + raise WorkerError(f"worker did not report ready: {msg}") + max_named = int(msg.get("max_named_sessions", 0)) + logger.info("Model worker ready (max_named_sessions=%d).", max_named) + return WorkerClient(proc, max_named_sessions=max_named)