diff --git a/backends/cuda/CMakeLists.txt b/backends/cuda/CMakeLists.txt index d56e994eab4..f153c652e3b 100644 --- a/backends/cuda/CMakeLists.txt +++ b/backends/cuda/CMakeLists.txt @@ -179,7 +179,9 @@ install( ) # CUDA backend implementation -set(_aoti_cuda_backend_sources runtime/cuda_backend.cpp) +set(_aoti_cuda_backend_sources runtime/cuda_backend.cpp + runtime/cuda_mutable_state.cpp +) if(_cuda_is_msvc_toolchain) # MSVC links aoti_cuda_backend into portable_lib without relying on C++ # symbols exported from aoti_cuda_shims.dll. diff --git a/backends/cuda/runtime/TARGETS b/backends/cuda/runtime/TARGETS index c8449a95718..5c342f3ab83 100644 --- a/backends/cuda/runtime/TARGETS +++ b/backends/cuda/runtime/TARGETS @@ -105,9 +105,11 @@ runtime.cxx_library( name = "cuda_backend", srcs = [ "cuda_backend.cpp", + "cuda_mutable_state.cpp", ], headers = [ "cuda_delegate_handle.h", + "cuda_mutable_state.h", ], # @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole) link_whole = True, diff --git a/backends/cuda/runtime/cuda_backend.cpp b/backends/cuda/runtime/cuda_backend.cpp index d2738f7a976..ee81bd26a6e 100644 --- a/backends/cuda/runtime/cuda_backend.cpp +++ b/backends/cuda/runtime/cuda_backend.cpp @@ -42,6 +42,7 @@ #include #include #include +#include #include #include #include @@ -466,6 +467,10 @@ class ET_EXPERIMENTAL CudaBackend final kCudaGraphWarmupSteps); } + // Record whether this AOTI build exposes the constant-management symbols + // needed for per-session mutable-buffer rebinding (CUDA V2 multi-session). + mutable_state_note_handle(handle); + return (DelegateHandle*)handle; // Return the handle post-processing } @@ -514,6 +519,12 @@ class ET_EXPERIMENTAL CudaBackend final static_cast(device_type)); } + // CUDA V2 multi-session: if a logical session is active on this thread, + // rebind this container's mutable constants (KV/conv/recurrent) to the + // session's own GPU buffers before running. No-op for + // single-session/legacy. + ET_CHECK_OK_OR_RETURN_ERROR(mutable_state_rebind_for_execute(handle)); + // --------------------------------------------------------------- // CUDA graph REPLAY path — skip all tensor setup and just replay // --------------------------------------------------------------- diff --git a/backends/cuda/runtime/cuda_mutable_state.cpp b/backends/cuda/runtime/cuda_mutable_state.cpp new file mode 100644 index 00000000000..d59f52849fb --- /dev/null +++ b/backends/cuda/runtime/cuda_mutable_state.cpp @@ -0,0 +1,474 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include + +namespace executorch { +namespace backends { +namespace cuda { + +namespace slimc10 = ::executorch::backends::aoti::slim::c10; +using ::executorch::backends::aoti::slim::from_blob; +using ::executorch::backends::aoti::slim::SlimTensor; +using ::executorch::runtime::Error; +using ::executorch::runtime::Result; + +namespace { + +// Per-handle descriptor of one mutable constant (AOTI internal name differs per +// compiled method, so this is keyed per delegate handle within a context). +struct Desc { + std::string internal_name; + std::vector sizes; + std::vector strides; + slimc10::ScalarType dtype{slimc10::ScalarType::Float}; + slimc10::Device device{slimc10::DeviceType::CUDA, 0}; + size_t nbytes{0}; +}; + +// Cached user-managed pairs for a (handle, session): SlimTensors wrapping the +// session's GPU buffers (kept alive here) and the flat pairs array AOTI +// rebinds. +struct Bound { + std::vector> tensors; + std::vector pairs; +}; + +// All per-engine/model mutable state. Keyed by context id in Manager. +struct Context { + std::vector fqns; + std::unordered_set fqn_set; + + bool symbols_checked{false}; + bool symbols_available{false}; + + // FQN -> device template (the model's initial mutable contents) + sizes. + std::unordered_map template_ptr; + std::unordered_map template_nbytes; + int64_t total_bytes{0}; + + // Per-handle descriptor table + the union of discovered FQNs (for coverage). + std::unordered_map> + desc; + std::unordered_set discovered_fqns; + Error build_error{Error::Ok}; + + std::unordered_set sessions; + int next_token{0}; + // token -> (fqn -> device buffer) shared across the session's handles. + std::unordered_map> session_buf; + // (handle, token) -> cached wrappers + pairs. + std::unordered_map> bound; +}; + +struct Manager { + std::mutex mu; + std::unordered_map contexts; + std::unordered_map handle_ctx; + MutableStateContext next_ctx{1}; +}; + +Manager& mgr() { + static Manager m; + return m; +} + +// The context whose model is currently being loaded on this thread (so +// note_handle, called from CudaBackend::init, can associate handles). And the +// active (context, session) selected before execute on this thread. +thread_local MutableStateContext tl_loading_ctx = kInvalidMutableContext; +thread_local MutableStateContext tl_active_ctx = kInvalidMutableContext; +thread_local int tl_active_token = kNoMutableSession; + +bool handle_has_symbols(CudaDelegateHandle* h) { + return h->get_num_constants && h->get_constant_name && + h->get_constant_original_fqn && h->extract_constants_map && + h->update_user_managed_constant_buffer_pairs; +} + +// Build the descriptor table for a handle and capture per-FQN initial +// templates. Caller holds mgr().mu. Runs before any session has rebound this +// container, so the constants still hold the model's initial mutable state. +Error build_descriptors(Context& c, CudaDelegateHandle* h) { + auto container = h->container_handle; + + size_t n = 0; + h->get_num_constants(container, &n); + std::unordered_map fqn_to_internal; + for (size_t i = 0; i < n; ++i) { + const char* internal = nullptr; + const char* fqn = nullptr; + h->get_constant_name(container, i, &internal); + h->get_constant_original_fqn(container, i, &fqn); + if (internal && fqn && fqn[0] != '\0') { + fqn_to_internal[fqn] = internal; + } + } + + std::unordered_map extracted; + ET_CHECK_OK_OR_RETURN_ERROR( + h->extract_constants_map( + container, + reinterpret_cast(&extracted), + /*use_inactive=*/false), + "mutable_state: extract_constants_map failed"); + + auto& table = c.desc[h]; + for (const auto& fqn : c.fqns) { + auto it_name = fqn_to_internal.find(fqn); + auto it_t = extracted.find(fqn); + // A mutable FQN not present in this container = a method that does not use + // it (method-scoped). Skip; another container will own it. + if (it_name == fqn_to_internal.end() || it_t == extracted.end()) { + continue; + } + auto* t = reinterpret_cast(it_t->second); + Desc d; + d.internal_name = it_name->second; + d.sizes.assign(t->sizes().begin(), t->sizes().end()); + d.strides.assign(t->strides().begin(), t->strides().end()); + d.dtype = t->dtype(); + d.device = t->device(); + d.nbytes = t->nbytes(); + table.emplace(fqn, std::move(d)); + c.discovered_fqns.insert(fqn); + + if (c.template_ptr.find(fqn) == c.template_ptr.end()) { + void* tpl = nullptr; + if (cudaMalloc(&tpl, t->nbytes()) != cudaSuccess) { + ET_LOG(Error, "mutable_state: cudaMalloc template '%s'", fqn.c_str()); + return Error::Internal; + } + if (cudaMemcpy( + tpl, t->data_ptr(), t->nbytes(), cudaMemcpyDeviceToDevice) != + cudaSuccess) { + ET_LOG(Error, "mutable_state: cudaMemcpy template '%s'", fqn.c_str()); + return Error::Internal; + } + c.template_ptr[fqn] = tpl; + c.template_nbytes[fqn] = t->nbytes(); + c.total_bytes += static_cast(t->nbytes()); + } + } + return Error::Ok; +} + +// Allocate a session's GPU buffers, cloned from the initial templates. Caller +// holds mgr().mu. Allocates PER FQN so a buffer is created for any template +// discovered after the session's first allocation. +Error ensure_session_buffers(Context& c, int token) { + auto& buf = c.session_buf[token]; + for (const auto& kv : c.template_ptr) { + const std::string& fqn = kv.first; + if (buf.find(fqn) != buf.end()) { + continue; // already allocated for this session + } + void* tpl = kv.second; + size_t nbytes = c.template_nbytes[fqn]; + void* p = nullptr; + if (cudaMalloc(&p, nbytes) != cudaSuccess) { + ET_LOG( + Error, "mutable_state: cudaMalloc session buffer '%s'", fqn.c_str()); + return Error::Internal; + } + if (cudaMemcpy(p, tpl, nbytes, cudaMemcpyDeviceToDevice) != cudaSuccess) { + ET_LOG( + Error, "mutable_state: cudaMemcpy session buffer '%s'", fqn.c_str()); + return Error::Internal; + } + buf[fqn] = p; + } + return Error::Ok; +} + +// Build the cached wrappers + pairs for (handle, token). Caller holds mgr().mu. +Error ensure_bound(Context& c, CudaDelegateHandle* h, int token) { + if (c.bound[h].find(token) != c.bound[h].end()) { + return Error::Ok; + } + Bound b; + auto& buf = c.session_buf[token]; + for (const auto& fd : c.desc[h]) { + const std::string& fqn = fd.first; + const Desc& d = fd.second; + auto buf_it = buf.find(fqn); + if (buf_it == buf.end() || buf_it->second == nullptr) { + // Every descriptor for this handle must have a backing session buffer; + // a null bind would silently corrupt state. + ET_LOG(Error, "mutable_state: no session buffer for '%s'", fqn.c_str()); + return Error::Internal; + } + void* ptr = buf_it->second; + auto st = std::make_unique(from_blob( + ptr, + ::executorch::runtime::makeArrayRef(d.sizes.data(), d.sizes.size()), + ::executorch::runtime::makeArrayRef(d.strides.data(), d.strides.size()), + d.dtype, + d.device)); + aoti::AOTInductorConstantMapEntry entry; + entry.name = d.internal_name.c_str(); + entry.handle = reinterpret_cast(st.get()); + b.pairs.push_back(entry); + b.tensors.push_back(std::move(st)); + } + c.bound[h].emplace(token, std::move(b)); + return Error::Ok; +} + +void free_session_buffers(Context& c, int token) { + auto it = c.session_buf.find(token); + if (it != c.session_buf.end()) { + for (auto& kv : it->second) { + if (kv.second) { + cudaFree(kv.second); + } + } + c.session_buf.erase(it); + } + for (auto& hb : c.bound) { + hb.second.erase(token); + } + c.sessions.erase(token); +} + +} // namespace + +MutableStateContext mutable_state_create_context() { + auto& m = mgr(); + std::lock_guard g(m.mu); + MutableStateContext id = m.next_ctx++; + m.contexts[id]; // default-construct + return id; +} + +void mutable_state_destroy_context(MutableStateContext ctx) { + auto& m = mgr(); + std::lock_guard g(m.mu); + auto it = m.contexts.find(ctx); + if (it == m.contexts.end()) { + return; + } + Context& c = it->second; + for (auto& kv : c.template_ptr) { + if (kv.second) { + cudaFree(kv.second); + } + } + for (auto& sb : c.session_buf) { + for (auto& kv : sb.second) { + if (kv.second) { + cudaFree(kv.second); + } + } + } + // Drop handle->ctx associations for this context. + for (auto hit = m.handle_ctx.begin(); hit != m.handle_ctx.end();) { + hit = (hit->second == ctx) ? m.handle_ctx.erase(hit) : std::next(hit); + } + m.contexts.erase(it); +} + +void mutable_state_begin_load(MutableStateContext ctx) { + tl_loading_ctx = ctx; +} + +void mutable_state_end_load() { + tl_loading_ctx = kInvalidMutableContext; +} + +void mutable_state_register_fqns( + MutableStateContext ctx, + const std::vector& fqns) { + auto& m = mgr(); + std::lock_guard g(m.mu); + auto it = m.contexts.find(ctx); + if (it == m.contexts.end()) { + return; + } + it->second.fqns = fqns; + it->second.fqn_set.clear(); + it->second.fqn_set.insert(fqns.begin(), fqns.end()); +} + +bool mutable_state_available(MutableStateContext ctx) { + auto& m = mgr(); + std::lock_guard g(m.mu); + auto it = m.contexts.find(ctx); + return it != m.contexts.end() && it->second.symbols_available; +} + +int64_t mutable_state_bytes_per_session(MutableStateContext ctx) { + auto& m = mgr(); + std::lock_guard g(m.mu); + auto it = m.contexts.find(ctx); + return it == m.contexts.end() ? 0 : it->second.total_bytes; +} + +Error mutable_state_validate_coverage(MutableStateContext ctx) { + auto& m = mgr(); + std::lock_guard g(m.mu); + auto it = m.contexts.find(ctx); + if (it == m.contexts.end()) { + return Error::InvalidArgument; + } + Context& c = it->second; + if (!c.symbols_available) { + return Error::NotSupported; + } + if (c.build_error != Error::Ok) { + return c.build_error; + } + bool ok = true; + for (const auto& fqn : c.fqns) { + if (c.discovered_fqns.find(fqn) == c.discovered_fqns.end()) { + ET_LOG( + Error, + "mutable_state: declared mutable buffer '%s' not found in any loaded " + "method's constants (FQN mismatch?)", + fqn.c_str()); + ok = false; + } + } + return ok ? Error::Ok : Error::InvalidProgram; +} + +Result mutable_state_create_session(MutableStateContext ctx) { + auto& m = mgr(); + std::lock_guard g(m.mu); + auto it = m.contexts.find(ctx); + if (it == m.contexts.end()) { + return Error::InvalidArgument; + } + if (!it->second.symbols_available) { + ET_LOG( + Error, "mutable_state: rebinding unavailable; cannot create session"); + return Error::NotSupported; + } + int token = it->second.next_token++; + it->second.sessions.insert(token); + return token; +} + +void mutable_state_destroy_session(MutableStateContext ctx, int token) { + auto& m = mgr(); + std::lock_guard g(m.mu); + auto it = m.contexts.find(ctx); + if (it == m.contexts.end()) { + return; // context already torn down; nothing to free + } + free_session_buffers(it->second, token); +} + +void mutable_state_set_active(MutableStateContext ctx, int token) { + tl_active_ctx = ctx; + tl_active_token = token; +} + +void mutable_state_note_handle(CudaDelegateHandle* handle) { + MutableStateContext ctx = tl_loading_ctx; + if (ctx == kInvalidMutableContext) { + return; // not loading within a managed context (e.g. non-V2 path) + } + auto& m = mgr(); + std::lock_guard g(m.mu); + auto it = m.contexts.find(ctx); + if (it == m.contexts.end()) { + return; + } + Context& c = it->second; + m.handle_ctx[handle] = ctx; + bool ok = handle_has_symbols(handle); + c.symbols_available = c.symbols_checked ? (c.symbols_available && ok) : ok; + c.symbols_checked = true; + // Build this method's descriptor table + capture initial templates now, while + // the container still holds the model's initial mutable state and before any + // session rebinds. Requires FQNs registered before load_method. + if (ok && !c.fqns.empty() && c.desc.find(handle) == c.desc.end()) { + Error e = build_descriptors(c, handle); + if (e != Error::Ok) { + c.build_error = e; + } + } +} + +Error mutable_state_rebind_for_execute(CudaDelegateHandle* handle) { + if (tl_active_token == kNoMutableSession) { + return Error::Ok; // single-session / legacy: nothing to rebind + } + auto& m = mgr(); + std::lock_guard g(m.mu); + + auto hit = m.handle_ctx.find(handle); + if (hit == m.handle_ctx.end()) { + ET_LOG( + Error, + "mutable_state: active session set but handle has no context (load " + "scope missed?)"); + return Error::Internal; + } + MutableStateContext ctx = hit->second; + if (ctx != tl_active_ctx) { + ET_LOG( + Error, + "mutable_state: active context mismatch (caller set a different context " + "active than the one executing)"); + return Error::Internal; + } + auto cit = m.contexts.find(ctx); + if (cit == m.contexts.end()) { + return Error::Internal; + } + Context& c = cit->second; + if (!c.symbols_available) { + ET_LOG( + Error, "mutable_state: active session set but rebinding unavailable"); + return Error::NotSupported; + } + if (c.desc.find(handle) == c.desc.end()) { + ET_LOG( + Error, + "mutable_state: no descriptors for handle (note_handle missed?)"); + return Error::Internal; + } + const int token = tl_active_token; + ET_CHECK_OK_OR_RETURN_ERROR(ensure_session_buffers(c, token)); + ET_CHECK_OK_OR_RETURN_ERROR(ensure_bound(c, handle, token)); + + const Bound& b = c.bound[handle][token]; + if (b.pairs.empty()) { + return Error::Ok; + } + ET_CHECK_OK_OR_RETURN_ERROR( + handle->update_user_managed_constant_buffer_pairs( + handle->container_handle, + b.pairs.data(), + b.pairs.size(), + /*use_inactive=*/false, + /*validate_full_update=*/false), + "mutable_state: update_user_managed_constant_buffer_pairs failed"); + return Error::Ok; +} + +} // namespace cuda +} // namespace backends +} // namespace executorch diff --git a/backends/cuda/runtime/cuda_mutable_state.h b/backends/cuda/runtime/cuda_mutable_state.h new file mode 100644 index 00000000000..9ad2a05d92d --- /dev/null +++ b/backends/cuda/runtime/cuda_mutable_state.h @@ -0,0 +1,105 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include + +#include +#include + +// CUDA-PRIVATE per-session mutable-state management. This is intentionally NOT +// a generic ExecuTorch (Module/Method/BackendInterface) API: it is the +// CUDA/AOTI implementation of "one loaded model, many logical contexts" and is +// consumed only by CUDA-specific LLM engines (e.g. Qwen35MoEEngine). The public +// serving abstraction stays LLMEngine/LLMSession. +// +// State is keyed by a CONTEXT (one per loaded model/engine), NOT +// process-global, so multiple models (e.g. Qwen + Gemma) and repeated engine +// lifecycles in one process stay isolated. An engine: creates a context, scopes +// its model load (begin/end) so the backend associates each delegate handle +// with the context, registers the model's mutable FQNs, creates sessions, +// selects an active session before each execute, and destroys the context on +// teardown. + +namespace executorch { +namespace backends { +namespace cuda { + +struct CudaDelegateHandle; // defined in cuda_delegate_handle.h + +// Opaque per-engine context id (0 = invalid). +using MutableStateContext = int; +constexpr MutableStateContext kInvalidMutableContext = 0; + +// Active-session sentinel: execute() rebinds nothing (single-session / legacy). +constexpr int kNoMutableSession = -1; + +// --- Engine-facing API (call from the CUDA-specific LLM engine) ------------- + +// Create / destroy a context. destroy frees all of the context's sessions, +// templates, descriptors, and handle associations (safe to call once at engine +// teardown; sessions destroyed afterward become no-ops). +MutableStateContext mutable_state_create_context(); +void mutable_state_destroy_context(MutableStateContext ctx); + +// Scope a model load to a context: call begin BEFORE load_method and end AFTER, +// so the delegate handles initialized during the load are associated with +// `ctx`. Nesting is not supported (one load at a time per thread). +void mutable_state_begin_load(MutableStateContext ctx); +void mutable_state_end_load(); + +// Declare the context's per-session mutable-state FQNs (from the model's +// get_mutable_buffer_metadata). Call before begin_load/load_method. +void mutable_state_register_fqns( + MutableStateContext ctx, + const std::vector& fqns); + +// True if the context's loaded delegate(s) expose the AOTI constant-management +// symbols required for per-session rebinding. If false, the caller MUST run +// single-session. +bool mutable_state_available(MutableStateContext ctx); + +// Bytes one session adds (sum of mutable-buffer sizes), 0 if not yet known. +int64_t mutable_state_bytes_per_session(MutableStateContext ctx); + +// Validate every declared FQN was discovered in some loaded method's constants. +// Call after loading all methods; non-Ok must abort multi-session serving. +::executorch::runtime::Error mutable_state_validate_coverage( + MutableStateContext ctx); + +// Create / destroy a logical session within a context. create returns a token +// (>= 0); buffers are allocated lazily on the session's first execute. +::executorch::runtime::Result mutable_state_create_session( + MutableStateContext ctx); +void mutable_state_destroy_session(MutableStateContext ctx, int token); + +// Select the active (context, session) for subsequent Module::execute calls ON +// THIS THREAD. Set before execute, reset token to kNoMutableSession after; the +// engine must hold its serialization lock across set + execute + read-out. +void mutable_state_set_active(MutableStateContext ctx, int token); + +// --- CudaBackend-internal hooks (called from cuda_backend.cpp) --------------- + +// From CudaBackend::init: associate this handle with the context currently +// being loaded (begin_load), record symbol availability, and build the +// descriptor table + capture initial templates from the still-initial +// constants. +void mutable_state_note_handle(CudaDelegateHandle* handle); + +// From CudaBackend::execute, before running: if a session is active on this +// thread for this handle's context, rebind the container's mutable constants to +// the session's buffers. No-op (Ok) when no session is active. +::executorch::runtime::Error mutable_state_rebind_for_execute( + CudaDelegateHandle* handle); + +} // namespace cuda +} // namespace backends +} // namespace executorch diff --git a/examples/models/qwen3_5_moe/CMakeLists.txt b/examples/models/qwen3_5_moe/CMakeLists.txt index b9e10b4d72d..0fe24130c35 100644 --- a/examples/models/qwen3_5_moe/CMakeLists.txt +++ b/examples/models/qwen3_5_moe/CMakeLists.txt @@ -67,7 +67,7 @@ list(APPEND link_libraries tokenizers::tokenizers) add_executable(qwen3_5_moe_runner main.cpp qwen35_moe_engine.cpp) target_include_directories( - qwen3_5_moe_runner PUBLIC ${_common_include_directories} + qwen3_5_moe_runner PUBLIC ${_common_include_directories} ${_json_include} ) target_link_libraries(qwen3_5_moe_runner PUBLIC ${link_libraries}) diff --git a/examples/models/qwen3_5_moe/README.md b/examples/models/qwen3_5_moe/README.md index 37455c85c9d..2231bdd7566 100644 --- a/examples/models/qwen3_5_moe/README.md +++ b/examples/models/qwen3_5_moe/README.md @@ -136,7 +136,6 @@ cmake-out/examples/models/qwen3_5_moe/qwen3_5_moe_runner \ | `--prompt_file` | (none) | Path to a file with the prompt (overrides `--prompt`) | | `--temperature` | `0.8` | Sampling temperature (0 = greedy) | | `--max_new_tokens` | `128` | Maximum tokens to generate | -| `--cuda_graph` | off | Capture/replay the decode method as a CUDA graph (CUDA only). See the caveat below. | | `--warmup` | `0` | Warmup iterations to discard before timing (one model load; the session is reset between iterations) | | `--num_iters` | `1` | Timed iterations to average, after warmup | @@ -216,12 +215,6 @@ is safe under asyncio. LD_LIBRARY_PATH=$CONDA_PREFIX/lib:$LD_LIBRARY_PATH \ cmake-out/examples/models/qwen3_5_moe/qwen3_5_moe_runner ... ``` -- **`aoti_torch_cuda_sort_stable ... API call failed` when re-running prefill - with `--cuda_graph`**: capturing the decode CUDA graph and then running another - prefill in the same process currently fails (allocator interaction). Use - `--cuda_graph` for single prefill+decode runs; omit it when looping with - `--warmup`/`--num_iters`. - - **OOM during export**: The model requires significant GPU memory even with int4 quantization. Try reducing `--max-seq-len` or using a GPU with more VRAM. diff --git a/examples/models/qwen3_5_moe/export.py b/examples/models/qwen3_5_moe/export.py index ed787b3c110..0be63a49a4e 100644 --- a/examples/models/qwen3_5_moe/export.py +++ b/examples/models/qwen3_5_moe/export.py @@ -913,6 +913,77 @@ def _export_metal(model, config, args): print("Done!") +def _qwen_mutable_buffer_fqns(model): + """Explicit per-session mutable-state buffers of the Qwen model (source of truth). + + State ownership is safety-critical for multi-session serving — a single missed + buffer silently bleeds one session's context into another — so we enumerate it + from the model's own module contract rather than inferring it from export + internals. Read-only buffers are deliberately excluded: RoPE inv_freq, + cache_positions, attention masks, TurboQuant centroids/boundaries/rotation, and + all weights. + + This is the model-specific half of the contract; the backend consumes the + resulting get_mutable_buffer_metadata generically. Other models supply their own. + """ + from executorch.examples.models.qwen3_5_moe.model import GatedDeltaNet, KVCache + + fqns = [] + for prefix, module in model.named_modules(): + # TurboQuantKVCache is not a KVCache subclass; match by name to avoid a hard + # dependency on the turboquant module. Its codebook buffers (centroids, + # boundaries, rotation, rotation_T) are read-only and excluded. + if module.__class__.__name__ == "TurboQuantKVCache": + fqns += [ + f"{prefix}.k_packed", + f"{prefix}.k_norms", + f"{prefix}.v_packed", + f"{prefix}.v_norms", + ] + elif isinstance(module, KVCache): + fqns += [f"{prefix}.k_cache", f"{prefix}.v_cache"] + elif isinstance(module, GatedDeltaNet): + fqns += [f"{prefix}.conv_state", f"{prefix}.recurrent_state"] + + named = dict(model.named_buffers()) + missing = [f for f in fqns if f not in named] + if missing: + raise RuntimeError( + f"Qwen mutable-buffer contract references missing buffers: {missing}" + ) + if not fqns: + raise RuntimeError("Qwen mutable-buffer contract is empty") + return sorted(fqns) + + +def _mutable_buffer_metadata_json(model): + """Minimal JSON naming the model's per-session mutable buffers. + + The exported .pte advertises ONLY which named constants are per-session state + (KV/conv/recurrent) versus globally shared weights: + + {"version": 1, "mutable_buffers": ["layers.1.attn.kv_cache.k_cache", ...]} + + All tensor descriptors (dtype/sizes/strides/nbytes/device/AOTI internal name/ + initial template) are the backend's to derive from the loaded container — the + single source of truth — so we deliberately do NOT duplicate them here and risk + drift. The serving runtime backs each session's state with its own storage while + the immutable weights stay shared. + + The FQN set is the model's explicit contract (_qwen_mutable_buffer_fqns). + """ + import json + + fqns = _qwen_mutable_buffer_fqns(model) + named = dict(model.named_buffers()) + total = sum(named[f].numel() * named[f].element_size() for f in fqns) + print( + f" Recorded {len(fqns)} mutable buffers " + f"({total} B / {total / 1024:.1f} KiB per session)" + ) + return json.dumps({"version": 1, "mutable_buffers": fqns}) + + def _export_cuda(model, config, args): """Export model to .pte via torch.export + CUDA backend. @@ -1000,6 +1071,7 @@ def _export_cuda(model, config, args): "use_kv_cache": True, "use_sdpa_with_kv_cache": False, "enable_dynamic_shape": True, + "get_mutable_buffer_metadata": _mutable_buffer_metadata_json(model), } et_prog = to_edge_transform_and_lower( {"decode": decode_ep, "prefill": prefill_ep}, diff --git a/examples/models/qwen3_5_moe/main.cpp b/examples/models/qwen3_5_moe/main.cpp index 88bb2e0ff83..3fd4a062306 100644 --- a/examples/models/qwen3_5_moe/main.cpp +++ b/examples/models/qwen3_5_moe/main.cpp @@ -38,16 +38,12 @@ DEFINE_string( "Path to file containing prompt text (overrides --prompt)."); DEFINE_double(temperature, 0.8, "Sampling temperature (0 = greedy)."); DEFINE_int32(max_new_tokens, 128, "Maximum tokens to generate."); -DEFINE_bool( - cuda_graph, - false, - "Enable CUDA graph for decode method. CUDA only."); DEFINE_int32( warmup, 0, "Warmup iterations to discard before timing. One model load; the session is " - "reset between iterations. Warmup captures the CUDA graph and ramps GPU " - "clocks so the timed iterations reflect steady state."); + "reset between iterations. Warmup ramps GPU clocks so the timed iterations " + "reflect steady state."); DEFINE_int32(num_iters, 1, "Timed iterations to average (after warmup)."); namespace llm = ::executorch::extension::llm; @@ -85,7 +81,6 @@ int main(int argc, char** argv) { config.model_path = FLAGS_model_path; config.data_path = FLAGS_data_path; config.tokenizer_path = FLAGS_tokenizer_path; - config.cuda_graph = FLAGS_cuda_graph; printf("Loading methods...\n"); auto engine_result = llm::Qwen35MoEEngine::create(config); @@ -139,9 +134,9 @@ int main(int argc, char** argv) { stats.num_prompt_tokens = num_prompt_tokens; // Warmup + timed iterations on one loaded session (reset between). The first - // FLAGS_warmup iterations are discarded; they trigger CUDA-graph capture, - // allocator growth, and GPU clock ramp so the timed iterations reflect steady - // state. Text is printed only on the first iteration (coherence check). + // FLAGS_warmup iterations are discarded; they let allocator growth and GPU + // clock ramp settle so the timed iterations reflect steady state. Text is + // printed only on the first iteration (coherence check). llm::SamplingConfig sampling; sampling.temperature = static_cast(FLAGS_temperature); const int total_iters = FLAGS_warmup + std::max(1, FLAGS_num_iters); diff --git a/examples/models/qwen3_5_moe/model.md b/examples/models/qwen3_5_moe/model.md index d29177c4c87..7706aab43f4 100644 --- a/examples/models/qwen3_5_moe/model.md +++ b/examples/models/qwen3_5_moe/model.md @@ -145,12 +145,14 @@ Visual and MTP keys are skipped. `lm_head.weight` is cloned from any harness) drive the model without knowing it is Qwen-MoE or CUDA. - **`Qwen35MoEEngine`** owns immutable resources (tokenizer, metadata, EOS ids, - config). `create_session()` builds a `Module` with `share_memory_arenas=true` - and, on CUDA, sets the backend options that must precede `load_method` - (`weight_sharing_across_methods`, optional `enable_cuda_graph_for_method`), - then loads the `prefill`/`decode` methods. `serving_capacity()` reports a - single physical session — cross-session weight sharing is not yet proven, so - it fails closed to 1. + config) and one shared `Module` (`share_memory_arenas=true`, plus on CUDA the + `weight_sharing_across_methods` backend option set before `load_method`), then + loads the `prefill`/`decode` methods once. `create_session()` returns a session + that shares that one model but owns its own per-session mutable state + (KV/conv/recurrent), rebound before execute under the engine lock. + `serving_capacity()` reports how many such sessions fit without duplicating + weights (or 1 if the backend can't rebind). The serving path is still + single-slot until the worker exposes multi-session. - **`Qwen35MoESession`** owns the mutable conversation state (KV / conv / recurrent arenas via the Module, position cursor, pending token). `prefill_tokens` dispatches to `prefill` (T≥2) or `decode` (T==1); diff --git a/examples/models/qwen3_5_moe/qwen35_moe_engine.cpp b/examples/models/qwen3_5_moe/qwen35_moe_engine.cpp index 32ff1b4c0f9..b14ab38e656 100644 --- a/examples/models/qwen3_5_moe/qwen35_moe_engine.cpp +++ b/examples/models/qwen3_5_moe/qwen35_moe_engine.cpp @@ -20,6 +20,8 @@ #ifdef EXECUTORCH_BUILD_CUDA #include +#include +#include #else #include #endif @@ -71,10 +73,11 @@ Result read_sampled_token( #endif } -// Build a Qwen Module with shared mutable arenas (so prefill and decode share -// KV/conv/recurrent state) and, on CUDA, the weight-sharing/cuda-graph backend -// options that MUST be set before load_method. Loads the prefill+decode methods -// (this is the heavy ~weights load). Shared by create_session() and reset(). +// Build the one shared Qwen Module: shared mutable arenas (so prefill and +// decode share KV/conv/recurrent state) and, on CUDA, the weight-sharing +// backend option that MUST be set before load_method. Loads the prefill+decode +// methods once (the heavy ~weights load). Called once when the engine is +// created. Result> build_qwen_module( const Qwen35MoEConfig& config) { std::vector data_files; @@ -92,13 +95,9 @@ Result> build_qwen_module( #ifdef EXECUTORCH_BUILD_CUDA // Backend options are read during backend init(), so they must be set before - // load_method. - if (config.cuda_graph) { - executorch::runtime::BackendOptions<1> cuda_opts; - cuda_opts.set_option("enable_cuda_graph_for_method", "decode"); - ET_CHECK_OK_OR_RETURN_ERROR( - executorch::runtime::set_option("CudaBackend", cuda_opts.view())); - } + // load_method. (CUDA graph is intentionally not enabled: V2 rebinds each + // session's mutable buffers before execute, which a captured graph's baked + // pointers would ignore.) { // Cross-method per-FQN weight sharing: prefill and decode reuse one weight // allocation instead of duplicating it (critical to fit on one GPU). @@ -115,22 +114,82 @@ Result> build_qwen_module( return module; } +#ifdef EXECUTORCH_BUILD_CUDA +// Read the model's per-session mutable-buffer FQNs from its export metadata +// ({"version":1,"mutable_buffers":[...]}) and register them with the CUDA +// backend so it can give each session its own GPU buffers for that state. +Error register_mutable_fqns(Module* module, int mutable_ctx) { + auto res = module->execute("get_mutable_buffer_metadata"); + if (res.error() != Error::Ok) { + ET_LOG( + Error, + "Qwen35MoEEngine: model has no get_mutable_buffer_metadata; re-export " + "for V2 multi-session"); + return res.error(); + } + const auto& outs = res.get(); + if (outs.empty() || !outs[0].isString()) { + ET_LOG(Error, "get_mutable_buffer_metadata did not return a string"); + return Error::InvalidProgram; + } + std::string json_str(outs[0].toString()); + auto j = nlohmann::json::parse(json_str, nullptr, /*allow_exceptions=*/false); + if (j.is_discarded() || !j.is_object()) { + ET_LOG(Error, "get_mutable_buffer_metadata is not a valid JSON object"); + return Error::InvalidProgram; + } + if (!j.contains("version") || !j["version"].is_number_integer() || + j["version"].get() != 1) { + ET_LOG(Error, "get_mutable_buffer_metadata: unsupported/missing version"); + return Error::InvalidProgram; + } + if (!j.contains("mutable_buffers") || !j["mutable_buffers"].is_array() || + j["mutable_buffers"].empty()) { + ET_LOG( + Error, + "get_mutable_buffer_metadata: mutable_buffers must be a non-empty array"); + return Error::InvalidProgram; + } + std::vector fqns; + for (const auto& f : j["mutable_buffers"]) { + if (!f.is_string() || f.get().empty()) { + ET_LOG( + Error, + "get_mutable_buffer_metadata: every mutable_buffers entry must be a " + "non-empty string"); + return Error::InvalidProgram; + } + fqns.push_back(f.get()); + } + ::executorch::backends::cuda::mutable_state_register_fqns(mutable_ctx, fqns); + return Error::Ok; +} +#endif + // LLMSession over the Qwen3.5 MoE prefill/decode methods. Owns one physical // Module (one weight allocation + its KV/recurrent/conv state). Internal: the // server depends only on the LLMSession base. class Qwen35MoESession : public LLMSession { public: Qwen35MoESession( - std::unique_ptr module, + Module* module, + std::mutex* exec_mutex, + int mutable_ctx, + int session_token, + std::atomic* live_sessions, ::tokenizers::Tokenizer* tokenizer, std::unordered_map metadata, std::unordered_set eos_ids) - : module_(std::move(module)), + : module_(module), + exec_mutex_(exec_mutex), + mutable_ctx_(mutable_ctx), + session_token_(session_token), + live_sessions_(live_sessions), tokenizer_(tokenizer), metadata_(std::move(metadata)), eos_ids_(std::move(eos_ids)) { - // Persistent single-step decode buffers: stable addresses are required so - // CUDA-graph capture (which records buffer pointers) can replay each step. + // Persistent single-step decode buffers, reused (updated in place) across + // decode steps to avoid per-step reallocation. decode_tokens_ = from_blob( decode_token_data_, {1, 1}, executorch::aten::ScalarType::Long); decode_pos_ = @@ -141,6 +200,19 @@ class Qwen35MoESession : public LLMSession { #endif } + ~Qwen35MoESession() override { +#ifdef EXECUTORCH_BUILD_CUDA + if (session_token_ != ::executorch::backends::cuda::kNoMutableSession) { + ::executorch::backends::cuda::mutable_state_destroy_session( + mutable_ctx_, session_token_); + } +#endif + // Release the engine's capacity slot reserved in create_session(). + if (live_sessions_ != nullptr) { + live_sessions_->fetch_sub(1); + } + } + Error prefill_tokens( std::vector tokens, const SamplingConfig* initial_sampling) override { @@ -206,25 +278,12 @@ class Qwen35MoESession : public LLMSession { set_temp(first_token_temp); inputs.push_back(EValue(temp_tensor_)); #endif - auto res = module_->execute(method, inputs); - ET_CHECK_OK_OR_RETURN_ERROR(res.error()); auto sampled = - read_sampled_token(res.get()[0].toTensor(), first_token_temp); + run_locked(method, inputs, first_token_temp, /*sync_after=*/true); ET_CHECK_OK_OR_RETURN_ERROR(sampled.error()); pending_ = sampled.get(); prev_decode_token_.reset(); pos_ += T; // the prompt tokens are now resident in KV/state -#ifdef EXECUTORCH_BUILD_CUDA - // Make prefill's writes to the shared mutable arenas visible to decode - // (which may run on a different stream). This barrier is relied upon for - // correctness, and it also surfaces any async error from the prefill - // launch, so a non-success here must abort the request rather than decode - // on stale or corrupt state. - if (cudaDeviceSynchronize() != cudaSuccess) { - ET_LOG(Error, "prefill_tokens: cudaDeviceSynchronize failed"); - return Error::Internal; - } -#endif return Error::Ok; } @@ -296,9 +355,8 @@ class Qwen35MoESession : public LLMSession { set_temp(temperature_); inputs.push_back(EValue(temp_tensor_)); #endif - auto res = module_->execute("decode", inputs); - ET_CHECK_OK_OR_RETURN_ERROR(res.error()); - auto sampled = read_sampled_token(res.get()[0].toTensor(), temperature_); + auto sampled = + run_locked("decode", inputs, temperature_, /*sync_after=*/false); ET_CHECK_OK_OR_RETURN_ERROR(sampled.error()); pending_ = sampled.get(); prev_decode_token_ = token; @@ -347,7 +405,49 @@ class Qwen35MoESession : public LLMSession { } #endif - std::unique_ptr module_; + // Run a method with THIS session's mutable state bound, then read the sampled + // token — all inside one engine-lock critical section so another session + // cannot rebind between this session's rebind, execute, and read-out. + Result run_locked( + const char* method, + std::vector& inputs, + float temperature, + bool sync_after) { + std::lock_guard guard(*exec_mutex_); +#ifdef EXECUTORCH_BUILD_CUDA + ::executorch::backends::cuda::mutable_state_set_active( + mutable_ctx_, session_token_); +#endif + auto res = module_->execute(method, inputs); +#ifdef EXECUTORCH_BUILD_CUDA + ::executorch::backends::cuda::mutable_state_set_active( + mutable_ctx_, ::executorch::backends::cuda::kNoMutableSession); +#endif + ET_CHECK_OK_OR_RETURN_ERROR(res.error()); + auto sampled = read_sampled_token(res.get()[0].toTensor(), temperature); + ET_CHECK_OK_OR_RETURN_ERROR(sampled.error()); +#ifdef EXECUTORCH_BUILD_CUDA + // Prefill runs on a different stream than decode; sync so its writes to the + // session's mutable buffers are visible to the session's first decode (also + // surfaces any async launch error). Decode reads its own writes in stream + // order, so it does not need this. + if (sync_after && cudaDeviceSynchronize() != cudaSuccess) { + ET_LOG(Error, "run_locked: cudaDeviceSynchronize failed"); + return Error::Internal; + } +#else + (void)sync_after; +#endif + return sampled.get(); + } + + Module* module_; // non-owning; the engine's one shared physical model + std::mutex* + exec_mutex_; // non-owning; serializes rebind+execute across sessions + int mutable_ctx_; // engine's CUDA mutable-state context (per-engine) + int session_token_; // CUDA per-session mutable-state token (or + // kNoMutableSession) + std::atomic* live_sessions_; // non-owning; engine capacity counter ::tokenizers::Tokenizer* tokenizer_; // non-owning; owned by the engine std::unordered_map metadata_; std::unordered_set eos_ids_; @@ -358,7 +458,7 @@ class Qwen35MoESession : public LLMSession { float temperature_ = -1.0f; std::atomic stop_{false}; - // Persistent single-step decode buffers (stable addresses for CUDA graph). + // Persistent single-step decode buffers (reused across decode steps). int64_t decode_token_data_[1] = {0}; int64_t decode_pos_data_[1] = {0}; TensorPtr decode_tokens_; @@ -416,15 +516,133 @@ Result> Qwen35MoEEngine::create( "not stop at end of turn"); } + int mutable_ctx = 0; // kInvalidMutableContext +#ifdef EXECUTORCH_BUILD_CUDA + // Create this engine's own mutable-state context (per-engine, not global) and + // register the per-session mutable-buffer FQNs from the .pte metadata BEFORE + // loading the heavy methods, so the CUDA backend associates the load's + // handles with this context and builds descriptors from the still-initial + // constants. + mutable_ctx = ::executorch::backends::cuda::mutable_state_create_context(); + if (Error e = register_mutable_fqns(meta_module.get(), mutable_ctx); + e != Error::Ok) { + ::executorch::backends::cuda::mutable_state_destroy_context(mutable_ctx); + return e; + } + ::executorch::backends::cuda::mutable_state_begin_load(mutable_ctx); +#endif + + // Build the ONE shared physical model (the heavy ~weights load). All sessions + // reuse it; each rebinds its own mutable buffers before execute. + auto module_res = build_qwen_module(config); +#ifdef EXECUTORCH_BUILD_CUDA + ::executorch::backends::cuda::mutable_state_end_load(); +#endif + if (module_res.error() != Error::Ok) { +#ifdef EXECUTORCH_BUILD_CUDA + ::executorch::backends::cuda::mutable_state_destroy_context(mutable_ctx); +#endif + return module_res.error(); + } + std::unique_ptr shared_module = std::move(module_res.get()); + + bool rebind_available = false; +#ifdef EXECUTORCH_BUILD_CUDA + rebind_available = + ::executorch::backends::cuda::mutable_state_available(mutable_ctx); + if (rebind_available) { + // Fail closed: if any declared mutable FQN was not found in the loaded + // methods' constants, multi-session would run without rebinding it and + // bleed state — fall back to single-session instead. + if (::executorch::backends::cuda::mutable_state_validate_coverage( + mutable_ctx) != Error::Ok) { + ET_LOG( + Error, + "Qwen35MoEEngine: mutable-buffer coverage check failed; disabling " + "multi-session (capacity clamped to 1)."); + rebind_available = false; + } + } + if (!rebind_available) { + ET_LOG( + Info, + "Qwen35MoEEngine: per-session rebinding unavailable; serving capacity " + "clamped to 1 session."); + } +#endif + return std::unique_ptr(new Qwen35MoEEngine( - config, std::move(tokenizer), metadata_result.get(), std::move(eos_ids))); + config, + std::move(tokenizer), + metadata_result.get(), + std::move(eos_ids), + std::move(shared_module), + rebind_available, + mutable_ctx)); +} + +Qwen35MoEEngine::~Qwen35MoEEngine() { +#ifdef EXECUTORCH_BUILD_CUDA + if (mutable_ctx_ != 0) { + ::executorch::backends::cuda::mutable_state_destroy_context(mutable_ctx_); + } +#endif } Result> Qwen35MoEEngine::create_session() { - auto module = build_qwen_module(config_); - ET_CHECK_OK_OR_RETURN_ERROR(module.error()); + // Enforce serving_capacity(): without rebinding, capacity is 1, so a second + // session would silently share the resident KV/conv/recurrent state. Reserve + // a slot under the exec lock (released in ~Qwen35MoESession). + const int cap = + serving_capacity().max_physical_sessions_without_weight_duplication; + { + std::lock_guard g(exec_mutex_); + if (live_sessions_.load() >= cap) { + ET_LOG( + Error, + "Qwen35MoEEngine: at session capacity (%d); refusing create_session " + "(would share state or duplicate weights)", + cap); + return Error::InvalidState; + } + live_sessions_.fetch_add(1); + } + + int token = -1; // kNoMutableSession: single-session / no rebind +#ifdef EXECUTORCH_BUILD_CUDA + if (rebind_available_) { + auto t = ::executorch::backends::cuda::mutable_state_create_session( + mutable_ctx_); + if (t.error() != Error::Ok) { + live_sessions_.fetch_sub(1); + return t.error(); + } + token = t.get(); + } +#endif return std::unique_ptr(new Qwen35MoESession( - std::move(module.get()), tokenizer_.get(), metadata_, eos_ids_)); + shared_module_.get(), + &exec_mutex_, + mutable_ctx_, + token, + &live_sessions_, + tokenizer_.get(), + metadata_, + eos_ids_)); +} + +LLMServingCapacity Qwen35MoEEngine::serving_capacity() const { + LLMServingCapacity cap; // default: 1 session, 0 bytes (unknown) +#ifdef EXECUTORCH_BUILD_CUDA + if (rebind_available_) { + cap.max_physical_sessions_without_weight_duplication = + config_.max_sessions > 1 ? config_.max_sessions : 1; + cap.estimated_bytes_per_session = + ::executorch::backends::cuda::mutable_state_bytes_per_session( + mutable_ctx_); + } +#endif + return cap; } } // namespace executorch::extension::llm diff --git a/examples/models/qwen3_5_moe/qwen35_moe_engine.h b/examples/models/qwen3_5_moe/qwen35_moe_engine.h index 9fb9e99d71e..72793972d3e 100644 --- a/examples/models/qwen3_5_moe/qwen35_moe_engine.h +++ b/examples/models/qwen3_5_moe/qwen35_moe_engine.h @@ -12,17 +12,28 @@ // // The public surface is backend-agnostic: the server receives an LLMEngine and // never branches on CUDA vs MLX. Backend-specific execution (CUDA in-graph -// sampling, weight-sharing/cuda-graph backend options, device sync) is isolated -// behind EXECUTORCH_BUILD_CUDA inside the .cpp; those isolated points are where -// an MLX runtime would slot in. MLX is NOT implemented or validated here. +// sampling, the weight-sharing backend option, per-session mutable rebinding, +// device sync) is isolated behind EXECUTORCH_BUILD_CUDA inside the .cpp; those +// isolated points are where an MLX runtime would slot in. MLX is NOT +// implemented or validated here. // -// V1: serving_capacity() reports a single physical session (one Module = one -// weight allocation). Multiple weight-sharing sessions are a measured V2 step. +// V2 (CUDA): the ENGINE is multi-session — one shared Module (weights loaded +// once); create_session() hands out multiple logical sessions, each rebinding +// its own GPU buffers for the model's mutable state (KV/conv/recurrent) before +// execute, serialized by the engine lock. serving_capacity() reports how many +// such sessions fit without duplicating weights, or 1 if the backend cannot +// rebind. The per-session rebind machinery is CUDA-backend-private (see +// backends/cuda/runtime/cuda_mutable_state). +// +// The SERVING path (qwen3_5_moe_worker + control plane) is still single-slot: +// it creates one session and queues requests on it. Exposing the engine's +// multi-session capability over the worker protocol is a follow-up. #pragma once #include #include +#include #include #include #include @@ -41,26 +52,32 @@ struct Qwen35MoEConfig { std::string model_path; // .pte std::string data_path; // .ptd (CUDA delegate blob); empty if none std::string tokenizer_path; // HuggingFace tokenizer.json - bool cuda_graph = false; // enable CUDA graph capture for the decode method + // V2 multi-session: max physical sessions to advertise when the backend can + // host them without weight duplication (CUDA per-session mutable rebinding). + // Clamped to 1 if the backend cannot rebind. + int32_t max_sessions = 1; }; /// Engine over one loaded Qwen3.5 MoE Program. Owns immutable resources -/// (tokenizer, metadata, eos ids, config) and creates sessions that each own a -/// physical Module with its own KV/recurrent/conv state. +/// (tokenizer, metadata, eos ids, config) plus one shared Module (weights +/// loaded once); creates sessions that share that Module but each own their +/// per-session mutable state (KV/recurrent/conv), rebound before execute under +/// the engine lock. class ET_EXPERIMENTAL Qwen35MoEEngine : public LLMEngine { public: static ::executorch::runtime::Result> create( const Qwen35MoEConfig& config); + ~Qwen35MoEEngine() override; + ::executorch::runtime::Result> create_session() override; - // V1: one physical session; weight sharing across sessions is unproven, so we - // fail closed to 1 (the server queues concurrent requests on the resident - // session rather than duplicating ~18GB of weights). - LLMServingCapacity serving_capacity() const override { - return LLMServingCapacity{}; - } + // CUDA V2: one shared Module (one weight allocation); each session rebinds + // its own GPU buffers for the model's mutable state. Reports + // config.max_sessions when the backend supports per-session rebinding, else + // fails closed to 1. + LLMServingCapacity serving_capacity() const override; const std::unordered_map& metadata() const override { return metadata_; @@ -81,16 +98,37 @@ class ET_EXPERIMENTAL Qwen35MoEEngine : public LLMEngine { Qwen35MoEConfig config, std::unique_ptr<::tokenizers::Tokenizer> tokenizer, std::unordered_map metadata, - std::unordered_set eos_ids) + std::unordered_set eos_ids, + std::unique_ptr shared_module, + bool rebind_available, + int mutable_ctx) : config_(std::move(config)), tokenizer_(std::move(tokenizer)), metadata_(std::move(metadata)), - eos_ids_(std::move(eos_ids)) {} + eos_ids_(std::move(eos_ids)), + shared_module_(std::move(shared_module)), + rebind_available_(rebind_available), + mutable_ctx_(mutable_ctx) {} Qwen35MoEConfig config_; std::unique_ptr<::tokenizers::Tokenizer> tokenizer_; std::unordered_map metadata_; std::unordered_set eos_ids_; + + // One physical model shared by all sessions (one weight allocation). Sessions + // hold a non-owning pointer to it and execute under exec_mutex_. + std::unique_ptr shared_module_; + std::mutex exec_mutex_; + // Whether the loaded CUDA delegate supports per-session mutable rebinding. + bool rebind_available_ = false; + // CUDA mutable-state context for this engine's model (per-engine, not + // global); destroyed in the destructor. kInvalidMutableContext (0) when + // unused. + int mutable_ctx_ = 0; + // Live sessions, enforced against serving_capacity() so the engine never + // hands out more sessions than it can host without sharing state / + // duplicating weights. Decremented when a session is destroyed. + std::atomic live_sessions_{0}; }; } // namespace executorch::extension::llm diff --git a/examples/models/qwen3_5_moe/qwen35_moe_worker.cpp b/examples/models/qwen3_5_moe/qwen35_moe_worker.cpp index 2cc705f96e1..aa94a704bc2 100644 --- a/examples/models/qwen3_5_moe/qwen35_moe_worker.cpp +++ b/examples/models/qwen3_5_moe/qwen35_moe_worker.cpp @@ -18,8 +18,10 @@ // process segfaults in the int4 matmul (validated). Here the model runs in a // plain synchronous loop in its own process, which is reliable. // -// V1: single-slot (one engine == one ~18GB weight allocation == one session); -// the control plane queues concurrent requests on the resident session. +// Single-slot serving: this worker creates one session and the control plane +// queues concurrent requests on it. (The engine itself can host multiple +// sessions on the one ~18GB weight allocation; exposing that over the worker +// protocol is a follow-up.) #include @@ -33,7 +35,6 @@ DEFINE_string(model_path, "", "Model .pte file path."); DEFINE_string(tokenizer_path, "", "HuggingFace tokenizer.json path."); DEFINE_string(data_path, "", "Data file (.ptd) for the CUDA backend."); -DEFINE_bool(cuda_graph, false, "Enable CUDA graph for the decode method."); namespace { namespace llm = ::executorch::extension::llm; @@ -53,7 +54,6 @@ int main(int argc, char** argv) { config.model_path = FLAGS_model_path; config.data_path = FLAGS_data_path; config.tokenizer_path = FLAGS_tokenizer_path; - config.cuda_graph = FLAGS_cuda_graph; auto engine_result = llm::Qwen35MoEEngine::create(config); if (engine_result.error() != Error::Ok) {