From 729842940a5ef7d7cbe57efcd9fd2726c28e9197 Mon Sep 17 00:00:00 2001 From: Mergen Nachin Date: Wed, 3 Jun 2026 15:32:32 -0700 Subject: [PATCH 1/3] [INITIAL] Update [ghstack-poisoned] --- extension/llm/runner/__init__.py | 4 + extension/llm/runner/_llm_runner.pyi | 60 ++++ extension/llm/runner/pybindings.cpp | 257 +++++++++++++++++- .../llm/runner/test/test_runner_pybindings.py | 132 +++++++++ 4 files changed, 446 insertions(+), 7 deletions(-) diff --git a/extension/llm/runner/__init__.py b/extension/llm/runner/__init__.py index 4e0ced33b21..97950cc95fa 100644 --- a/extension/llm/runner/__init__.py +++ b/extension/llm/runner/__init__.py @@ -18,6 +18,8 @@ from executorch.extension.llm.runner._llm_runner import ( # noqa: F401 GenerationConfig, Image, + LLMEngine, + LLMSession, make_audio_input, make_image_input, make_raw_audio_input, @@ -234,5 +236,7 @@ def generate_text_hf( "MultimodalInput", "MultimodalRunner", "TextLLMRunner", + "LLMEngine", + "LLMSession", "Stats", ] diff --git a/extension/llm/runner/_llm_runner.pyi b/extension/llm/runner/_llm_runner.pyi index 271cf1e1540..79e5e3adc8c 100644 --- a/extension/llm/runner/_llm_runner.pyi +++ b/extension/llm/runner/_llm_runner.pyi @@ -411,6 +411,66 @@ class TextLLMRunner: def __repr__(self) -> str: ... +class LLMSession: + """A per-conversation session created by LLMEngine: reuses the engine's + program/resources (weight sharing is backend-dependent — see + LLMEngine.serving_capacity()) but owns its own KV cache. Backend calls + (prefill_tokens/decode_one) are serialized across the engine's sessions by + an engine-owned lock.""" + + def prefill_tokens(self, token_ids: List[int]) -> None: ... + def decode_one(self, temperature: float = -1.0) -> dict: + """One decode step -> {"token_id": int, "text": bytes, "is_eos": bool}.""" + ... + + def seek(self, pos: int) -> None: ... + def position(self) -> int: ... + def reset(self) -> None: ... + def stop(self) -> None: + """Token-boundary cooperative stop: safe from another thread, but it + does not abort a decode_one() already running — it takes effect before + the next decode_one().""" + ... + + def __repr__(self) -> str: ... + +class LLMEngine: + """Engine for multi-session text generation over one loaded program. + + Loads the model's program once; create_session() returns a LLMSession that + reuses it but owns its own KV cache. Whether extra sessions avoid + duplicating packed weights is backend-dependent — ask serving_capacity(). Backend execution across all sessions of one engine is + serialized by an engine-owned lock (backend ops are not assumed + thread-safe), so it is safe to drive multiple sessions from multiple Python + threads. + """ + + def __init__( + self, + model_path: str, + tokenizer_path: str, + data_path: Optional[str] = None, + method_name: str = "forward", + temperature: float = -1.0, + ) -> None: ... + def create_session(self) -> LLMSession: + """Create a session that reuses this engine's program/resources (weight + sharing is backend-dependent — see serving_capacity()), with its own KV + cache.""" + ... + + def serving_capacity(self) -> dict: + """Serving-capacity dict: max_physical_sessions_without_weight_duplication + (1 = single-slot, no weight duplication) and estimated_bytes_per_session + (0 = unknown). The server clamps physical sessions to this.""" + ... + + def metadata(self) -> dict: + """Model metadata from the .pte, e.g. get_max_context_len.""" + ... + + def __repr__(self) -> str: ... + class MultimodalRunner: """Runner for multimodal language models.""" diff --git a/extension/llm/runner/pybindings.cpp b/extension/llm/runner/pybindings.cpp index 3188b5390c4..cbd79959c80 100644 --- a/extension/llm/runner/pybindings.cpp +++ b/extension/llm/runner/pybindings.cpp @@ -24,6 +24,7 @@ #include #include +#include #include #include #include @@ -43,6 +44,74 @@ using namespace executorch::runtime; } \ }) +namespace { + +// Length of the longest prefix of `s` that does not end in the middle of a +// UTF-8 multi-byte sequence. Byte-level BPE tokenizers can emit a token that is +// only part of a character (e.g. one byte of a 3-byte CJK codepoint or emoji), +// so the std::string->str conversion must wait until the character is complete +// or it throws UnicodeDecodeError and aborts generation. +size_t utf8_complete_prefix_len(const std::string& s) { + size_t i = 0; + const size_t n = s.size(); + while (i < n) { + const unsigned char c = static_cast(s[i]); + size_t len; + if (c < 0x80) { + len = 1; + } else if ((c >> 5) == 0x6) { + len = 2; + } else if ((c >> 4) == 0xE) { + len = 3; + } else if ((c >> 3) == 0x1E) { + len = 4; + } else { + len = 1; // invalid lead byte; emit it and let "replace" handle it + } + if (i + len > n) { + break; // incomplete trailing sequence: hold it for the next token + } + i += len; + } + return i; +} + +// Wraps a Python str token-callback so multi-byte characters split across +// tokens are buffered and only forwarded once complete. `flush()` emits any +// trailing (incomplete) bytes with replacement so nothing is silently dropped. +class TokenStringCallback { + public: + explicit TokenStringCallback(py::object cb) : cb_(std::move(cb)) {} + + void operator()(const std::string& token) { + buf_ += token; + const size_t n = utf8_complete_prefix_len(buf_); + if (n == 0) { + return; + } + py::gil_scoped_acquire acquire; + cb_(py::reinterpret_steal( + PyUnicode_DecodeUTF8(buf_.data(), n, "replace"))); + buf_.erase(0, n); + } + + void flush() { + if (buf_.empty()) { + return; + } + py::gil_scoped_acquire acquire; + cb_(py::reinterpret_steal( + PyUnicode_DecodeUTF8(buf_.data(), buf_.size(), "replace"))); + buf_.clear(); + } + + private: + py::object cb_; + std::string buf_; +}; + +} // namespace + // Python wrapper class for TextLLMRunner class PyTextLLMRunner { public: @@ -79,10 +148,11 @@ class PyTextLLMRunner { // Convert Python callbacks to C++ std::function std::function cpp_token_callback = nullptr; + std::shared_ptr token_assembler; if (!token_callback.is_none()) { - cpp_token_callback = [token_callback](const std::string& token) { - py::gil_scoped_acquire acquire; - token_callback(token); + token_assembler = std::make_shared(token_callback); + cpp_token_callback = [token_assembler](const std::string& token) { + (*token_assembler)(token); }; } @@ -94,13 +164,16 @@ class PyTextLLMRunner { }; } - // Release GIL during generation + // Release GIL during generation. { py::gil_scoped_release release; Error error = runner_->generate( prompt, config, cpp_token_callback, cpp_stats_callback); THROW_IF_ERROR(error, "Generation failed"); } + if (token_assembler) { + token_assembler->flush(); + } } void stop() { @@ -174,10 +247,11 @@ class PyMultimodalRunner { // Convert Python callbacks to C++ std::function std::function cpp_token_callback = nullptr; + std::shared_ptr token_assembler; if (!token_callback.is_none()) { - cpp_token_callback = [token_callback](const std::string& token) { - py::gil_scoped_acquire acquire; - token_callback(token); + token_assembler = std::make_shared(token_callback); + cpp_token_callback = [token_assembler](const std::string& token) { + (*token_assembler)(token); }; } @@ -196,6 +270,9 @@ class PyMultimodalRunner { inputs, config, cpp_token_callback, cpp_stats_callback); THROW_IF_ERROR(error, "Generation failed"); } + if (token_assembler) { + token_assembler->flush(); + } } std::string generate_text( @@ -251,6 +328,124 @@ class PyMultimodalRunner { std::unique_ptr runner_; }; +// A session handle (LLMSession), the model-agnostic per-conversation API. +// Backend calls (prefill_tokens/decode_one) take the engine-owned lock so +// concurrent sessions of one engine serialize (Module::execute isn't assumed +// thread-safe); cheap state ops (seek/reset/position/stop) don't. +class PyLLMSession { + public: + PyLLMSession( + std::unique_ptr session, + std::shared_ptr exec_mutex) + : session_(std::move(session)), exec_mutex_(std::move(exec_mutex)) {} + + void prefill_tokens(std::vector tokens) { + py::gil_scoped_release release; + auto exec_lock = lock_exec(); + THROW_IF_ERROR( + session_->prefill_tokens(std::move(tokens)), "prefill_tokens failed"); + } + + py::dict decode_one(float temperature = -1.0f) { + uint64_t token_id; + std::string text; + bool is_eos; + { + py::gil_scoped_release release; + auto exec_lock = lock_exec(); + SamplingConfig sampling; + sampling.temperature = temperature; + auto res = session_->decode_one(sampling); + THROW_IF_ERROR(res.error(), "decode_one failed"); + const auto& r = res.get(); + token_id = r.token_id; + text = r.text_piece; + is_eos = r.is_eos; + } + py::dict d; + d["token_id"] = token_id; + d["text"] = py::bytes(text); + d["is_eos"] = is_eos; + return d; + } + + void seek(int64_t pos) { + THROW_IF_ERROR(session_->seek(pos), "seek failed"); + } + int64_t position() const { + return session_->position(); + } + void reset() { + THROW_IF_ERROR(session_->reset(), "reset failed"); + } + void stop() { + session_->stop(); + } + + private: + std::unique_lock lock_exec() { + return exec_mutex_ ? std::unique_lock(*exec_mutex_) + : std::unique_lock(); + } + std::unique_ptr session_; + std::shared_ptr exec_mutex_; +}; + +// Engine over one loaded Program: loads it once; create_session() returns an +// LLMSession that reuses it but owns its own KV state. Physical weight sharing +// across sessions is backend-dependent (serving_capacity() is authoritative). +class PyLLMEngine { + public: + PyLLMEngine( + const std::string& model_path, + const std::string& tokenizer_path, + std::optional data_path = std::nullopt, + const std::string& method_name = "forward", + float temperature = -1.0f) { + if (data_path.has_value()) { + throw std::runtime_error( + "LLMEngine: shared sessions with external data (.ptd / data_path) are " + "not yet supported; use a self-contained .pte (the session Module " + "needs the data_map_loader threaded through — tracked as a follow-up)."); + } + engine_ = TextLLMEngine::create( + model_path, tokenizer_path, data_path, temperature, method_name); + if (!engine_) { + throw std::runtime_error( + "Failed to create LLMEngine with model: " + model_path); + } + } + + std::unique_ptr create_session() { + auto res = engine_->create_session(); + THROW_IF_ERROR(res.error(), "Failed to create session from LLMEngine"); + // Hand the session the engine-owned lock so backend execution across all + // sessions of this engine is serialized. + return std::make_unique(std::move(res.get()), exec_mutex_); + } + + py::dict serving_capacity() const { + const auto c = engine_->serving_capacity(); + py::dict d; + d["max_physical_sessions_without_weight_duplication"] = + c.max_physical_sessions_without_weight_duplication; + d["estimated_bytes_per_session"] = c.estimated_bytes_per_session; + return d; + } + + py::dict metadata() const { + py::dict d; + for (const auto& [key, value] : engine_->metadata()) { + d[py::str(key)] = value; + } + return d; + } + + private: + std::unique_ptr engine_; + std::shared_ptr exec_mutex_ = std::make_shared(); +}; + PYBIND11_MODULE(_llm_runner, m) { m.doc() = "Python bindings for ExecuTorch LLM Runners"; @@ -735,6 +930,54 @@ PYBIND11_MODULE(_llm_runner, m) { return ""; }); + // Bind PyLLMEngine: shared-weight engine, create_session() per conversation. + // Bind PyLLMSession: the per-conversation session handle. + py::class_(m, "LLMSession") + .def( + "prefill_tokens", + &PyLLMSession::prefill_tokens, + py::arg("token_ids"), + "Prefill pre-tokenized input at the current cache position.") + .def( + "decode_one", + &PyLLMSession::decode_one, + py::arg("temperature") = -1.0f, + "Decode one token; returns {token_id:int, text:bytes, is_eos:bool}.") + .def("seek", &PyLLMSession::seek, py::arg("pos"), "Rewind KV to `pos`.") + .def("position", &PyLLMSession::position, "Resident KV token count.") + .def("reset", &PyLLMSession::reset, "Clear KV / position.") + .def("stop", &PyLLMSession::stop, "Signal an in-flight decode to stop.") + .def("__repr__", [](const PyLLMSession&) { return ""; }); + + py::class_(m, "LLMEngine") + .def( + py::init< + const std::string&, + const std::string&, + std::optional, + const std::string&, + float>(), + py::arg("model_path"), + py::arg("tokenizer_path"), + py::arg("data_path") = py::none(), + py::arg("method_name") = "forward", + py::arg("temperature") = -1.0f, + "Load a model's program once for multi-session serving.") + .def( + "create_session", + &PyLLMEngine::create_session, + "Create an LLMSession that reuses the engine's program/resources " + "(weight sharing is backend-dependent — see serving_capacity()) but " + "owns its own KV cache. Backend execution across sessions is " + "serialized by an engine-owned lock.") + .def( + "serving_capacity", + &PyLLMEngine::serving_capacity, + "Serving-capacity dict; the server clamps physical sessions to " + "max_physical_sessions_without_weight_duplication (1 = single-slot).") + .def("metadata", &PyLLMEngine::metadata, "Model metadata from the .pte.") + .def("__repr__", [](const PyLLMEngine&) { return ""; }); + // Bind PyMultimodalRunner py::class_(m, "MultimodalRunner") // Constructor with tokenizer path diff --git a/extension/llm/runner/test/test_runner_pybindings.py b/extension/llm/runner/test/test_runner_pybindings.py index 5619e586c4b..33ec4971a62 100644 --- a/extension/llm/runner/test/test_runner_pybindings.py +++ b/extension/llm/runner/test/test_runner_pybindings.py @@ -14,19 +14,51 @@ import os import tempfile + +import threading import unittest import torch from executorch.extension.llm.runner import ( GenerationConfig, Image, + LLMEngine, + LLMSession, make_image_input, make_text_input, MultimodalInput, MultimodalRunner, + TextLLMRunner, ) +class TestSessionApiBoundary(unittest.TestCase): + """The Python serving boundary: token-step primitives live ONLY on + LLMSession, never on the legacy TextLLMRunner (whose token-step methods are + C++ implementation details behind TextLLMSession). Pure class introspection, + so no model/.pte is needed.""" + + TOKEN_STEP = ("prefill_tokens", "decode_one", "seek", "position") + + def test_text_llm_runner_does_not_expose_token_step(self): + for name in self.TOKEN_STEP: + self.assertFalse( + hasattr(TextLLMRunner, name), + f"TextLLMRunner must not expose token-step method {name!r} to " + f"Python; drive sessions through LLMSession instead.", + ) + + def test_llm_session_exposes_token_step(self): + for name in (*self.TOKEN_STEP, "reset", "stop"): + self.assertTrue( + hasattr(LLMSession, name), f"LLMSession must expose {name!r}" + ) + + def test_llm_engine_exposes_serving_api(self): + for name in ("create_session", "serving_capacity"): + self.assertTrue(hasattr(LLMEngine, name), f"LLMEngine must expose {name!r}") + + class TestGenerationConfig(unittest.TestCase): """Test the GenerationConfig class.""" @@ -264,3 +296,103 @@ def test_make_image_input(self): img_tensor_rgba = torch.ones((4, 50, 50), dtype=torch.uint8) * 128 image_input_rgba = make_image_input(img_tensor_rgba) self.assertTrue(image_input_rgba.is_image()) + + +# Real-engine tests need a model; gated on env vars so they're skipped in CI +# environments without one (fake-runner unit tests can't exercise the real +# shared-Program / serialization behavior). +_MODEL = os.environ.get("ET_TEST_TEXT_LLM_MODEL") +_TOKENIZER = os.environ.get("ET_TEST_TEXT_LLM_TOKENIZER") + + +@unittest.skipUnless( + _MODEL and _TOKENIZER, + "set ET_TEST_TEXT_LLM_MODEL and ET_TEST_TEXT_LLM_TOKENIZER to run", +) +class TestLLMEngineSessions(unittest.TestCase): + """LLMEngine: sessions share weights, stay isolated, and serialize backend + execution so concurrent sessions don't corrupt each other.""" + + @classmethod + def setUpClass(cls): + # LLM .pte files use custom/quantized ops; register them (the server's + # runner_pool does this automatically, but a direct engine test must). + try: + import executorch.extension.llm.custom_ops.custom_ops # noqa: F401 + import executorch.kernels.quantized # noqa: F401 + except Exception: # noqa: BLE001 - assume statically linked otherwise + pass + # The session API takes token ids; tokenize prompts in Python (the server + # does the same). Load the model's tokenizer.json directly. + from tokenizers import Tokenizer as HFTokenizer + + cls._hf = HFTokenizer.from_file(_TOKENIZER) + + @classmethod + def _ids(cls, prompt): + return cls._hf.encode(prompt).ids + + @staticmethod + def _gen_text(runner, prompt): # standalone TextLLMRunner baseline + out = [] + runner.reset() + runner.generate( + prompt, + GenerationConfig(echo=False, max_new_tokens=12, temperature=0.0), + lambda t: out.append(t), + ) + return "".join(out) + + def _session_ids(self, session, prompt_ids, n=12): + """Drive a session via prefill_tokens + a decode_one loop (the actual + new path); return the exact generated token ids.""" + session.reset() + session.prefill_tokens(prompt_ids) + ids = [] + for _ in range(n): + step = session.decode_one(0.0) + ids.append(step["token_id"]) + if step["is_eos"]: + break + return ids + + def test_sessions_isolated_and_match_baseline(self): + p1 = "<|im_start|>user\nName one primary color.<|im_end|>\n<|im_start|>assistant\n" + p2 = "<|im_start|>user\nWhat is 2+2?<|im_end|>\n<|im_start|>assistant\n" + base = TextLLMRunner(model_path=_MODEL, tokenizer_path=_TOKENIZER) + b1, b2 = self._gen_text(base, p1), self._gen_text(base, p2) + + engine = LLMEngine(model_path=_MODEL, tokenizer_path=_TOKENIZER) + s1, s2 = engine.create_session(), engine.create_session() + ids1 = self._session_ids(s1, self._ids(p1)) + ids2 = self._session_ids(s2, self._ids(p2)) + ids1b = self._session_ids(s1, self._ids(p1)) # after s2 ran + # The session's decode_one ids, decoded, match the standalone generation. + self.assertEqual(self._hf.decode(ids1).strip(), b1.strip()) + self.assertEqual(self._hf.decode(ids2).strip(), b2.strip()) + self.assertEqual(ids1, ids1b, "session1 must be unaffected by session2") + + def test_concurrent_sessions_do_not_crash(self): + # The original num_runners>1 path crashed (heap corruption) under + # concurrent backend calls; the engine lock must serialize them safely. + p = self._ids( + "<|im_start|>user\nCount to five.<|im_end|>\n<|im_start|>assistant\n" + ) + engine = LLMEngine(model_path=_MODEL, tokenizer_path=_TOKENIZER) + s1, s2 = engine.create_session(), engine.create_session() + expect = self._session_ids(s1, p) + errors = [] + + def worker(sess): + try: + for _ in range(3): + self.assertEqual(self._session_ids(sess, p), expect) + except Exception as e: # noqa: BLE001 + errors.append(repr(e)) + + threads = [threading.Thread(target=worker, args=(s,)) for s in (s1, s2)] + for t in threads: + t.start() + for t in threads: + t.join() + self.assertEqual(errors, [], "concurrent sessions crashed or drifted") From 75a6aa385b439bd6c928b49b7e5da9e9576e1bb6 Mon Sep 17 00:00:00 2001 From: Mergen Nachin Date: Wed, 3 Jun 2026 15:32:32 -0700 Subject: [PATCH 2/3] [INITIAL] Update [ghstack-poisoned] --- extension/llm/server/__init__.py | 5 + extension/llm/server/python/__init__.py | 7 + extension/llm/server/python/chat_template.py | 122 ++++++++++ extension/llm/server/python/errors.py | 49 ++++ extension/llm/server/python/prefix_cache.py | 221 ++++++++++++++++++ extension/llm/server/python/protocol.py | 148 ++++++++++++ extension/llm/server/python/requirements.txt | 5 + .../server/python/tests/test_prefix_cache.py | 214 +++++++++++++++++ .../server/python/tool_parsers/__init__.py | 16 ++ .../llm/server/python/tool_parsers/hermes.py | 92 ++++++++ .../llm/server/python/tool_parsers/types.py | 33 +++ extension/llm/server/spec/README.md | 72 ++++++ 12 files changed, 984 insertions(+) create mode 100644 extension/llm/server/__init__.py create mode 100644 extension/llm/server/python/__init__.py create mode 100644 extension/llm/server/python/chat_template.py create mode 100644 extension/llm/server/python/errors.py create mode 100644 extension/llm/server/python/prefix_cache.py create mode 100644 extension/llm/server/python/protocol.py create mode 100644 extension/llm/server/python/requirements.txt create mode 100644 extension/llm/server/python/tests/test_prefix_cache.py create mode 100644 extension/llm/server/python/tool_parsers/__init__.py create mode 100644 extension/llm/server/python/tool_parsers/hermes.py create mode 100644 extension/llm/server/python/tool_parsers/types.py create mode 100644 extension/llm/server/spec/README.md diff --git a/extension/llm/server/__init__.py b/extension/llm/server/__init__.py new file mode 100644 index 00000000000..2e41cd717f6 --- /dev/null +++ b/extension/llm/server/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/extension/llm/server/python/__init__.py b/extension/llm/server/python/__init__.py new file mode 100644 index 00000000000..00b6274c01f --- /dev/null +++ b/extension/llm/server/python/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""OpenAI-compatible server for ExecuTorch LLMs (Python implementation).""" diff --git a/extension/llm/server/python/chat_template.py b/extension/llm/server/python/chat_template.py new file mode 100644 index 00000000000..04f807fb4d9 --- /dev/null +++ b/extension/llm/server/python/chat_template.py @@ -0,0 +1,122 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Render OpenAI chat messages into a single prompt string. + +The ExecuTorch runner tokenizes a plain prompt; chat formatting is the server's +job (control plane). We require the model's own Hugging Face ``chat_template`` +(via ``--hf-tokenizer``) for correct, tool-aware, reasoning-aware formatting. +The generic ChatML fallback is opt-in only (``allow_fallback``): it is +approximate and cannot reproduce model-specific controls (e.g. enable_thinking), +so it must be a deliberate choice rather than a silent default. +""" + +import logging +from typing import Any, Optional + +from .protocol import ChatMessage + +logger = logging.getLogger(__name__) + + +_DEFAULT_SPECIAL_TOKENS = ["<|im_end|>", "<|endoftext|>", "<|eot_id|>", "<|end|>"] + + +class ChatTemplate: + def __init__( + self, + hf_tokenizer_path: Optional[str] = None, + default_template_kwargs: Optional[dict[str, Any]] = None, + allow_fallback: bool = False, + ): + # Server-level defaults (e.g. {"enable_thinking": False}); per-request + # chat_template_kwargs override these. + self._defaults = default_template_kwargs or {} + self._hf = None + if hf_tokenizer_path: + from transformers import AutoTokenizer + + self._hf = AutoTokenizer.from_pretrained(hf_tokenizer_path) + if self._hf.chat_template is None: + self._hf = None + if not allow_fallback: + raise ValueError( + f"HF tokenizer at {hf_tokenizer_path} has no chat_template; " + "pass an explicit fallback flag to use approximate ChatML." + ) + logger.warning( + "No chat_template at %s; using approximate ChatML.", + hf_tokenizer_path, + ) + elif not allow_fallback: + raise ValueError( + "A chat template is required: pass --hf-tokenizer for the model's own " + "template, or opt into approximate ChatML with --allow-chatml-fallback." + ) + else: + logger.warning( + "No --hf-tokenizer; using approximate ChatML (no thinking control)." + ) + + def render( + self, + messages: list[ChatMessage], + tools: Optional[list[dict[str, Any]]] = None, + template_kwargs: Optional[dict[str, Any]] = None, + ) -> str: + kwargs = {**self._defaults, **(template_kwargs or {})} + if self._hf is not None: + return self._hf.apply_chat_template( + [m.model_dump(exclude_none=True) for m in messages], + tools=tools, + add_generation_prompt=True, + tokenize=False, + **kwargs, + ) + return self._fallback(messages) + + def chat_template_str(self) -> Optional[str]: + """Raw chat-template string (for tool-format auto-detection), if available.""" + return ( + getattr(self._hf, "chat_template", None) if self._hf is not None else None + ) + + def tokenizer(self): + """The underlying HF tokenizer (for token-level prefix caching), or None. + + Must match the runner's tokenizer (same model) for prefix reuse to be + valid — i.e. the recommended --hf-tokenizer matching the exported model. + """ + return self._hf + + def count_tokens(self, prompt: str) -> Optional[int]: + """Token count for the rendered prompt, or None if no tokenizer is available.""" + if self._hf is not None: + return len(self._hf.encode(prompt)) + return None + + def special_tokens(self) -> list[str]: + """Special-token strings whose appearance ends the visible content. + + From the HF tokenizer when available (model-accurate), else a default set + covering common chat models. + """ + if self._hf is not None: + toks = list(getattr(self._hf, "all_special_tokens", []) or []) + return [t for t in toks if isinstance(t, str) and t] + return list(_DEFAULT_SPECIAL_TOKENS) + + @staticmethod + def _fallback(messages: list[ChatMessage]) -> str: + # Approximate ChatML. Provide --hf-tokenizer for model-correct formatting + # (including reasoning controls like enable_thinking, which the fallback + # cannot reproduce). + parts = [] + for m in messages: + content = m.content if isinstance(m.content, str) else str(m.content or "") + parts.append(f"<|im_start|>{m.role}\n{content}<|im_end|>") + parts.append("<|im_start|>assistant\n") + return "\n".join(parts) diff --git a/extension/llm/server/python/errors.py b/extension/llm/server/python/errors.py new file mode 100644 index 00000000000..8e87966f983 --- /dev/null +++ b/extension/llm/server/python/errors.py @@ -0,0 +1,49 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""OpenAI-shaped API errors. + +Raising these lets the server return a structured `{"error": {...}}` body with +the right HTTP status instead of dropping the connection. +""" + +from typing import Optional + + +class APIError(Exception): + def __init__( + self, status: int, message: str, err_type: str, code: Optional[str] = None + ): + super().__init__(message) + self.status = status + self.message = message + self.err_type = err_type + self.code = code + + def body(self) -> dict: + return { + "error": {"message": self.message, "type": self.err_type, "code": self.code} + } + + +class ContextLengthExceeded(APIError): + def __init__(self, num_tokens: int, max_context: int): + super().__init__( + status=400, + message=( + f"This model's maximum context length is {max_context} tokens, " + f"but the request has {num_tokens} prompt tokens." + ), + err_type="invalid_request_error", + code="context_length_exceeded", + ) + + +class GenerationError(APIError): + def __init__(self, detail: str): + super().__init__( + status=500, message=f"Generation failed: {detail}", err_type="server_error" + ) diff --git a/extension/llm/server/python/prefix_cache.py b/extension/llm/server/python/prefix_cache.py new file mode 100644 index 00000000000..74757c2b262 --- /dev/null +++ b/extension/llm/server/python/prefix_cache.py @@ -0,0 +1,221 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Turn-to-turn KV prefix reuse over the Engine/Session API. + +An agent re-sends a large, mostly-unchanged prompt every turn (system + tools + +context + history). The C++ KV cache is contiguous and position-indexed, so we +reuse it: find the longest token prefix shared with what's cached, seek() to it, +prefill only the new suffix, then drive decode_one(). + +Because decode_one() returns the EXACT sampled token ids, we track them +(`_cached = prompt_ids + generated_ids`) — so a follow-up turn that includes the +prior completion reuses that too (not just the static system prefix). This is +safe in a way re-tokenizing generated text is not (BPE encode(a)+encode(b) != +encode(a+b)). + +Constraints baked in: + - Only position-0 prefixes are reusable (RoPE is position-dependent) — exactly + the agent shape (system+tools+context+history at the front). + - seek() refuses on sliding-window models; we catch that and fall back to a + full reset + re-prefill. + +Lifecycle: a session is conversation-scoped, not request-scoped. The pool keeps +it warm across requests (KV preserved) and routes follow-up turns back to it by +prefix affinity; it is reset only on an unrecoverable error or torn down at +shutdown — never reset per request, which would discard the cache reuse exists +to exploit. +""" + +import codecs +import logging +import time +from typing import Optional, Protocol, runtime_checkable + +logger = logging.getLogger(__name__) + + +def longest_common_prefix(a: list[int], b: list[int]) -> int: + n = 0 + for x, y in zip(a, b): + if x != y: + break + n += 1 + return n + + +@runtime_checkable +class Session(Protocol): + """The C++ LLMSession surface this driver uses.""" + + def prefill_tokens(self, token_ids: list[int]) -> None: ... + def decode_one( + self, temperature: float = ... + ) -> dict: ... # {token_id, text:bytes, is_eos} + def seek(self, pos: int) -> None: ... + def position(self) -> int: ... + def reset(self) -> None: ... + def stop(self) -> None: ... + + +class Tokenizer(Protocol): + def encode(self, text: str, add_special_tokens: bool = ...) -> list[int]: ... + + +class _Stats: + """Token counts handed to the pool's stats_callback (matches the C++ Stats + attribute names the pool reads).""" + + __slots__ = ("num_prompt_tokens", "num_generated_tokens") + + def __init__(self, num_prompt_tokens: int, num_generated_tokens: int): + self.num_prompt_tokens = num_prompt_tokens + self.num_generated_tokens = num_generated_tokens + + +class PrefixCachingSession: + """Drives one LLMSession with turn-to-turn prefix reuse, tracking the exact + tokens resident in its KV cache (prompt + generated).""" + + def __init__( + self, + session: Session, + tokenizer: Tokenizer, + index: int = 0, + max_context_len: Optional[int] = None, + max_seq_len: Optional[int] = None, + ): + self._session = session + self._tok = tokenizer + self._cached: list[int] = [] + self._index = index + self._fallbacks = 0 + self._stop = False + self._max_context_len = max_context_len + self._max_seq_len = max_seq_len + + @property + def cached_tokens(self) -> list[int]: + return self._cached + + def _encode(self, text: str) -> list[int]: + return list(self._tok.encode(text, add_special_tokens=False)) + + def reuse_len(self, prompt_ids: list[int]) -> int: + """Tokens reusable from cache, capped at the runner's resident position + (never seek() past resident KV) and at len-1 (always prefill >=1 token).""" + reuse = longest_common_prefix(self._cached, prompt_ids) + position = getattr(self._session, "position", None) + if position is not None: + reuse = min(reuse, position()) + if reuse >= len(prompt_ids): + reuse = len(prompt_ids) - 1 + return max(0, reuse) + + def _resolved_max_new_tokens(self, config) -> Optional[int]: + max_new = getattr(config, "max_new_tokens", -1) if config is not None else -1 + if self._max_context_len is None: + return None if max_new <= 0 else max_new + position = self._session.position() + # Match TextLLMRunner.generate(): sliding-window exports do not treat + # position as consumed full-context capacity. + if self._max_seq_len is not None and self._max_seq_len < self._max_context_len: + position = 0 + if config is not None and hasattr(config, "resolve_max_new_tokens"): + return config.resolve_max_new_tokens(self._max_context_len, position) + if max_new <= 0: + return max(0, self._max_context_len - position) + return max(0, min(max_new, self._max_context_len - position)) + + def generate( # noqa: C901 - prefill/reuse + decode loop + fallbacks read clearest inline + self, prompt: str, config, token_callback=None, stats_callback=None + ) -> None: + prompt_ids = self._encode(prompt) + self._stop = False + start = time.perf_counter() + ttft = None + + # --- prefill: reuse the shared prefix, else (on failure) full prefill --- + reuse = self.reuse_len(prompt_ids) + fallback = False + try: + # seek(reuse) (reuse may be 0 for a cold session) repositions to the + # shared prefix, discarding any stale KV beyond it; then prefill only + # the suffix. + self._session.seek(reuse) + self._session.prefill_tokens(prompt_ids[reuse:]) + except Exception as e: # noqa: BLE001 - reuse setup failed -> safe full path + fallback = True + self._fallbacks += 1 + reuse = 0 + logger.debug("prefix reuse setup failed (%s); full prefill", e) + self._session.reset() + self._cached = [] + try: + self._session.prefill_tokens(prompt_ids) + except Exception: + self._session.reset() + self._cached = [] + raise + + # --- decode loop: bounded by max_new_tokens; stop on EOS or stop() --- + max_new = self._resolved_max_new_tokens(config) + if max_new is not None and max_new <= 0: + raise RuntimeError("No available context capacity for generation") + temperature = ( + getattr(config, "temperature", -1.0) if config is not None else -1.0 + ) + decoder = codecs.getincrementaldecoder("utf-8")(errors="replace") + gen_ids: list[int] = [] + n = 0 + try: + while (max_new is None or n < max_new) and not self._stop: + step = self._session.decode_one(temperature) + gen_ids.append(step["token_id"]) + n += 1 + piece = decoder.decode(step["text"]) # assemble UTF-8 from byte pieces + if piece: + if ttft is None: + ttft = time.perf_counter() - start + if token_callback: + token_callback(piece) + if step["is_eos"]: + break + tail = decoder.decode(b"", final=True) + if tail and token_callback: + token_callback(tail) + except ( + Exception + ): # noqa: BLE001 - real decode error: reset + propagate (no retry) + self._session.reset() + self._cached = [] + raise + + # Track EXACT ids (prompt + generated) so the next turn can reuse the + # completion too. seek() stays capped at position(), so this is safe. + self._cached = prompt_ids + gen_ids + if stats_callback: + stats_callback(_Stats(len(prompt_ids), n)) + logger.info( + "prefix-cache runner=%d reused=%d suffix=%d generated=%d fallback=%s " + "fallbacks_total=%d ttft_ms=%.0f", + self._index, + reuse, + len(prompt_ids) - reuse, + n, + fallback, + self._fallbacks, + (ttft or 0.0) * 1000, + ) + + def reset(self) -> None: + self._session.reset() + self._cached = [] + + def stop(self) -> None: + # Cooperative cancellation: the decode loop checks _stop each step. + self._stop = True + self._session.stop() diff --git a/extension/llm/server/python/protocol.py b/extension/llm/server/python/protocol.py new file mode 100644 index 00000000000..2d73d2d7f64 --- /dev/null +++ b/extension/llm/server/python/protocol.py @@ -0,0 +1,148 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""OpenAI-compatible request/response schemas for the ExecuTorch LLM server. + +This is the Python view of the contract defined in ``extension/llm/server/spec``. +Any language server must serialize to the same shapes; the conformance suite in +``extension/llm/server/conformance`` validates them. +""" + +import time +import uuid +from typing import Any, Literal, Optional, Union + +from pydantic import BaseModel, Field + + +def _new_id(prefix: str) -> str: + return f"{prefix}-{uuid.uuid4().hex}" + + +class FunctionCall(BaseModel): + name: Optional[str] = None + arguments: Optional[str] = None + + +class ToolCall(BaseModel): + index: Optional[int] = None + id: Optional[str] = None + type: Literal["function"] = "function" + function: FunctionCall + + +class ChatMessage(BaseModel): + role: str + content: Optional[Union[str, list[dict[str, Any]]]] = None + name: Optional[str] = None + tool_calls: Optional[list[ToolCall]] = None + tool_call_id: Optional[str] = None + + +class StreamOptions(BaseModel): + include_usage: bool = False + + +class ChatCompletionRequest(BaseModel): + model: Optional[str] = None + messages: list[ChatMessage] + stream: bool = False + stream_options: Optional[StreamOptions] = None + temperature: Optional[float] = None + top_p: Optional[float] = None + max_tokens: Optional[int] = None + max_completion_tokens: Optional[int] = None + stop: Optional[Union[str, list[str]]] = None + n: int = 1 + seed: Optional[int] = None + # Sampling knobs that change generation output. We don't plumb these, so they + # are modeled (not dropped) in order to be rejected with a clear error rather + # than silently ignored — see serving_chat's unsupported-parameter check. + frequency_penalty: Optional[float] = None + presence_penalty: Optional[float] = None + top_k: Optional[int] = None + logit_bias: Optional[dict[str, float]] = None + # Output-contract fields: modeled (not dropped) so we reject the ones we + # can't honor rather than returning an output that violates what was asked. + response_format: Optional[dict[str, Any]] = None + logprobs: Optional[bool] = None + top_logprobs: Optional[int] = None + parallel_tool_calls: Optional[bool] = None + # Per-request chat-template controls, e.g. {"enable_thinking": false} for Qwen3. + chat_template_kwargs: Optional[dict[str, Any]] = None + # Accepted now so the contract is stable; parsing/enforcement land in M2/M5. + tools: Optional[list[dict[str, Any]]] = None + tool_choice: Optional[Union[str, dict[str, Any]]] = None + reasoning_effort: Optional[str] = None + + def resolved_max_tokens(self) -> int: + # `is not None` (not `or`): an explicit 0 must not be treated as unset. + # Callers validate positivity; -1 means "unset / auto". + if self.max_completion_tokens is not None: + return self.max_completion_tokens + if self.max_tokens is not None: + return self.max_tokens + return -1 + + +class Usage(BaseModel): + prompt_tokens: int = 0 + completion_tokens: int = 0 + total_tokens: int = 0 + + +class ResponseMessage(BaseModel): + role: str = "assistant" + content: Optional[str] = None + tool_calls: Optional[list[ToolCall]] = None + + +class Choice(BaseModel): + index: int = 0 + message: ResponseMessage + finish_reason: Optional[str] = None + + +class ChatCompletionResponse(BaseModel): + id: str = Field(default_factory=lambda: _new_id("chatcmpl")) + object: Literal["chat.completion"] = "chat.completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: list[Choice] + usage: Usage = Field(default_factory=Usage) + + +class DeltaMessage(BaseModel): + role: Optional[str] = None + content: Optional[str] = None + tool_calls: Optional[list[ToolCall]] = None + + +class ChunkChoice(BaseModel): + index: int = 0 + delta: DeltaMessage + finish_reason: Optional[str] = None + + +class ChatCompletionChunk(BaseModel): + id: str + object: Literal["chat.completion.chunk"] = "chat.completion.chunk" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: list[ChunkChoice] + usage: Optional[Usage] = None + + +class ModelCard(BaseModel): + id: str + object: Literal["model"] = "model" + created: int = Field(default_factory=lambda: int(time.time())) + owned_by: str = "executorch" + + +class ModelList(BaseModel): + object: Literal["list"] = "list" + data: list[ModelCard] diff --git a/extension/llm/server/python/requirements.txt b/extension/llm/server/python/requirements.txt new file mode 100644 index 00000000000..70ad7ccb4dd --- /dev/null +++ b/extension/llm/server/python/requirements.txt @@ -0,0 +1,5 @@ +fastapi>=0.110 +uvicorn[standard]>=0.27 +pydantic>=2.0 +# Optional but recommended for model-correct chat templating (--hf-tokenizer): +# transformers>=4.40 diff --git a/extension/llm/server/python/tests/test_prefix_cache.py b/extension/llm/server/python/tests/test_prefix_cache.py new file mode 100644 index 00000000000..bef784fcc6e --- /dev/null +++ b/extension/llm/server/python/tests/test_prefix_cache.py @@ -0,0 +1,214 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Tests for the prefix-cache session driver over the LLMSession API, using a +fake session (prefill_tokens/decode_one/seek/position/reset) + a byte-level fake +tokenizer so reuse decisions and exact-id cache tracking are verified without a +real engine.""" + +import pytest + +from executorch.extension.llm.server.python.prefix_cache import ( + longest_common_prefix, + PrefixCachingSession, +) + + +class FakeTokenizer: + # Byte-level: prefix-preserving and deterministic — ideal for LCP testing. + def encode(self, text, add_special_tokens=False): + return list(text.encode("utf-8")) + + +class FakeSession: + """LLMSession-shaped fake: decode_one() emits `gen_ids` in order, signalling + is_eos on the last one. Tracks position like the real session.""" + + def __init__(self, gen_ids=(10,), fail_seek=False): + self.gen_ids = list(gen_ids) + self.fail_seek = fail_seek + self._cursor = 0 + self._pos = 0 + self.seeks = [] + self.prefilled = [] + self.reset_count = 0 + + def prefill_tokens(self, ids): + assert len(ids) >= 1, "prefill_tokens must get >=1 token" + self.prefilled.append(list(ids)) + self._pos += len(ids) + + def decode_one(self, temperature=-1.0): + tid = self.gen_ids[self._cursor] if self._cursor < len(self.gen_ids) else 0 + self._cursor += 1 + self._pos += 1 + return { + "token_id": tid, + "text": bytes([65 + (tid % 26)]), # arbitrary 1-byte piece + "is_eos": self._cursor >= len(self.gen_ids), + } + + def seek(self, pos): + if self.fail_seek: + raise RuntimeError("seek unsupported (SWA)") + self.seeks.append(pos) + self._pos = pos + self._cursor = 0 + + def position(self): + return self._pos + + def reset(self): + self.reset_count += 1 + self._pos = 0 + self._cursor = 0 + + def stop(self): + pass + + +def _sess(gen_ids=(10,), fail_seek=False): + return PrefixCachingSession(FakeSession(gen_ids, fail_seek), FakeTokenizer()) + + +class FakeConfig: + def __init__(self, max_new_tokens=-1, seq_len=-1, temperature=0.0): + self.max_new_tokens = max_new_tokens + self.seq_len = seq_len + self.temperature = temperature + + def resolve_max_new_tokens(self, max_context_len, num_tokens_occupied): + if self.seq_len == -1 and self.max_new_tokens == -1: + result = max_context_len - num_tokens_occupied + elif self.seq_len == -1: + result = min(self.max_new_tokens, max_context_len - num_tokens_occupied) + elif self.max_new_tokens == -1: + result = min(self.seq_len, max_context_len) - num_tokens_occupied + else: + result = min( + min(self.seq_len, max_context_len) - num_tokens_occupied, + self.max_new_tokens, + ) + return max(0, result) + + +def test_longest_common_prefix(): + assert longest_common_prefix([1, 2, 3], [1, 2, 9]) == 2 + assert longest_common_prefix([], [1]) == 0 + assert longest_common_prefix([1, 2], [1, 2, 3]) == 2 + + +def test_first_turn_prefills_all_and_tracks_exact_ids(): + s = _sess(gen_ids=[10, 11]) + s.generate("abc", config=None) + assert s._session.seeks == [0] # seek(0) on a fresh session + assert s._session.prefilled == [list(b"abc")] + # cache = prompt ids + EXACT generated ids (decode_one token_ids). + assert s.cached_tokens == list(b"abc") + [10, 11] + + +def test_completion_reuse_via_exact_ids(): + # Turn 2 includes turn-1's completion; exact-id tracking lets us reuse it. + s = _sess(gen_ids=[120, 121]) # ids 120,121 -> bytes 'y','z' (120%26=16 -> 'Q'..) + s.generate("abc", None) # cache -> b"abc" + [120,121] + cached_after_t1 = list(s.cached_tokens) + # Build a turn-2 prompt whose ids are exactly the cached ids + a new suffix. + s._session.gen_ids = [99] + # Monkeypatch the tokenizer to return cached + new (simulating prompt=history). + s._tok = type( + "T", + (), + { + "encode": lambda self, t, add_special_tokens=False: cached_after_t1 + + [55, 56] + }, + )() + s.generate("ignored", None) + # Reused the whole cached prefix (capped at len-1), prefilled only the tail. + assert s._session.seeks[-1] == len(cached_after_t1) + assert s._session.prefilled[-1] == [55, 56] + + +def test_mid_divergence_partial_reuse(): + s = _sess(gen_ids=[10]) + s.generate("abcdef", None) # cache -> b"abcdef" + [10] + s.generate("abcXYZ", None) # shares b"abc" + assert s._session.seeks[-1] == 3 + assert s._session.prefilled[-1] == list(b"XYZ") + + +def test_reuse_capped_at_resident_position(): + # If _cached ever exceeds resident position, reuse is capped at position(). + s = _sess(gen_ids=[10]) + s._cached = list(b"abc!") # seed 4 cached + s._session._pos = 3 # but only 3 resident + s.generate("abc!XY", None) + assert s._session.seeks[-1] == 3 # capped, not 4 + assert s._session.prefilled[-1] == list(b"!XY") + + +def test_unset_max_tokens_resolves_against_context_capacity(): + session = FakeSession(gen_ids=[10, 11, 12, 13]) + s = PrefixCachingSession( + session, + FakeTokenizer(), + max_context_len=5, + max_seq_len=5, + ) + out = [] + s.generate("abc", FakeConfig(max_new_tokens=-1), token_callback=out.append) + assert out == ["K", "L"] # prompt pos 3 leaves exactly 2 decode slots + assert s.cached_tokens == list(b"abc") + [10, 11] + assert session.position() == 5 + + +def test_explicit_max_tokens_clamped_to_context_capacity(): + session = FakeSession(gen_ids=[10, 11, 12, 13]) + s = PrefixCachingSession( + session, + FakeTokenizer(), + max_context_len=5, + max_seq_len=5, + ) + s.generate("abc", FakeConfig(max_new_tokens=10)) + assert s.cached_tokens == list(b"abc") + [10, 11] + + +def test_warm_reuse_matches_cold_output(): + # A warm session (reuses a shared prefix) emits the same tokens as a cold + # session that full-prefills the same prompt — reuse must not perturb output. + warm = _sess(gen_ids=[65, 66, 67]) + warm.generate("system prefix ", None) # warm the cache + warm_out = [] + warm.generate("system prefix and more", None, token_callback=warm_out.append) + + cold = _sess(gen_ids=[65, 66, 67]) # fresh: seek(0) + full prefill + cold_out = [] + cold.generate("system prefix and more", None, token_callback=cold_out.append) + assert "".join(warm_out) == "".join(cold_out) + + +def test_fallback_on_seek_failure(): + # SWA model where seek() raises -> reset + full prefill, still correct. + s = _sess(gen_ids=[10], fail_seek=True) + out = [] + s.generate("abc", None, token_callback=lambda t: out.append(t)) + assert s._session.reset_count >= 1 + assert s._session.prefilled[-1] == list(b"abc") # full prefill after fallback + assert s.cached_tokens == list(b"abc") + [10] + + +def test_generation_error_propagates_without_retry(): + # A failure during decode_one must propagate (after reset), not be retried. + class FailingSession(FakeSession): + def decode_one(self, temperature=-1.0): + raise RuntimeError("backend boom") + + s = PrefixCachingSession(FailingSession(), FakeTokenizer()) + with pytest.raises(RuntimeError, match="backend boom"): + s.generate("abc", None) + assert s._session.reset_count >= 1 + assert s.cached_tokens == [] diff --git a/extension/llm/server/python/tool_parsers/__init__.py b/extension/llm/server/python/tool_parsers/__init__.py new file mode 100644 index 00000000000..71703d33a0c --- /dev/null +++ b/extension/llm/server/python/tool_parsers/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Tool-call parsing. Hermes/Qwen format only (...). + +The server buffers the model's full output and parses it once into complete +OpenAI tool_calls; parse failures degrade to visible text. +""" + +from .hermes import HermesDetector +from .types import ParseResult, ToolCallItem + +__all__ = ["HermesDetector", "ParseResult", "ToolCallItem"] diff --git a/extension/llm/server/python/tool_parsers/hermes.py b/extension/llm/server/python/tool_parsers/hermes.py new file mode 100644 index 00000000000..809c18761a5 --- /dev/null +++ b/extension/llm/server/python/tool_parsers/hermes.py @@ -0,0 +1,92 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Hermes-style tool calls: {"name": ..., "arguments": {...}}. + +Used by Qwen2.5/Qwen3 (and Hermes models) — the only tool-call format this +server supports. The server buffers a model's full output and parses it once +into complete OpenAI tool_calls (no partial-fragment streaming). Parse failures +fall back to visible text — never a crash or a silent drop. +""" + +import json +import logging +import re +from typing import Any, Optional + +from .types import ParseResult, ToolCallItem + +logger = logging.getLogger(__name__) + +_CALL_RE = re.compile(r"\s*(.*?)\s*", re.DOTALL) + + +class _UndefinedToolCall(Exception): + """A named a tool not in the request's `tools`. v1 degrades the + WHOLE response to visible text rather than emitting a partial set — never + silently drop an undefined call while keeping its siblings (spec).""" + + +class HermesDetector: + """Parses Hermes/Qwen tool calls. Create a fresh instance per request (it + holds the per-request tool-call index); never share across requests.""" + + bot_token = "" + + def __init__(self): + self._next_index = 0 + + def detect_and_parse(self, text: str, tool_names: set[str]) -> ParseResult: + """Return leading text + any complete tool calls. On no call or a parse + failure, return the original text unchanged (kept visible to the client).""" + if self.bot_token not in text: + return ParseResult(normal_text=text) + normal = text[: text.find(self.bot_token)].strip() + try: + calls = self._parse_calls(text, tool_names) + except _UndefinedToolCall as e: + # Degrade the whole response to visible text so the undefined call + # isn't silently dropped (and its valid siblings aren't executed in + # isolation, losing the model's full intent). + logger.debug("undefined tool %s; returning raw text (no partial calls)", e) + return ParseResult(normal_text=text) + except Exception as e: # noqa: BLE001 - never crash; fall back to visible text + logger.debug("tool parse failed (%s); returning raw text", e) + return ParseResult(normal_text=text) + if not calls: + return ParseResult(normal_text=text) + return ParseResult(normal_text=normal, calls=calls) + + def _parse_calls(self, text: str, tool_names: set[str]) -> list[ToolCallItem]: + calls = [] + for raw in _CALL_RE.findall(text): + if not raw.strip(): + continue + obj = json.loads(raw.strip()) + for entry in obj if isinstance(obj, list) else [obj]: + calls.append( + self._make_item( + entry.get("name"), + entry.get("arguments", entry.get("parameters")), + tool_names, + ) + ) + return calls + + def _make_item( + self, name: Optional[str], arguments: Any, tool_names: set[str] + ) -> ToolCallItem: + if not name or name not in tool_names: + raise _UndefinedToolCall(repr(name)) + item = ToolCallItem( + tool_index=self._next_index, + name=name, + arguments=json.dumps( + arguments if arguments is not None else {}, ensure_ascii=False + ), + ) + self._next_index += 1 + return item diff --git a/extension/llm/server/python/tool_parsers/types.py b/extension/llm/server/python/tool_parsers/types.py new file mode 100644 index 00000000000..2dae5c79458 --- /dev/null +++ b/extension/llm/server/python/tool_parsers/types.py @@ -0,0 +1,33 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Protocol-agnostic tool-parsing types. + +Kept independent of the OpenAI wire schema so the parser package is reusable; +serving_chat translates these into OpenAI tool_calls / deltas at the edge. +Design adapted from SGLang's core_types, with explicit per-request state. +""" + +from dataclasses import dataclass, field +from typing import Optional + + +@dataclass +class ToolCallItem: + """A parsed tool call. `arguments` is a JSON string (the full arguments — + this server emits complete calls, not fragments).""" + + tool_index: int + name: Optional[str] = None + arguments: str = "" + + +@dataclass +class ParseResult: + """Outcome of a parse: free text plus any tool calls found.""" + + normal_text: str = "" + calls: list[ToolCallItem] = field(default_factory=list) diff --git a/extension/llm/server/spec/README.md b/extension/llm/server/spec/README.md new file mode 100644 index 00000000000..04da6edc3f4 --- /dev/null +++ b/extension/llm/server/spec/README.md @@ -0,0 +1,72 @@ +# ExecuTorch LLM Server — Contract Spec + +The language-neutral contract every ExecuTorch LLM server (Python today, C++ +later) implements. The conformance suite in `../conformance` validates an +implementation against this spec by hitting a live server, so it is independent +of language and engine. + +## Supported endpoints + +| Endpoint | Status | +|----------|--------| +| `GET /v1/models` | implemented | +| `POST /v1/chat/completions` (stream + non-stream) | implemented | +| `GET /health` | implemented | +| `POST /v1/completions` | planned | + +## `POST /v1/chat/completions` + +OpenAI Chat Completions subset. **Honored** request fields: `model`, `messages`, +`stream`, `temperature`, `max_tokens` / `max_completion_tokens`, `stop`, `tools`, +`tool_choice` (only `"none"` to disable tools, or `"auto"`/unset for default +parsing), `stream_options.include_usage`, and `chat_template_kwargs` (e.g. +`enable_thinking`). + +**Rejected** with `400 invalid_request_error` (`code: "unsupported_parameter"`) +rather than silently ignored — a client relying on them would otherwise get +wrong behavior: `top_p` (anything other than `1.0`), `seed`, `n` (> 1), +`reasoning_effort`, `frequency_penalty`/`presence_penalty` (nonzero), `top_k`, +`logit_bias`, `tool_choice` = `"required"` or a specific-function choice +(forcing/restricting a call needs constrained decoding, which v1 lacks), +`response_format` other than `{"type": "text"}` (no constrained JSON), +`logprobs`/`top_logprobs` (not returned), and `parallel_tool_calls: false` +(single-call can't be guaranteed without constraining). Unknown fields that +don't affect the output (e.g. `user`, `store`, `metadata`) are accepted and +ignored. + +Non-streaming response: `chat.completion` with one `choice` +(`message.role = "assistant"`, string `content` or `tool_calls`, `finish_reason` +∈ `stop` | `length` | `tool_calls`) and a `usage` block. + +Streaming response: `text/event-stream` of `chat.completion.chunk` objects — +first chunk carries `delta.role = "assistant"`, subsequent chunks carry +`delta.content` (or buffered `delta.tool_calls`), a final chunk carries +`finish_reason`, optionally a usage-only chunk (with +`stream_options.include_usage`), terminated by `data: [DONE]`. + +### Tool calling + +Hermes/Qwen format only (`{"name":...,"arguments":{...}}`). +The server buffers the model's full output and emits **complete** OpenAI +`tool_calls` (no partial-argument fragments). Calls to tools absent from the +request, and malformed tool JSON, degrade to visible text — never a crash or +silent drop. `tool_choice="none"` disables tool parsing. + +### Errors & cancellation + +Errors return `{"error": {"message", "type", "code"}}` with an appropriate +status (e.g. `400 context_length_exceeded` when `--max-context` is set and the +prompt exceeds it). A mid-stream failure emits an `error` SSE event then +`[DONE]` rather than dropping the socket. A client disconnect cancels generation +(the runner's `stop()` is called). + +### Prefix cache (opt-in, `--enable-prefix-cache`) + +Off by default. When enabled (requires `--hf-tokenizer`), each runner reuses the +longest common token prefix of consecutive prompts via `seek()` / +`prefill_tokens()`, capped at the runner's resident KV position. It is +conservative and fail-safe: any reuse failure (including sliding-window models, +where it is disabled) falls back to a full prefill. Outputs are equivalent to +cache-off but, on real models, not bit-identical (greedy decoding is sensitive +to the tiny numerical differences between prefill chunkings — the same tradeoff +vLLM/llama.cpp/SGLang make). From e35d01aede68bb25b323d9063b077c1060381214 Mon Sep 17 00:00:00 2001 From: Mergen Nachin Date: Thu, 4 Jun 2026 15:38:20 -0700 Subject: [PATCH 3/3] [UPDATE] Update [ghstack-poisoned] --- extension/llm/server/python/chat_template.py | 28 +++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/extension/llm/server/python/chat_template.py b/extension/llm/server/python/chat_template.py index 403e6943e3c..52a7ef21243 100644 --- a/extension/llm/server/python/chat_template.py +++ b/extension/llm/server/python/chat_template.py @@ -14,6 +14,7 @@ so it must be a deliberate choice rather than a silent default. """ +import json import logging from typing import Any, Optional @@ -25,6 +26,29 @@ _DEFAULT_SPECIAL_TOKENS = ["<|im_end|>", "<|endoftext|>", "<|eot_id|>", "<|end|>"] +def _decode_tool_call_arguments(messages: list[dict[str, Any]]) -> None: + """In-place: parse each tool call's ``function.arguments`` from a JSON string + into an object. + + OpenAI sends assistant tool-call arguments as a JSON-encoded string, but HF + chat templates expect a mapping (e.g. Qwen renders ``arguments|items`` into + ```` tags). Without this, a multi-turn tool conversation makes + the template raise "Can only get item pairs from a mapping". Left as-is if + the value isn't valid JSON, so a template that wants the raw string still works. + """ + for m in messages: + for tc in m.get("tool_calls") or []: + fn = tc.get("function") + if not isinstance(fn, dict): + continue + args = fn.get("arguments") + if isinstance(args, str): + try: + fn["arguments"] = json.loads(args) + except (ValueError, TypeError): + pass + + class ChatTemplate: def __init__( self, @@ -69,8 +93,10 @@ def render( ) -> str: kwargs = {**self._defaults, **(template_kwargs or {})} if self._hf is not None: + dumped = [m.model_dump(exclude_none=True) for m in messages] + _decode_tool_call_arguments(dumped) return self._hf.apply_chat_template( - [m.model_dump(exclude_none=True) for m in messages], + dumped, tools=tools, add_generation_prompt=True, tokenize=False,