From 729842940a5ef7d7cbe57efcd9fd2726c28e9197 Mon Sep 17 00:00:00 2001 From: Mergen Nachin Date: Wed, 3 Jun 2026 15:32:32 -0700 Subject: [PATCH] [INITIAL] Update [ghstack-poisoned] --- extension/llm/runner/__init__.py | 4 + extension/llm/runner/_llm_runner.pyi | 60 ++++ extension/llm/runner/pybindings.cpp | 257 +++++++++++++++++- .../llm/runner/test/test_runner_pybindings.py | 132 +++++++++ 4 files changed, 446 insertions(+), 7 deletions(-) diff --git a/extension/llm/runner/__init__.py b/extension/llm/runner/__init__.py index 4e0ced33b21..97950cc95fa 100644 --- a/extension/llm/runner/__init__.py +++ b/extension/llm/runner/__init__.py @@ -18,6 +18,8 @@ from executorch.extension.llm.runner._llm_runner import ( # noqa: F401 GenerationConfig, Image, + LLMEngine, + LLMSession, make_audio_input, make_image_input, make_raw_audio_input, @@ -234,5 +236,7 @@ def generate_text_hf( "MultimodalInput", "MultimodalRunner", "TextLLMRunner", + "LLMEngine", + "LLMSession", "Stats", ] diff --git a/extension/llm/runner/_llm_runner.pyi b/extension/llm/runner/_llm_runner.pyi index 271cf1e1540..79e5e3adc8c 100644 --- a/extension/llm/runner/_llm_runner.pyi +++ b/extension/llm/runner/_llm_runner.pyi @@ -411,6 +411,66 @@ class TextLLMRunner: def __repr__(self) -> str: ... +class LLMSession: + """A per-conversation session created by LLMEngine: reuses the engine's + program/resources (weight sharing is backend-dependent — see + LLMEngine.serving_capacity()) but owns its own KV cache. Backend calls + (prefill_tokens/decode_one) are serialized across the engine's sessions by + an engine-owned lock.""" + + def prefill_tokens(self, token_ids: List[int]) -> None: ... + def decode_one(self, temperature: float = -1.0) -> dict: + """One decode step -> {"token_id": int, "text": bytes, "is_eos": bool}.""" + ... + + def seek(self, pos: int) -> None: ... + def position(self) -> int: ... + def reset(self) -> None: ... + def stop(self) -> None: + """Token-boundary cooperative stop: safe from another thread, but it + does not abort a decode_one() already running — it takes effect before + the next decode_one().""" + ... + + def __repr__(self) -> str: ... + +class LLMEngine: + """Engine for multi-session text generation over one loaded program. + + Loads the model's program once; create_session() returns a LLMSession that + reuses it but owns its own KV cache. Whether extra sessions avoid + duplicating packed weights is backend-dependent — ask serving_capacity(). Backend execution across all sessions of one engine is + serialized by an engine-owned lock (backend ops are not assumed + thread-safe), so it is safe to drive multiple sessions from multiple Python + threads. + """ + + def __init__( + self, + model_path: str, + tokenizer_path: str, + data_path: Optional[str] = None, + method_name: str = "forward", + temperature: float = -1.0, + ) -> None: ... + def create_session(self) -> LLMSession: + """Create a session that reuses this engine's program/resources (weight + sharing is backend-dependent — see serving_capacity()), with its own KV + cache.""" + ... + + def serving_capacity(self) -> dict: + """Serving-capacity dict: max_physical_sessions_without_weight_duplication + (1 = single-slot, no weight duplication) and estimated_bytes_per_session + (0 = unknown). The server clamps physical sessions to this.""" + ... + + def metadata(self) -> dict: + """Model metadata from the .pte, e.g. get_max_context_len.""" + ... + + def __repr__(self) -> str: ... + class MultimodalRunner: """Runner for multimodal language models.""" diff --git a/extension/llm/runner/pybindings.cpp b/extension/llm/runner/pybindings.cpp index 3188b5390c4..cbd79959c80 100644 --- a/extension/llm/runner/pybindings.cpp +++ b/extension/llm/runner/pybindings.cpp @@ -24,6 +24,7 @@ #include #include +#include #include #include #include @@ -43,6 +44,74 @@ using namespace executorch::runtime; } \ }) +namespace { + +// Length of the longest prefix of `s` that does not end in the middle of a +// UTF-8 multi-byte sequence. Byte-level BPE tokenizers can emit a token that is +// only part of a character (e.g. one byte of a 3-byte CJK codepoint or emoji), +// so the std::string->str conversion must wait until the character is complete +// or it throws UnicodeDecodeError and aborts generation. +size_t utf8_complete_prefix_len(const std::string& s) { + size_t i = 0; + const size_t n = s.size(); + while (i < n) { + const unsigned char c = static_cast(s[i]); + size_t len; + if (c < 0x80) { + len = 1; + } else if ((c >> 5) == 0x6) { + len = 2; + } else if ((c >> 4) == 0xE) { + len = 3; + } else if ((c >> 3) == 0x1E) { + len = 4; + } else { + len = 1; // invalid lead byte; emit it and let "replace" handle it + } + if (i + len > n) { + break; // incomplete trailing sequence: hold it for the next token + } + i += len; + } + return i; +} + +// Wraps a Python str token-callback so multi-byte characters split across +// tokens are buffered and only forwarded once complete. `flush()` emits any +// trailing (incomplete) bytes with replacement so nothing is silently dropped. +class TokenStringCallback { + public: + explicit TokenStringCallback(py::object cb) : cb_(std::move(cb)) {} + + void operator()(const std::string& token) { + buf_ += token; + const size_t n = utf8_complete_prefix_len(buf_); + if (n == 0) { + return; + } + py::gil_scoped_acquire acquire; + cb_(py::reinterpret_steal( + PyUnicode_DecodeUTF8(buf_.data(), n, "replace"))); + buf_.erase(0, n); + } + + void flush() { + if (buf_.empty()) { + return; + } + py::gil_scoped_acquire acquire; + cb_(py::reinterpret_steal( + PyUnicode_DecodeUTF8(buf_.data(), buf_.size(), "replace"))); + buf_.clear(); + } + + private: + py::object cb_; + std::string buf_; +}; + +} // namespace + // Python wrapper class for TextLLMRunner class PyTextLLMRunner { public: @@ -79,10 +148,11 @@ class PyTextLLMRunner { // Convert Python callbacks to C++ std::function std::function cpp_token_callback = nullptr; + std::shared_ptr token_assembler; if (!token_callback.is_none()) { - cpp_token_callback = [token_callback](const std::string& token) { - py::gil_scoped_acquire acquire; - token_callback(token); + token_assembler = std::make_shared(token_callback); + cpp_token_callback = [token_assembler](const std::string& token) { + (*token_assembler)(token); }; } @@ -94,13 +164,16 @@ class PyTextLLMRunner { }; } - // Release GIL during generation + // Release GIL during generation. { py::gil_scoped_release release; Error error = runner_->generate( prompt, config, cpp_token_callback, cpp_stats_callback); THROW_IF_ERROR(error, "Generation failed"); } + if (token_assembler) { + token_assembler->flush(); + } } void stop() { @@ -174,10 +247,11 @@ class PyMultimodalRunner { // Convert Python callbacks to C++ std::function std::function cpp_token_callback = nullptr; + std::shared_ptr token_assembler; if (!token_callback.is_none()) { - cpp_token_callback = [token_callback](const std::string& token) { - py::gil_scoped_acquire acquire; - token_callback(token); + token_assembler = std::make_shared(token_callback); + cpp_token_callback = [token_assembler](const std::string& token) { + (*token_assembler)(token); }; } @@ -196,6 +270,9 @@ class PyMultimodalRunner { inputs, config, cpp_token_callback, cpp_stats_callback); THROW_IF_ERROR(error, "Generation failed"); } + if (token_assembler) { + token_assembler->flush(); + } } std::string generate_text( @@ -251,6 +328,124 @@ class PyMultimodalRunner { std::unique_ptr runner_; }; +// A session handle (LLMSession), the model-agnostic per-conversation API. +// Backend calls (prefill_tokens/decode_one) take the engine-owned lock so +// concurrent sessions of one engine serialize (Module::execute isn't assumed +// thread-safe); cheap state ops (seek/reset/position/stop) don't. +class PyLLMSession { + public: + PyLLMSession( + std::unique_ptr session, + std::shared_ptr exec_mutex) + : session_(std::move(session)), exec_mutex_(std::move(exec_mutex)) {} + + void prefill_tokens(std::vector tokens) { + py::gil_scoped_release release; + auto exec_lock = lock_exec(); + THROW_IF_ERROR( + session_->prefill_tokens(std::move(tokens)), "prefill_tokens failed"); + } + + py::dict decode_one(float temperature = -1.0f) { + uint64_t token_id; + std::string text; + bool is_eos; + { + py::gil_scoped_release release; + auto exec_lock = lock_exec(); + SamplingConfig sampling; + sampling.temperature = temperature; + auto res = session_->decode_one(sampling); + THROW_IF_ERROR(res.error(), "decode_one failed"); + const auto& r = res.get(); + token_id = r.token_id; + text = r.text_piece; + is_eos = r.is_eos; + } + py::dict d; + d["token_id"] = token_id; + d["text"] = py::bytes(text); + d["is_eos"] = is_eos; + return d; + } + + void seek(int64_t pos) { + THROW_IF_ERROR(session_->seek(pos), "seek failed"); + } + int64_t position() const { + return session_->position(); + } + void reset() { + THROW_IF_ERROR(session_->reset(), "reset failed"); + } + void stop() { + session_->stop(); + } + + private: + std::unique_lock lock_exec() { + return exec_mutex_ ? std::unique_lock(*exec_mutex_) + : std::unique_lock(); + } + std::unique_ptr session_; + std::shared_ptr exec_mutex_; +}; + +// Engine over one loaded Program: loads it once; create_session() returns an +// LLMSession that reuses it but owns its own KV state. Physical weight sharing +// across sessions is backend-dependent (serving_capacity() is authoritative). +class PyLLMEngine { + public: + PyLLMEngine( + const std::string& model_path, + const std::string& tokenizer_path, + std::optional data_path = std::nullopt, + const std::string& method_name = "forward", + float temperature = -1.0f) { + if (data_path.has_value()) { + throw std::runtime_error( + "LLMEngine: shared sessions with external data (.ptd / data_path) are " + "not yet supported; use a self-contained .pte (the session Module " + "needs the data_map_loader threaded through — tracked as a follow-up)."); + } + engine_ = TextLLMEngine::create( + model_path, tokenizer_path, data_path, temperature, method_name); + if (!engine_) { + throw std::runtime_error( + "Failed to create LLMEngine with model: " + model_path); + } + } + + std::unique_ptr create_session() { + auto res = engine_->create_session(); + THROW_IF_ERROR(res.error(), "Failed to create session from LLMEngine"); + // Hand the session the engine-owned lock so backend execution across all + // sessions of this engine is serialized. + return std::make_unique(std::move(res.get()), exec_mutex_); + } + + py::dict serving_capacity() const { + const auto c = engine_->serving_capacity(); + py::dict d; + d["max_physical_sessions_without_weight_duplication"] = + c.max_physical_sessions_without_weight_duplication; + d["estimated_bytes_per_session"] = c.estimated_bytes_per_session; + return d; + } + + py::dict metadata() const { + py::dict d; + for (const auto& [key, value] : engine_->metadata()) { + d[py::str(key)] = value; + } + return d; + } + + private: + std::unique_ptr engine_; + std::shared_ptr exec_mutex_ = std::make_shared(); +}; + PYBIND11_MODULE(_llm_runner, m) { m.doc() = "Python bindings for ExecuTorch LLM Runners"; @@ -735,6 +930,54 @@ PYBIND11_MODULE(_llm_runner, m) { return ""; }); + // Bind PyLLMEngine: shared-weight engine, create_session() per conversation. + // Bind PyLLMSession: the per-conversation session handle. + py::class_(m, "LLMSession") + .def( + "prefill_tokens", + &PyLLMSession::prefill_tokens, + py::arg("token_ids"), + "Prefill pre-tokenized input at the current cache position.") + .def( + "decode_one", + &PyLLMSession::decode_one, + py::arg("temperature") = -1.0f, + "Decode one token; returns {token_id:int, text:bytes, is_eos:bool}.") + .def("seek", &PyLLMSession::seek, py::arg("pos"), "Rewind KV to `pos`.") + .def("position", &PyLLMSession::position, "Resident KV token count.") + .def("reset", &PyLLMSession::reset, "Clear KV / position.") + .def("stop", &PyLLMSession::stop, "Signal an in-flight decode to stop.") + .def("__repr__", [](const PyLLMSession&) { return ""; }); + + py::class_(m, "LLMEngine") + .def( + py::init< + const std::string&, + const std::string&, + std::optional, + const std::string&, + float>(), + py::arg("model_path"), + py::arg("tokenizer_path"), + py::arg("data_path") = py::none(), + py::arg("method_name") = "forward", + py::arg("temperature") = -1.0f, + "Load a model's program once for multi-session serving.") + .def( + "create_session", + &PyLLMEngine::create_session, + "Create an LLMSession that reuses the engine's program/resources " + "(weight sharing is backend-dependent — see serving_capacity()) but " + "owns its own KV cache. Backend execution across sessions is " + "serialized by an engine-owned lock.") + .def( + "serving_capacity", + &PyLLMEngine::serving_capacity, + "Serving-capacity dict; the server clamps physical sessions to " + "max_physical_sessions_without_weight_duplication (1 = single-slot).") + .def("metadata", &PyLLMEngine::metadata, "Model metadata from the .pte.") + .def("__repr__", [](const PyLLMEngine&) { return ""; }); + // Bind PyMultimodalRunner py::class_(m, "MultimodalRunner") // Constructor with tokenizer path diff --git a/extension/llm/runner/test/test_runner_pybindings.py b/extension/llm/runner/test/test_runner_pybindings.py index 5619e586c4b..33ec4971a62 100644 --- a/extension/llm/runner/test/test_runner_pybindings.py +++ b/extension/llm/runner/test/test_runner_pybindings.py @@ -14,19 +14,51 @@ import os import tempfile + +import threading import unittest import torch from executorch.extension.llm.runner import ( GenerationConfig, Image, + LLMEngine, + LLMSession, make_image_input, make_text_input, MultimodalInput, MultimodalRunner, + TextLLMRunner, ) +class TestSessionApiBoundary(unittest.TestCase): + """The Python serving boundary: token-step primitives live ONLY on + LLMSession, never on the legacy TextLLMRunner (whose token-step methods are + C++ implementation details behind TextLLMSession). Pure class introspection, + so no model/.pte is needed.""" + + TOKEN_STEP = ("prefill_tokens", "decode_one", "seek", "position") + + def test_text_llm_runner_does_not_expose_token_step(self): + for name in self.TOKEN_STEP: + self.assertFalse( + hasattr(TextLLMRunner, name), + f"TextLLMRunner must not expose token-step method {name!r} to " + f"Python; drive sessions through LLMSession instead.", + ) + + def test_llm_session_exposes_token_step(self): + for name in (*self.TOKEN_STEP, "reset", "stop"): + self.assertTrue( + hasattr(LLMSession, name), f"LLMSession must expose {name!r}" + ) + + def test_llm_engine_exposes_serving_api(self): + for name in ("create_session", "serving_capacity"): + self.assertTrue(hasattr(LLMEngine, name), f"LLMEngine must expose {name!r}") + + class TestGenerationConfig(unittest.TestCase): """Test the GenerationConfig class.""" @@ -264,3 +296,103 @@ def test_make_image_input(self): img_tensor_rgba = torch.ones((4, 50, 50), dtype=torch.uint8) * 128 image_input_rgba = make_image_input(img_tensor_rgba) self.assertTrue(image_input_rgba.is_image()) + + +# Real-engine tests need a model; gated on env vars so they're skipped in CI +# environments without one (fake-runner unit tests can't exercise the real +# shared-Program / serialization behavior). +_MODEL = os.environ.get("ET_TEST_TEXT_LLM_MODEL") +_TOKENIZER = os.environ.get("ET_TEST_TEXT_LLM_TOKENIZER") + + +@unittest.skipUnless( + _MODEL and _TOKENIZER, + "set ET_TEST_TEXT_LLM_MODEL and ET_TEST_TEXT_LLM_TOKENIZER to run", +) +class TestLLMEngineSessions(unittest.TestCase): + """LLMEngine: sessions share weights, stay isolated, and serialize backend + execution so concurrent sessions don't corrupt each other.""" + + @classmethod + def setUpClass(cls): + # LLM .pte files use custom/quantized ops; register them (the server's + # runner_pool does this automatically, but a direct engine test must). + try: + import executorch.extension.llm.custom_ops.custom_ops # noqa: F401 + import executorch.kernels.quantized # noqa: F401 + except Exception: # noqa: BLE001 - assume statically linked otherwise + pass + # The session API takes token ids; tokenize prompts in Python (the server + # does the same). Load the model's tokenizer.json directly. + from tokenizers import Tokenizer as HFTokenizer + + cls._hf = HFTokenizer.from_file(_TOKENIZER) + + @classmethod + def _ids(cls, prompt): + return cls._hf.encode(prompt).ids + + @staticmethod + def _gen_text(runner, prompt): # standalone TextLLMRunner baseline + out = [] + runner.reset() + runner.generate( + prompt, + GenerationConfig(echo=False, max_new_tokens=12, temperature=0.0), + lambda t: out.append(t), + ) + return "".join(out) + + def _session_ids(self, session, prompt_ids, n=12): + """Drive a session via prefill_tokens + a decode_one loop (the actual + new path); return the exact generated token ids.""" + session.reset() + session.prefill_tokens(prompt_ids) + ids = [] + for _ in range(n): + step = session.decode_one(0.0) + ids.append(step["token_id"]) + if step["is_eos"]: + break + return ids + + def test_sessions_isolated_and_match_baseline(self): + p1 = "<|im_start|>user\nName one primary color.<|im_end|>\n<|im_start|>assistant\n" + p2 = "<|im_start|>user\nWhat is 2+2?<|im_end|>\n<|im_start|>assistant\n" + base = TextLLMRunner(model_path=_MODEL, tokenizer_path=_TOKENIZER) + b1, b2 = self._gen_text(base, p1), self._gen_text(base, p2) + + engine = LLMEngine(model_path=_MODEL, tokenizer_path=_TOKENIZER) + s1, s2 = engine.create_session(), engine.create_session() + ids1 = self._session_ids(s1, self._ids(p1)) + ids2 = self._session_ids(s2, self._ids(p2)) + ids1b = self._session_ids(s1, self._ids(p1)) # after s2 ran + # The session's decode_one ids, decoded, match the standalone generation. + self.assertEqual(self._hf.decode(ids1).strip(), b1.strip()) + self.assertEqual(self._hf.decode(ids2).strip(), b2.strip()) + self.assertEqual(ids1, ids1b, "session1 must be unaffected by session2") + + def test_concurrent_sessions_do_not_crash(self): + # The original num_runners>1 path crashed (heap corruption) under + # concurrent backend calls; the engine lock must serialize them safely. + p = self._ids( + "<|im_start|>user\nCount to five.<|im_end|>\n<|im_start|>assistant\n" + ) + engine = LLMEngine(model_path=_MODEL, tokenizer_path=_TOKENIZER) + s1, s2 = engine.create_session(), engine.create_session() + expect = self._session_ids(s1, p) + errors = [] + + def worker(sess): + try: + for _ in range(3): + self.assertEqual(self._session_ids(sess, p), expect) + except Exception as e: # noqa: BLE001 + errors.append(repr(e)) + + threads = [threading.Thread(target=worker, args=(s,)) for s in (s1, s2)] + for t in threads: + t.start() + for t in threads: + t.join() + self.assertEqual(errors, [], "concurrent sessions crashed or drifted")