diff --git a/extension/llm/runner/llm_runner_helper.cpp b/extension/llm/runner/llm_runner_helper.cpp index 0744c09e641..5063f6ef132 100644 --- a/extension/llm/runner/llm_runner_helper.cpp +++ b/extension/llm/runner/llm_runner_helper.cpp @@ -15,6 +15,7 @@ #include #include #include +#include #include #include #include @@ -29,8 +30,18 @@ namespace executorch::extension::llm { using ::executorch::extension::Module; +using ::executorch::extension::Program; using ::executorch::runtime::Error; +// Assembles the per-Module components (decoder/prefiller/token generator/io +// manager/stats) into a TextLLMRunner. Shared by the path-based and the +// shared-Program (TextLLMEngine session) construction paths. +static std::unique_ptr assemble_text_llm_runner( + std::unique_ptr module, + std::unique_ptr<::tokenizers::Tokenizer> tokenizer, + float temperature, + const std::string& method_name); + std::unique_ptr load_tokenizer( const std::string& tokenizer_path, std::unique_ptr> special_tokens, @@ -251,6 +262,15 @@ std::unique_ptr create_text_llm_runner( max_cached_memory_size_bytes_)); } + return assemble_text_llm_runner( + std::move(module), std::move(tokenizer), temperature, method_name); +} + +static std::unique_ptr assemble_text_llm_runner( + std::unique_ptr module, + std::unique_ptr<::tokenizers::Tokenizer> tokenizer, + float temperature, + const std::string& method_name) { // Get metadata from Module ET_LOG(Info, "Reading metadata from model"); auto metadata_result = llm::get_llm_metadata(tokenizer.get(), module.get()); @@ -305,6 +325,198 @@ std::unique_ptr create_text_llm_runner( temperature); } +// Builds a TextLLMRunner over an already-loaded Program: the runner's Module +// reuses `program` while owning its own method state and KV cache. File-local — +// the per-session construction path for TextLLMEngine (which keeps the backing +// DataLoader alive for the runners' lifetime). External callers go through +// LLMEngine -> LLMSession, not a raw shared-Program runner. 
+static std::unique_ptr create_text_llm_runner_from_program( + std::shared_ptr program, + std::unique_ptr<::tokenizers::Tokenizer> tokenizer, + float temperature, + const std::string& method_name) { + if (!tokenizer || !tokenizer->is_loaded()) { + ET_LOG(Error, "Tokenizer is null or not loaded"); + return nullptr; + } + if (!program) { + ET_LOG(Error, "Program is null"); + return nullptr; + } + // A Module over the already-loaded Program: it reuses that Program rather + // than re-loading it, and its loaded method allocates its own planned (KV) + // memory. Whether packed weights are physically shared vs. re-materialized + // per method instance is backend-dependent (serving_capacity() is the + // authority). + constexpr uint32_t kMaxCachedMemoryBytes = 1024 * 1024 * 10; // 10MB + auto module = std::make_unique( + std::move(program), + nullptr, // memory allocator + std::make_unique( + kMaxCachedMemoryBytes)); + return assemble_text_llm_runner( + std::move(module), std::move(tokenizer), temperature, method_name); +} + +namespace detail { +// The TextLLM adapter: implements the model-agnostic LLMSession over a +// TextLLMRunner. TextLLMRunner's token-step methods are private; this adapter +// is their only (friended) caller, so the engine and server depend solely on +// LLMSession. +TextLLMSession::TextLLMSession(std::unique_ptr runner) + : runner_(std::move(runner)) {} + +Error TextLLMSession::prefill_tokens( + std::vector tokens, + const SamplingConfig* initial_sampling) { + // The model samples the FIRST generated token during prefill, so apply the + // request's sampling here (not a stale default). Only temperature is + // plumbed; reject non-default top_p/top_k/seed for parity with decode_one(). 
+ float temperature = -1.0f; + if (initial_sampling != nullptr) { + if (initial_sampling->top_p != 1.0f || initial_sampling->top_k != 0 || + initial_sampling->seed != 0) { + ET_LOG( + Error, + "TextLLMSession: only temperature is supported; top_p/top_k/seed " + "are not yet implemented"); + return ::executorch::runtime::Error::NotSupported; + } + temperature = initial_sampling->temperature; + } + return runner_->prefill_tokens(std::move(tokens), temperature).error(); +} + +::executorch::runtime::Result TextLLMSession::decode_one( + const SamplingConfig& sampling) { + // Only temperature is plumbed today; top_p/top_k/seed need a per-session + // sampler (a follow-up). Reject non-default values rather than silently + // ignoring them, so callers can't assume constraints are applied. + if (sampling.top_p != 1.0f || sampling.top_k != 0 || sampling.seed != 0) { + ET_LOG( + Error, + "TextLLMSession: only temperature is supported; top_p/top_k/seed are " + "not yet implemented"); + return ::executorch::runtime::Error::NotSupported; + } + return runner_->decode_one(sampling.temperature); +} + +Error TextLLMSession::seek(int64_t pos) { + return runner_->seek(pos); +} + +int64_t TextLLMSession::position() const { + return runner_->position(); +} + +Error TextLLMSession::reset() { + runner_->reset(); + return Error::Ok; +} + +void TextLLMSession::stop() { + runner_->stop(); +} + +std::unique_ptr make_text_llm_session( + std::unique_ptr runner) { + return std::make_unique(std::move(runner)); +} +} // namespace detail + +TextLLMEngine::TextLLMEngine( + std::unique_ptr loader_module, + std::shared_ptr program, + std::string tokenizer_path, + float temperature, + std::string method_name, + std::unordered_map metadata) + : loader_module_(std::move(loader_module)), + program_(std::move(program)), + tokenizer_path_(std::move(tokenizer_path)), + temperature_(temperature), + method_name_(std::move(method_name)), + metadata_(std::move(metadata)) {} + +std::unique_ptr 
TextLLMEngine::create( + const std::string& model_path, + const std::string& tokenizer_path, + std::optional data_path, + float temperature, + const std::string& method_name, + Module::LoadMode load_mode) { + // External .ptd weights are not yet supported for shared sessions: each + // session Module built from the shared Program would also need the + // data_map_loader threaded into its load_method() to resolve external + // weights (see Module::load_method merged_data_map_). Fail loudly rather than + // silently produce sessions that error on first generate. + if (data_path.has_value()) { + ET_LOG( + Error, + "TextLLMEngine: external data_path (.ptd) is not yet supported for " + "shared sessions; use a self-contained .pte for now."); + return nullptr; + } + // Load the program ONCE; sessions reuse it (loaded a single time, per-session + // KV). Physical weight sharing across sessions is backend-dependent — see + // serving_capacity(). + auto loader_module = std::make_unique(model_path, load_mode); + if (loader_module->load() != Error::Ok) { + ET_LOG( + Error, + "TextLLMEngine: failed to load program from %s", + model_path.c_str()); + return nullptr; + } + auto program = loader_module->program(); + if (!program) { + ET_LOG(Error, "TextLLMEngine: program is null after load"); + return nullptr; + } + // Read model-level metadata once (shared by all sessions). 
+ auto meta_tokenizer = load_tokenizer(tokenizer_path); + if (!meta_tokenizer) { + ET_LOG( + Error, + "TextLLMEngine: failed to load tokenizer from %s", + tokenizer_path.c_str()); + return nullptr; + } + auto metadata_result = + get_llm_metadata(meta_tokenizer.get(), loader_module.get()); + if (metadata_result.error() != Error::Ok) { + ET_LOG(Error, "TextLLMEngine: failed to read metadata"); + return nullptr; + } + return std::unique_ptr(new TextLLMEngine( + std::move(loader_module), + std::move(program), + tokenizer_path, + temperature, + method_name, + metadata_result.get())); +} + +::executorch::runtime::Result> +TextLLMEngine::create_session() { + auto tokenizer = load_tokenizer(tokenizer_path_); + if (!tokenizer) { + ET_LOG( + Error, + "TextLLMEngine: failed to load tokenizer from %s", + tokenizer_path_.c_str()); + return Error::InvalidState; + } + auto runner = create_text_llm_runner_from_program( + program_, std::move(tokenizer), temperature_, method_name_); + if (!runner) { + ET_LOG(Error, "TextLLMEngine: failed to build session runner"); + return Error::InvalidState; + } + return detail::make_text_llm_session(std::move(runner)); +} + std::unique_ptr create_multimodal_runner( const std::string& model_path, std::unique_ptr<::tokenizers::Tokenizer> tokenizer, diff --git a/extension/llm/runner/llm_runner_helper.h b/extension/llm/runner/llm_runner_helper.h index b4c7c59806d..c2eaf0c8ac3 100644 --- a/extension/llm/runner/llm_runner_helper.h +++ b/extension/llm/runner/llm_runner_helper.h @@ -18,6 +18,7 @@ #include #include +#include #include #include #include @@ -141,6 +142,76 @@ ET_EXPERIMENTAL std::unique_ptr create_text_llm_runner( const std::string& method_name = "forward", Module::LoadMode load_mode = Module::LoadMode::MmapUseMlockIgnoreErrors); +/** + * @brief Engine for multi-session text generation over one loaded Program. 
+ * + * Loads the model's Program (weights/constants) once; create_session() builds a + * TextLLMRunner that reuses that Program but owns its own method/KV state. This + * is the correctness-first foundation for serving multiple conversations. + * Backend execution should be serialized by the caller until per-backend thread + * safety is proven (Module::execute is not assumed thread-safe). Whether extra + * sessions avoid duplicating packed weights is backend-dependent and reported + * by serving_capacity() (conservatively one). + */ +class ET_EXPERIMENTAL TextLLMEngine : public LLMEngine { + public: + static std::unique_ptr create( + const std::string& model_path, + const std::string& tokenizer_path, + std::optional data_path = std::nullopt, + float temperature = -1.0f, + const std::string& method_name = "forward", + Module::LoadMode load_mode = Module::LoadMode::MmapUseMlockIgnoreErrors); + + // Returns a TextLLMSession (LLMSession) that reuses this engine's loaded + // Program (physical weight sharing is backend-dependent; see + // serving_capacity). + ::executorch::runtime::Result> create_session() + override; + // Conservative: a single physical session (no proven cross-session weight + // sharing). Raise on a backend proven to share packed weights. + LLMServingCapacity serving_capacity() const override { + return LLMServingCapacity{}; + } + const std::unordered_map& metadata() const override { + return metadata_; + } + + TextLLMEngine(const TextLLMEngine&) = delete; + TextLLMEngine& operator=(const TextLLMEngine&) = delete; + + private: + TextLLMEngine( + std::unique_ptr loader_module, + std::shared_ptr program, + std::string tokenizer_path, + float temperature, + std::string method_name, + std::unordered_map metadata); + + // Keeps the shared Program's DataLoader alive for the lifetime of sessions. 
+ std::unique_ptr loader_module_; + std::shared_ptr program_; + std::string tokenizer_path_; + float temperature_; + std::string method_name_; + std::unordered_map metadata_; +}; + +namespace detail { +// Implementation detail (not a public API): wraps a TextLLMRunner in an +// LLMSession (the runner -> session seam). The supported entry point is +// LLMEngine::create_session(); this exists only so TextLLMEngine can build its +// sessions and so unit tests can drive the runner's token-step primitives +// through the public LLMSession surface (the concrete adapter type is private). +// Do not depend on wrapping arbitrary runners. +// +// @param runner A loaded TextLLMRunner; ownership transfers to the session. +// @return std::unique_ptr wrapping `runner`. +std::unique_ptr make_text_llm_session( + std::unique_ptr runner); +} // namespace detail + /** * @brief Creates a MultimodalRunner instance with dependency injection * diff --git a/extension/llm/runner/llm_session.h b/extension/llm/runner/llm_session.h new file mode 100644 index 00000000000..089c96c7427 --- /dev/null +++ b/extension/llm/runner/llm_session.h @@ -0,0 +1,137 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Model-agnostic Engine/Session interfaces. Model-specific execution lives in +// adapters that implement these (TextLLMSession over TextLLMRunner today; +// Gemma4Session etc. later); the serving code (HTTP control plane + C++ worker +// binaries) depends only on these interfaces, never on a concrete runner. + +#pragma once + +#include +#include +#include +#include +#include + +#include +#include + +namespace executorch::extension::llm { + +/// Per-decode sampling parameters. 
An adapter applies the fields it supports +/// and rejects non-default values of the rest rather than silently ignoring +/// them (today only temperature is plumbed). -1 temperature means model +/// default. +struct SamplingConfig { + float temperature = -1.0f; + float top_p = 1.0f; + int32_t top_k = 0; // 0 = disabled + uint64_t seed = 0; // 0 = unset +}; + +/// One decoded step: the exact sampled token id (for prefix-cache id tracking +/// and batching) and its decoded text piece (raw bytes; may be a partial UTF-8 +/// sequence the caller assembles). +/// +/// `is_eos` is literal: the sampled token is an end-of-sequence token (use it +/// for the "stop" finish reason, metrics, cache/accounting). `is_terminal` is +/// the loop signal: generation ended at this step — either EOS or a cooperative +/// stop() took effect. A decode loop should end when is_terminal is set; every +/// EOS step is also terminal, but a stop step is terminal without being EOS. +struct DecodeResult { + uint64_t token_id; + std::string text_piece; + bool is_eos; + bool is_terminal; +}; + +/// How many physical sessions an engine can host, so the server admits logical +/// requests without silently multiplying model memory. This is a *serving +/// capacity* concern (engine-level), distinct from how a session advances a +/// conversation (LLMSession) — keep backend memory flags off LLMSession. +struct LLMServingCapacity { + // Physical sessions creatable without duplicating packed weights. + // Conservatively 1 (some backends repack weights per runtime, so extra + // sessions would copy the whole model); raise only on a backend proven to + // share packed weights. + int32_t max_physical_sessions_without_weight_duplication = 1; + // Planned bytes one session adds (KV + activations), for memory-budget + // admission. 0 = unknown; the server skips the memory clamp. + int64_t estimated_bytes_per_session = 0; +}; + +/// One conversation's mutable state (KV cache, position cursor). 
Created by an +/// LLMEngine; conversation/cache-scoped (kept warm across requests for prefix +/// reuse), not request-scoped. +class ET_EXPERIMENTAL LLMSession { + public: + virtual ~LLMSession() = default; + + /// Prefill pre-tokenized input at the current position (call seek() first for + /// prefix reuse). Must be non-empty and fit the context window. + /// + /// `initial_sampling` (optional): the sampling config for the FIRST generated + /// token, for backends that sample during prefill (e.g. in-graph sampling). + /// Pass it so the first token uses the request's sampling instead of a stale + /// default. Backends that only sample in decode_one() ignore it. + virtual ::executorch::runtime::Error prefill_tokens( + std::vector tokens, + const SamplingConfig* initial_sampling = nullptr) = 0; + + /// Decode one token from the pending state; looping reproduces a full + /// generation while returning exact sampled token ids. A single decode_one() + /// runs one forward pass and is not interruptible mid-call (see stop()). + virtual ::executorch::runtime::Result decode_one( + const SamplingConfig& sampling) = 0; + + /// Rewind the KV cache to `pos` (prefix reuse). Valid for full-KV models. + /// Returns InvalidArgument if `pos` is outside [0, position()]. Returns + /// NotSupported for models whose state cannot be safely rewound (for example, + /// non-KV-cache, sliding-window, or recurrent-state models); callers should + /// fall back to reset() + full prefill. + virtual ::executorch::runtime::Error seek(int64_t pos) = 0; + + /// Number of tokens with resident KV (upper bound for seek()). + virtual int64_t position() const = 0; + + /// Clear the KV cache / position for a fresh conversation. + virtual ::executorch::runtime::Error reset() = 0; + + /// Request that a decode_one() loop stop. This is a TOKEN-BOUNDARY, + /// cooperative stop: it is safe to call from another thread, but it does not + /// abort a decode_one() that is already running. 
It takes effect at the next + /// decode_one(), which then returns a terminal step (is_terminal set, is_eos + /// false) without forwarding a new token. The stop is cleared by the next + /// prefill_tokens() or reset(). + virtual void stop() = 0; +}; + +/// Holds the immutable model resources (program, tokenizer, metadata) once and +/// creates sessions that reuse them while isolating their own KV state. How +/// many sessions can be created without duplicating packed weights is backend- +/// dependent — see serving_capacity(). +class ET_EXPERIMENTAL LLMEngine { + public: + virtual ~LLMEngine() = default; + + /// Build a new session that reuses this engine's program/resources when the + /// backend supports it, with its own KV cache. serving_capacity() is the + /// authority on how many physical sessions are safe without weight + /// duplication. + virtual ::executorch::runtime::Result> + create_session() = 0; + + /// How many physical sessions this engine can host without duplicating + /// weights (+ optional per-session memory estimate); the server clamps the + /// number of physical sessions it creates to this. + virtual LLMServingCapacity serving_capacity() const = 0; + virtual const std::unordered_map& metadata() const = 0; +}; + +} // namespace executorch::extension::llm diff --git a/extension/llm/runner/targets.bzl b/extension/llm/runner/targets.bzl index d3e12266adc..60ba47e9013 100644 --- a/extension/llm/runner/targets.bzl +++ b/extension/llm/runner/targets.bzl @@ -114,8 +114,15 @@ def define_common_targets(): exported_headers = [ "text_llm_runner.h", "llm_runner_helper.h", + "llm_session.h", "constants.h", ], + # Internal: the detail::TextLLMSession adapter (sole friended caller + # of TextLLMRunner's token-step hooks). Private so dependents reach + # it only through LLMEngine/LLMSession + make_text_llm_session(). 
+ headers = [ + "text_llm_session.h", + ], srcs = [ "text_llm_runner.cpp", "llm_runner_helper.cpp", diff --git a/extension/llm/runner/test/test_text_llm_runner.cpp b/extension/llm/runner/test/test_text_llm_runner.cpp index 5a98c55bb2f..e55eb06b545 100644 --- a/extension/llm/runner/test/test_text_llm_runner.cpp +++ b/extension/llm/runner/test/test_text_llm_runner.cpp @@ -20,13 +20,17 @@ #include using namespace ::testing; +using executorch::extension::llm::DecodeResult; using executorch::extension::llm::GenerationConfig; +using executorch::extension::llm::LLMServingCapacity; using executorch::extension::llm::LogitProcessor; +using executorch::extension::llm::SamplingConfig; using executorch::extension::llm::Stats; using executorch::extension::llm::TextDecoderRunner; using executorch::extension::llm::TextLLMRunner; using executorch::extension::llm::TextPrefiller; using executorch::extension::llm::TextTokenGenerator; +using executorch::extension::llm::detail::make_text_llm_session; using executorch::runtime::Error; using executorch::runtime::Result; using executorch::runtime::testing::TensorFactory; @@ -95,8 +99,8 @@ class MockTextPrefiller : public TextPrefiller { MOCK_METHOD( Result, prefill, - (std::vector&, int64_t&), - ()); + (std::vector&, int64_t&, float), + (override)); MOCK_METHOD(::executorch::runtime::Error, load, (), ()); MOCK_METHOD(bool, is_loaded, (), ()); }; @@ -190,7 +194,7 @@ class RunnerTest : public Test { ON_CALL(*text_prefiller, is_loaded()).WillByDefault(Return(true)); // Set up default behavior for the text prefiller ON_CALL(*text_prefiller, prefill) - .WillByDefault([](const std::vector&, int64_t) { + .WillByDefault([](const std::vector&, int64_t, float) { return Result(4); }); @@ -220,6 +224,52 @@ class RunnerTest : public Test { }; } + // Builds a loaded TextLLMRunner with default mocks whose prefiller advances + // the position cursor by the number of tokens (so position()/seek() and the + // capacity bound can be exercised directly). 
+ std::unique_ptr makeRunner( + std::unordered_map metadata, + std::shared_ptr logit_processor = nullptr, + uint64_t prefill_token = 42, + MockTextPrefiller** out_prefiller = nullptr) { + auto tokenizer = createMockTokenizer(); + auto text_decoder_runner = createMockTextDecoderRunner(); + auto text_prefiller = createMockTextPrefiller(text_decoder_runner.get()); + ON_CALL(*text_prefiller, prefill(_, _, _)) + .WillByDefault([prefill_token]( + std::vector& tokens, int64_t& pos, float) { + pos += tokens.size(); + return Result(prefill_token); + }); + ON_CALL(*text_prefiller, is_loaded()).WillByDefault(Return(true)); + // Expose the mock prefiller (still owned by the runner) so a test can set + // expectations on it. + if (out_prefiller != nullptr) { + *out_prefiller = text_prefiller.get(); + } + auto stats = std::make_unique(); + auto text_token_generator = createTextTokenGenerator( + tokenizer.get(), text_decoder_runner.get(), stats.get()); + if (logit_processor) { + text_token_generator->add_logit_processor(std::move(logit_processor)); + } + auto module = std::make_unique(); + auto io_manager = + std::make_unique(*module); + auto runner = std::make_unique( + std::move(metadata), + std::unique_ptr<::tokenizers::Tokenizer>(tokenizer.release()), + std::move(module), + std::move(text_decoder_runner), + std::unique_ptr<::executorch::extension::llm::TextPrefiller>( + text_prefiller.release()), + std::move(io_manager), + std::move(text_token_generator), + std::move(stats)); + runner->load(); + return runner; + } + protected: Stats stats_; std::vector return_logits_ = {0.1f, 0.2f, 0.3f, 0.4f}; @@ -242,8 +292,8 @@ TEST_F(RunnerTest, GenerateCallsCallbackExactlyMaxNewTokensTimes) { }); // Set up expectations for the text prefiller - ON_CALL(*text_prefiller, prefill(_, _)) - .WillByDefault([&](std::vector&, int64_t&) { + ON_CALL(*text_prefiller, prefill(_, _, _)) + .WillByDefault([&](std::vector&, int64_t&, float) { return (Result(4)); }); @@ -310,8 +360,8 @@ TEST_F(RunnerTest, 
WarmupCallsGenerateWithWarmingFlag) { }); // Set up expectations for the text prefiller - ON_CALL(*text_prefiller, prefill(_, _)) - .WillByDefault([&](std::vector&, int64_t&) { + ON_CALL(*text_prefiller, prefill(_, _, _)) + .WillByDefault([&](std::vector&, int64_t&, float) { return (Result(4)); }); @@ -396,8 +446,8 @@ TEST_F(RunnerTest, PrefillReturnsNextToken) { std::vector{1, 2, 3}); }); - ON_CALL(*text_prefiller, prefill(_, _)) - .WillByDefault([&](std::vector& tokens, int64_t& pos) { + ON_CALL(*text_prefiller, prefill(_, _, _)) + .WillByDefault([&](std::vector& tokens, int64_t& pos, float) { pos += tokens.size(); return Result(42); }); @@ -442,8 +492,8 @@ TEST_F(RunnerTest, PrefillThenGenerateEmpty) { std::vector{1, 2, 3}); }); - ON_CALL(*text_prefiller, prefill(_, _)) - .WillByDefault([&](std::vector& tokens, int64_t& pos) { + ON_CALL(*text_prefiller, prefill(_, _, _)) + .WillByDefault([&](std::vector& tokens, int64_t& pos, float) { pos += tokens.size(); return Result(4); }); @@ -490,6 +540,197 @@ TEST_F(RunnerTest, PrefillThenGenerateEmpty) { EXPECT_EQ(counter.getCount(), config.max_new_tokens); } +// prefill(prompt, GenerationConfig) must pass config.temperature to the +// prefiller: the first token is sampled during prefill, so dropping it would +// silently sample greedy regardless of the requested temperature. 
+TEST_F(RunnerTest, PrefillWithConfigPassesTemperatureToPrefiller) { + MockTextPrefiller* prefiller = nullptr; + auto runner = makeRunner( + createDefaultMetadata(), nullptr, /*prefill_token=*/42, &prefiller); + ASSERT_NE(prefiller, nullptr); + + EXPECT_CALL(*prefiller, prefill(_, _, FloatEq(0.7f))) + .WillOnce([](std::vector& tokens, int64_t& pos, float) { + pos += tokens.size(); + return Result(42); + }); + + GenerationConfig config; + config.temperature = 0.7f; + auto result = runner->prefill("hello", config); + EXPECT_TRUE(result.ok()); + EXPECT_EQ(result.get(), 42); +} + +// The token-step methods are private on TextLLMRunner (internal serving hooks); +// they are exercised through the public LLMSession surface returned by +// make_text_llm_session() — the same path the engine/server depend on. +// +// prefill_tokens() must reject a suffix larger than the KV cache, mirroring the +// capacity bound generate(prompt) enforces (prefill_tokens is the +// prefix-cache primitive and the only place this is checked for it). +TEST_F(RunnerTest, PrefillTokensRejectsOverContext) { + auto session = + make_text_llm_session(makeRunner(createDefaultMetadata())); // context 128 + EXPECT_EQ( + session->prefill_tokens(std::vector(200, 1)), // 200 > 128 + Error::InvalidArgument); +} + +// seek() + prefill_tokens() across the boundary is rejected: a valid prefill +// followed by a suffix that pushes pos_ past max_context_len must fail. +TEST_F(RunnerTest, PrefillTokensRejectsWhenPosPlusSuffixExceedsContext) { + auto session = make_text_llm_session(makeRunner(createDefaultMetadata())); + EXPECT_EQ(session->prefill_tokens(std::vector(100, 1)), Error::Ok); + EXPECT_EQ(session->position(), 100); + EXPECT_EQ( + session->prefill_tokens(std::vector(50, 1)), // 100 + 50 > 128 + Error::InvalidArgument); + EXPECT_EQ(session->position(), 100); // rejected before advancing +} + +// Empty tokens are rejected. 
+TEST_F(RunnerTest, PrefillTokensRejectsEmpty) { + auto session = make_text_llm_session(makeRunner(createDefaultMetadata())); + EXPECT_EQ( + session->prefill_tokens(std::vector{}), Error::InvalidArgument); +} + +// position() tracks prefilled tokens; seek() rewinds within range and rejects +// out-of-range targets. +TEST_F(RunnerTest, SeekAndPositionTrackResidentTokens) { + auto session = make_text_llm_session(makeRunner(createDefaultMetadata())); + EXPECT_EQ(session->position(), 0); + EXPECT_EQ(session->prefill_tokens(std::vector(10, 1)), Error::Ok); + EXPECT_EQ(session->position(), 10); + EXPECT_EQ(session->seek(5), Error::Ok); + EXPECT_EQ(session->position(), 5); + EXPECT_EQ( + session->seek(999), Error::InvalidArgument); // past current position +} + +// seek() is refused on sliding-window models (max_seq_len < max_context_len), +// so the prefix cache falls back to a full reset+prefill instead of corrupting. +TEST_F(RunnerTest, SeekRejectedForSlidingWindow) { + auto md = createDefaultMetadata(); + md["get_max_seq_len"] = 64; // < get_max_context_len (128) => sliding window + auto session = make_text_llm_session(makeRunner(md)); + EXPECT_EQ(session->seek(0), Error::NotSupported); +} + +// decode_one() emits the pending token id exactly, then forwards it (advancing +// position by one). Looping it reproduces a generation while exposing ids. 
+TEST_F(RunnerTest, DecodeOneReturnsExactTokenIdAndAdvances) { + auto session = make_text_llm_session(makeRunner(createDefaultMetadata())); + ASSERT_EQ(session->prefill_tokens(std::vector{1, 2, 3}), Error::Ok); + EXPECT_EQ(session->position(), 3); + + auto r1 = session->decode_one(SamplingConfig{0.0f}); + ASSERT_TRUE(r1.ok()); + EXPECT_EQ(r1.get().token_id, 42u); // the prefill-pending token (mock prefill) + EXPECT_FALSE(r1.get().is_eos); + EXPECT_FALSE(r1.get().is_terminal); + EXPECT_EQ(session->position(), 4); // forwarded one token + + auto r2 = session->decode_one(SamplingConfig{0.0f}); + ASSERT_TRUE(r2.ok()); + EXPECT_EQ(r2.get().token_id, 3u); // argmax of canned logits {.1,.2,.3,.4} + EXPECT_EQ(session->position(), 5); +} + +// decode_one() without a pending token (no prior prefill) must error. +TEST_F(RunnerTest, DecodeOneWithoutPendingTokenFails) { + auto session = make_text_llm_session(makeRunner(createDefaultMetadata())); + EXPECT_FALSE(session->decode_one(SamplingConfig{0.0f}).ok()); +} + +// decode_one() must stop at EOS WITHOUT forwarding it (like generate()): the +// EOS token is not made resident, position() does not advance, and no pending +// token remains — so prefix reuse stays correct and a further decode_one() +// errors. (The fixture's EOS id is 100.) +TEST_F(RunnerTest, DecodeOneStopsAtEosWithoutForwarding) { + auto session = make_text_llm_session( + makeRunner(createDefaultMetadata(), nullptr, /*prefill_token=*/100)); + ASSERT_EQ(session->prefill_tokens(std::vector{1, 2, 3}), Error::Ok); + EXPECT_EQ(session->position(), 3); + + auto r = session->decode_one(SamplingConfig{0.0f}); + ASSERT_TRUE(r.ok()); + EXPECT_EQ(r.get().token_id, 100u); + EXPECT_TRUE(r.get().is_eos); + EXPECT_TRUE(r.get().is_terminal); + EXPECT_EQ(session->position(), 3); // EOS not forwarded -> position unchanged + + // No pending token remains -> a further decode_one() errors. 
+ EXPECT_FALSE(session->decode_one(SamplingConfig{0.0f}).ok()); +} + +// stop() is honored at the next decode_one(): it returns a terminal step +// (is_terminal set, is_eos false) WITHOUT forwarding, position() does not +// advance, and no pending token remains; a fresh prefill_tokens() clears the +// stop so decoding resumes. +TEST_F(RunnerTest, DecodeOneHonorsStop) { + auto session = make_text_llm_session(makeRunner(createDefaultMetadata())); + ASSERT_EQ(session->prefill_tokens(std::vector{1, 2, 3}), Error::Ok); + + session->stop(); + auto r = session->decode_one(SamplingConfig{0.0f}); + ASSERT_TRUE(r.ok()); + EXPECT_FALSE(r.get().is_eos); // a stop is not an EOS token + EXPECT_TRUE(r.get().is_terminal); // but it ends the loop + EXPECT_EQ(session->position(), 3); // not forwarded + EXPECT_FALSE(session->decode_one(SamplingConfig{0.0f}).ok()); // no pending + + // A new prefill clears the stop; decoding resumes. + ASSERT_EQ(session->prefill_tokens(std::vector{4, 5}), Error::Ok); + auto r2 = session->decode_one(SamplingConfig{0.0f}); + ASSERT_TRUE(r2.ok()); + EXPECT_FALSE(r2.get().is_eos); + EXPECT_FALSE(r2.get().is_terminal); +} + +// decode_one() is the last-resort safety net for session drivers: even if a +// caller forgets to resolve max_new_tokens, it must not step past KV capacity. +TEST_F(RunnerTest, DecodeOneRejectsWhenContextFull) { + auto md = createDefaultMetadata(); // max_context_len = 128 + auto session = make_text_llm_session(makeRunner(md)); + ASSERT_EQ(session->prefill_tokens(std::vector(127, 1)), Error::Ok); + auto r1 = session->decode_one(SamplingConfig{0.0f}); + ASSERT_TRUE(r1.ok()); + EXPECT_EQ(session->position(), 128); + + auto r2 = session->decode_one(SamplingConfig{0.0f}); + EXPECT_FALSE(r2.ok()); + EXPECT_EQ(r2.error(), Error::InvalidArgument); +} + +// decode_one() must apply the generator's logit processors before sampling, +// exactly like generate(). 
Argmax of {0.1,0.2,0.3,0.4} is token 3; masking it +// makes the next sampled (pending) token 2 — proving the session decode path +// honors grammar/tool masks/penalties and can't diverge from generate(). +TEST_F(RunnerTest, DecodeOneAppliesLogitProcessors) { + auto session = make_text_llm_session(makeRunner( + createDefaultMetadata(), std::make_shared(3))); + ASSERT_EQ(session->prefill_tokens(std::vector{1, 2, 3}), Error::Ok); + + // r1 emits the prefill-pending token (42, not sampled here); its forward pass + // samples the next pending token from masked logits. + ASSERT_TRUE(session->decode_one(SamplingConfig{0.0f}).ok()); + auto r2 = session->decode_one(SamplingConfig{0.0f}); + ASSERT_TRUE(r2.ok()); + EXPECT_EQ(r2.get().token_id, 2u); // token 3 masked -> argmax is now 2 +} + +// Serving capacity is conservatively a single physical session by default +// (no proven cross-session weight sharing). TextLLMEngine::serving_capacity() +// returns this default; the engine-backed end-to-end check is in the pybinding +// test. 
+TEST_F(RunnerTest, ServingCapacityIsSingleSlotByDefault) { + LLMServingCapacity cap; + EXPECT_EQ(cap.max_physical_sessions_without_weight_duplication, 1); + EXPECT_EQ(cap.estimated_bytes_per_session, 0); +} + // Test that generate("") without prior prefill() returns an error TEST_F(RunnerTest, GenerateEmptyWithoutPrefillFails) { auto tokenizer = createMockTokenizer(); @@ -591,8 +832,8 @@ TEST_F(RunnerTest, MultiTurnWithSeqLenRespectsPos) { std::vector{1, 2, 3}); }); - ON_CALL(*text_prefiller, prefill(_, _)) - .WillByDefault([&](std::vector& tokens, int64_t& pos) { + ON_CALL(*text_prefiller, prefill(_, _, _)) + .WillByDefault([&](std::vector& tokens, int64_t& pos, float) { pos += tokens.size(); return Result(4); }); diff --git a/extension/llm/runner/test/test_text_prefiller.cpp b/extension/llm/runner/test/test_text_prefiller.cpp index 5ed5031dace..cba44ac8ab6 100644 --- a/extension/llm/runner/test/test_text_prefiller.cpp +++ b/extension/llm/runner/test/test_text_prefiller.cpp @@ -79,8 +79,8 @@ class TextPrefillerTest : public Test { MOCK_METHOD( ::executorch::runtime::Result, prefill_chunk, - (std::vector&, int64_t&), - ()); + (std::vector&, int64_t&, float), + (override)); }; // Create a mock TextPrefiller @@ -112,9 +112,9 @@ TEST_F(TextPrefillerTest, PrefillCallsPrefillChunkOnceWhenPromptFits) { int64_t start_pos = 0; // Expect prefill_chunk to be called exactly once with the entire prompt - EXPECT_CALL(*prefiller, prefill_chunk(_, _)) + EXPECT_CALL(*prefiller, prefill_chunk(_, _, _)) .Times(1) - .WillOnce([&](std::vector& tokens, int64_t& pos) { + .WillOnce([&](std::vector& tokens, int64_t& pos, float) { // Verify the tokens passed to prefill_chunk EXPECT_EQ(tokens.size(), prompt_tokens.size()); for (size_t i = 0; i < tokens.size(); i++) { @@ -217,14 +217,14 @@ TEST_F(TextPrefillerTest, PrefillHandlesPrefillChunkErrorsCorrectly) { InSequence seq; // First chunk: tokens [1, 2, 3] - succeeds - EXPECT_CALL(*prefiller, prefill_chunk(_, _)) - 
.WillOnce([&](std::vector& tokens, int64_t& pos) { + EXPECT_CALL(*prefiller, prefill_chunk(_, _, _)) + .WillOnce([&](std::vector& tokens, int64_t& pos, float) { return Result(10); }); // Second chunk: tokens [4, 5] - fails - EXPECT_CALL(*prefiller, prefill_chunk(_, _)) - .WillOnce([&](std::vector& tokens, int64_t& pos) { + EXPECT_CALL(*prefiller, prefill_chunk(_, _, _)) + .WillOnce([&](std::vector& tokens, int64_t& pos, float) { return Result(Error::InvalidArgument); }); } diff --git a/extension/llm/runner/test/test_util.cpp b/extension/llm/runner/test/test_util.cpp index 3d66c212375..f189c0868d4 100644 --- a/extension/llm/runner/test/test_util.cpp +++ b/extension/llm/runner/test/test_util.cpp @@ -18,6 +18,8 @@ namespace { using ::executorch::aten::ScalarType; using ::executorch::extension::make_tensor_ptr; using ::executorch::extension::llm::convert_to_bfloat16; +using ::executorch::extension::llm::stop_safe_prefix_len; +using ::executorch::extension::llm::utf8_complete_prefix_len; class ConvertToBFloat16Test : public ::testing::Test { protected: @@ -63,4 +65,61 @@ TEST_F(ConvertToBFloat16Test, RejectsNonFloatTensor) { EXPECT_EQ(result.error(), ::executorch::runtime::Error::InvalidArgument); } +TEST(Utf8CompletePrefixLenTest, HandlesAsciiAndMultiByteBoundaries) { + EXPECT_EQ(utf8_complete_prefix_len(""), 0u); + EXPECT_EQ(utf8_complete_prefix_len("ascii"), 5u); + + // Complete multi-byte characters are fully consumed. + EXPECT_EQ(utf8_complete_prefix_len("\xc3\xa9"), 2u); // é (2-byte) + EXPECT_EQ(utf8_complete_prefix_len("\xe2\x82\xac"), 3u); // € (3-byte) + EXPECT_EQ(utf8_complete_prefix_len("\xf0\x9f\x98\x80"), 4u); // 😀 (4-byte) + + // A character split across the end is held back (not counted). 
+ EXPECT_EQ(utf8_complete_prefix_len("\xc3"), 0u); // 1/2 of é + EXPECT_EQ(utf8_complete_prefix_len("\xe2\x82"), 0u); // 2/3 of € + EXPECT_EQ(utf8_complete_prefix_len("\xf0\x9f\x98"), 0u); // 3/4 of 😀 + + // A complete prefix followed by a split character keeps the complete part. + EXPECT_EQ(utf8_complete_prefix_len("hi\xe2\x82"), 2u); + EXPECT_EQ(utf8_complete_prefix_len("\xe2\x82\xac\xf0\x9f"), 3u); + + // An invalid lead byte counts as length 1 (emitted, not stalled). + EXPECT_EQ(utf8_complete_prefix_len("\x80"), 1u); +} + +TEST(StopSafePrefixLenTest, NoStopsEmitsEverything) { + bool hit = true; + EXPECT_EQ(stop_safe_prefix_len("hello world", {}, hit), 11u); + EXPECT_FALSE(hit); +} + +TEST(StopSafePrefixLenTest, StopFoundReturnsEarliestOffsetAndExcludesIt) { + bool hit = false; + // "STOP" begins at offset 6; emit "Hello " (6 bytes), drop the stop and rest. + EXPECT_EQ(stop_safe_prefix_len("Hello STOP there", {"STOP"}, hit), 6u); + EXPECT_TRUE(hit); + // Earliest of several wins. + hit = false; + EXPECT_EQ(stop_safe_prefix_len("aXbY", {"Y", "X"}, hit), 1u); + EXPECT_TRUE(hit); +} + +TEST(StopSafePrefixLenTest, HoldsBackPossiblePartialStopTail) { + bool hit = false; + // No full stop yet, but the trailing "ST" could become "STOP": hold back + // len("STOP")-1 == 3 bytes, so of "hi ST" (5 bytes) only "hi" (2) is safe. + EXPECT_EQ(stop_safe_prefix_len("hi ST", {"STOP"}, hit), 2u); + EXPECT_FALSE(hit); +} + +TEST(StopSafePrefixLenTest, HoldBackSnapsToUtf8Boundary) { + bool hit = false; + // "ab" + "€"(3 bytes). Stop "XX" => hold back 1 byte, which would land inside + // the euro sign; snap down so the multi-byte char isn't split. 
+ const std::string text = "ab\xe2\x82\xac"; + const size_t safe = stop_safe_prefix_len(text, {"XX"}, hit); + EXPECT_FALSE(hit); + EXPECT_EQ(safe, 2u); // only "ab"; the € is held whole +} + } // namespace diff --git a/extension/llm/runner/text_llm_runner.cpp b/extension/llm/runner/text_llm_runner.cpp index cf7ab50b9c8..b02692ed9a0 100644 --- a/extension/llm/runner/text_llm_runner.cpp +++ b/extension/llm/runner/text_llm_runner.cpp @@ -113,6 +113,12 @@ Error TextLLMRunner::generate( int64_t max_seq_len = metadata_.at(kMaxSeqLen); int64_t max_context_len = metadata_.at(kMaxContextLen); + // Resolve sampling temperature once: the first token is sampled during + // prefill and the rest in the token generator, so both must use the same + // temperature. + const float resolved_temp = + temperature_ == -1.0f ? config.temperature : temperature_; + uint64_t cur_token = 0; int num_prompt_tokens = 0; std::vector prompt_tokens; @@ -177,7 +183,8 @@ Error TextLLMRunner::generate( // Prefill first // Here feed all tokens to the model and get the next predicted token // after the prompt. After that we will enter generate loop. - auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos_); + auto prefill_res = + text_prefiller_->prefill(prompt_tokens, pos_, resolved_temp); ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error()); cur_token = prefill_res.get(); prefill_next_token_.reset(); @@ -235,10 +242,6 @@ Error TextLLMRunner::generate( // Set ignore_eos based on config text_token_generator_->set_ignore_eos(config.ignore_eos); - // Use the configuration's temperature - float resolved_temp = - temperature_ == -1.0f ? config.temperature : temperature_; - // Generate max_new_tokens - 1 because prefill already generated 1 token. 
auto generate_result = text_token_generator_->generate( prompt_tokens, pos_, max_new_tokens - 1, resolved_temp, wrapped_callback); @@ -263,9 +266,15 @@ Error TextLLMRunner::generate( RUNNER_ET_LOG(config.warming, "Max new tokens %i reached!", max_new_tokens); } - stats_->num_prompt_tokens = - prompt.empty() ? static_cast(pos_) : num_prompt_tokens; - stats_->num_generated_tokens = num_generated_tokens; + // The prefill step produced and emitted one token (cur_token) before the + // token generator ran, so the total generated count is that token plus the + // generator's. For an empty prompt (continuation/prefix-reuse path) the + // prompt length is everything resident before this turn's generation (pos_ + // already includes the generated tokens, so subtract them). + stats_->num_prompt_tokens = prompt.empty() + ? static_cast(pos_) - num_generated_tokens + : num_prompt_tokens; + stats_->num_generated_tokens = num_generated_tokens + 1; if (config.warming) { ET_LOG(Info, "Warmup run finished!"); @@ -284,6 +293,17 @@ Result TextLLMRunner::prefill( const std::vector& inputs, int32_t num_bos, int32_t num_eos) { + // No GenerationConfig on this overload: the deprecated temperature_ wins if + // set (mirroring generate()'s precedence), otherwise greedy. + return prefill_impl( + inputs, num_bos, num_eos, temperature_ == -1.0f ? 0.0f : temperature_); +} + +Result TextLLMRunner::prefill_impl( + const std::vector& inputs, + int32_t num_bos, + int32_t num_eos, + float temperature) { if (!is_loaded()) { ET_CHECK_OK_OR_RETURN_ERROR(load()); } @@ -297,7 +317,9 @@ Result TextLLMRunner::prefill( "Failed to encode prompt %s", input.get_text().c_str()); std::vector tokens = encode_res.get(); - auto prefill_res = text_prefiller_->prefill(tokens, pos_); + // The first generated token is sampled here during prefill, so honor the + // requested temperature instead of defaulting to greedy. 
+ auto prefill_res = text_prefiller_->prefill(tokens, pos_, temperature); ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error()); prefill_next_token_ = prefill_res.get(); num_bos = 0; @@ -324,7 +346,13 @@ Result TextLLMRunner::prefill( Result TextLLMRunner::prefill( const std::string& prompt, const GenerationConfig& config) { - return prefill(prompt, config.num_bos, config.num_eos); + std::vector inputs; + inputs.emplace_back(MultimodalInput(prompt)); + // Honor the request's sampling for the first token (sampled during prefill), + // mirroring generate(): the deprecated temperature_ wins if set. + const float resolved_temp = + temperature_ == -1.0f ? config.temperature : temperature_; + return prefill_impl(inputs, config.num_bos, config.num_eos, resolved_temp); } Error TextLLMRunner::warmup(const std::string& prompt, int32_t max_new_tokens) { @@ -343,6 +371,9 @@ Error TextLLMRunner::warmup(const std::string& prompt, int32_t max_new_tokens) { } void TextLLMRunner::stop() { + // Honored by both decode paths: generate() checks the token generator's flag, + // the token-step (session) loop checks stop_requested_ in decode_one(). + stop_requested_.store(true, std::memory_order_relaxed); if (is_loaded()) { text_token_generator_->stop(); } else { @@ -354,6 +385,176 @@ void TextLLMRunner::reset() { stats_->reset(); pos_ = 0; prefill_next_token_.reset(); + prev_decode_token_.reset(); + stop_requested_.store(false, std::memory_order_relaxed); +} + +::executorch::runtime::Error TextLLMRunner::seek(int64_t pos) { + // Token-step primitives require a KV cache (a non-KV model has no resident KV + // to rewind); fail closed. 
+ if (metadata_.at(kUseKVCache) == 0) { + ET_LOG(Error, "seek() requires a KV-cache model (use_kv_cache=true)"); + return ::executorch::runtime::Error::NotSupported; + } + // Sliding-window models (max_seq_len < max_context_len) recycle KV space, so + // pos_ is not an absolute position and the prefix [0, pos) may have slid out + // of the window; rewinding would attend to stale KV. Refuse so the caller + // falls back to reset() + full re-prefill. + if (metadata_.at(kMaxSeqLen) < metadata_.at(kMaxContextLen)) { + ET_LOG( + Error, + "seek() is unsupported for sliding-window models " + "(max_seq_len %" PRId64 " < max_context_len %" PRId64 ")", + metadata_.at(kMaxSeqLen), + metadata_.at(kMaxContextLen)); + return ::executorch::runtime::Error::NotSupported; + } + if (pos < 0 || pos > pos_) { + ET_LOG(Error, "seek(%" PRId64 ") out of range [0, %" PRId64 "]", pos, pos_); + return ::executorch::runtime::Error::InvalidArgument; + } + pos_ = pos; + prefill_next_token_.reset(); + prev_decode_token_.reset(); + return ::executorch::runtime::Error::Ok; +} + +::executorch::runtime::Result TextLLMRunner::prefill_tokens( + std::vector tokens, + float temperature) { + if (!is_loaded()) { + ET_CHECK_OK_OR_RETURN_ERROR(load()); + } + // The token-step primitives assume KV-cached decode (each step forwards only + // the new token at pos_). A non-KV model needs full-sequence re-forwarding, + // which this path does not implement — fail closed rather than decode wrong. + ET_CHECK_OR_RETURN_ERROR( + metadata_.at(kUseKVCache) != 0, + NotSupported, + "prefill_tokens/decode_one require a KV-cache model (use_kv_cache=true)"); + if (tokens.empty()) { + ET_LOG(Error, "prefill_tokens called with empty tokens"); + return ::executorch::runtime::Error::InvalidArgument; + } + // Same context-capacity guard as generate(): a caller that seek()s then + // prefill_tokens()es a suffix must not push pos_ past the KV cache. 
This is + // the only place the bound is enforced for prefill_tokens() (the public + // prefix-cache primitive), since it doesn't go through generate(prompt). + const int64_t max_seq_len = metadata_.at(kMaxSeqLen); + const int64_t max_context_len = metadata_.at(kMaxContextLen); + const int num_tokens = static_cast(tokens.size()); + if (max_seq_len >= max_context_len) { + ET_CHECK_OR_RETURN_ERROR( + pos_ + num_tokens < max_context_len, + InvalidArgument, + "pos_ %" PRId64 " + num_tokens %d >= max_context_len %" PRId64 + ", prefill_tokens would exceed KV cache capacity", + pos_, + num_tokens, + max_context_len); + } else { + ET_CHECK_OR_RETURN_ERROR( + num_tokens < max_context_len, + InvalidArgument, + "num_tokens %d >= max_context_len %" PRId64 + ", prefill_tokens exceeds KV cache capacity", + num_tokens, + max_context_len); + } + // Resolve temperature like decode_one() so the first token (sampled here in + // prefill) honors the request instead of defaulting to greedy. + const float temp = (temperature < 0.0f) + ? (temperature_ == -1.0f ? 0.0f : temperature_) + : temperature; + // A new prefill starts a fresh generation turn; clear any prior stop request. + stop_requested_.store(false, std::memory_order_relaxed); + auto prefill_res = text_prefiller_->prefill(tokens, pos_, temp); + ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error()); + prefill_next_token_ = prefill_res.get(); + prev_decode_token_.reset(); + return prefill_next_token_.value(); +} + +::executorch::runtime::Result TextLLMRunner::decode_one( + float temperature) { + if (!is_loaded()) { + ET_CHECK_OK_OR_RETURN_ERROR(load()); + } + // See prefill_tokens(): single-token KV stepping is invalid without a KV + // cache. 
+ ET_CHECK_OR_RETURN_ERROR( + metadata_.at(kUseKVCache) != 0, + NotSupported, + "decode_one requires a KV-cache model (use_kv_cache=true)"); + ET_CHECK_OR_RETURN_ERROR( + prefill_next_token_.has_value(), + InvalidState, + "decode_one requires a pending token; call prefill()/prefill_tokens() first"); + + // The pending token is the one we emit this step. + const uint64_t token = prefill_next_token_.value(); + const bool is_eos = text_token_generator_->is_eos(token); + + // Decode the text piece with BPE context (previous token), like generate(). + // Surface tokenizer errors rather than hiding them as an empty piece (matches + // generate(), which logs and returns InvalidArgument). + const uint64_t prev = prev_decode_token_.value_or(token); + auto decode_res = tokenizer_->decode(prev, token); + if (!decode_res.ok()) { + ET_LOG( + Error, + "Tokenizers error code %d", + static_cast(decode_res.error())); + return ::executorch::runtime::Error::InvalidArgument; + } + std::string text_piece = std::move(*decode_res); + + // Terminate WITHOUT forwarding the token — at EOS (like generate(), which + // breaks before the next step()) or at a cooperative stop() request observed + // at this token boundary. The token is not made resident and pos_ does not + // advance, so position()/prefix reuse stay correct; no pending token remains, + // so a subsequent decode_one() correctly errors (generation is complete). + // is_eos stays literal; is_terminal is set either way so the loop ends. + if (is_eos || stop_requested_.load(std::memory_order_relaxed)) { + prefill_next_token_.reset(); + return DecodeResult{ + token, std::move(text_piece), is_eos, /*is_terminal=*/true}; + } + + // Only a NON-EOS token is forwarded (made resident at pos_), so the capacity + // check belongs here — after the EOS short-circuit. This lets the final EOS + // be emitted even when the KV cache is exactly full. 
+ if (metadata_.at(kMaxSeqLen) >= metadata_.at(kMaxContextLen)) { + ET_CHECK_OR_RETURN_ERROR( + pos_ < metadata_.at(kMaxContextLen), + InvalidArgument, + "decode_one would exceed KV cache capacity: pos_ %" PRId64 + " >= max_context_len %" PRId64, + pos_, + metadata_.at(kMaxContextLen)); + } + + // Forward `token` at pos_ to predict the next pending token. + std::vector tok_data = {token}; + std::vector<::executorch::aten::SizesType> shape = {1, 1}; + auto tok_tensor = + from_blob(tok_data.data(), shape, ::executorch::aten::ScalarType::Long); + auto logits_res = text_decoder_runner_->step(tok_tensor, pos_); + ET_CHECK_OK_OR_RETURN_ERROR(logits_res.error()); + // Apply the same logit processors generate() does (grammar/tool masks, + // penalties, top-k/top-p) so the session decode path can't diverge from it. + ET_CHECK_OK_OR_RETURN_ERROR( + text_token_generator_->apply_logit_processors(logits_res.get())); + const float temp = (temperature < 0.0f) + ? (temperature_ == -1.0f ? 0.0f : temperature_) + : temperature; + prefill_next_token_ = static_cast( + text_decoder_runner_->logits_to_token(logits_res.get(), temp)); + prev_decode_token_ = token; + pos_ += 1; + + return DecodeResult{ + token, std::move(text_piece), /*is_eos=*/false, /*is_terminal=*/false}; } } // namespace executorch::extension::llm diff --git a/extension/llm/runner/text_llm_runner.h b/extension/llm/runner/text_llm_runner.h index c73b6a4bed6..8867775568a 100644 --- a/extension/llm/runner/text_llm_runner.h +++ b/extension/llm/runner/text_llm_runner.h @@ -11,6 +11,7 @@ #pragma once +#include #include #include #include @@ -28,9 +29,15 @@ // Helper functions are now in llm_runner_helper.h // These are provided for backward compatibility #include +// DecodeResult (returned by decode_one) lives with the Engine/Session API. 
+#include namespace executorch::extension::llm { +namespace detail { +class TextLLMSession; +} // namespace detail + class ET_EXPERIMENTAL TextLLMRunner : public IRunner { public: /** @@ -123,8 +130,9 @@ class ET_EXPERIMENTAL TextLLMRunner : public IRunner { prefill(const std::string& prompt, int32_t num_bos = 0, int32_t num_eos = 0); /** - * Prefill a text prompt using GenerationConfig. - * Deprecated: prefer prefill(prompt, num_bos, num_eos). + * Prefill a text prompt using GenerationConfig. Samples the first token + * (sampled during prefill) at config.temperature, unlike the bare overloads + * which default to greedy. */ ::executorch::runtime::Result prefill( const std::string& prompt, @@ -161,6 +169,35 @@ class ET_EXPERIMENTAL TextLLMRunner : public IRunner { void stop() override; private: + // Internal serving hooks: the single-token-step primitives that detail:: + // TextLLMSession composes into the LLMSession serving API. Not a public or + // ABI-stable surface — call them through LLMEngine/LLMSession, never + // directly. Friending the adapter keeps them out of every other caller (and + // out of the Python bindings) so nothing can take a dependency on this shape. + friend class detail::TextLLMSession; + + ::executorch::runtime::Error seek(int64_t pos); + + ::executorch::runtime::Result prefill_tokens( + std::vector tokens, + float temperature = -1.0f); + + int64_t position() const { + return pos_; + } + + ::executorch::runtime::Result decode_one( + float temperature = -1.0f); + + // Shared implementation for the prefill() overloads: encodes the text inputs, + // prefills them at the current position, and samples the first token (sampled + // during prefill) at `temperature`. 
+ ::executorch::runtime::Result prefill_impl( + const std::vector& inputs, + int32_t num_bos, + int32_t num_eos, + float temperature); + // Components std::unique_ptr<::tokenizers::Tokenizer> tokenizer_; std::unordered_map metadata_; @@ -185,8 +222,17 @@ class ET_EXPERIMENTAL TextLLMRunner : public IRunner { // Token predicted by the last prefill() call, consumed by generate(""). std::optional prefill_next_token_; + // Previously emitted token, for BPE-context text decoding in decode_one(). + std::optional prev_decode_token_; + // The position in KV cache of the input, starting from 0. int64_t pos_ = 0; + + // Cooperative stop for the token-step (session) decode loop: stop() sets it, + // decode_one() honors it at the next token boundary, and prefill_tokens() / + // reset() clear it for a fresh generation. Atomic so stop() is safe to call + // from another thread while decode_one() runs. + std::atomic stop_requested_{false}; }; } // namespace executorch::extension::llm diff --git a/extension/llm/runner/text_llm_session.h b/extension/llm/runner/text_llm_session.h new file mode 100644 index 00000000000..700691f5317 --- /dev/null +++ b/extension/llm/runner/text_llm_session.h @@ -0,0 +1,43 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Internal adapter implementing the model-agnostic LLMSession over a +// TextLLMRunner. It lives in `detail` (not a public API) so it can be named as +// a friend of TextLLMRunner: this is the *only* caller of the runner's internal +// token-step hooks (prefill_tokens/decode_one/seek/position). Server and engine +// code depend on LLMSession alone, never on this type or on TextLLMRunner. 
+ +#pragma once + +#include +#include + +#include +#include + +namespace executorch::extension::llm::detail { + +class TextLLMSession : public LLMSession { + public: + explicit TextLLMSession(std::unique_ptr runner); + + ::executorch::runtime::Error prefill_tokens( + std::vector tokens, + const SamplingConfig* initial_sampling = nullptr) override; + ::executorch::runtime::Result decode_one( + const SamplingConfig& sampling) override; + ::executorch::runtime::Error seek(int64_t pos) override; + int64_t position() const override; + ::executorch::runtime::Error reset() override; + void stop() override; + + private: + std::unique_ptr runner_; +}; + +} // namespace executorch::extension::llm::detail diff --git a/extension/llm/runner/text_prefiller.cpp b/extension/llm/runner/text_prefiller.cpp index a391cef01de..b87fc25a8ec 100644 --- a/extension/llm/runner/text_prefiller.cpp +++ b/extension/llm/runner/text_prefiller.cpp @@ -28,7 +28,8 @@ TextPrefiller::TextPrefiller( ::executorch::runtime::Result TextPrefiller::prefill( std::vector& prompt_tokens, - int64_t& start_pos) { + int64_t& start_pos, + float temperature) { ET_CHECK_MSG(!prompt_tokens.empty(), "Prompt cannot be null"); if (!text_decoder_runner_->is_method_loaded()) { ET_CHECK_OK_OR_RETURN_ERROR(text_decoder_runner_->load()); @@ -54,8 +55,15 @@ ::executorch::runtime::Result TextPrefiller::prefill( num_tokens_to_prefill_with, prompt_tokens_to_process.begin()); - // Process this chunk - auto chunk_result = prefill_chunk(prompt_tokens_to_process, start_pos); + // Process this chunk. Only the LAST chunk produces the first generated + // token, so apply `temperature` there; intermediate chunks just prefill. + const bool is_last_chunk = + num_tokens_to_process + num_tokens_to_prefill_with >= + num_prompt_tokens; + auto chunk_result = prefill_chunk( + prompt_tokens_to_process, + start_pos, + is_last_chunk ? 
temperature : 0.0f); ET_CHECK_OK_OR_RETURN_ERROR(chunk_result.error()); cur_token = chunk_result.get(); @@ -65,13 +73,14 @@ ::executorch::runtime::Result TextPrefiller::prefill( return cur_token; } else { // If prompt tokens don't exceed max_seq_len_, process them directly - return prefill_chunk(prompt_tokens, start_pos); + return prefill_chunk(prompt_tokens, start_pos, temperature); } } ::executorch::runtime::Result TextPrefiller::prefill_chunk( std::vector& prompt_tokens, - int64_t& start_pos) { + int64_t& start_pos, + float temperature) { // enable_parallel_prefill_ maybe set even when not using kv cache // When kv cache is not used, start pos is ignored int32_t num_prompt_tokens = prompt_tokens.size(); @@ -92,7 +101,8 @@ ::executorch::runtime::Result TextPrefiller::prefill_chunk( Info, "Prefill token result numel(): %zu", outputs_res.get().numel()); start_pos += num_prompt_tokens; - cur_token = text_decoder_runner_->logits_to_token(outputs_res.get()); + cur_token = + text_decoder_runner_->logits_to_token(outputs_res.get(), temperature); } else { // sequential prefill int64_t pos = 0; // position in the sequence // NOLINTNEXTLINE(facebook-hte-ParameterUncheckedArrayBounds) @@ -128,7 +138,8 @@ ::executorch::runtime::Result TextPrefiller::prefill_chunk( start_pos++; } - cur_token = text_decoder_runner_->logits_to_token(logits_tensor); + cur_token = + text_decoder_runner_->logits_to_token(logits_tensor, temperature); } return cur_token; } diff --git a/extension/llm/runner/text_prefiller.h b/extension/llm/runner/text_prefiller.h index a02cd3d1bf4..3ca961a15cb 100644 --- a/extension/llm/runner/text_prefiller.h +++ b/extension/llm/runner/text_prefiller.h @@ -32,22 +32,28 @@ class ET_EXPERIMENTAL TextPrefiller { * tokenizer. * @param start_pos The starting position in KV cache of the input in the LLM * Module. + * @param temperature Sampling temperature for the first generated token + * (which is sampled here during prefill). Defaults to greedy (0.0). 
* @return The next token of the LLM Module after prefill. */ virtual ::executorch::runtime::Result prefill( std::vector& prompt_tokens, - int64_t& start_pos); + int64_t& start_pos, + float temperature = 0.0f); /** * Helper method to prefill a chunk of tokens. * @param prompt_tokens The chunk of text prompt tokens to process. * @param start_pos The starting position in KV cache of the input in the LLM * Module. + * @param temperature Sampling temperature for the token produced by this + * chunk. Defaults to greedy (0.0). * @return The next token of the LLM Module after prefilling this chunk. */ virtual ::executorch::runtime::Result prefill_chunk( std::vector& prompt_tokens, - int64_t& start_pos); + int64_t& start_pos, + float temperature = 0.0f); /** * Load the necessary resources for the TextPrefiller. diff --git a/extension/llm/runner/text_token_generator.h b/extension/llm/runner/text_token_generator.h index 3627cacf3c3..856f1af5b9a 100644 --- a/extension/llm/runner/text_token_generator.h +++ b/extension/llm/runner/text_token_generator.h @@ -55,6 +55,18 @@ class ET_EXPERIMENTAL TextTokenGenerator { return logit_processors_.size(); } + /// Apply the registered logit processors (grammar/tool masks, penalties, + /// top-k/top-p, ...) to `logits` in order, before sampling. Both the + /// generate() loop and session decode_one() call this so the two decode paths + /// stay consistent. 
+ inline ::executorch::runtime::Error apply_logit_processors( + executorch::aten::Tensor& logits) { + for (auto& processor : logit_processors_) { + ET_CHECK_OK_OR_RETURN_ERROR(processor->process(logits)); + } + return ::executorch::runtime::Error::Ok; + } + virtual ~TextTokenGenerator() = default; /** @@ -126,9 +138,7 @@ class ET_EXPERIMENTAL TextTokenGenerator { prev_token = cur_token; - for (auto& processor : logit_processors_) { - ET_CHECK_OK_OR_RETURN_ERROR(processor->process(logits_tensor)); - } + ET_CHECK_OK_OR_RETURN_ERROR(apply_logit_processors(logits_tensor)); stats_->on_sampling_begin(); cur_token = @@ -180,6 +190,11 @@ class ET_EXPERIMENTAL TextTokenGenerator { should_stop_.store(true, std::memory_order_relaxed); } + /// Whether `token` is an end-of-sequence token (used by single-step decode). + inline bool is_eos(uint64_t token) const { + return eos_ids_->find(token) != eos_ids_->end(); + } + /** * Load the necessary resources for TextTokenGenerator. * This method should be called before using the generate() method. diff --git a/extension/llm/runner/util.h b/extension/llm/runner/util.h index 6bfde46eda0..61860580c50 100644 --- a/extension/llm/runner/util.h +++ b/extension/llm/runner/util.h @@ -13,7 +13,9 @@ #include #include #include +#include #include +#include #include #if defined(__linux__) || defined(__ANDROID__) || defined(__unix__) #include @@ -66,6 +68,87 @@ ET_EXPERIMENTAL void inline safe_printf(const char* piece) { printf("%s", piece); } +// Length of the longest prefix of `s` that does not end in the middle of a +// UTF-8 multi-byte sequence. A byte-level tokenizer can emit a token that is +// only part of a character (e.g. one byte of a 3-byte CJK codepoint or emoji), +// so a caller streaming text must hold the incomplete tail until it completes +// rather than decode the partial bytes. An invalid lead byte counts as length 1 +// (emitted, so the caller can replace it) rather than stalling output. 
ET_EXPERIMENTAL size_t inline utf8_complete_prefix_len(const std::string& s) {
  // Walk the string one encoded character at a time, from the front. `i`
  // always sits on a character boundary, so whatever it has reached when the
  // loop ends is the length of the longest safely-emittable prefix.
  size_t i = 0;
  const size_t n = s.size();
  while (i < n) {
    // Inspect the byte as unsigned so the shifts below see 0..255 regardless
    // of whether plain `char` is signed on this platform.
    const unsigned char c = static_cast<unsigned char>(s[i]);
    // Sequence length implied by the lead byte's high bits.
    size_t len;
    if (c < 0x80) {
      len = 1; // 0xxxxxxx: ASCII
    } else if ((c >> 5) == 0x6) {
      len = 2; // 110xxxxx
    } else if ((c >> 4) == 0xE) {
      len = 3; // 1110xxxx
    } else if ((c >> 3) == 0x1E) {
      len = 4; // 11110xxx
    } else {
      len = 1; // invalid lead byte; emit it and let the caller replace it
    }
    if (i + len > n) {
      break; // incomplete trailing sequence: hold it for more bytes
    }
    i += len;
  }
  return i;
}

// How many leading bytes of `text` a streaming consumer may safely emit given a
// set of `stops` strings, and whether a stop was hit (`stop_hit`).
//   * If any stop occurs, returns the byte offset of the EARLIEST occurrence
//     and sets stop_hit=true — text before it is safe; the stop and everything
//     after are dropped (the stop is excluded from output).
//   * Otherwise returns the length minus the longest possible partial-stop tail
//     (max(len(stop))-1 bytes), snapped DOWN to a UTF-8 boundary so a
//     multi-byte character is never split; stop_hit=false. Holding back that
//     tail lets a stop that straddles the next piece still be caught.
// `text` is expected to be complete-UTF-8 (e.g. the assembled output of
// utf8_complete_prefix_len). Empty `stops` => emit everything, no hold-back.
ET_EXPERIMENTAL size_t inline stop_safe_prefix_len(
    const std::string& text,
    const std::vector<std::string>& stops,
    bool& stop_hit) {
  stop_hit = false;
  if (stops.empty()) {
    return text.size(); // nothing to watch for: everything is safe to emit
  }

  // One pass over the stops: track the earliest full match in `text` and the
  // longest stop length (needed for the hold-back when nothing matched).
  size_t earliest = std::string::npos;
  size_t max_len = 0;
  for (const auto& s : stops) {
    if (s.empty()) {
      continue; // an empty stop would match at offset 0; ignore it
    }
    max_len = std::max(max_len, s.size());
    const size_t p = text.find(s);
    if (p != std::string::npos &&
        (earliest == std::string::npos || p < earliest)) {
      earliest = p;
    }
  }
  if (earliest != std::string::npos) {
    // A stop is fully present: emit only what precedes it.
    stop_hit = true;
    return earliest;
  }

  // No full stop yet. The last max_len-1 bytes could still be the start of a
  // stop completed by the next piece, so keep them back.
  const size_t hold = max_len > 0 ? max_len - 1 : 0;
  if (text.size() <= hold) {
    return 0;
  }
  size_t end = text.size() - hold;
  // Don't cut in the middle of a UTF-8 character: back up over continuation
  // bytes (10xxxxxx). When hold == 0, `end == text.size()`; reading
  // text[text.size()] yields the terminating '\0' (well-defined for
  // std::string), which is not a continuation byte, so the loop exits at once.
  while (end > 0 && (static_cast<unsigned char>(text[end]) & 0xC0) == 0x80) {
    --end;
  }
  return end;
}

// ----------------------------------------------------------------------------
// utilities: time