From 8a8a62779793ed99fb367d0f165687eaf6cfb2eb Mon Sep 17 00:00:00 2001 From: Mergen Nachin Date: Wed, 3 Jun 2026 15:32:32 -0700 Subject: [PATCH 1/6] [INITIAL] Update [ghstack-poisoned] --- extension/llm/runner/llm_runner_helper.cpp | 177 ++++++++++++++++++ extension/llm/runner/llm_runner_helper.h | 88 +++++++++ extension/llm/runner/llm_session.h | 120 ++++++++++++ extension/llm/runner/targets.bzl | 1 + .../llm/runner/test/test_text_llm_runner.cpp | 160 ++++++++++++++++ extension/llm/runner/text_llm_runner.cpp | 134 ++++++++++++- extension/llm/runner/text_llm_runner.h | 75 ++++++++ extension/llm/runner/text_token_generator.h | 21 ++- 8 files changed, 770 insertions(+), 6 deletions(-) create mode 100644 extension/llm/runner/llm_session.h diff --git a/extension/llm/runner/llm_runner_helper.cpp b/extension/llm/runner/llm_runner_helper.cpp index 0744c09e641..293c5abd68b 100644 --- a/extension/llm/runner/llm_runner_helper.cpp +++ b/extension/llm/runner/llm_runner_helper.cpp @@ -29,8 +29,18 @@ namespace executorch::extension::llm { using ::executorch::extension::Module; +using ::executorch::extension::Program; using ::executorch::runtime::Error; +// Assembles the per-Module components (decoder/prefiller/token generator/io +// manager/stats) into a TextLLMRunner. Shared by the path-based and the +// shared-Program (TextLLMEngine session) construction paths. 
+static std::unique_ptr assemble_text_llm_runner( + std::unique_ptr module, + std::unique_ptr<::tokenizers::Tokenizer> tokenizer, + float temperature, + const std::string& method_name); + std::unique_ptr load_tokenizer( const std::string& tokenizer_path, std::unique_ptr> special_tokens, @@ -251,6 +261,15 @@ std::unique_ptr create_text_llm_runner( max_cached_memory_size_bytes_)); } + return assemble_text_llm_runner( + std::move(module), std::move(tokenizer), temperature, method_name); +} + +static std::unique_ptr assemble_text_llm_runner( + std::unique_ptr module, + std::unique_ptr<::tokenizers::Tokenizer> tokenizer, + float temperature, + const std::string& method_name) { // Get metadata from Module ET_LOG(Info, "Reading metadata from model"); auto metadata_result = llm::get_llm_metadata(tokenizer.get(), module.get()); @@ -305,6 +324,164 @@ std::unique_ptr create_text_llm_runner( temperature); } +std::unique_ptr create_text_llm_runner_from_program( + std::shared_ptr program, + std::unique_ptr<::tokenizers::Tokenizer> tokenizer, + float temperature, + const std::string& method_name) { + if (!tokenizer || !tokenizer->is_loaded()) { + ET_LOG(Error, "Tokenizer is null or not loaded"); + return nullptr; + } + if (!program) { + ET_LOG(Error, "Program is null"); + return nullptr; + } + // A Module over the already-loaded Program: it reuses that Program rather + // than re-loading it, and its loaded method allocates its own planned (KV) + // memory. Whether packed weights are physically shared vs. re-materialized + // per method instance is backend-dependent (serving_capacity() is the + // authority); on XNNPACK assume per-instance. 
+ constexpr uint32_t kMaxCachedMemoryBytes = 1024 * 1024 * 10; // 10MB + auto module = std::make_unique( + std::move(program), + nullptr, // memory allocator + std::make_unique( + kMaxCachedMemoryBytes)); + return assemble_text_llm_runner( + std::move(module), std::move(tokenizer), temperature, method_name); +} + +namespace { +// The TextLLM adapter: implements the model-agnostic LLMSession over a +// TextLLMRunner. TextLLMRunner is an implementation detail here — the engine +// and server depend only on LLMSession. +class TextLLMSession : public LLMSession { + public: + explicit TextLLMSession(std::unique_ptr runner) + : runner_(std::move(runner)) {} + + Error prefill_tokens(std::vector tokens) override { + return runner_->prefill_tokens(std::move(tokens)).error(); + } + ::executorch::runtime::Result decode_one( + const SamplingConfig& sampling) override { + // Only temperature is plumbed today; top_p/top_k/seed need a per-session + // sampler (applied in a follow-up). + return runner_->decode_one(sampling.temperature); + } + Error seek(int64_t pos) override { + return runner_->seek(pos); + } + int64_t position() const override { + return runner_->position(); + } + Error reset() override { + runner_->reset(); + return Error::Ok; + } + void stop() override { + runner_->stop(); + } + + private: + std::unique_ptr runner_; +}; +} // namespace + +TextLLMEngine::TextLLMEngine( + std::unique_ptr loader_module, + std::shared_ptr program, + std::string tokenizer_path, + float temperature, + std::string method_name, + std::unordered_map metadata) + : loader_module_(std::move(loader_module)), + program_(std::move(program)), + tokenizer_path_(std::move(tokenizer_path)), + temperature_(temperature), + method_name_(std::move(method_name)), + metadata_(std::move(metadata)) {} + +std::unique_ptr TextLLMEngine::create( + const std::string& model_path, + const std::string& tokenizer_path, + std::optional data_path, + float temperature, + const std::string& method_name, + 
Module::LoadMode load_mode) { + // External .ptd weights are not yet supported for shared sessions: each + // session Module built from the shared Program would also need the + // data_map_loader threaded into its load_method() to resolve external + // weights (see Module::load_method merged_data_map_). Fail loudly rather than + // silently produce sessions that error on first generate. + if (data_path.has_value()) { + ET_LOG( + Error, + "TextLLMEngine: external data_path (.ptd) is not yet supported for " + "shared sessions; use a self-contained .pte for now."); + return nullptr; + } + // Load the program ONCE; sessions reuse it (loaded a single time, per-session + // KV). Physical weight sharing across sessions is backend-dependent — see + // serving_capacity(). + auto loader_module = std::make_unique(model_path, load_mode); + if (loader_module->load() != Error::Ok) { + ET_LOG( + Error, + "TextLLMEngine: failed to load program from %s", + model_path.c_str()); + return nullptr; + } + auto program = loader_module->program(); + if (!program) { + ET_LOG(Error, "TextLLMEngine: program is null after load"); + return nullptr; + } + // Read model-level metadata once (shared by all sessions). 
+ auto meta_tokenizer = load_tokenizer(tokenizer_path); + if (!meta_tokenizer) { + ET_LOG( + Error, + "TextLLMEngine: failed to load tokenizer from %s", + tokenizer_path.c_str()); + return nullptr; + } + auto metadata_result = + get_llm_metadata(meta_tokenizer.get(), loader_module.get()); + if (metadata_result.error() != Error::Ok) { + ET_LOG(Error, "TextLLMEngine: failed to read metadata"); + return nullptr; + } + return std::unique_ptr(new TextLLMEngine( + std::move(loader_module), + std::move(program), + tokenizer_path, + temperature, + method_name, + metadata_result.get())); +} + +::executorch::runtime::Result> +TextLLMEngine::create_session() { + auto tokenizer = load_tokenizer(tokenizer_path_); + if (!tokenizer) { + ET_LOG( + Error, + "TextLLMEngine: failed to load tokenizer from %s", + tokenizer_path_.c_str()); + return Error::InvalidState; + } + auto runner = create_text_llm_runner_from_program( + program_, std::move(tokenizer), temperature_, method_name_); + if (!runner) { + ET_LOG(Error, "TextLLMEngine: failed to build session runner"); + return Error::InvalidState; + } + return std::unique_ptr( + std::make_unique(std::move(runner))); +} + std::unique_ptr create_multimodal_runner( const std::string& model_path, std::unique_ptr<::tokenizers::Tokenizer> tokenizer, diff --git a/extension/llm/runner/llm_runner_helper.h b/extension/llm/runner/llm_runner_helper.h index b4c7c59806d..de1ecc743e3 100644 --- a/extension/llm/runner/llm_runner_helper.h +++ b/extension/llm/runner/llm_runner_helper.h @@ -18,6 +18,7 @@ #include #include +#include #include #include #include @@ -141,6 +142,93 @@ ET_EXPERIMENTAL std::unique_ptr create_text_llm_runner( const std::string& method_name = "forward", Module::LoadMode load_mode = Module::LoadMode::MmapUseMlockIgnoreErrors); +/** + * @brief Creates a TextLLMRunner over an already-loaded Program. 
+ * + * Unlike create_text_llm_runner(model_path, ...), this does not load the model + * file again: the resulting runner's Module reuses `program` while owning its + * own method state and KV cache. This is the per-session construction path for + * TextLLMEngine — N sessions reuse one loaded Program but isolate their mutable + * KV state. Whether they also avoid re-materializing packed weights per session + * is backend-dependent (serving_capacity() is authoritative; XNNPACK repacks + * per method instance, so assume per-session weights there). + * + * The caller must keep the DataLoader backing `program` alive for the lifetime + * of every runner created from it (TextLLMEngine holds the loader Module). + * + * @param program Shared, already-loaded program. + * @param tokenizer Initialized tokenizer instance (owned by the new runner). + * @param temperature Optional temperature (deprecated; prefer + * GenerationConfig). + * @param method_name Name of the method to execute in the model. + * @return std::unique_ptr on success, or nullptr on failure. + */ +ET_EXPERIMENTAL std::unique_ptr +create_text_llm_runner_from_program( + std::shared_ptr program, + std::unique_ptr<::tokenizers::Tokenizer> tokenizer, + float temperature = -1.0f, + const std::string& method_name = "forward"); + +/** + * @brief Engine for multi-session text generation over one loaded Program. + * + * Loads the model's Program (weights/constants) once; create_session() builds a + * TextLLMRunner that reuses that Program but owns its own method/KV state. This + * is the correctness-first foundation for serving multiple conversations. + * Backend execution should be serialized by the caller until per-backend thread + * safety is proven (Module::execute is not assumed thread-safe). Whether extra + * sessions actually avoid duplicating packed weights is a backend property + * (e.g. 
AOTI/CUDA share device weights) reported by serving_capacity(); on the + * XNNPACK path weights are repacked per method instance and the KV cache is + * baked into the .pte, so it conservatively reports a single physical session. + */ +class ET_EXPERIMENTAL TextLLMEngine : public LLMEngine { + public: + static std::unique_ptr create( + const std::string& model_path, + const std::string& tokenizer_path, + std::optional data_path = std::nullopt, + float temperature = -1.0f, + const std::string& method_name = "forward", + Module::LoadMode load_mode = Module::LoadMode::MmapUseMlockIgnoreErrors); + + // Returns a TextLLMSession (LLMSession) that reuses this engine's loaded + // Program (physical weight sharing is backend-dependent; see + // serving_capacity). + ::executorch::runtime::Result> create_session() + override; + // Conservative v1: a self-contained .pte repacks XNNPACK weights per runtime, + // so we don't claim multiple physical sessions share weights. Raise this on a + // backend/artifact proven to share packed weights. + LLMServingCapacity serving_capacity() const override { + return LLMServingCapacity{}; + } + const std::unordered_map& metadata() const override { + return metadata_; + } + + TextLLMEngine(const TextLLMEngine&) = delete; + TextLLMEngine& operator=(const TextLLMEngine&) = delete; + + private: + TextLLMEngine( + std::unique_ptr loader_module, + std::shared_ptr program, + std::string tokenizer_path, + float temperature, + std::string method_name, + std::unordered_map metadata); + + // Keeps the shared Program's DataLoader alive for the lifetime of sessions. 
+ std::unique_ptr loader_module_; + std::shared_ptr program_; + std::string tokenizer_path_; + float temperature_; + std::string method_name_; + std::unordered_map metadata_; +}; + /** * @brief Creates a MultimodalRunner instance with dependency injection * diff --git a/extension/llm/runner/llm_session.h b/extension/llm/runner/llm_session.h new file mode 100644 index 00000000000..456c26c281e --- /dev/null +++ b/extension/llm/runner/llm_session.h @@ -0,0 +1,120 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Model-agnostic Engine/Session interfaces. Model-specific execution lives in +// adapters that implement these (TextLLMSession over TextLLMRunner today; +// Gemma4Session etc. later); the server and pybind layer depend only on these +// interfaces, never on a concrete runner. + +#pragma once + +#include +#include +#include +#include +#include + +#include +#include + +namespace executorch::extension::llm { + +/// Per-decode sampling parameters. An adapter applies what it supports and +/// reports the rest via capabilities(); -1 temperature means model default. +struct SamplingConfig { + float temperature = -1.0f; + float top_p = 1.0f; + int32_t top_k = 0; // 0 = disabled + uint64_t seed = 0; // 0 = unset +}; + +/// One decoded step: the exact sampled token id (for prefix-cache id tracking +/// and batching), its decoded text piece (raw bytes; may be a partial UTF-8 +/// sequence the caller assembles), and whether it is an EOS token. +struct DecodeResult { + uint64_t token_id; + std::string text_piece; + bool is_eos; +}; + +/// How many physical sessions an engine can host, so the server admits logical +/// requests without silently multiplying model memory. 
This is a *serving +/// capacity* concern (engine-level), distinct from how a session advances a +/// conversation (LLMSession) — keep backend memory flags off LLMSession. +struct LLMServingCapacity { + // Physical sessions (loaded runtimes) creatable without duplicating packed + // weights. Conservatively 1: a self-contained .pte with inline constants + // repacks weights per XNNPACK runtime, so N logical requests queue on one + // physical session (llama.cpp single-slot), not N copies of the model. A + // backend that provably shares packed weights (XNNWeightsCache with named + // external data; CUDA/AOTI shared device weights) can report >1. + int32_t max_physical_sessions_without_weight_duplication = 1; + // Planned bytes one session adds (KV + activations), for memory-budget + // admission. 0 = unknown; the server skips the memory clamp. + int64_t estimated_bytes_per_session = 0; +}; + +/// One conversation's mutable state (KV cache, position cursor). Created by an +/// LLMEngine; conversation/cache-scoped (kept warm across requests for prefix +/// reuse), not request-scoped. +class ET_EXPERIMENTAL LLMSession { + public: + virtual ~LLMSession() = default; + + /// Prefill pre-tokenized input at the current position (call seek() first for + /// prefix reuse). Must be non-empty and fit the context window. + virtual ::executorch::runtime::Error prefill_tokens( + std::vector tokens) = 0; + + /// Decode one token from the pending state; looping reproduces a full + /// generation while returning exact sampled token ids. A single decode_one() + /// runs one forward pass and is not interruptible mid-call (see stop()). + virtual ::executorch::runtime::Result decode_one( + const SamplingConfig& sampling) = 0; + + /// Rewind the KV cache to `pos` (prefix reuse). Valid for full-KV models; + /// sliding-window KV may reject a seek past its window (the caller falls back + /// to a fresh prefill). 
+ virtual ::executorch::runtime::Error seek(int64_t pos) = 0; + + /// Number of tokens with resident KV (upper bound for seek()). + virtual int64_t position() const = 0; + + /// Clear the KV cache / position for a fresh conversation. + virtual ::executorch::runtime::Error reset() = 0; + + /// Request that a decode_one() loop stop. This is a TOKEN-BOUNDARY, + /// cooperative stop: it is safe to call from another thread, but it does not + /// abort a decode_one() that is already running — it takes effect before the + /// next decode_one() (the loop driver checks between tokens). + virtual void stop() = 0; +}; + +/// Holds the immutable model resources (program, tokenizer, metadata) once and +/// creates sessions that reuse them while isolating their own KV state. How +/// many sessions can be created without duplicating packed weights is backend- +/// dependent — see serving_capacity(). +class ET_EXPERIMENTAL LLMEngine { + public: + virtual ~LLMEngine() = default; + + /// Build a new session that reuses this engine's program/resources when the + /// backend supports it, with its own KV cache. serving_capacity() is the + /// authority on how many physical sessions are safe without weight + /// duplication. + virtual ::executorch::runtime::Result> + create_session() = 0; + + /// How many physical sessions this engine can host without duplicating + /// weights (+ optional per-session memory estimate); the server clamps the + /// number of physical sessions it creates to this. 
+ virtual LLMServingCapacity serving_capacity() const = 0; + virtual const std::unordered_map& metadata() const = 0; +}; + +} // namespace executorch::extension::llm diff --git a/extension/llm/runner/targets.bzl b/extension/llm/runner/targets.bzl index d3e12266adc..2914078cd40 100644 --- a/extension/llm/runner/targets.bzl +++ b/extension/llm/runner/targets.bzl @@ -114,6 +114,7 @@ def define_common_targets(): exported_headers = [ "text_llm_runner.h", "llm_runner_helper.h", + "llm_session.h", "constants.h", ], srcs = [ diff --git a/extension/llm/runner/test/test_text_llm_runner.cpp b/extension/llm/runner/test/test_text_llm_runner.cpp index 5a98c55bb2f..e3f0b54abc8 100644 --- a/extension/llm/runner/test/test_text_llm_runner.cpp +++ b/extension/llm/runner/test/test_text_llm_runner.cpp @@ -21,6 +21,7 @@ using namespace ::testing; using executorch::extension::llm::GenerationConfig; +using executorch::extension::llm::LLMServingCapacity; using executorch::extension::llm::LogitProcessor; using executorch::extension::llm::Stats; using executorch::extension::llm::TextDecoderRunner; @@ -220,6 +221,44 @@ class RunnerTest : public Test { }; } + // Builds a loaded TextLLMRunner with default mocks whose prefiller advances + // the position cursor by the number of tokens (so position()/seek() and the + // capacity bound can be exercised directly). 
+ std::unique_ptr makeRunner( + std::unordered_map metadata, + std::shared_ptr logit_processor = nullptr) { + auto tokenizer = createMockTokenizer(); + auto text_decoder_runner = createMockTextDecoderRunner(); + auto text_prefiller = createMockTextPrefiller(text_decoder_runner.get()); + ON_CALL(*text_prefiller, prefill(_, _)) + .WillByDefault([](std::vector& tokens, int64_t& pos) { + pos += tokens.size(); + return Result(42); + }); + ON_CALL(*text_prefiller, is_loaded()).WillByDefault(Return(true)); + auto stats = std::make_unique(); + auto text_token_generator = createTextTokenGenerator( + tokenizer.get(), text_decoder_runner.get(), stats.get()); + if (logit_processor) { + text_token_generator->add_logit_processor(std::move(logit_processor)); + } + auto module = std::make_unique(); + auto io_manager = + std::make_unique(*module); + auto runner = std::make_unique( + std::move(metadata), + std::unique_ptr<::tokenizers::Tokenizer>(tokenizer.release()), + std::move(module), + std::move(text_decoder_runner), + std::unique_ptr<::executorch::extension::llm::TextPrefiller>( + text_prefiller.release()), + std::move(io_manager), + std::move(text_token_generator), + std::move(stats)); + runner->load(); + return runner; + } + protected: Stats stats_; std::vector return_logits_ = {0.1f, 0.2f, 0.3f, 0.4f}; @@ -490,6 +529,127 @@ TEST_F(RunnerTest, PrefillThenGenerateEmpty) { EXPECT_EQ(counter.getCount(), config.max_new_tokens); } +// prefill_tokens() must reject a suffix larger than the KV cache, mirroring the +// capacity bound generate(prompt) enforces (prefill_tokens is the public +// prefix-cache primitive and the only place this is checked for it). 
+TEST_F(RunnerTest, PrefillTokensRejectsOverContext) { + auto runner = makeRunner(createDefaultMetadata()); // max_context_len = 128 + auto result = + runner->prefill_tokens(std::vector(200, 1)); // 200 > 128 + EXPECT_FALSE(result.ok()); + EXPECT_EQ(result.error(), Error::InvalidArgument); +} + +// seek() + prefill_tokens() across the boundary is rejected: a valid prefill +// followed by a suffix that pushes pos_ past max_context_len must fail. +TEST_F(RunnerTest, PrefillTokensRejectsWhenPosPlusSuffixExceedsContext) { + auto runner = makeRunner(createDefaultMetadata()); // 128 + auto ok = runner->prefill_tokens(std::vector(100, 1)); + EXPECT_TRUE(ok.ok()); + EXPECT_EQ(runner->position(), 100); + auto bad = + runner->prefill_tokens(std::vector(50, 1)); // 100 + 50 > 128 + EXPECT_FALSE(bad.ok()); + EXPECT_EQ(bad.error(), Error::InvalidArgument); + EXPECT_EQ(runner->position(), 100); // rejected before advancing +} + +// Empty tokens are rejected. +TEST_F(RunnerTest, PrefillTokensRejectsEmpty) { + auto runner = makeRunner(createDefaultMetadata()); + auto result = runner->prefill_tokens(std::vector{}); + EXPECT_FALSE(result.ok()); + EXPECT_EQ(result.error(), Error::InvalidArgument); +} + +// position() tracks prefilled tokens; seek() rewinds within range and rejects +// out-of-range targets. +TEST_F(RunnerTest, SeekAndPositionTrackResidentTokens) { + auto runner = makeRunner(createDefaultMetadata()); + EXPECT_EQ(runner->position(), 0); + EXPECT_TRUE(runner->prefill_tokens(std::vector(10, 1)).ok()); + EXPECT_EQ(runner->position(), 10); + EXPECT_EQ(runner->seek(5), Error::Ok); + EXPECT_EQ(runner->position(), 5); + EXPECT_EQ(runner->seek(999), Error::InvalidArgument); // past current position +} + +// seek() is refused on sliding-window models (max_seq_len < max_context_len), +// so the prefix cache falls back to a full reset+prefill instead of corrupting. 
+TEST_F(RunnerTest, SeekRejectedForSlidingWindow) { + auto md = createDefaultMetadata(); + md["get_max_seq_len"] = 64; // < get_max_context_len (128) => sliding window + auto runner = makeRunner(md); + EXPECT_EQ(runner->seek(0), Error::NotSupported); +} + +// decode_one() emits the pending token id exactly, then forwards it (advancing +// position by one). Looping it reproduces a generation while exposing ids. +TEST_F(RunnerTest, DecodeOneReturnsExactTokenIdAndAdvances) { + auto runner = makeRunner(createDefaultMetadata()); + ASSERT_TRUE(runner->prefill_tokens(std::vector{1, 2, 3}).ok()); + EXPECT_EQ(runner->position(), 3); + + auto r1 = runner->decode_one(0.0f); + ASSERT_TRUE(r1.ok()); + EXPECT_EQ(r1.get().token_id, 42u); // the prefill-pending token (mock prefill) + EXPECT_FALSE(r1.get().is_eos); + EXPECT_EQ(runner->position(), 4); // forwarded one token + + auto r2 = runner->decode_one(0.0f); + ASSERT_TRUE(r2.ok()); + EXPECT_EQ(r2.get().token_id, 3u); // argmax of canned logits {.1,.2,.3,.4} + EXPECT_EQ(runner->position(), 5); +} + +// decode_one() without a pending token (no prior prefill) must error. +TEST_F(RunnerTest, DecodeOneWithoutPendingTokenFails) { + auto runner = makeRunner(createDefaultMetadata()); + EXPECT_FALSE(runner->decode_one(0.0f).ok()); +} + +// decode_one() is the last-resort safety net for session drivers: even if a +// caller forgets to resolve max_new_tokens, it must not step past KV capacity. +TEST_F(RunnerTest, DecodeOneRejectsWhenContextFull) { + auto md = createDefaultMetadata(); // max_context_len = 128 + auto runner = makeRunner(md); + ASSERT_TRUE(runner->prefill_tokens(std::vector(127, 1)).ok()); + auto r1 = runner->decode_one(0.0f); + ASSERT_TRUE(r1.ok()); + EXPECT_EQ(runner->position(), 128); + + auto r2 = runner->decode_one(0.0f); + EXPECT_FALSE(r2.ok()); + EXPECT_EQ(r2.error(), Error::InvalidArgument); +} + +// decode_one() must apply the generator's logit processors before sampling, +// exactly like generate(). 
Argmax of {0.1,0.2,0.3,0.4} is token 3; masking it +// makes the next sampled (pending) token 2 — proving the session decode path +// honors grammar/tool masks/penalties and can't diverge from generate(). +TEST_F(RunnerTest, DecodeOneAppliesLogitProcessors) { + auto runner = makeRunner( + createDefaultMetadata(), std::make_shared(3)); + ASSERT_TRUE(runner->prefill_tokens(std::vector{1, 2, 3}).ok()); + + // r1 emits the prefill-pending token (42, not sampled here); its forward pass + // samples the next pending token from masked logits. + ASSERT_TRUE(runner->decode_one(0.0f).ok()); + auto r2 = runner->decode_one(0.0f); + ASSERT_TRUE(r2.ok()); + EXPECT_EQ(r2.get().token_id, 2u); // token 3 masked -> argmax is now 2 +} + +// v1 serving capacity is conservatively single-slot: a self-contained .pte +// repacks XNNPACK weights per runtime, so we don't claim shared physical +// sessions. (TextLLMEngine::serving_capacity() returns this default; the +// engine-backed end-to-end check is in the pybinding test.) +TEST_F(RunnerTest, ServingCapacityIsSingleSlotByDefault) { + LLMServingCapacity cap; + EXPECT_EQ(cap.max_physical_sessions_without_weight_duplication, 1); + EXPECT_EQ(cap.estimated_bytes_per_session, 0); +} + // Test that generate("") without prior prefill() returns an error TEST_F(RunnerTest, GenerateEmptyWithoutPrefillFails) { auto tokenizer = createMockTokenizer(); diff --git a/extension/llm/runner/text_llm_runner.cpp b/extension/llm/runner/text_llm_runner.cpp index cf7ab50b9c8..34e4df61554 100644 --- a/extension/llm/runner/text_llm_runner.cpp +++ b/extension/llm/runner/text_llm_runner.cpp @@ -263,9 +263,15 @@ Error TextLLMRunner::generate( RUNNER_ET_LOG(config.warming, "Max new tokens %i reached!", max_new_tokens); } - stats_->num_prompt_tokens = - prompt.empty() ? 
static_cast(pos_) : num_prompt_tokens; - stats_->num_generated_tokens = num_generated_tokens; + // The prefill step produced and emitted one token (cur_token) before the + // token generator ran, so the total generated count is that token plus the + // generator's. For an empty prompt (continuation/prefix-reuse path) the + // prompt length is everything resident before this turn's generation (pos_ + // already includes the generated tokens, so subtract them). + stats_->num_prompt_tokens = prompt.empty() + ? static_cast(pos_) - num_generated_tokens + : num_prompt_tokens; + stats_->num_generated_tokens = num_generated_tokens + 1; if (config.warming) { ET_LOG(Info, "Warmup run finished!"); @@ -354,6 +360,128 @@ void TextLLMRunner::reset() { stats_->reset(); pos_ = 0; prefill_next_token_.reset(); + prev_decode_token_.reset(); +} + +::executorch::runtime::Error TextLLMRunner::seek(int64_t pos) { + // Sliding-window / ring-buffer models (max_seq_len < max_context_len) recycle + // KV space, so pos_ is not an absolute position and the prefix [0, pos) may + // have slid out of the window. Rewinding would attend to stale/overwritten KV + // and silently corrupt output. Refuse (fail-safe) so the caller falls back to + // reset() + full re-prefill — the same conservative choice vLLM + // (common_prefix_blocks=0 for SWA layers) and llama.cpp (seq_rm/get_can_shift + // fail for SWA) make. 
+ if (metadata_.at(kMaxSeqLen) < metadata_.at(kMaxContextLen)) { + ET_LOG( + Error, + "seek() is unsupported for sliding-window models " + "(max_seq_len %" PRId64 " < max_context_len %" PRId64 ")", + metadata_.at(kMaxSeqLen), + metadata_.at(kMaxContextLen)); + return ::executorch::runtime::Error::NotSupported; + } + if (pos < 0 || pos > pos_) { + ET_LOG(Error, "seek(%" PRId64 ") out of range [0, %" PRId64 "]", pos, pos_); + return ::executorch::runtime::Error::InvalidArgument; + } + pos_ = pos; + prefill_next_token_.reset(); + prev_decode_token_.reset(); + return ::executorch::runtime::Error::Ok; +} + +::executorch::runtime::Result TextLLMRunner::prefill_tokens( + std::vector tokens) { + if (!is_loaded()) { + ET_CHECK_OK_OR_RETURN_ERROR(load()); + } + if (tokens.empty()) { + ET_LOG(Error, "prefill_tokens called with empty tokens"); + return ::executorch::runtime::Error::InvalidArgument; + } + // Same context-capacity guard as generate(): a caller that seek()s then + // prefill_tokens()es a suffix must not push pos_ past the KV cache. This is + // the only place the bound is enforced for prefill_tokens() (the public + // prefix-cache primitive), since it doesn't go through generate(prompt). 
+ const int64_t max_seq_len = metadata_.at(kMaxSeqLen); + const int64_t max_context_len = metadata_.at(kMaxContextLen); + const int num_tokens = static_cast(tokens.size()); + if (max_seq_len >= max_context_len) { + ET_CHECK_OR_RETURN_ERROR( + pos_ + num_tokens < max_context_len, + InvalidArgument, + "pos_ %" PRId64 " + num_tokens %d >= max_context_len %" PRId64 + ", prefill_tokens would exceed KV cache capacity", + pos_, + num_tokens, + max_context_len); + } else { + ET_CHECK_OR_RETURN_ERROR( + num_tokens < max_context_len, + InvalidArgument, + "num_tokens %d >= max_context_len %" PRId64 + ", prefill_tokens exceeds KV cache capacity", + num_tokens, + max_context_len); + } + auto prefill_res = text_prefiller_->prefill(tokens, pos_); + ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error()); + prefill_next_token_ = prefill_res.get(); + prev_decode_token_.reset(); + return prefill_next_token_.value(); +} + +::executorch::runtime::Result TextLLMRunner::decode_one( + float temperature) { + if (!is_loaded()) { + ET_CHECK_OK_OR_RETURN_ERROR(load()); + } + ET_CHECK_OR_RETURN_ERROR( + prefill_next_token_.has_value(), + InvalidState, + "decode_one requires a pending token; call prefill()/prefill_tokens() first"); + if (metadata_.at(kMaxSeqLen) >= metadata_.at(kMaxContextLen)) { + ET_CHECK_OR_RETURN_ERROR( + pos_ < metadata_.at(kMaxContextLen), + InvalidArgument, + "decode_one would exceed KV cache capacity: pos_ %" PRId64 + " >= max_context_len %" PRId64, + pos_, + metadata_.at(kMaxContextLen)); + } + + // The pending token is the one we emit this step. + const uint64_t token = prefill_next_token_.value(); + const bool is_eos = text_token_generator_->is_eos(token); + + // Decode the text piece with BPE context (previous token), like generate(). 
+ const uint64_t prev = prev_decode_token_.value_or(token); + std::string text_piece; + auto decode_res = tokenizer_->decode(prev, token); + if (decode_res.ok()) { + text_piece = std::move(*decode_res); + } + + // Forward `token` at pos_ to predict the next pending token. + std::vector tok_data = {token}; + std::vector<::executorch::aten::SizesType> shape = {1, 1}; + auto tok_tensor = + from_blob(tok_data.data(), shape, ::executorch::aten::ScalarType::Long); + auto logits_res = text_decoder_runner_->step(tok_tensor, pos_); + ET_CHECK_OK_OR_RETURN_ERROR(logits_res.error()); + // Apply the same logit processors generate() does (grammar/tool masks, + // penalties, top-k/top-p) so the session decode path can't diverge from it. + ET_CHECK_OK_OR_RETURN_ERROR( + text_token_generator_->apply_logit_processors(logits_res.get())); + const float temp = (temperature < 0.0f) + ? (temperature_ == -1.0f ? 0.0f : temperature_) + : temperature; + prefill_next_token_ = static_cast( + text_decoder_runner_->logits_to_token(logits_res.get(), temp)); + prev_decode_token_ = token; + pos_ += 1; + + return DecodeResult{token, std::move(text_piece), is_eos}; } } // namespace executorch::extension::llm diff --git a/extension/llm/runner/text_llm_runner.h b/extension/llm/runner/text_llm_runner.h index c73b6a4bed6..9347e93d262 100644 --- a/extension/llm/runner/text_llm_runner.h +++ b/extension/llm/runner/text_llm_runner.h @@ -28,6 +28,8 @@ // Helper functions are now in llm_runner_helper.h // These are provided for backward compatibility #include +// DecodeResult (returned by decode_one) lives with the Engine/Session API. +#include namespace executorch::extension::llm { @@ -152,6 +154,76 @@ class ET_EXPERIMENTAL TextLLMRunner : public IRunner { */ void reset() override; + /** + * @brief Truncate/rewind the KV cache to `pos` tokens. + * + * Sets the cache cursor so that subsequent attention reads positions + * [0, pos) and the next prefill() overwrites starting at `pos`. 
This enables + * prefix reuse across turns: rewind to the length of the prefix shared with + * the previous request, then prefill only the new suffix instead of + * re-prefilling the whole prompt. reset() is the special case pos == 0. + * + * Because the KV buffers are addressed by the position cursor (reset() does + * not clear them), no buffer manipulation is needed — stale entries beyond + * `pos` are ignored by attention and overwritten by the next prefill. + * + * @note Valid only for models exported with max_seq_len == max_context_len, + * where the cursor is an absolute position. Sliding-window/chunked models use + * a ring buffer in which positions are not absolute; do not seek() those. + * + * @param pos Target cache length in tokens; must be in [0, current position]. + * @return Error::Ok on success, Error::InvalidArgument if `pos` is out of + * range, or Error::NotSupported for sliding-window models. + */ + ::executorch::runtime::Error seek(int64_t pos); + + /** + * @brief Prefill pre-tokenized input at the current KV cache position. + * + * Like prefill(prompt), but takes token ids directly instead of a string, so + * the caller controls the exact tokens written — required for prefix reuse, + * where the server computes the shared-prefix length, seek()s to it, and + * prefills only the new suffix tokens (no detokenize/re-tokenize round trip + * that could write mismatched KV). + * + * Tokens are written starting at the current position (call seek() first to + * position the cursor). The predicted next token is stored for a following + * generate("") call. + * + * @param tokens The token ids to prefill. Must be non-empty. + * @return The next token predicted after prefill, or an error. + */ + ::executorch::runtime::Result prefill_tokens( + std::vector tokens); + + /** + * @brief Current KV cache position (number of tokens with resident KV). + * + * This is the upper bound for a valid seek(): tokens at positions [0, pos) + * have been through a forward pass and are reusable. 
Note the last *sampled* + * token of a generation is not forwarded, so it is not resident — callers + * tracking emitted tokens must cap any reuse length at this value. + */ + int64_t position() const { + return pos_; + } + + /** + * @brief Decodes a single token: emits the current pending token (predicted + * by the preceding prefill/prefill_tokens or decode_one), then forwards it to + * predict the next pending token. Calling this in a loop reproduces the token + * sequence of generate(), but returns the exact sampled token id (not just + * decoded text) — the canonical decode unit for prefix-cache id tracking and + * future batched/interleaved scheduling. + * + * Requires a pending token (a prior prefill/prefill_tokens). `temperature` + * follows GenerationConfig semantics (-1 => use the model default / greedy). + * + * @return DecodeResult{token_id, text_piece, is_eos}, or an error. + */ + ::executorch::runtime::Result decode_one( + float temperature = -1.0f); + /** * @brief Stops the ongoing text generation process * @@ -185,6 +257,9 @@ class ET_EXPERIMENTAL TextLLMRunner : public IRunner { // Token predicted by the last prefill() call, consumed by generate(""). std::optional prefill_next_token_; + // Previously emitted token, for BPE-context text decoding in decode_one(). + std::optional prev_decode_token_; + // The position in KV cache of the input, starting from 0. int64_t pos_ = 0; }; diff --git a/extension/llm/runner/text_token_generator.h b/extension/llm/runner/text_token_generator.h index 3627cacf3c3..856f1af5b9a 100644 --- a/extension/llm/runner/text_token_generator.h +++ b/extension/llm/runner/text_token_generator.h @@ -55,6 +55,18 @@ class ET_EXPERIMENTAL TextTokenGenerator { return logit_processors_.size(); } + /// Apply the registered logit processors (grammar/tool masks, penalties, + /// top-k/top-p, ...) to `logits` in order, before sampling. 
Both the + /// generate() loop and session decode_one() call this so the two decode paths + /// stay consistent. + inline ::executorch::runtime::Error apply_logit_processors( + executorch::aten::Tensor& logits) { + for (auto& processor : logit_processors_) { + ET_CHECK_OK_OR_RETURN_ERROR(processor->process(logits)); + } + return ::executorch::runtime::Error::Ok; + } + virtual ~TextTokenGenerator() = default; /** @@ -126,9 +138,7 @@ class ET_EXPERIMENTAL TextTokenGenerator { prev_token = cur_token; - for (auto& processor : logit_processors_) { - ET_CHECK_OK_OR_RETURN_ERROR(processor->process(logits_tensor)); - } + ET_CHECK_OK_OR_RETURN_ERROR(apply_logit_processors(logits_tensor)); stats_->on_sampling_begin(); cur_token = @@ -180,6 +190,11 @@ class ET_EXPERIMENTAL TextTokenGenerator { should_stop_.store(true, std::memory_order_relaxed); } + /// Whether `token` is an end-of-sequence token (used by single-step decode). + inline bool is_eos(uint64_t token) const { + return eos_ids_->find(token) != eos_ids_->end(); + } + /** * Load the necessary resources for TextTokenGenerator. * This method should be called before using the generate() method. 
From dff4856fe0b0f219dcc14c95e7eeec8ed8867dc0 Mon Sep 17 00:00:00 2001 From: Mergen Nachin Date: Wed, 3 Jun 2026 15:52:32 -0700 Subject: [PATCH 2/6] [UPDATE] Update [ghstack-poisoned] --- extension/llm/runner/llm_runner_helper.cpp | 10 +++++- extension/llm/runner/llm_session.h | 6 ++-- .../llm/runner/test/test_text_llm_runner.cpp | 32 ++++++++++++++++--- extension/llm/runner/text_llm_runner.cpp | 23 ++++++++++--- extension/llm/runner/text_llm_runner.h | 6 ++-- 5 files changed, 63 insertions(+), 14 deletions(-) diff --git a/extension/llm/runner/llm_runner_helper.cpp b/extension/llm/runner/llm_runner_helper.cpp index 293c5abd68b..623d5d64453 100644 --- a/extension/llm/runner/llm_runner_helper.cpp +++ b/extension/llm/runner/llm_runner_helper.cpp @@ -367,7 +367,15 @@ class TextLLMSession : public LLMSession { ::executorch::runtime::Result decode_one( const SamplingConfig& sampling) override { // Only temperature is plumbed today; top_p/top_k/seed need a per-session - // sampler (applied in a follow-up). + // sampler (a follow-up). Reject non-default values rather than silently + // ignoring them, so callers can't assume constraints are applied. + if (sampling.top_p != 1.0f || sampling.top_k != 0 || sampling.seed != 0) { + ET_LOG( + Error, + "TextLLMSession: only temperature is supported; top_p/top_k/seed are " + "not yet implemented"); + return ::executorch::runtime::Error::NotSupported; + } return runner_->decode_one(sampling.temperature); } Error seek(int64_t pos) override { diff --git a/extension/llm/runner/llm_session.h b/extension/llm/runner/llm_session.h index 456c26c281e..a82a48d6578 100644 --- a/extension/llm/runner/llm_session.h +++ b/extension/llm/runner/llm_session.h @@ -24,8 +24,10 @@ namespace executorch::extension::llm { -/// Per-decode sampling parameters. An adapter applies what it supports and -/// reports the rest via capabilities(); -1 temperature means model default. +/// Per-decode sampling parameters. 
An adapter applies the fields it supports +/// and rejects non-default values of the rest rather than silently ignoring +/// them (today only temperature is plumbed). -1 temperature means model +/// default. struct SamplingConfig { float temperature = -1.0f; float top_p = 1.0f; diff --git a/extension/llm/runner/test/test_text_llm_runner.cpp b/extension/llm/runner/test/test_text_llm_runner.cpp index e3f0b54abc8..8d97ec00fea 100644 --- a/extension/llm/runner/test/test_text_llm_runner.cpp +++ b/extension/llm/runner/test/test_text_llm_runner.cpp @@ -226,15 +226,17 @@ class RunnerTest : public Test { // capacity bound can be exercised directly). std::unique_ptr makeRunner( std::unordered_map metadata, - std::shared_ptr logit_processor = nullptr) { + std::shared_ptr logit_processor = nullptr, + uint64_t prefill_token = 42) { auto tokenizer = createMockTokenizer(); auto text_decoder_runner = createMockTextDecoderRunner(); auto text_prefiller = createMockTextPrefiller(text_decoder_runner.get()); ON_CALL(*text_prefiller, prefill(_, _)) - .WillByDefault([](std::vector& tokens, int64_t& pos) { - pos += tokens.size(); - return Result(42); - }); + .WillByDefault( + [prefill_token](std::vector& tokens, int64_t& pos) { + pos += tokens.size(); + return Result(prefill_token); + }); ON_CALL(*text_prefiller, is_loaded()).WillByDefault(Return(true)); auto stats = std::make_unique(); auto text_token_generator = createTextTokenGenerator( @@ -608,6 +610,26 @@ TEST_F(RunnerTest, DecodeOneWithoutPendingTokenFails) { EXPECT_FALSE(runner->decode_one(0.0f).ok()); } +// decode_one() must stop at EOS WITHOUT forwarding it (like generate()): the +// EOS token is not made resident, position() does not advance, and no pending +// token remains — so prefix reuse stays correct and a further decode_one() +// errors. (The fixture's EOS id is 100.) 
+TEST_F(RunnerTest, DecodeOneStopsAtEosWithoutForwarding) { + auto runner = + makeRunner(createDefaultMetadata(), nullptr, /*prefill_token=*/100); + ASSERT_TRUE(runner->prefill_tokens(std::vector{1, 2, 3}).ok()); + EXPECT_EQ(runner->position(), 3); + + auto r = runner->decode_one(0.0f); + ASSERT_TRUE(r.ok()); + EXPECT_EQ(r.get().token_id, 100u); + EXPECT_TRUE(r.get().is_eos); + EXPECT_EQ(runner->position(), 3); // EOS not forwarded -> position unchanged + + // No pending token remains -> a further decode_one() errors. + EXPECT_FALSE(runner->decode_one(0.0f).ok()); +} + // decode_one() is the last-resort safety net for session drivers: even if a // caller forgets to resolve max_new_tokens, it must not step past KV capacity. TEST_F(RunnerTest, DecodeOneRejectsWhenContextFull) { diff --git a/extension/llm/runner/text_llm_runner.cpp b/extension/llm/runner/text_llm_runner.cpp index 34e4df61554..7135dd020a6 100644 --- a/extension/llm/runner/text_llm_runner.cpp +++ b/extension/llm/runner/text_llm_runner.cpp @@ -455,11 +455,26 @@ ::executorch::runtime::Result TextLLMRunner::decode_one( const bool is_eos = text_token_generator_->is_eos(token); // Decode the text piece with BPE context (previous token), like generate(). + // Surface tokenizer errors rather than hiding them as an empty piece (matches + // generate(), which logs and returns InvalidArgument). const uint64_t prev = prev_decode_token_.value_or(token); - std::string text_piece; auto decode_res = tokenizer_->decode(prev, token); - if (decode_res.ok()) { - text_piece = std::move(*decode_res); + if (!decode_res.ok()) { + ET_LOG( + Error, + "Tokenizers error code %d", + static_cast(decode_res.error())); + return ::executorch::runtime::Error::InvalidArgument; + } + std::string text_piece = std::move(*decode_res); + + // Stop at EOS WITHOUT forwarding it, like generate() (which breaks before the + // next step()): the EOS token is not made resident and pos_ does not advance, + // so position()/prefix reuse stay correct. 
No pending token remains, so a + // subsequent decode_one() correctly errors (generation is complete). + if (is_eos) { + prefill_next_token_.reset(); + return DecodeResult{token, std::move(text_piece), true}; } // Forward `token` at pos_ to predict the next pending token. @@ -481,7 +496,7 @@ ::executorch::runtime::Result TextLLMRunner::decode_one( prev_decode_token_ = token; pos_ += 1; - return DecodeResult{token, std::move(text_piece), is_eos}; + return DecodeResult{token, std::move(text_piece), false}; } } // namespace executorch::extension::llm diff --git a/extension/llm/runner/text_llm_runner.h b/extension/llm/runner/text_llm_runner.h index 9347e93d262..471c5628b47 100644 --- a/extension/llm/runner/text_llm_runner.h +++ b/extension/llm/runner/text_llm_runner.h @@ -172,8 +172,10 @@ class ET_EXPERIMENTAL TextLLMRunner : public IRunner { * a ring buffer in which positions are not absolute; do not seek() those. * * @param pos Target cache length in tokens; must be in [0, current position]. - * @return Error::Ok on success, Error::InvalidArgument if `pos` is out of - * range. + * @return Error::Ok on success; Error::InvalidArgument if `pos` is out of + * range; Error::NotSupported for sliding-window/chunked models + * (max_seq_len < max_context_len), where seek is unsafe — callers should fall + * back to reset() + full prefill. 
*/ ::executorch::runtime::Error seek(int64_t pos); From 0756e5b65389357e652a6077a9ae0550a6aa95b5 Mon Sep 17 00:00:00 2001 From: Mergen Nachin Date: Thu, 4 Jun 2026 11:48:03 -0700 Subject: [PATCH 3/6] [UPDATE] Update [ghstack-poisoned] --- extension/llm/runner/llm_runner_helper.cpp | 94 ++++++----- extension/llm/runner/llm_runner_helper.h | 14 +- extension/llm/runner/llm_session.h | 18 +- extension/llm/runner/targets.bzl | 1 + .../llm/runner/test/test_text_llm_runner.cpp | 156 +++++++++--------- extension/llm/runner/text_llm_runner.cpp | 75 ++++++--- extension/llm/runner/text_llm_runner.h | 86 +++------- extension/llm/runner/text_llm_session.h | 43 +++++ extension/llm/runner/text_prefiller.cpp | 25 ++- extension/llm/runner/text_prefiller.h | 10 +- 10 files changed, 293 insertions(+), 229 deletions(-) create mode 100644 extension/llm/runner/text_llm_session.h diff --git a/extension/llm/runner/llm_runner_helper.cpp b/extension/llm/runner/llm_runner_helper.cpp index 623d5d64453..6f01ddb15f8 100644 --- a/extension/llm/runner/llm_runner_helper.cpp +++ b/extension/llm/runner/llm_runner_helper.cpp @@ -15,6 +15,7 @@ #include #include #include +#include #include #include #include @@ -341,7 +342,7 @@ std::unique_ptr create_text_llm_runner_from_program( // than re-loading it, and its loaded method allocates its own planned (KV) // memory. Whether packed weights are physically shared vs. re-materialized // per method instance is backend-dependent (serving_capacity() is the - // authority); on XNNPACK assume per-instance. + // authority). constexpr uint32_t kMaxCachedMemoryBytes = 1024 * 1024 * 10; // 10MB auto module = std::make_unique( std::move(program), @@ -352,50 +353,67 @@ std::unique_ptr create_text_llm_runner_from_program( std::move(module), std::move(tokenizer), temperature, method_name); } -namespace { +namespace detail { // The TextLLM adapter: implements the model-agnostic LLMSession over a -// TextLLMRunner. 
TextLLMRunner is an implementation detail here — the engine -// and server depend only on LLMSession. -class TextLLMSession : public LLMSession { - public: - explicit TextLLMSession(std::unique_ptr runner) - : runner_(std::move(runner)) {} - - Error prefill_tokens(std::vector tokens) override { - return runner_->prefill_tokens(std::move(tokens)).error(); - } - ::executorch::runtime::Result decode_one( - const SamplingConfig& sampling) override { - // Only temperature is plumbed today; top_p/top_k/seed need a per-session - // sampler (a follow-up). Reject non-default values rather than silently - // ignoring them, so callers can't assume constraints are applied. - if (sampling.top_p != 1.0f || sampling.top_k != 0 || sampling.seed != 0) { +// TextLLMRunner. TextLLMRunner's token-step methods are private; this adapter +// is their only (friended) caller, so the engine and server depend solely on +// LLMSession. +TextLLMSession::TextLLMSession(std::unique_ptr runner) + : runner_(std::move(runner)) {} + +Error TextLLMSession::prefill_tokens( + std::vector tokens, + const SamplingConfig* initial_sampling) { + // The model samples the FIRST generated token during prefill, so apply the + // request's sampling here (not a stale default). Only temperature is + // plumbed; reject non-default top_p/top_k/seed for parity with decode_one(). 
+ float temperature = -1.0f; + if (initial_sampling != nullptr) { + if (initial_sampling->top_p != 1.0f || initial_sampling->top_k != 0 || + initial_sampling->seed != 0) { ET_LOG( Error, - "TextLLMSession: only temperature is supported; top_p/top_k/seed are " - "not yet implemented"); + "TextLLMSession: only temperature is supported; top_p/top_k/seed " + "are not yet implemented"); return ::executorch::runtime::Error::NotSupported; } - return runner_->decode_one(sampling.temperature); - } - Error seek(int64_t pos) override { - return runner_->seek(pos); - } - int64_t position() const override { - return runner_->position(); - } - Error reset() override { - runner_->reset(); - return Error::Ok; + temperature = initial_sampling->temperature; } - void stop() override { - runner_->stop(); + return runner_->prefill_tokens(std::move(tokens), temperature).error(); +} + +::executorch::runtime::Result TextLLMSession::decode_one( + const SamplingConfig& sampling) { + // Only temperature is plumbed today; top_p/top_k/seed need a per-session + // sampler (a follow-up). Reject non-default values rather than silently + // ignoring them, so callers can't assume constraints are applied. 
+ if (sampling.top_p != 1.0f || sampling.top_k != 0 || sampling.seed != 0) { + ET_LOG( + Error, + "TextLLMSession: only temperature is supported; top_p/top_k/seed are " + "not yet implemented"); + return ::executorch::runtime::Error::NotSupported; } + return runner_->decode_one(sampling.temperature); +} + +Error TextLLMSession::seek(int64_t pos) { + return runner_->seek(pos); +} + +int64_t TextLLMSession::position() const { + return runner_->position(); +} + +Error TextLLMSession::reset() { + runner_->reset(); + return Error::Ok; +} - private: - std::unique_ptr runner_; -}; -} // namespace +void TextLLMSession::stop() { + runner_->stop(); +} +} // namespace detail TextLLMEngine::TextLLMEngine( std::unique_ptr loader_module, @@ -487,7 +505,7 @@ TextLLMEngine::create_session() { return Error::InvalidState; } return std::unique_ptr( - std::make_unique(std::move(runner))); + std::make_unique(std::move(runner))); } std::unique_ptr create_multimodal_runner( diff --git a/extension/llm/runner/llm_runner_helper.h b/extension/llm/runner/llm_runner_helper.h index de1ecc743e3..ab28a644cc2 100644 --- a/extension/llm/runner/llm_runner_helper.h +++ b/extension/llm/runner/llm_runner_helper.h @@ -150,8 +150,7 @@ ET_EXPERIMENTAL std::unique_ptr create_text_llm_runner( * own method state and KV cache. This is the per-session construction path for * TextLLMEngine — N sessions reuse one loaded Program but isolate their mutable * KV state. Whether they also avoid re-materializing packed weights per session - * is backend-dependent (serving_capacity() is authoritative; XNNPACK repacks - * per method instance, so assume per-session weights there). + * is backend-dependent (serving_capacity() is authoritative). * * The caller must keep the DataLoader backing `program` alive for the lifetime * of every runner created from it (TextLLMEngine holds the loader Module). 
@@ -178,10 +177,8 @@ create_text_llm_runner_from_program( * is the correctness-first foundation for serving multiple conversations. * Backend execution should be serialized by the caller until per-backend thread * safety is proven (Module::execute is not assumed thread-safe). Whether extra - * sessions actually avoid duplicating packed weights is a backend property - * (e.g. AOTI/CUDA share device weights) reported by serving_capacity(); on the - * XNNPACK path weights are repacked per method instance and the KV cache is - * baked into the .pte, so it conservatively reports a single physical session. + * sessions avoid duplicating packed weights is backend-dependent and reported + * by serving_capacity() (conservatively one). */ class ET_EXPERIMENTAL TextLLMEngine : public LLMEngine { public: @@ -198,9 +195,8 @@ class ET_EXPERIMENTAL TextLLMEngine : public LLMEngine { // serving_capacity). ::executorch::runtime::Result> create_session() override; - // Conservative v1: a self-contained .pte repacks XNNPACK weights per runtime, - // so we don't claim multiple physical sessions share weights. Raise this on a - // backend/artifact proven to share packed weights. + // Conservative: a single physical session (no proven cross-session weight + // sharing). Raise on a backend proven to share packed weights. LLMServingCapacity serving_capacity() const override { return LLMServingCapacity{}; } diff --git a/extension/llm/runner/llm_session.h b/extension/llm/runner/llm_session.h index a82a48d6578..df5bab710bd 100644 --- a/extension/llm/runner/llm_session.h +++ b/extension/llm/runner/llm_session.h @@ -49,12 +49,10 @@ struct DecodeResult { /// capacity* concern (engine-level), distinct from how a session advances a /// conversation (LLMSession) — keep backend memory flags off LLMSession. struct LLMServingCapacity { - // Physical sessions (loaded runtimes) creatable without duplicating packed - // weights. 
Conservatively 1: a self-contained .pte with inline constants - // repacks weights per XNNPACK runtime, so N logical requests queue on one - // physical session (llama.cpp single-slot), not N copies of the model. A - // backend that provably shares packed weights (XNNWeightsCache with named - // external data; CUDA/AOTI shared device weights) can report >1. + // Physical sessions creatable without duplicating packed weights. + // Conservatively 1 (some backends repack weights per runtime, so extra + // sessions would copy the whole model); raise only on a backend proven to + // share packed weights. int32_t max_physical_sessions_without_weight_duplication = 1; // Planned bytes one session adds (KV + activations), for memory-budget // admission. 0 = unknown; the server skips the memory clamp. @@ -70,8 +68,14 @@ class ET_EXPERIMENTAL LLMSession { /// Prefill pre-tokenized input at the current position (call seek() first for /// prefix reuse). Must be non-empty and fit the context window. + /// + /// `initial_sampling` (optional): the sampling config for the FIRST generated + /// token, for backends that sample during prefill (e.g. in-graph sampling). + /// Pass it so the first token uses the request's sampling instead of a stale + /// default. Backends that only sample in decode_one() ignore it. virtual ::executorch::runtime::Error prefill_tokens( - std::vector tokens) = 0; + std::vector tokens, + const SamplingConfig* initial_sampling = nullptr) = 0; /// Decode one token from the pending state; looping reproduces a full /// generation while returning exact sampled token ids. 
A single decode_one() diff --git a/extension/llm/runner/targets.bzl b/extension/llm/runner/targets.bzl index 2914078cd40..ae1569547c4 100644 --- a/extension/llm/runner/targets.bzl +++ b/extension/llm/runner/targets.bzl @@ -113,6 +113,7 @@ def define_common_targets(): name = "runner_lib" + aten_suffix, exported_headers = [ "text_llm_runner.h", + "text_llm_session.h", "llm_runner_helper.h", "llm_session.h", "constants.h", diff --git a/extension/llm/runner/test/test_text_llm_runner.cpp b/extension/llm/runner/test/test_text_llm_runner.cpp index 8d97ec00fea..410b3be3b12 100644 --- a/extension/llm/runner/test/test_text_llm_runner.cpp +++ b/extension/llm/runner/test/test_text_llm_runner.cpp @@ -10,6 +10,7 @@ #include #include #include +#include #include #include #include @@ -20,14 +21,17 @@ #include using namespace ::testing; +using executorch::extension::llm::DecodeResult; using executorch::extension::llm::GenerationConfig; using executorch::extension::llm::LLMServingCapacity; using executorch::extension::llm::LogitProcessor; +using executorch::extension::llm::SamplingConfig; using executorch::extension::llm::Stats; using executorch::extension::llm::TextDecoderRunner; using executorch::extension::llm::TextLLMRunner; using executorch::extension::llm::TextPrefiller; using executorch::extension::llm::TextTokenGenerator; +using executorch::extension::llm::detail::TextLLMSession; using executorch::runtime::Error; using executorch::runtime::Result; using executorch::runtime::testing::TensorFactory; @@ -96,8 +100,8 @@ class MockTextPrefiller : public TextPrefiller { MOCK_METHOD( Result, prefill, - (std::vector&, int64_t&), - ()); + (std::vector&, int64_t&, float), + (override)); MOCK_METHOD(::executorch::runtime::Error, load, (), ()); MOCK_METHOD(bool, is_loaded, (), ()); }; @@ -191,7 +195,7 @@ class RunnerTest : public Test { ON_CALL(*text_prefiller, is_loaded()).WillByDefault(Return(true)); // Set up default behavior for the text prefiller ON_CALL(*text_prefiller, prefill) - 
.WillByDefault([](const std::vector&, int64_t) { + .WillByDefault([](const std::vector&, int64_t, float) { return Result(4); }); @@ -231,12 +235,12 @@ class RunnerTest : public Test { auto tokenizer = createMockTokenizer(); auto text_decoder_runner = createMockTextDecoderRunner(); auto text_prefiller = createMockTextPrefiller(text_decoder_runner.get()); - ON_CALL(*text_prefiller, prefill(_, _)) - .WillByDefault( - [prefill_token](std::vector& tokens, int64_t& pos) { - pos += tokens.size(); - return Result(prefill_token); - }); + ON_CALL(*text_prefiller, prefill(_, _, _)) + .WillByDefault([prefill_token]( + std::vector& tokens, int64_t& pos, float) { + pos += tokens.size(); + return Result(prefill_token); + }); ON_CALL(*text_prefiller, is_loaded()).WillByDefault(Return(true)); auto stats = std::make_unique(); auto text_token_generator = createTextTokenGenerator( @@ -283,8 +287,8 @@ TEST_F(RunnerTest, GenerateCallsCallbackExactlyMaxNewTokensTimes) { }); // Set up expectations for the text prefiller - ON_CALL(*text_prefiller, prefill(_, _)) - .WillByDefault([&](std::vector&, int64_t&) { + ON_CALL(*text_prefiller, prefill(_, _, _)) + .WillByDefault([&](std::vector&, int64_t&, float) { return (Result(4)); }); @@ -351,8 +355,8 @@ TEST_F(RunnerTest, WarmupCallsGenerateWithWarmingFlag) { }); // Set up expectations for the text prefiller - ON_CALL(*text_prefiller, prefill(_, _)) - .WillByDefault([&](std::vector&, int64_t&) { + ON_CALL(*text_prefiller, prefill(_, _, _)) + .WillByDefault([&](std::vector&, int64_t&, float) { return (Result(4)); }); @@ -437,8 +441,8 @@ TEST_F(RunnerTest, PrefillReturnsNextToken) { std::vector{1, 2, 3}); }); - ON_CALL(*text_prefiller, prefill(_, _)) - .WillByDefault([&](std::vector& tokens, int64_t& pos) { + ON_CALL(*text_prefiller, prefill(_, _, _)) + .WillByDefault([&](std::vector& tokens, int64_t& pos, float) { pos += tokens.size(); return Result(42); }); @@ -483,8 +487,8 @@ TEST_F(RunnerTest, PrefillThenGenerateEmpty) { std::vector{1, 2, 
3}); }); - ON_CALL(*text_prefiller, prefill(_, _)) - .WillByDefault([&](std::vector& tokens, int64_t& pos) { + ON_CALL(*text_prefiller, prefill(_, _, _)) + .WillByDefault([&](std::vector& tokens, int64_t& pos, float) { pos += tokens.size(); return Result(4); }); @@ -534,46 +538,46 @@ TEST_F(RunnerTest, PrefillThenGenerateEmpty) { // prefill_tokens() must reject a suffix larger than the KV cache, mirroring the // capacity bound generate(prompt) enforces (prefill_tokens is the public // prefix-cache primitive and the only place this is checked for it). +// The token-step methods are private on TextLLMRunner (internal serving hooks); +// they are exercised through their sole friended caller, +// detail::TextLLMSession, which is the LLMSession surface the server/engine +// actually depend on. TEST_F(RunnerTest, PrefillTokensRejectsOverContext) { - auto runner = makeRunner(createDefaultMetadata()); // max_context_len = 128 - auto result = - runner->prefill_tokens(std::vector(200, 1)); // 200 > 128 - EXPECT_FALSE(result.ok()); - EXPECT_EQ(result.error(), Error::InvalidArgument); + TextLLMSession session(makeRunner(createDefaultMetadata())); // context = 128 + EXPECT_EQ( + session.prefill_tokens(std::vector(200, 1)), // 200 > 128 + Error::InvalidArgument); } // seek() + prefill_tokens() across the boundary is rejected: a valid prefill // followed by a suffix that pushes pos_ past max_context_len must fail. 
TEST_F(RunnerTest, PrefillTokensRejectsWhenPosPlusSuffixExceedsContext) { - auto runner = makeRunner(createDefaultMetadata()); // 128 - auto ok = runner->prefill_tokens(std::vector(100, 1)); - EXPECT_TRUE(ok.ok()); - EXPECT_EQ(runner->position(), 100); - auto bad = - runner->prefill_tokens(std::vector(50, 1)); // 100 + 50 > 128 - EXPECT_FALSE(bad.ok()); - EXPECT_EQ(bad.error(), Error::InvalidArgument); - EXPECT_EQ(runner->position(), 100); // rejected before advancing + TextLLMSession session(makeRunner(createDefaultMetadata())); // 128 + EXPECT_EQ(session.prefill_tokens(std::vector(100, 1)), Error::Ok); + EXPECT_EQ(session.position(), 100); + EXPECT_EQ( + session.prefill_tokens(std::vector(50, 1)), // 100 + 50 > 128 + Error::InvalidArgument); + EXPECT_EQ(session.position(), 100); // rejected before advancing } // Empty tokens are rejected. TEST_F(RunnerTest, PrefillTokensRejectsEmpty) { - auto runner = makeRunner(createDefaultMetadata()); - auto result = runner->prefill_tokens(std::vector{}); - EXPECT_FALSE(result.ok()); - EXPECT_EQ(result.error(), Error::InvalidArgument); + TextLLMSession session(makeRunner(createDefaultMetadata())); + EXPECT_EQ( + session.prefill_tokens(std::vector{}), Error::InvalidArgument); } // position() tracks prefilled tokens; seek() rewinds within range and rejects // out-of-range targets. 
TEST_F(RunnerTest, SeekAndPositionTrackResidentTokens) { - auto runner = makeRunner(createDefaultMetadata()); - EXPECT_EQ(runner->position(), 0); - EXPECT_TRUE(runner->prefill_tokens(std::vector(10, 1)).ok()); - EXPECT_EQ(runner->position(), 10); - EXPECT_EQ(runner->seek(5), Error::Ok); - EXPECT_EQ(runner->position(), 5); - EXPECT_EQ(runner->seek(999), Error::InvalidArgument); // past current position + TextLLMSession session(makeRunner(createDefaultMetadata())); + EXPECT_EQ(session.position(), 0); + EXPECT_EQ(session.prefill_tokens(std::vector(10, 1)), Error::Ok); + EXPECT_EQ(session.position(), 10); + EXPECT_EQ(session.seek(5), Error::Ok); + EXPECT_EQ(session.position(), 5); + EXPECT_EQ(session.seek(999), Error::InvalidArgument); // past current position } // seek() is refused on sliding-window models (max_seq_len < max_context_len), @@ -581,33 +585,33 @@ TEST_F(RunnerTest, SeekAndPositionTrackResidentTokens) { TEST_F(RunnerTest, SeekRejectedForSlidingWindow) { auto md = createDefaultMetadata(); md["get_max_seq_len"] = 64; // < get_max_context_len (128) => sliding window - auto runner = makeRunner(md); - EXPECT_EQ(runner->seek(0), Error::NotSupported); + TextLLMSession session(makeRunner(md)); + EXPECT_EQ(session.seek(0), Error::NotSupported); } // decode_one() emits the pending token id exactly, then forwards it (advancing // position by one). Looping it reproduces a generation while exposing ids. 
TEST_F(RunnerTest, DecodeOneReturnsExactTokenIdAndAdvances) { - auto runner = makeRunner(createDefaultMetadata()); - ASSERT_TRUE(runner->prefill_tokens(std::vector{1, 2, 3}).ok()); - EXPECT_EQ(runner->position(), 3); + TextLLMSession session(makeRunner(createDefaultMetadata())); + ASSERT_EQ(session.prefill_tokens(std::vector{1, 2, 3}), Error::Ok); + EXPECT_EQ(session.position(), 3); - auto r1 = runner->decode_one(0.0f); + auto r1 = session.decode_one(SamplingConfig{0.0f}); ASSERT_TRUE(r1.ok()); EXPECT_EQ(r1.get().token_id, 42u); // the prefill-pending token (mock prefill) EXPECT_FALSE(r1.get().is_eos); - EXPECT_EQ(runner->position(), 4); // forwarded one token + EXPECT_EQ(session.position(), 4); // forwarded one token - auto r2 = runner->decode_one(0.0f); + auto r2 = session.decode_one(SamplingConfig{0.0f}); ASSERT_TRUE(r2.ok()); EXPECT_EQ(r2.get().token_id, 3u); // argmax of canned logits {.1,.2,.3,.4} - EXPECT_EQ(runner->position(), 5); + EXPECT_EQ(session.position(), 5); } // decode_one() without a pending token (no prior prefill) must error. TEST_F(RunnerTest, DecodeOneWithoutPendingTokenFails) { - auto runner = makeRunner(createDefaultMetadata()); - EXPECT_FALSE(runner->decode_one(0.0f).ok()); + TextLLMSession session(makeRunner(createDefaultMetadata())); + EXPECT_FALSE(session.decode_one(SamplingConfig{0.0f}).ok()); } // decode_one() must stop at EOS WITHOUT forwarding it (like generate()): the @@ -615,32 +619,32 @@ TEST_F(RunnerTest, DecodeOneWithoutPendingTokenFails) { // token remains — so prefix reuse stays correct and a further decode_one() // errors. (The fixture's EOS id is 100.) 
TEST_F(RunnerTest, DecodeOneStopsAtEosWithoutForwarding) { - auto runner = - makeRunner(createDefaultMetadata(), nullptr, /*prefill_token=*/100); - ASSERT_TRUE(runner->prefill_tokens(std::vector{1, 2, 3}).ok()); - EXPECT_EQ(runner->position(), 3); + TextLLMSession session( + makeRunner(createDefaultMetadata(), nullptr, /*prefill_token=*/100)); + ASSERT_EQ(session.prefill_tokens(std::vector{1, 2, 3}), Error::Ok); + EXPECT_EQ(session.position(), 3); - auto r = runner->decode_one(0.0f); + auto r = session.decode_one(SamplingConfig{0.0f}); ASSERT_TRUE(r.ok()); EXPECT_EQ(r.get().token_id, 100u); EXPECT_TRUE(r.get().is_eos); - EXPECT_EQ(runner->position(), 3); // EOS not forwarded -> position unchanged + EXPECT_EQ(session.position(), 3); // EOS not forwarded -> position unchanged // No pending token remains -> a further decode_one() errors. - EXPECT_FALSE(runner->decode_one(0.0f).ok()); + EXPECT_FALSE(session.decode_one(SamplingConfig{0.0f}).ok()); } // decode_one() is the last-resort safety net for session drivers: even if a // caller forgets to resolve max_new_tokens, it must not step past KV capacity. 
TEST_F(RunnerTest, DecodeOneRejectsWhenContextFull) { auto md = createDefaultMetadata(); // max_context_len = 128 - auto runner = makeRunner(md); - ASSERT_TRUE(runner->prefill_tokens(std::vector(127, 1)).ok()); - auto r1 = runner->decode_one(0.0f); + TextLLMSession session(makeRunner(md)); + ASSERT_EQ(session.prefill_tokens(std::vector(127, 1)), Error::Ok); + auto r1 = session.decode_one(SamplingConfig{0.0f}); ASSERT_TRUE(r1.ok()); - EXPECT_EQ(runner->position(), 128); + EXPECT_EQ(session.position(), 128); - auto r2 = runner->decode_one(0.0f); + auto r2 = session.decode_one(SamplingConfig{0.0f}); EXPECT_FALSE(r2.ok()); EXPECT_EQ(r2.error(), Error::InvalidArgument); } @@ -650,22 +654,22 @@ TEST_F(RunnerTest, DecodeOneRejectsWhenContextFull) { // makes the next sampled (pending) token 2 — proving the session decode path // honors grammar/tool masks/penalties and can't diverge from generate(). TEST_F(RunnerTest, DecodeOneAppliesLogitProcessors) { - auto runner = makeRunner( - createDefaultMetadata(), std::make_shared(3)); - ASSERT_TRUE(runner->prefill_tokens(std::vector{1, 2, 3}).ok()); + TextLLMSession session(makeRunner( + createDefaultMetadata(), std::make_shared(3))); + ASSERT_EQ(session.prefill_tokens(std::vector{1, 2, 3}), Error::Ok); // r1 emits the prefill-pending token (42, not sampled here); its forward pass // samples the next pending token from masked logits. - ASSERT_TRUE(runner->decode_one(0.0f).ok()); - auto r2 = runner->decode_one(0.0f); + ASSERT_TRUE(session.decode_one(SamplingConfig{0.0f}).ok()); + auto r2 = session.decode_one(SamplingConfig{0.0f}); ASSERT_TRUE(r2.ok()); EXPECT_EQ(r2.get().token_id, 2u); // token 3 masked -> argmax is now 2 } -// v1 serving capacity is conservatively single-slot: a self-contained .pte -// repacks XNNPACK weights per runtime, so we don't claim shared physical -// sessions. (TextLLMEngine::serving_capacity() returns this default; the -// engine-backed end-to-end check is in the pybinding test.) 
+// Serving capacity is conservatively a single physical session by default +// (no proven cross-session weight sharing). TextLLMEngine::serving_capacity() +// returns this default; the engine-backed end-to-end check is in the pybinding +// test. TEST_F(RunnerTest, ServingCapacityIsSingleSlotByDefault) { LLMServingCapacity cap; EXPECT_EQ(cap.max_physical_sessions_without_weight_duplication, 1); @@ -773,8 +777,8 @@ TEST_F(RunnerTest, MultiTurnWithSeqLenRespectsPos) { std::vector{1, 2, 3}); }); - ON_CALL(*text_prefiller, prefill(_, _)) - .WillByDefault([&](std::vector& tokens, int64_t& pos) { + ON_CALL(*text_prefiller, prefill(_, _, _)) + .WillByDefault([&](std::vector& tokens, int64_t& pos, float) { pos += tokens.size(); return Result(4); }); diff --git a/extension/llm/runner/text_llm_runner.cpp b/extension/llm/runner/text_llm_runner.cpp index 7135dd020a6..c03d8ebb35d 100644 --- a/extension/llm/runner/text_llm_runner.cpp +++ b/extension/llm/runner/text_llm_runner.cpp @@ -113,6 +113,12 @@ Error TextLLMRunner::generate( int64_t max_seq_len = metadata_.at(kMaxSeqLen); int64_t max_context_len = metadata_.at(kMaxContextLen); + // Resolve sampling temperature once: the first token is sampled during + // prefill and the rest in the token generator, so both must use the same + // temperature. + const float resolved_temp = + temperature_ == -1.0f ? config.temperature : temperature_; + uint64_t cur_token = 0; int num_prompt_tokens = 0; std::vector prompt_tokens; @@ -177,7 +183,8 @@ Error TextLLMRunner::generate( // Prefill first // Here feed all tokens to the model and get the next predicted token // after the prompt. After that we will enter generate loop. 
- auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos_); + auto prefill_res = + text_prefiller_->prefill(prompt_tokens, pos_, resolved_temp); ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error()); cur_token = prefill_res.get(); prefill_next_token_.reset(); @@ -235,10 +242,6 @@ Error TextLLMRunner::generate( // Set ignore_eos based on config text_token_generator_->set_ignore_eos(config.ignore_eos); - // Use the configuration's temperature - float resolved_temp = - temperature_ == -1.0f ? config.temperature : temperature_; - // Generate max_new_tokens - 1 because prefill already generated 1 token. auto generate_result = text_token_generator_->generate( prompt_tokens, pos_, max_new_tokens - 1, resolved_temp, wrapped_callback); @@ -364,13 +367,16 @@ void TextLLMRunner::reset() { } ::executorch::runtime::Error TextLLMRunner::seek(int64_t pos) { - // Sliding-window / ring-buffer models (max_seq_len < max_context_len) recycle - // KV space, so pos_ is not an absolute position and the prefix [0, pos) may - // have slid out of the window. Rewinding would attend to stale/overwritten KV - // and silently corrupt output. Refuse (fail-safe) so the caller falls back to - // reset() + full re-prefill — the same conservative choice vLLM - // (common_prefix_blocks=0 for SWA layers) and llama.cpp (seq_rm/get_can_shift - // fail for SWA) make. + // Token-step primitives require a KV cache (a non-KV model has no resident KV + // to rewind); fail closed. + if (metadata_.at(kUseKVCache) == 0) { + ET_LOG(Error, "seek() requires a KV-cache model (use_kv_cache=true)"); + return ::executorch::runtime::Error::NotSupported; + } + // Sliding-window models (max_seq_len < max_context_len) recycle KV space, so + // pos_ is not an absolute position and the prefix [0, pos) may have slid out + // of the window; rewinding would attend to stale KV. Refuse so the caller + // falls back to reset() + full re-prefill. 
if (metadata_.at(kMaxSeqLen) < metadata_.at(kMaxContextLen)) { ET_LOG( Error, @@ -391,10 +397,18 @@ ::executorch::runtime::Error TextLLMRunner::seek(int64_t pos) { } ::executorch::runtime::Result TextLLMRunner::prefill_tokens( - std::vector tokens) { + std::vector tokens, + float temperature) { if (!is_loaded()) { ET_CHECK_OK_OR_RETURN_ERROR(load()); } + // The token-step primitives assume KV-cached decode (each step forwards only + // the new token at pos_). A non-KV model needs full-sequence re-forwarding, + // which this path does not implement — fail closed rather than decode wrong. + ET_CHECK_OR_RETURN_ERROR( + metadata_.at(kUseKVCache) != 0, + NotSupported, + "prefill_tokens/decode_one require a KV-cache model (use_kv_cache=true)"); if (tokens.empty()) { ET_LOG(Error, "prefill_tokens called with empty tokens"); return ::executorch::runtime::Error::InvalidArgument; @@ -424,7 +438,12 @@ ::executorch::runtime::Result TextLLMRunner::prefill_tokens( num_tokens, max_context_len); } - auto prefill_res = text_prefiller_->prefill(tokens, pos_); + // Resolve temperature like decode_one() so the first token (sampled here in + // prefill) honors the request instead of defaulting to greedy. + const float temp = (temperature < 0.0f) + ? (temperature_ == -1.0f ? 0.0f : temperature_) + : temperature; + auto prefill_res = text_prefiller_->prefill(tokens, pos_, temp); ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error()); prefill_next_token_ = prefill_res.get(); prev_decode_token_.reset(); @@ -436,19 +455,16 @@ ::executorch::runtime::Result TextLLMRunner::decode_one( if (!is_loaded()) { ET_CHECK_OK_OR_RETURN_ERROR(load()); } + // See prefill_tokens(): single-token KV stepping is invalid without a KV + // cache. 
+ ET_CHECK_OR_RETURN_ERROR( + metadata_.at(kUseKVCache) != 0, + NotSupported, + "decode_one requires a KV-cache model (use_kv_cache=true)"); ET_CHECK_OR_RETURN_ERROR( prefill_next_token_.has_value(), InvalidState, "decode_one requires a pending token; call prefill()/prefill_tokens() first"); - if (metadata_.at(kMaxSeqLen) >= metadata_.at(kMaxContextLen)) { - ET_CHECK_OR_RETURN_ERROR( - pos_ < metadata_.at(kMaxContextLen), - InvalidArgument, - "decode_one would exceed KV cache capacity: pos_ %" PRId64 - " >= max_context_len %" PRId64, - pos_, - metadata_.at(kMaxContextLen)); - } // The pending token is the one we emit this step. const uint64_t token = prefill_next_token_.value(); @@ -477,6 +493,19 @@ ::executorch::runtime::Result TextLLMRunner::decode_one( return DecodeResult{token, std::move(text_piece), true}; } + // Only a NON-EOS token is forwarded (made resident at pos_), so the capacity + // check belongs here — after the EOS short-circuit. This lets the final EOS + // be emitted even when the KV cache is exactly full. + if (metadata_.at(kMaxSeqLen) >= metadata_.at(kMaxContextLen)) { + ET_CHECK_OR_RETURN_ERROR( + pos_ < metadata_.at(kMaxContextLen), + InvalidArgument, + "decode_one would exceed KV cache capacity: pos_ %" PRId64 + " >= max_context_len %" PRId64, + pos_, + metadata_.at(kMaxContextLen)); + } + // Forward `token` at pos_ to predict the next pending token. 
std::vector tok_data = {token}; std::vector<::executorch::aten::SizesType> shape = {1, 1}; diff --git a/extension/llm/runner/text_llm_runner.h b/extension/llm/runner/text_llm_runner.h index 471c5628b47..90032ae8b59 100644 --- a/extension/llm/runner/text_llm_runner.h +++ b/extension/llm/runner/text_llm_runner.h @@ -33,6 +33,10 @@ namespace executorch::extension::llm { +namespace detail { +class TextLLMSession; +} // namespace detail + class ET_EXPERIMENTAL TextLLMRunner : public IRunner { public: /** @@ -155,86 +159,34 @@ class ET_EXPERIMENTAL TextLLMRunner : public IRunner { void reset() override; /** - * @brief Truncate/rewind the KV cache to `pos` tokens. - * - * Sets the cache cursor so that subsequent attention reads positions - * [0, pos) and the next prefill() overwrites starting at `pos`. This enables - * prefix reuse across turns: rewind to the length of the prefix shared with - * the previous request, then prefill only the new suffix instead of - * re-prefilling the whole prompt. reset() is the special case pos == 0. - * - * Because the KV buffers are addressed by the position cursor (reset() does - * not clear them), no buffer manipulation is needed — stale entries beyond - * `pos` are ignored by attention and overwritten by the next prefill. - * - * @note Valid only for models exported with max_seq_len == max_context_len, - * where the cursor is an absolute position. Sliding-window/chunked models use - * a ring buffer in which positions are not absolute; do not seek() those. + * @brief Stops the ongoing text generation process * - * @param pos Target cache length in tokens; must be in [0, current position]. - * @return Error::Ok on success; Error::InvalidArgument if `pos` is out of - * range; Error::NotSupported for sliding-window/chunked models - * (max_seq_len < max_context_len), where seek is unsafe — callers should fall - * back to reset() + full prefill. 
+ * This method signals the generator to stop producing new tokens and + * terminate the current generation process. */ + void stop() override; + + private: + // Internal serving hooks: the single-token-step primitives that detail:: + // TextLLMSession composes into the LLMSession serving API. Not a public or + // ABI-stable surface — call them through LLMEngine/LLMSession, never + // directly. Friending the adapter keeps them out of every other caller (and + // out of the Python bindings) so nothing can take a dependency on this shape. + friend class detail::TextLLMSession; + ::executorch::runtime::Error seek(int64_t pos); - /** - * @brief Prefill pre-tokenized input at the current KV cache position. - * - * Like prefill(prompt), but takes token ids directly instead of a string, so - * the caller controls the exact tokens written — required for prefix reuse, - * where the server computes the shared-prefix length, seek()s to it, and - * prefills only the new suffix tokens (no detokenize/re-tokenize round trip - * that could write mismatched KV). - * - * Tokens are written starting at the current position (call seek() first to - * position the cursor). The predicted next token is stored for a following - * generate("") call. - * - * @param tokens The token ids to prefill. Must be non-empty. - * @return The next token predicted after prefill, or an error. - */ ::executorch::runtime::Result prefill_tokens( - std::vector tokens); + std::vector tokens, + float temperature = -1.0f); - /** - * @brief Current KV cache position (number of tokens with resident KV). - * - * This is the upper bound for a valid seek(): tokens at positions [0, pos) - * have been through a forward pass and are reusable. Note the last *sampled* - * token of a generation is not forwarded, so it is not resident — callers - * tracking emitted tokens must cap any reuse length at this value. 
- */ int64_t position() const { return pos_; } - /** - * @brief Decodes a single token: emits the current pending token (predicted - * by the preceding prefill/prefill_tokens or decode_one), then forwards it to - * predict the next pending token. Calling this in a loop reproduces the token - * sequence of generate(), but returns the exact sampled token id (not just - * decoded text) — the canonical decode unit for prefix-cache id tracking and - * future batched/interleaved scheduling. - * - * Requires a pending token (a prior prefill/prefill_tokens). `temperature` - * follows GenerationConfig semantics (-1 => use the model default / greedy). - * - * @return DecodeResult{token_id, text_piece, is_eos}, or an error. - */ ::executorch::runtime::Result decode_one( float temperature = -1.0f); - /** - * @brief Stops the ongoing text generation process - * - * This method signals the generator to stop producing new tokens and - * terminate the current generation process. - */ - void stop() override; - - private: // Components std::unique_ptr<::tokenizers::Tokenizer> tokenizer_; std::unordered_map metadata_; diff --git a/extension/llm/runner/text_llm_session.h b/extension/llm/runner/text_llm_session.h new file mode 100644 index 00000000000..700691f5317 --- /dev/null +++ b/extension/llm/runner/text_llm_session.h @@ -0,0 +1,43 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Internal adapter implementing the model-agnostic LLMSession over a +// TextLLMRunner. It lives in `detail` (not a public API) so it can be named as +// a friend of TextLLMRunner: this is the *only* caller of the runner's internal +// token-step hooks (prefill_tokens/decode_one/seek/position). Server and engine +// code depend on LLMSession alone, never on this type or on TextLLMRunner. 
+ +#pragma once + +#include +#include + +#include +#include + +namespace executorch::extension::llm::detail { + +class TextLLMSession : public LLMSession { + public: + explicit TextLLMSession(std::unique_ptr runner); + + ::executorch::runtime::Error prefill_tokens( + std::vector tokens, + const SamplingConfig* initial_sampling = nullptr) override; + ::executorch::runtime::Result decode_one( + const SamplingConfig& sampling) override; + ::executorch::runtime::Error seek(int64_t pos) override; + int64_t position() const override; + ::executorch::runtime::Error reset() override; + void stop() override; + + private: + std::unique_ptr runner_; +}; + +} // namespace executorch::extension::llm::detail diff --git a/extension/llm/runner/text_prefiller.cpp b/extension/llm/runner/text_prefiller.cpp index a391cef01de..b87fc25a8ec 100644 --- a/extension/llm/runner/text_prefiller.cpp +++ b/extension/llm/runner/text_prefiller.cpp @@ -28,7 +28,8 @@ TextPrefiller::TextPrefiller( ::executorch::runtime::Result TextPrefiller::prefill( std::vector& prompt_tokens, - int64_t& start_pos) { + int64_t& start_pos, + float temperature) { ET_CHECK_MSG(!prompt_tokens.empty(), "Prompt cannot be null"); if (!text_decoder_runner_->is_method_loaded()) { ET_CHECK_OK_OR_RETURN_ERROR(text_decoder_runner_->load()); @@ -54,8 +55,15 @@ ::executorch::runtime::Result TextPrefiller::prefill( num_tokens_to_prefill_with, prompt_tokens_to_process.begin()); - // Process this chunk - auto chunk_result = prefill_chunk(prompt_tokens_to_process, start_pos); + // Process this chunk. Only the LAST chunk produces the first generated + // token, so apply `temperature` there; intermediate chunks just prefill. + const bool is_last_chunk = + num_tokens_to_process + num_tokens_to_prefill_with >= + num_prompt_tokens; + auto chunk_result = prefill_chunk( + prompt_tokens_to_process, + start_pos, + is_last_chunk ? 
temperature : 0.0f); ET_CHECK_OK_OR_RETURN_ERROR(chunk_result.error()); cur_token = chunk_result.get(); @@ -65,13 +73,14 @@ ::executorch::runtime::Result TextPrefiller::prefill( return cur_token; } else { // If prompt tokens don't exceed max_seq_len_, process them directly - return prefill_chunk(prompt_tokens, start_pos); + return prefill_chunk(prompt_tokens, start_pos, temperature); } } ::executorch::runtime::Result TextPrefiller::prefill_chunk( std::vector& prompt_tokens, - int64_t& start_pos) { + int64_t& start_pos, + float temperature) { // enable_parallel_prefill_ maybe set even when not using kv cache // When kv cache is not used, start pos is ignored int32_t num_prompt_tokens = prompt_tokens.size(); @@ -92,7 +101,8 @@ ::executorch::runtime::Result TextPrefiller::prefill_chunk( Info, "Prefill token result numel(): %zu", outputs_res.get().numel()); start_pos += num_prompt_tokens; - cur_token = text_decoder_runner_->logits_to_token(outputs_res.get()); + cur_token = + text_decoder_runner_->logits_to_token(outputs_res.get(), temperature); } else { // sequential prefill int64_t pos = 0; // position in the sequence // NOLINTNEXTLINE(facebook-hte-ParameterUncheckedArrayBounds) @@ -128,7 +138,8 @@ ::executorch::runtime::Result TextPrefiller::prefill_chunk( start_pos++; } - cur_token = text_decoder_runner_->logits_to_token(logits_tensor); + cur_token = + text_decoder_runner_->logits_to_token(logits_tensor, temperature); } return cur_token; } diff --git a/extension/llm/runner/text_prefiller.h b/extension/llm/runner/text_prefiller.h index a02cd3d1bf4..3ca961a15cb 100644 --- a/extension/llm/runner/text_prefiller.h +++ b/extension/llm/runner/text_prefiller.h @@ -32,22 +32,28 @@ class ET_EXPERIMENTAL TextPrefiller { * tokenizer. * @param start_pos The starting position in KV cache of the input in the LLM * Module. + * @param temperature Sampling temperature for the first generated token + * (which is sampled here during prefill). Defaults to greedy (0.0). 
* @return The next token of the LLM Module after prefill. */ virtual ::executorch::runtime::Result prefill( std::vector& prompt_tokens, - int64_t& start_pos); + int64_t& start_pos, + float temperature = 0.0f); /** * Helper method to prefill a chunk of tokens. * @param prompt_tokens The chunk of text prompt tokens to process. * @param start_pos The starting position in KV cache of the input in the LLM * Module. + * @param temperature Sampling temperature for the token produced by this + * chunk. Defaults to greedy (0.0). * @return The next token of the LLM Module after prefilling this chunk. */ virtual ::executorch::runtime::Result prefill_chunk( std::vector& prompt_tokens, - int64_t& start_pos); + int64_t& start_pos, + float temperature = 0.0f); /** * Load the necessary resources for the TextPrefiller. From 170f01db8eabc0e944ea86383b6768b7ca9a1566 Mon Sep 17 00:00:00 2001 From: Mergen Nachin Date: Thu, 4 Jun 2026 15:14:45 -0700 Subject: [PATCH 4/6] [UPDATE] Update [ghstack-poisoned] --- extension/llm/runner/llm_runner_helper.cpp | 15 +- extension/llm/runner/llm_runner_helper.h | 41 ++--- extension/llm/runner/llm_session.h | 17 +- extension/llm/runner/targets.bzl | 7 +- .../llm/runner/test/test_text_llm_runner.cpp | 155 ++++++++++++------ .../llm/runner/test/test_text_prefiller.cpp | 16 +- extension/llm/runner/test/test_util.cpp | 23 +++ extension/llm/runner/text_llm_runner.cpp | 47 +++++- extension/llm/runner/text_llm_runner.h | 21 ++- extension/llm/runner/util.h | 32 ++++ 10 files changed, 270 insertions(+), 104 deletions(-) diff --git a/extension/llm/runner/llm_runner_helper.cpp b/extension/llm/runner/llm_runner_helper.cpp index 6f01ddb15f8..5063f6ef132 100644 --- a/extension/llm/runner/llm_runner_helper.cpp +++ b/extension/llm/runner/llm_runner_helper.cpp @@ -325,7 +325,12 @@ static std::unique_ptr assemble_text_llm_runner( temperature); } -std::unique_ptr create_text_llm_runner_from_program( +// Builds a TextLLMRunner over an already-loaded Program: the 
runner's Module +// reuses `program` while owning its own method state and KV cache. File-local — +// the per-session construction path for TextLLMEngine (which keeps the backing +// DataLoader alive for the runners' lifetime). External callers go through +// LLMEngine -> LLMSession, not a raw shared-Program runner. +static std::unique_ptr create_text_llm_runner_from_program( std::shared_ptr program, std::unique_ptr<::tokenizers::Tokenizer> tokenizer, float temperature, @@ -413,6 +418,11 @@ Error TextLLMSession::reset() { void TextLLMSession::stop() { runner_->stop(); } + +std::unique_ptr make_text_llm_session( + std::unique_ptr runner) { + return std::make_unique(std::move(runner)); +} } // namespace detail TextLLMEngine::TextLLMEngine( @@ -504,8 +514,7 @@ TextLLMEngine::create_session() { ET_LOG(Error, "TextLLMEngine: failed to build session runner"); return Error::InvalidState; } - return std::unique_ptr( - std::make_unique(std::move(runner))); + return detail::make_text_llm_session(std::move(runner)); } std::unique_ptr create_multimodal_runner( diff --git a/extension/llm/runner/llm_runner_helper.h b/extension/llm/runner/llm_runner_helper.h index ab28a644cc2..c2eaf0c8ac3 100644 --- a/extension/llm/runner/llm_runner_helper.h +++ b/extension/llm/runner/llm_runner_helper.h @@ -142,33 +142,6 @@ ET_EXPERIMENTAL std::unique_ptr create_text_llm_runner( const std::string& method_name = "forward", Module::LoadMode load_mode = Module::LoadMode::MmapUseMlockIgnoreErrors); -/** - * @brief Creates a TextLLMRunner over an already-loaded Program. - * - * Unlike create_text_llm_runner(model_path, ...), this does not load the model - * file again: the resulting runner's Module reuses `program` while owning its - * own method state and KV cache. This is the per-session construction path for - * TextLLMEngine — N sessions reuse one loaded Program but isolate their mutable - * KV state. 
Whether they also avoid re-materializing packed weights per session - * is backend-dependent (serving_capacity() is authoritative). - * - * The caller must keep the DataLoader backing `program` alive for the lifetime - * of every runner created from it (TextLLMEngine holds the loader Module). - * - * @param program Shared, already-loaded program. - * @param tokenizer Initialized tokenizer instance (owned by the new runner). - * @param temperature Optional temperature (deprecated; prefer - * GenerationConfig). - * @param method_name Name of the method to execute in the model. - * @return std::unique_ptr on success, or nullptr on failure. - */ -ET_EXPERIMENTAL std::unique_ptr -create_text_llm_runner_from_program( - std::shared_ptr program, - std::unique_ptr<::tokenizers::Tokenizer> tokenizer, - float temperature = -1.0f, - const std::string& method_name = "forward"); - /** * @brief Engine for multi-session text generation over one loaded Program. * @@ -225,6 +198,20 @@ class ET_EXPERIMENTAL TextLLMEngine : public LLMEngine { std::unordered_map metadata_; }; +namespace detail { +// Implementation detail (not a public API): wraps a TextLLMRunner in an +// LLMSession (the runner -> session seam). The supported entry point is +// LLMEngine::create_session(); this exists only so TextLLMEngine can build its +// sessions and so unit tests can drive the runner's token-step primitives +// through the public LLMSession surface (the concrete adapter type is private). +// Do not depend on wrapping arbitrary runners. +// +// @param runner A loaded TextLLMRunner; ownership transfers to the session. +// @return std::unique_ptr wrapping `runner`. 
+std::unique_ptr make_text_llm_session( + std::unique_ptr runner); +} // namespace detail + /** * @brief Creates a MultimodalRunner instance with dependency injection * diff --git a/extension/llm/runner/llm_session.h b/extension/llm/runner/llm_session.h index df5bab710bd..1bfdb0a08a0 100644 --- a/extension/llm/runner/llm_session.h +++ b/extension/llm/runner/llm_session.h @@ -36,12 +36,19 @@ struct SamplingConfig { }; /// One decoded step: the exact sampled token id (for prefix-cache id tracking -/// and batching), its decoded text piece (raw bytes; may be a partial UTF-8 -/// sequence the caller assembles), and whether it is an EOS token. +/// and batching) and its decoded text piece (raw bytes; may be a partial UTF-8 +/// sequence the caller assembles). +/// +/// `is_eos` is literal: the sampled token is an end-of-sequence token (use it +/// for the "stop" finish reason, metrics, cache/accounting). `is_terminal` is +/// the loop signal: generation ended at this step — either EOS or a cooperative +/// stop() took effect. A decode loop should end when is_terminal is set; every +/// EOS step is also terminal, but a stop step is terminal without being EOS. struct DecodeResult { uint64_t token_id; std::string text_piece; bool is_eos; + bool is_terminal; }; /// How many physical sessions an engine can host, so the server admits logical @@ -96,8 +103,10 @@ class ET_EXPERIMENTAL LLMSession { /// Request that a decode_one() loop stop. This is a TOKEN-BOUNDARY, /// cooperative stop: it is safe to call from another thread, but it does not - /// abort a decode_one() that is already running — it takes effect before the - /// next decode_one() (the loop driver checks between tokens). + /// abort a decode_one() that is already running. It takes effect at the next + /// decode_one(), which then returns a terminal step (is_terminal set, is_eos + /// false) without forwarding a new token. The stop is cleared by the next + /// prefill_tokens() or reset(). 
virtual void stop() = 0; }; diff --git a/extension/llm/runner/targets.bzl b/extension/llm/runner/targets.bzl index ae1569547c4..60ba47e9013 100644 --- a/extension/llm/runner/targets.bzl +++ b/extension/llm/runner/targets.bzl @@ -113,11 +113,16 @@ def define_common_targets(): name = "runner_lib" + aten_suffix, exported_headers = [ "text_llm_runner.h", - "text_llm_session.h", "llm_runner_helper.h", "llm_session.h", "constants.h", ], + # Internal: the detail::TextLLMSession adapter (sole friended caller + # of TextLLMRunner's token-step hooks). Private so dependents reach + # it only through LLMEngine/LLMSession + make_text_llm_session(). + headers = [ + "text_llm_session.h", + ], srcs = [ "text_llm_runner.cpp", "llm_runner_helper.cpp", diff --git a/extension/llm/runner/test/test_text_llm_runner.cpp b/extension/llm/runner/test/test_text_llm_runner.cpp index 410b3be3b12..e55eb06b545 100644 --- a/extension/llm/runner/test/test_text_llm_runner.cpp +++ b/extension/llm/runner/test/test_text_llm_runner.cpp @@ -10,7 +10,6 @@ #include #include #include -#include #include #include #include @@ -31,7 +30,7 @@ using executorch::extension::llm::TextDecoderRunner; using executorch::extension::llm::TextLLMRunner; using executorch::extension::llm::TextPrefiller; using executorch::extension::llm::TextTokenGenerator; -using executorch::extension::llm::detail::TextLLMSession; +using executorch::extension::llm::detail::make_text_llm_session; using executorch::runtime::Error; using executorch::runtime::Result; using executorch::runtime::testing::TensorFactory; @@ -231,7 +230,8 @@ class RunnerTest : public Test { std::unique_ptr makeRunner( std::unordered_map metadata, std::shared_ptr logit_processor = nullptr, - uint64_t prefill_token = 42) { + uint64_t prefill_token = 42, + MockTextPrefiller** out_prefiller = nullptr) { auto tokenizer = createMockTokenizer(); auto text_decoder_runner = createMockTextDecoderRunner(); auto text_prefiller = 
createMockTextPrefiller(text_decoder_runner.get()); @@ -242,6 +242,11 @@ class RunnerTest : public Test { return Result(prefill_token); }); ON_CALL(*text_prefiller, is_loaded()).WillByDefault(Return(true)); + // Expose the mock prefiller (still owned by the runner) so a test can set + // expectations on it. + if (out_prefiller != nullptr) { + *out_prefiller = text_prefiller.get(); + } auto stats = std::make_unique(); auto text_token_generator = createTextTokenGenerator( tokenizer.get(), text_decoder_runner.get(), stats.get()); @@ -535,49 +540,73 @@ TEST_F(RunnerTest, PrefillThenGenerateEmpty) { EXPECT_EQ(counter.getCount(), config.max_new_tokens); } +// prefill(prompt, GenerationConfig) must pass config.temperature to the +// prefiller: the first token is sampled during prefill, so dropping it would +// silently sample greedy regardless of the requested temperature. +TEST_F(RunnerTest, PrefillWithConfigPassesTemperatureToPrefiller) { + MockTextPrefiller* prefiller = nullptr; + auto runner = makeRunner( + createDefaultMetadata(), nullptr, /*prefill_token=*/42, &prefiller); + ASSERT_NE(prefiller, nullptr); + + EXPECT_CALL(*prefiller, prefill(_, _, FloatEq(0.7f))) + .WillOnce([](std::vector& tokens, int64_t& pos, float) { + pos += tokens.size(); + return Result(42); + }); + + GenerationConfig config; + config.temperature = 0.7f; + auto result = runner->prefill("hello", config); + EXPECT_TRUE(result.ok()); + EXPECT_EQ(result.get(), 42); +} + +// The token-step methods are private on TextLLMRunner (internal serving hooks); +// they are exercised through the public LLMSession surface returned by +// make_text_llm_session() — the same path the engine/server depend on. +// // prefill_tokens() must reject a suffix larger than the KV cache, mirroring the -// capacity bound generate(prompt) enforces (prefill_tokens is the public +// capacity bound generate(prompt) enforces (prefill_tokens is the // prefix-cache primitive and the only place this is checked for it). 
-// The token-step methods are private on TextLLMRunner (internal serving hooks); -// they are exercised through their sole friended caller, -// detail::TextLLMSession, which is the LLMSession surface the server/engine -// actually depend on. TEST_F(RunnerTest, PrefillTokensRejectsOverContext) { - TextLLMSession session(makeRunner(createDefaultMetadata())); // context = 128 + auto session = + make_text_llm_session(makeRunner(createDefaultMetadata())); // context 128 EXPECT_EQ( - session.prefill_tokens(std::vector(200, 1)), // 200 > 128 + session->prefill_tokens(std::vector(200, 1)), // 200 > 128 Error::InvalidArgument); } // seek() + prefill_tokens() across the boundary is rejected: a valid prefill // followed by a suffix that pushes pos_ past max_context_len must fail. TEST_F(RunnerTest, PrefillTokensRejectsWhenPosPlusSuffixExceedsContext) { - TextLLMSession session(makeRunner(createDefaultMetadata())); // 128 - EXPECT_EQ(session.prefill_tokens(std::vector(100, 1)), Error::Ok); - EXPECT_EQ(session.position(), 100); + auto session = make_text_llm_session(makeRunner(createDefaultMetadata())); + EXPECT_EQ(session->prefill_tokens(std::vector(100, 1)), Error::Ok); + EXPECT_EQ(session->position(), 100); EXPECT_EQ( - session.prefill_tokens(std::vector(50, 1)), // 100 + 50 > 128 + session->prefill_tokens(std::vector(50, 1)), // 100 + 50 > 128 Error::InvalidArgument); - EXPECT_EQ(session.position(), 100); // rejected before advancing + EXPECT_EQ(session->position(), 100); // rejected before advancing } // Empty tokens are rejected. TEST_F(RunnerTest, PrefillTokensRejectsEmpty) { - TextLLMSession session(makeRunner(createDefaultMetadata())); + auto session = make_text_llm_session(makeRunner(createDefaultMetadata())); EXPECT_EQ( - session.prefill_tokens(std::vector{}), Error::InvalidArgument); + session->prefill_tokens(std::vector{}), Error::InvalidArgument); } // position() tracks prefilled tokens; seek() rewinds within range and rejects // out-of-range targets. 
TEST_F(RunnerTest, SeekAndPositionTrackResidentTokens) { - TextLLMSession session(makeRunner(createDefaultMetadata())); - EXPECT_EQ(session.position(), 0); - EXPECT_EQ(session.prefill_tokens(std::vector(10, 1)), Error::Ok); - EXPECT_EQ(session.position(), 10); - EXPECT_EQ(session.seek(5), Error::Ok); - EXPECT_EQ(session.position(), 5); - EXPECT_EQ(session.seek(999), Error::InvalidArgument); // past current position + auto session = make_text_llm_session(makeRunner(createDefaultMetadata())); + EXPECT_EQ(session->position(), 0); + EXPECT_EQ(session->prefill_tokens(std::vector(10, 1)), Error::Ok); + EXPECT_EQ(session->position(), 10); + EXPECT_EQ(session->seek(5), Error::Ok); + EXPECT_EQ(session->position(), 5); + EXPECT_EQ( + session->seek(999), Error::InvalidArgument); // past current position } // seek() is refused on sliding-window models (max_seq_len < max_context_len), @@ -585,33 +614,34 @@ TEST_F(RunnerTest, SeekAndPositionTrackResidentTokens) { TEST_F(RunnerTest, SeekRejectedForSlidingWindow) { auto md = createDefaultMetadata(); md["get_max_seq_len"] = 64; // < get_max_context_len (128) => sliding window - TextLLMSession session(makeRunner(md)); - EXPECT_EQ(session.seek(0), Error::NotSupported); + auto session = make_text_llm_session(makeRunner(md)); + EXPECT_EQ(session->seek(0), Error::NotSupported); } // decode_one() emits the pending token id exactly, then forwards it (advancing // position by one). Looping it reproduces a generation while exposing ids. 
TEST_F(RunnerTest, DecodeOneReturnsExactTokenIdAndAdvances) { - TextLLMSession session(makeRunner(createDefaultMetadata())); - ASSERT_EQ(session.prefill_tokens(std::vector{1, 2, 3}), Error::Ok); - EXPECT_EQ(session.position(), 3); + auto session = make_text_llm_session(makeRunner(createDefaultMetadata())); + ASSERT_EQ(session->prefill_tokens(std::vector{1, 2, 3}), Error::Ok); + EXPECT_EQ(session->position(), 3); - auto r1 = session.decode_one(SamplingConfig{0.0f}); + auto r1 = session->decode_one(SamplingConfig{0.0f}); ASSERT_TRUE(r1.ok()); EXPECT_EQ(r1.get().token_id, 42u); // the prefill-pending token (mock prefill) EXPECT_FALSE(r1.get().is_eos); - EXPECT_EQ(session.position(), 4); // forwarded one token + EXPECT_FALSE(r1.get().is_terminal); + EXPECT_EQ(session->position(), 4); // forwarded one token - auto r2 = session.decode_one(SamplingConfig{0.0f}); + auto r2 = session->decode_one(SamplingConfig{0.0f}); ASSERT_TRUE(r2.ok()); EXPECT_EQ(r2.get().token_id, 3u); // argmax of canned logits {.1,.2,.3,.4} - EXPECT_EQ(session.position(), 5); + EXPECT_EQ(session->position(), 5); } // decode_one() without a pending token (no prior prefill) must error. TEST_F(RunnerTest, DecodeOneWithoutPendingTokenFails) { - TextLLMSession session(makeRunner(createDefaultMetadata())); - EXPECT_FALSE(session.decode_one(SamplingConfig{0.0f}).ok()); + auto session = make_text_llm_session(makeRunner(createDefaultMetadata())); + EXPECT_FALSE(session->decode_one(SamplingConfig{0.0f}).ok()); } // decode_one() must stop at EOS WITHOUT forwarding it (like generate()): the @@ -619,32 +649,57 @@ TEST_F(RunnerTest, DecodeOneWithoutPendingTokenFails) { // token remains — so prefix reuse stays correct and a further decode_one() // errors. (The fixture's EOS id is 100.) 
TEST_F(RunnerTest, DecodeOneStopsAtEosWithoutForwarding) { - TextLLMSession session( + auto session = make_text_llm_session( makeRunner(createDefaultMetadata(), nullptr, /*prefill_token=*/100)); - ASSERT_EQ(session.prefill_tokens(std::vector{1, 2, 3}), Error::Ok); - EXPECT_EQ(session.position(), 3); + ASSERT_EQ(session->prefill_tokens(std::vector{1, 2, 3}), Error::Ok); + EXPECT_EQ(session->position(), 3); - auto r = session.decode_one(SamplingConfig{0.0f}); + auto r = session->decode_one(SamplingConfig{0.0f}); ASSERT_TRUE(r.ok()); EXPECT_EQ(r.get().token_id, 100u); EXPECT_TRUE(r.get().is_eos); - EXPECT_EQ(session.position(), 3); // EOS not forwarded -> position unchanged + EXPECT_TRUE(r.get().is_terminal); + EXPECT_EQ(session->position(), 3); // EOS not forwarded -> position unchanged // No pending token remains -> a further decode_one() errors. - EXPECT_FALSE(session.decode_one(SamplingConfig{0.0f}).ok()); + EXPECT_FALSE(session->decode_one(SamplingConfig{0.0f}).ok()); +} + +// stop() is honored at the next decode_one(): it returns a terminal step +// (is_terminal set, is_eos false) WITHOUT forwarding, position() does not +// advance, and no pending token remains; a fresh prefill_tokens() clears the +// stop so decoding resumes. +TEST_F(RunnerTest, DecodeOneHonorsStop) { + auto session = make_text_llm_session(makeRunner(createDefaultMetadata())); + ASSERT_EQ(session->prefill_tokens(std::vector{1, 2, 3}), Error::Ok); + + session->stop(); + auto r = session->decode_one(SamplingConfig{0.0f}); + ASSERT_TRUE(r.ok()); + EXPECT_FALSE(r.get().is_eos); // a stop is not an EOS token + EXPECT_TRUE(r.get().is_terminal); // but it ends the loop + EXPECT_EQ(session->position(), 3); // not forwarded + EXPECT_FALSE(session->decode_one(SamplingConfig{0.0f}).ok()); // no pending + + // A new prefill clears the stop; decoding resumes. 
+ ASSERT_EQ(session->prefill_tokens(std::vector{4, 5}), Error::Ok); + auto r2 = session->decode_one(SamplingConfig{0.0f}); + ASSERT_TRUE(r2.ok()); + EXPECT_FALSE(r2.get().is_eos); + EXPECT_FALSE(r2.get().is_terminal); } // decode_one() is the last-resort safety net for session drivers: even if a // caller forgets to resolve max_new_tokens, it must not step past KV capacity. TEST_F(RunnerTest, DecodeOneRejectsWhenContextFull) { auto md = createDefaultMetadata(); // max_context_len = 128 - TextLLMSession session(makeRunner(md)); - ASSERT_EQ(session.prefill_tokens(std::vector(127, 1)), Error::Ok); - auto r1 = session.decode_one(SamplingConfig{0.0f}); + auto session = make_text_llm_session(makeRunner(md)); + ASSERT_EQ(session->prefill_tokens(std::vector(127, 1)), Error::Ok); + auto r1 = session->decode_one(SamplingConfig{0.0f}); ASSERT_TRUE(r1.ok()); - EXPECT_EQ(session.position(), 128); + EXPECT_EQ(session->position(), 128); - auto r2 = session.decode_one(SamplingConfig{0.0f}); + auto r2 = session->decode_one(SamplingConfig{0.0f}); EXPECT_FALSE(r2.ok()); EXPECT_EQ(r2.error(), Error::InvalidArgument); } @@ -654,14 +709,14 @@ TEST_F(RunnerTest, DecodeOneRejectsWhenContextFull) { // makes the next sampled (pending) token 2 — proving the session decode path // honors grammar/tool masks/penalties and can't diverge from generate(). TEST_F(RunnerTest, DecodeOneAppliesLogitProcessors) { - TextLLMSession session(makeRunner( + auto session = make_text_llm_session(makeRunner( createDefaultMetadata(), std::make_shared(3))); - ASSERT_EQ(session.prefill_tokens(std::vector{1, 2, 3}), Error::Ok); + ASSERT_EQ(session->prefill_tokens(std::vector{1, 2, 3}), Error::Ok); // r1 emits the prefill-pending token (42, not sampled here); its forward pass // samples the next pending token from masked logits. 
- ASSERT_TRUE(session.decode_one(SamplingConfig{0.0f}).ok()); - auto r2 = session.decode_one(SamplingConfig{0.0f}); + ASSERT_TRUE(session->decode_one(SamplingConfig{0.0f}).ok()); + auto r2 = session->decode_one(SamplingConfig{0.0f}); ASSERT_TRUE(r2.ok()); EXPECT_EQ(r2.get().token_id, 2u); // token 3 masked -> argmax is now 2 } diff --git a/extension/llm/runner/test/test_text_prefiller.cpp b/extension/llm/runner/test/test_text_prefiller.cpp index 5ed5031dace..cba44ac8ab6 100644 --- a/extension/llm/runner/test/test_text_prefiller.cpp +++ b/extension/llm/runner/test/test_text_prefiller.cpp @@ -79,8 +79,8 @@ class TextPrefillerTest : public Test { MOCK_METHOD( ::executorch::runtime::Result, prefill_chunk, - (std::vector&, int64_t&), - ()); + (std::vector&, int64_t&, float), + (override)); }; // Create a mock TextPrefiller @@ -112,9 +112,9 @@ TEST_F(TextPrefillerTest, PrefillCallsPrefillChunkOnceWhenPromptFits) { int64_t start_pos = 0; // Expect prefill_chunk to be called exactly once with the entire prompt - EXPECT_CALL(*prefiller, prefill_chunk(_, _)) + EXPECT_CALL(*prefiller, prefill_chunk(_, _, _)) .Times(1) - .WillOnce([&](std::vector& tokens, int64_t& pos) { + .WillOnce([&](std::vector& tokens, int64_t& pos, float) { // Verify the tokens passed to prefill_chunk EXPECT_EQ(tokens.size(), prompt_tokens.size()); for (size_t i = 0; i < tokens.size(); i++) { @@ -217,14 +217,14 @@ TEST_F(TextPrefillerTest, PrefillHandlesPrefillChunkErrorsCorrectly) { InSequence seq; // First chunk: tokens [1, 2, 3] - succeeds - EXPECT_CALL(*prefiller, prefill_chunk(_, _)) - .WillOnce([&](std::vector& tokens, int64_t& pos) { + EXPECT_CALL(*prefiller, prefill_chunk(_, _, _)) + .WillOnce([&](std::vector& tokens, int64_t& pos, float) { return Result(10); }); // Second chunk: tokens [4, 5] - fails - EXPECT_CALL(*prefiller, prefill_chunk(_, _)) - .WillOnce([&](std::vector& tokens, int64_t& pos) { + EXPECT_CALL(*prefiller, prefill_chunk(_, _, _)) + .WillOnce([&](std::vector& tokens, int64_t& 
pos, float) { return Result(Error::InvalidArgument); }); } diff --git a/extension/llm/runner/test/test_util.cpp b/extension/llm/runner/test/test_util.cpp index 3d66c212375..751b3f9daf9 100644 --- a/extension/llm/runner/test/test_util.cpp +++ b/extension/llm/runner/test/test_util.cpp @@ -18,6 +18,7 @@ namespace { using ::executorch::aten::ScalarType; using ::executorch::extension::make_tensor_ptr; using ::executorch::extension::llm::convert_to_bfloat16; +using ::executorch::extension::llm::utf8_complete_prefix_len; class ConvertToBFloat16Test : public ::testing::Test { protected: @@ -63,4 +64,26 @@ TEST_F(ConvertToBFloat16Test, RejectsNonFloatTensor) { EXPECT_EQ(result.error(), ::executorch::runtime::Error::InvalidArgument); } +TEST(Utf8CompletePrefixLenTest, HandlesAsciiAndMultiByteBoundaries) { + EXPECT_EQ(utf8_complete_prefix_len(""), 0u); + EXPECT_EQ(utf8_complete_prefix_len("ascii"), 5u); + + // Complete multi-byte characters are fully consumed. + EXPECT_EQ(utf8_complete_prefix_len("\xc3\xa9"), 2u); // é (2-byte) + EXPECT_EQ(utf8_complete_prefix_len("\xe2\x82\xac"), 3u); // € (3-byte) + EXPECT_EQ(utf8_complete_prefix_len("\xf0\x9f\x98\x80"), 4u); // 😀 (4-byte) + + // A character split across the end is held back (not counted). + EXPECT_EQ(utf8_complete_prefix_len("\xc3"), 0u); // 1/2 of é + EXPECT_EQ(utf8_complete_prefix_len("\xe2\x82"), 0u); // 2/3 of € + EXPECT_EQ(utf8_complete_prefix_len("\xf0\x9f\x98"), 0u); // 3/4 of 😀 + + // A complete prefix followed by a split character keeps the complete part. + EXPECT_EQ(utf8_complete_prefix_len("hi\xe2\x82"), 2u); + EXPECT_EQ(utf8_complete_prefix_len("\xe2\x82\xac\xf0\x9f"), 3u); + + // An invalid lead byte counts as length 1 (emitted, not stalled). 
+ EXPECT_EQ(utf8_complete_prefix_len("\x80"), 1u); +} + } // namespace diff --git a/extension/llm/runner/text_llm_runner.cpp b/extension/llm/runner/text_llm_runner.cpp index c03d8ebb35d..b02692ed9a0 100644 --- a/extension/llm/runner/text_llm_runner.cpp +++ b/extension/llm/runner/text_llm_runner.cpp @@ -293,6 +293,17 @@ Result TextLLMRunner::prefill( const std::vector& inputs, int32_t num_bos, int32_t num_eos) { + // No GenerationConfig on this overload: the deprecated temperature_ wins if + // set (mirroring generate()'s precedence), otherwise greedy. + return prefill_impl( + inputs, num_bos, num_eos, temperature_ == -1.0f ? 0.0f : temperature_); +} + +Result TextLLMRunner::prefill_impl( + const std::vector& inputs, + int32_t num_bos, + int32_t num_eos, + float temperature) { if (!is_loaded()) { ET_CHECK_OK_OR_RETURN_ERROR(load()); } @@ -306,7 +317,9 @@ Result TextLLMRunner::prefill( "Failed to encode prompt %s", input.get_text().c_str()); std::vector tokens = encode_res.get(); - auto prefill_res = text_prefiller_->prefill(tokens, pos_); + // The first generated token is sampled here during prefill, so honor the + // requested temperature instead of defaulting to greedy. + auto prefill_res = text_prefiller_->prefill(tokens, pos_, temperature); ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error()); prefill_next_token_ = prefill_res.get(); num_bos = 0; @@ -333,7 +346,13 @@ Result TextLLMRunner::prefill( Result TextLLMRunner::prefill( const std::string& prompt, const GenerationConfig& config) { - return prefill(prompt, config.num_bos, config.num_eos); + std::vector inputs; + inputs.emplace_back(MultimodalInput(prompt)); + // Honor the request's sampling for the first token (sampled during prefill), + // mirroring generate(): the deprecated temperature_ wins if set. + const float resolved_temp = + temperature_ == -1.0f ? 
config.temperature : temperature_; + return prefill_impl(inputs, config.num_bos, config.num_eos, resolved_temp); } Error TextLLMRunner::warmup(const std::string& prompt, int32_t max_new_tokens) { @@ -352,6 +371,9 @@ Error TextLLMRunner::warmup(const std::string& prompt, int32_t max_new_tokens) { } void TextLLMRunner::stop() { + // Honored by both decode paths: generate() checks the token generator's flag, + // the token-step (session) loop checks stop_requested_ in decode_one(). + stop_requested_.store(true, std::memory_order_relaxed); if (is_loaded()) { text_token_generator_->stop(); } else { @@ -364,6 +386,7 @@ void TextLLMRunner::reset() { pos_ = 0; prefill_next_token_.reset(); prev_decode_token_.reset(); + stop_requested_.store(false, std::memory_order_relaxed); } ::executorch::runtime::Error TextLLMRunner::seek(int64_t pos) { @@ -443,6 +466,8 @@ ::executorch::runtime::Result TextLLMRunner::prefill_tokens( const float temp = (temperature < 0.0f) ? (temperature_ == -1.0f ? 0.0f : temperature_) : temperature; + // A new prefill starts a fresh generation turn; clear any prior stop request. + stop_requested_.store(false, std::memory_order_relaxed); auto prefill_res = text_prefiller_->prefill(tokens, pos_, temp); ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error()); prefill_next_token_ = prefill_res.get(); @@ -484,13 +509,16 @@ ::executorch::runtime::Result TextLLMRunner::decode_one( } std::string text_piece = std::move(*decode_res); - // Stop at EOS WITHOUT forwarding it, like generate() (which breaks before the - // next step()): the EOS token is not made resident and pos_ does not advance, - // so position()/prefix reuse stay correct. No pending token remains, so a - // subsequent decode_one() correctly errors (generation is complete). - if (is_eos) { + // Terminate WITHOUT forwarding the token — at EOS (like generate(), which + // breaks before the next step()) or at a cooperative stop() request observed + // at this token boundary. 
The token is not made resident and pos_ does not + // advance, so position()/prefix reuse stay correct; no pending token remains, + // so a subsequent decode_one() correctly errors (generation is complete). + // is_eos stays literal; is_terminal is set either way so the loop ends. + if (is_eos || stop_requested_.load(std::memory_order_relaxed)) { prefill_next_token_.reset(); - return DecodeResult{token, std::move(text_piece), true}; + return DecodeResult{ + token, std::move(text_piece), is_eos, /*is_terminal=*/true}; } // Only a NON-EOS token is forwarded (made resident at pos_), so the capacity @@ -525,7 +553,8 @@ ::executorch::runtime::Result TextLLMRunner::decode_one( prev_decode_token_ = token; pos_ += 1; - return DecodeResult{token, std::move(text_piece), false}; + return DecodeResult{ + token, std::move(text_piece), /*is_eos=*/false, /*is_terminal=*/false}; } } // namespace executorch::extension::llm diff --git a/extension/llm/runner/text_llm_runner.h b/extension/llm/runner/text_llm_runner.h index 90032ae8b59..8867775568a 100644 --- a/extension/llm/runner/text_llm_runner.h +++ b/extension/llm/runner/text_llm_runner.h @@ -11,6 +11,7 @@ #pragma once +#include #include #include #include @@ -129,8 +130,9 @@ class ET_EXPERIMENTAL TextLLMRunner : public IRunner { prefill(const std::string& prompt, int32_t num_bos = 0, int32_t num_eos = 0); /** - * Prefill a text prompt using GenerationConfig. - * Deprecated: prefer prefill(prompt, num_bos, num_eos). + * Prefill a text prompt using GenerationConfig. Samples the first token + * (sampled during prefill) at config.temperature, unlike the bare overloads + * which default to greedy. 
*/ ::executorch::runtime::Result prefill( const std::string& prompt, @@ -187,6 +189,15 @@ class ET_EXPERIMENTAL TextLLMRunner : public IRunner { ::executorch::runtime::Result decode_one( float temperature = -1.0f); + // Shared implementation for the prefill() overloads: encodes the text inputs, + // prefills them at the current position, and samples the first token (sampled + // during prefill) at `temperature`. + ::executorch::runtime::Result prefill_impl( + const std::vector& inputs, + int32_t num_bos, + int32_t num_eos, + float temperature); + // Components std::unique_ptr<::tokenizers::Tokenizer> tokenizer_; std::unordered_map metadata_; @@ -216,6 +227,12 @@ class ET_EXPERIMENTAL TextLLMRunner : public IRunner { // The position in KV cache of the input, starting from 0. int64_t pos_ = 0; + + // Cooperative stop for the token-step (session) decode loop: stop() sets it, + // decode_one() honors it at the next token boundary, and prefill_tokens() / + // reset() clear it for a fresh generation. Atomic so stop() is safe to call + // from another thread while decode_one() runs. + std::atomic stop_requested_{false}; }; } // namespace executorch::extension::llm diff --git a/extension/llm/runner/util.h b/extension/llm/runner/util.h index 6bfde46eda0..751e0191a6b 100644 --- a/extension/llm/runner/util.h +++ b/extension/llm/runner/util.h @@ -14,6 +14,7 @@ #include #include #include +#include #include #if defined(__linux__) || defined(__ANDROID__) || defined(__unix__) #include @@ -66,6 +67,37 @@ ET_EXPERIMENTAL void inline safe_printf(const char* piece) { printf("%s", piece); } +// Length of the longest prefix of `s` that does not end in the middle of a +// UTF-8 multi-byte sequence. A byte-level tokenizer can emit a token that is +// only part of a character (e.g. one byte of a 3-byte CJK codepoint or emoji), +// so a caller streaming text must hold the incomplete tail until it completes +// rather than decode the partial bytes. 
An invalid lead byte counts as length 1 +// (emitted, so the caller can replace it) rather than stalling output. +ET_EXPERIMENTAL size_t inline utf8_complete_prefix_len(const std::string& s) { + size_t i = 0; + const size_t n = s.size(); + while (i < n) { + const unsigned char c = static_cast(s[i]); + size_t len; + if (c < 0x80) { + len = 1; + } else if ((c >> 5) == 0x6) { + len = 2; + } else if ((c >> 4) == 0xE) { + len = 3; + } else if ((c >> 3) == 0x1E) { + len = 4; + } else { + len = 1; // invalid lead byte; emit it and let the caller replace it + } + if (i + len > n) { + break; // incomplete trailing sequence: hold it for more bytes + } + i += len; + } + return i; +} + // ---------------------------------------------------------------------------- // utilities: time From 464863915426c404a3ef54e3b4a603ea489a495c Mon Sep 17 00:00:00 2001 From: Mergen Nachin Date: Fri, 5 Jun 2026 10:29:45 -0700 Subject: [PATCH 5/6] [UPDATE] Update [ghstack-poisoned] --- extension/llm/runner/llm_session.h | 4 +- extension/llm/runner/test/test_util.cpp | 36 ++++++++++++++++++ extension/llm/runner/util.h | 50 +++++++++++++++++++++++++ 3 files changed, 88 insertions(+), 2 deletions(-) diff --git a/extension/llm/runner/llm_session.h b/extension/llm/runner/llm_session.h index 1bfdb0a08a0..e65cdebd1d2 100644 --- a/extension/llm/runner/llm_session.h +++ b/extension/llm/runner/llm_session.h @@ -8,8 +8,8 @@ // Model-agnostic Engine/Session interfaces. Model-specific execution lives in // adapters that implement these (TextLLMSession over TextLLMRunner today; -// Gemma4Session etc. later); the server and pybind layer depend only on these -// interfaces, never on a concrete runner. +// Gemma4Session etc. later); the serving code (HTTP control plane + C++ worker +// binaries) depends only on these interfaces, never on a concrete runner. 
#pragma once diff --git a/extension/llm/runner/test/test_util.cpp b/extension/llm/runner/test/test_util.cpp index 751b3f9daf9..f189c0868d4 100644 --- a/extension/llm/runner/test/test_util.cpp +++ b/extension/llm/runner/test/test_util.cpp @@ -18,6 +18,7 @@ namespace { using ::executorch::aten::ScalarType; using ::executorch::extension::make_tensor_ptr; using ::executorch::extension::llm::convert_to_bfloat16; +using ::executorch::extension::llm::stop_safe_prefix_len; using ::executorch::extension::llm::utf8_complete_prefix_len; class ConvertToBFloat16Test : public ::testing::Test { @@ -86,4 +87,39 @@ TEST(Utf8CompletePrefixLenTest, HandlesAsciiAndMultiByteBoundaries) { EXPECT_EQ(utf8_complete_prefix_len("\x80"), 1u); } +TEST(StopSafePrefixLenTest, NoStopsEmitsEverything) { + bool hit = true; + EXPECT_EQ(stop_safe_prefix_len("hello world", {}, hit), 11u); + EXPECT_FALSE(hit); +} + +TEST(StopSafePrefixLenTest, StopFoundReturnsEarliestOffsetAndExcludesIt) { + bool hit = false; + // "STOP" begins at offset 6; emit "Hello " (6 bytes), drop the stop and rest. + EXPECT_EQ(stop_safe_prefix_len("Hello STOP there", {"STOP"}, hit), 6u); + EXPECT_TRUE(hit); + // Earliest of several wins. + hit = false; + EXPECT_EQ(stop_safe_prefix_len("aXbY", {"Y", "X"}, hit), 1u); + EXPECT_TRUE(hit); +} + +TEST(StopSafePrefixLenTest, HoldsBackPossiblePartialStopTail) { + bool hit = false; + // No full stop yet, but the trailing "ST" could become "STOP": hold back + // len("STOP")-1 == 3 bytes, so of "hi ST" (5 bytes) only "hi" (2) is safe. + EXPECT_EQ(stop_safe_prefix_len("hi ST", {"STOP"}, hit), 2u); + EXPECT_FALSE(hit); +} + +TEST(StopSafePrefixLenTest, HoldBackSnapsToUtf8Boundary) { + bool hit = false; + // "ab" + "€"(3 bytes). Stop "XX" => hold back 1 byte, which would land inside + // the euro sign; snap down so the multi-byte char isn't split. 
+ const std::string text = "ab\xe2\x82\xac"; + const size_t safe = stop_safe_prefix_len(text, {"XX"}, hit); + EXPECT_FALSE(hit); + EXPECT_EQ(safe, 2u); // only "ab"; the € is held whole +} + } // namespace diff --git a/extension/llm/runner/util.h b/extension/llm/runner/util.h index 751e0191a6b..a92dc7a7ba4 100644 --- a/extension/llm/runner/util.h +++ b/extension/llm/runner/util.h @@ -13,6 +13,7 @@ #include #include #include +#include #include #include #include @@ -98,6 +99,55 @@ ET_EXPERIMENTAL size_t inline utf8_complete_prefix_len(const std::string& s) { return i; } +// How many leading bytes of `text` a streaming consumer may safely emit given a +// set of `stops` strings, and whether a stop was hit (`stop_hit`). +// * If any stop occurs, returns the byte offset of the EARLIEST occurrence and +// sets stop_hit=true — text before it is safe; the stop and everything after +// are dropped (the stop is excluded from output). +// * Otherwise returns the length minus the longest possible partial-stop tail +// (max(len(stop))-1 bytes), snapped DOWN to a UTF-8 boundary so a multi-byte +// character is never split; stop_hit=false. Holding back that tail lets a +// stop that straddles the next piece still be caught. +// `text` is expected to be complete-UTF-8 (e.g. the assembled output of +// utf8_complete_prefix_len). Empty `stops` => emit everything, no hold-back. 
+ET_EXPERIMENTAL size_t inline stop_safe_prefix_len( + const std::string& text, + const std::vector& stops, + bool& stop_hit) { + stop_hit = false; + if (stops.empty()) { + return text.size(); + } + size_t earliest = std::string::npos; + size_t max_len = 0; + for (const auto& s : stops) { + if (s.empty()) { + continue; + } + max_len = std::max(max_len, s.size()); + const size_t p = text.find(s); + if (p != std::string::npos && + (earliest == std::string::npos || p < earliest)) { + earliest = p; + } + } + if (earliest != std::string::npos) { + stop_hit = true; + return earliest; + } + const size_t hold = max_len > 0 ? max_len - 1 : 0; + if (text.size() <= hold) { + return 0; + } + size_t end = text.size() - hold; + // Don't cut in the middle of a UTF-8 character: back up over continuation + // bytes (10xxxxxx). + while (end > 0 && (static_cast(text[end]) & 0xC0) == 0x80) { + --end; + } + return end; +} + // ---------------------------------------------------------------------------- // utilities: time From 6edbb17040633c786b23c04ab0458163a6aa2bf9 Mon Sep 17 00:00:00 2001 From: Mergen Nachin Date: Mon, 8 Jun 2026 12:19:38 -0700 Subject: [PATCH 6/6] [UPDATE] Update [ghstack-poisoned] --- extension/llm/runner/llm_session.h | 8 +++++--- extension/llm/runner/util.h | 13 +++++++------ 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/extension/llm/runner/llm_session.h b/extension/llm/runner/llm_session.h index e65cdebd1d2..089c96c7427 100644 --- a/extension/llm/runner/llm_session.h +++ b/extension/llm/runner/llm_session.h @@ -90,9 +90,11 @@ class ET_EXPERIMENTAL LLMSession { virtual ::executorch::runtime::Result decode_one( const SamplingConfig& sampling) = 0; - /// Rewind the KV cache to `pos` (prefix reuse). Valid for full-KV models; - /// sliding-window KV may reject a seek past its window (the caller falls back - /// to a fresh prefill). + /// Rewind the KV cache to `pos` (prefix reuse). Valid for full-KV models. 
+ /// Returns InvalidArgument if `pos` is outside [0, position()]. Returns + /// NotSupported for models whose state cannot be safely rewound (for example, + /// non-KV-cache, sliding-window, or recurrent-state models); callers should + /// fall back to reset() + full prefill. virtual ::executorch::runtime::Error seek(int64_t pos) = 0; /// Number of tokens with resident KV (upper bound for seek()). diff --git a/extension/llm/runner/util.h b/extension/llm/runner/util.h index a92dc7a7ba4..61860580c50 100644 --- a/extension/llm/runner/util.h +++ b/extension/llm/runner/util.h @@ -101,13 +101,14 @@ ET_EXPERIMENTAL size_t inline utf8_complete_prefix_len(const std::string& s) { // How many leading bytes of `text` a streaming consumer may safely emit given a // set of `stops` strings, and whether a stop was hit (`stop_hit`). -// * If any stop occurs, returns the byte offset of the EARLIEST occurrence and -// sets stop_hit=true — text before it is safe; the stop and everything after -// are dropped (the stop is excluded from output). +// * If any stop occurs, returns the byte offset of the EARLIEST occurrence +// and +// sets stop_hit=true — text before it is safe; the stop and everything +// after are dropped (the stop is excluded from output). // * Otherwise returns the length minus the longest possible partial-stop tail -// (max(len(stop))-1 bytes), snapped DOWN to a UTF-8 boundary so a multi-byte -// character is never split; stop_hit=false. Holding back that tail lets a -// stop that straddles the next piece still be caught. +// (max(len(stop))-1 bytes), snapped DOWN to a UTF-8 boundary so a +// multi-byte character is never split; stop_hit=false. Holding back that +// tail lets a stop that straddles the next piece still be caught. // `text` is expected to be complete-UTF-8 (e.g. the assembled output of // utf8_complete_prefix_len). Empty `stops` => emit everything, no hold-back. ET_EXPERIMENTAL size_t inline stop_safe_prefix_len(