diff --git a/Makefile b/Makefile index 9c8476d30ed..d0c046d878d 100644 --- a/Makefile +++ b/Makefile @@ -129,7 +129,7 @@ help: @echo " gemma3-cpu - Build Gemma3 runner with CPU backend" @echo " gemma4_31b-cuda - Build Gemma 4 31B runner with CUDA backend" @echo " gemma4_31b-mlx - Build Gemma 4 31B runner with MLX backend" - @echo " qwen3_5_moe-cuda - Build Qwen3.5 MoE runner with CUDA backend" + @echo " qwen3_5_moe-cuda - Build Qwen3.5 MoE runner + OpenAI serving worker (CUDA)" @echo " qwen3_5_moe-metal - Build Qwen3.5 MoE runner with Metal backend" @echo " clean - Clean build artifacts" @@ -431,11 +431,13 @@ voxtral_tts-cuda: qwen3_5_moe-cuda: @echo "==> Building and installing ExecuTorch with CUDA..." cmake --workflow --preset llm-release-cuda - @echo "==> Building Qwen3.5 MoE runner with CUDA..." + @echo "==> Building Qwen3.5 MoE runner + serving worker with CUDA..." cd examples/models/qwen3_5_moe && cmake --workflow --preset qwen3-5-moe-cuda @echo "" @echo "✓ Build complete!" @echo " Binary: cmake-out/examples/models/qwen3_5_moe/qwen3_5_moe_runner" + @echo " Serving worker: cmake-out/examples/models/qwen3_5_moe/qwen3_5_moe_worker" + @echo " Launch: see examples/models/qwen3_5_moe/README.md (Serving)" gemma4_31b-cuda: @echo "==> Building and installing ExecuTorch with CUDA..." diff --git a/examples/models/qwen3_5_moe/CMakeLists.txt b/examples/models/qwen3_5_moe/CMakeLists.txt index d1cfe54a56f..a9b179f3fa0 100644 --- a/examples/models/qwen3_5_moe/CMakeLists.txt +++ b/examples/models/qwen3_5_moe/CMakeLists.txt @@ -15,6 +15,11 @@ set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../..) include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake) set(_common_include_directories ${EXECUTORCH_ROOT}/..) +# Vendored single-include nlohmann/json for the worker JSONL protocol (no new +# dependency). 
+set(_json_include + ${EXECUTORCH_ROOT}/extension/llm/tokenizers/third-party/json/single_include +) # gflags set(gflags_DIR ${CMAKE_CURRENT_BINARY_DIR}/../../../third-party/gflags) @@ -60,7 +65,7 @@ endif() # Tokenizer list(APPEND link_libraries tokenizers::tokenizers) -add_executable(qwen3_5_moe_runner main.cpp) +add_executable(qwen3_5_moe_runner main.cpp qwen35_moe_engine.cpp) target_include_directories( qwen3_5_moe_runner PUBLIC ${_common_include_directories} ) @@ -70,3 +75,18 @@ if(NOT CMAKE_BUILD_TYPE STREQUAL "Debug") target_link_options_gc_sections(qwen3_5_moe_runner) target_link_options(qwen3_5_moe_runner PRIVATE "LINKER:-s") endif() + +# Process-isolated serving worker (qwen3_5_moe_worker): constructs +# Qwen35MoEEngine directly and speaks the JSONL worker protocol that the Python +# control plane drives via WorkerClient (no pybind, no Python model code). Built +# alongside the runner by the qwen3-5-moe-cuda preset. +add_executable(qwen3_5_moe_worker qwen35_moe_worker.cpp qwen35_moe_engine.cpp) +target_include_directories( + qwen3_5_moe_worker PUBLIC ${_common_include_directories} ${_json_include} +) +target_link_libraries(qwen3_5_moe_worker PUBLIC ${link_libraries}) + +if(NOT CMAKE_BUILD_TYPE STREQUAL "Debug") + target_link_options_gc_sections(qwen3_5_moe_worker) + target_link_options(qwen3_5_moe_worker PRIVATE "LINKER:-s") +endif() diff --git a/examples/models/qwen3_5_moe/CMakePresets.json b/examples/models/qwen3_5_moe/CMakePresets.json index 34ebc938280..d5f841c5587 100644 --- a/examples/models/qwen3_5_moe/CMakePresets.json +++ b/examples/models/qwen3_5_moe/CMakePresets.json @@ -13,7 +13,7 @@ }, { "name": "qwen3-5-moe-cuda", - "displayName": "Qwen3.5 MoE runner (CUDA)", + "displayName": "Qwen3.5 MoE runner + serving worker (CUDA)", "inherits": ["qwen3-5-moe-base"], "cacheVariables": { "EXECUTORCH_BUILD_CUDA": "ON" @@ -41,9 +41,9 @@ "buildPresets": [ { "name": "qwen3-5-moe-cuda", - "displayName": "Build Qwen3.5 MoE runner (CUDA)", + "displayName": "Build 
Qwen3.5 MoE runner + serving worker (CUDA)", "configurePreset": "qwen3-5-moe-cuda", - "targets": ["qwen3_5_moe_runner"] + "targets": ["qwen3_5_moe_runner", "qwen3_5_moe_worker"] }, { "name": "qwen3-5-moe-metal", @@ -55,7 +55,7 @@ "workflowPresets": [ { "name": "qwen3-5-moe-cuda", - "displayName": "Configure and build Qwen3.5 MoE runner (CUDA)", + "displayName": "Configure and build Qwen3.5 MoE runner + serving worker (CUDA)", "steps": [ { "type": "configure", diff --git a/examples/models/qwen3_5_moe/README.md b/examples/models/qwen3_5_moe/README.md index 83373a804f4..e51744e25d7 100644 --- a/examples/models/qwen3_5_moe/README.md +++ b/examples/models/qwen3_5_moe/README.md @@ -100,14 +100,16 @@ It can be uploaded to HuggingFace Hub for easy sharing. ExecuTorch must be installed from source first (see [Prerequisites](#prerequisites)). The `make` target handles building -core libraries and the runner binary. +core libraries and the binaries. ```bash make qwen3_5_moe-cuda ``` This builds ExecuTorch with CUDA backend support, then the runner binary -at `cmake-out/examples/models/qwen3_5_moe/qwen3_5_moe_runner`. +at `cmake-out/examples/models/qwen3_5_moe/qwen3_5_moe_runner` and the +serving worker at `cmake-out/examples/models/qwen3_5_moe/qwen3_5_moe_worker` +(see [Serving](#serving-openai-compatible)). ## Run @@ -133,11 +135,95 @@ cmake-out/examples/models/qwen3_5_moe/qwen3_5_moe_runner \ | `--data_path` | (none) | Path to `.ptd` delegate data file (required for CUDA) | | `--tokenizer_path` | (required) | Path to HuggingFace `tokenizer.json` | | `--prompt` | `"Hello"` | Input prompt text | +| `--prompt_file` | (none) | Path to a file with the prompt (overrides `--prompt`) | | `--temperature` | `0.8` | Sampling temperature (0 = greedy) | | `--max_new_tokens` | `128` | Maximum tokens to generate | +| `--cuda_graph` | off | Capture/replay the decode method as a CUDA graph (CUDA only). See the caveat below. 
| +| `--warmup` | `0` | Warmup iterations to discard before timing (one model load; the session is reset between iterations) | +| `--num_iters` | `1` | Timed iterations to average, after warmup | + +## Serving (OpenAI-compatible) + +Run an OpenAI-compatible HTTP server so an agent harness (pi, opencode, …) can +use the model for local tool-use. Point your client at `http://<host>:<port>/v1`. + +The CUDA build produces the runner **and** the serving worker: + +```bash +make qwen3_5_moe-cuda +``` + +Launch (the `LD_LIBRARY_PATH` shim is forwarded to the worker for the CUDA blob): + +```bash +LD_LIBRARY_PATH=$CONDA_PREFIX/lib:$LD_LIBRARY_PATH \ + python -m executorch.examples.models.qwen3_5_moe.serve \ + --model-path qwen35_moe_exports/model.pte \ + --data-path qwen35_moe_exports/aoti_cuda_blob.ptd \ + --tokenizer-path ~/models/Qwen3.5-35B-A3B/tokenizer.json \ + --hf-tokenizer ~/models/Qwen3.5-35B-A3B \ + --model-id qwen3.5-moe --no-think +``` + +### Architecture (process isolation) + +Two processes, one model load: + +``` +serve.py (control plane: FastAPI/asyncio, OpenAI protocol, chat + templating, tool parsing, validation — NO CUDA, NO pybind) + │ JSONL over stdin/stdout + ▼ +qwen3_5_moe_worker (C++ binary: one Qwen35MoEEngine + one session, synchronous + loop — the CUDA model; NO asyncio server) +``` + +The model runs in a **separate worker process** because executing the AOTI CUDA +model inside a live asyncio server process segfaults in the int4 matmul +(reproducible, and isolated by elimination to the asyncio-loop × CUDA +interaction). The worker runs the model like the CLI — a plain synchronous loop — +which is reliable. The control plane only does blocking pipe I/O (no CUDA), which +is safe under asyncio. 
+ +### Serve Options + +| Flag | Default | Description | +|------|---------|-------------| +| `--model-path` | (required) | Path to exported `.pte` model | +| `--data-path` | (none) | Path to `.ptd` delegate data file (required for CUDA) | +| `--tokenizer-path` | (required) | Path to HuggingFace `tokenizer.json` | +| `--hf-tokenizer` | (required) | HF tokenizer id/dir for the chat template + encoding | +| `--model-id` | `qwen3.5-moe` | Model id reported on `/v1/models` | +| `--host` / `--port` | `127.0.0.1` / `8000` | Bind address | +| `--max-context` | (none) | Reject prompts that exceed it with 400 | +| `--no-think` | off | Default reasoning off (`enable_thinking=False`) | + +### V1 limitations + +- **Single-slot** (`serving_capacity=1`): one worker, one session, one model + load. `--num-runners > 1` is rejected; concurrent requests queue on the worker. +- **No prefix cache**: the recurrent/conv state cannot be rewound by position + (`seek()` is NotSupported), so turn-to-turn KV reuse is off. +- Supports the chat-completions contract of the generic server; `top_p != 1`, + `seed`, `top_k`, `logprobs`, etc. are rejected (only temperature is plumbed). ## Troubleshooting +- **Runner exits silently right after `Loading methods...`**: the AOTI CUDA blob + is compiled with the conda toolchain's `libstdc++`, which is newer than the + system one (it needs e.g. `GLIBCXX_3.4.34`). Prepend the conda lib dir so the + runner loads the matching `libstdc++`: + + ```bash + LD_LIBRARY_PATH=$CONDA_PREFIX/lib:$LD_LIBRARY_PATH \ + cmake-out/examples/models/qwen3_5_moe/qwen3_5_moe_runner ... + ``` +- **`aoti_torch_cuda_sort_stable ... API call failed` when re-running prefill + with `--cuda_graph`**: capturing the decode CUDA graph and then running another + prefill in the same process currently fails (allocator interaction). Use + `--cuda_graph` for single prefill+decode runs; omit it when looping with + `--warmup`/`--num_iters`. 
+ - **OOM during export**: The model requires significant GPU memory even with int4 quantization. Try reducing `--max-seq-len` or using a GPU with more VRAM. diff --git a/examples/models/qwen3_5_moe/main.cpp b/examples/models/qwen3_5_moe/main.cpp index 19d93af0d58..88bb2e0ff83 100644 --- a/examples/models/qwen3_5_moe/main.cpp +++ b/examples/models/qwen3_5_moe/main.cpp @@ -6,17 +6,17 @@ * LICENSE file in the root directory of this source tree. */ +// Thin CLI over Qwen35MoEEngine / Qwen35MoESession: parse flags, build the +// engine + a session, encode the prompt, prefill_tokens(), then loop +// decode_one() printing pieces and timing/stats. All model execution lives in +// qwen35_moe_engine.{h,cpp}. + #include -#include +#include #include #include -#include -#include -#include -#include #include -#include #include #include @@ -26,8 +26,6 @@ #ifdef EXECUTORCH_BUILD_CUDA #include -#else -#include #endif DEFINE_string(model_path, "", "Model .pte file path."); @@ -44,57 +42,16 @@ DEFINE_bool( cuda_graph, false, "Enable CUDA graph for decode method. CUDA only."); +DEFINE_int32( + warmup, + 0, + "Warmup iterations to discard before timing. One model load; the session is " + "reset between iterations. Warmup captures the CUDA graph and ramps GPU " + "clocks so the timed iterations reflect steady state."); +DEFINE_int32(num_iters, 1, "Timed iterations to average (after warmup)."); namespace llm = ::executorch::extension::llm; -using ::executorch::extension::from_blob; -using ::executorch::extension::Module; -using ::executorch::extension::TensorPtr; using ::executorch::runtime::Error; -using ::executorch::runtime::EValue; - -using SizesType = executorch::aten::SizesType; - -// Convert a model output tensor to the next sampled token id. -// -// On the CUDA build, the model fuses the sampler in (see sampler.py / -// Qwen35MoE.forward) and returns a single sampled token id as a [B, 1] -// float tensor; we just copy that scalar back from device. 
-// -// On non-CUDA builds (Metal / MLX / CPU), the model returns raw logits -// of shape [B, T, V] in the model dtype (typically bf16). We sample on -// CPU via the shared `llm::logits_to_token` helper, which accepts a -// temperature (0 = greedy / argmax). -static uint64_t read_token(const executorch::aten::Tensor& output) { -#ifdef EXECUTORCH_BUILD_CUDA - const void* ptr = output.const_data_ptr(); - - cudaPointerAttributes attrs; - bool on_device = cudaPointerGetAttributes(&attrs, ptr) == cudaSuccess && - attrs.type == cudaMemoryTypeDevice; - - float val; - if (on_device) { - cudaError_t err = - cudaMemcpy(&val, ptr, sizeof(float), cudaMemcpyDeviceToHost); - if (err != cudaSuccess) { - ET_LOG( - Error, - "read_token: cudaMemcpy D2H failed: %s", - cudaGetErrorString(err)); - return 0; - } - } else { - memcpy(&val, ptr, sizeof(float)); - } - return static_cast(val); -#else - // logits_to_token handles 2D / 3D logits and Float / Half / BFloat16 / - // UInt16 dtypes. Negative temperatures are clamped to 0 (greedy). - const float temp = - FLAGS_temperature <= 0.0 ? 
0.0f : static_cast(FLAGS_temperature); - return static_cast(llm::logits_to_token(output, temp)); -#endif -} int main(int argc, char** argv) { gflags::ParseCommandLineFlags(&argc, &argv, true); @@ -111,28 +68,6 @@ int main(int argc, char** argv) { llm::Stats stats; #ifdef EXECUTORCH_BUILD_CUDA - // GPU memory before load - size_t gpu_free_bytes = 0, gpu_total_bytes = 0; - cudaMemGetInfo(&gpu_free_bytes, &gpu_total_bytes); - stats.gpu_total_bytes = gpu_total_bytes; - stats.gpu_free_before_load_bytes = gpu_free_bytes; -#endif - - stats.model_load_start_ms = llm::time_in_ms(); - - // Load tokenizer - auto tokenizer = std::make_unique(); - auto tok_status = tokenizer->load(FLAGS_tokenizer_path); - if (tok_status != tokenizers::Error::Ok) { - ET_LOG( - Error, - "Failed to load tokenizer from %s", - FLAGS_tokenizer_path.c_str()); - return 1; - } - -#ifdef EXECUTORCH_BUILD_CUDA - // GPU memory: before load { size_t free = 0, total = 0; if (cudaMemGetInfo(&free, &total) == cudaSuccess) { @@ -144,90 +79,32 @@ int main(int argc, char** argv) { stats.model_load_start_ms = llm::time_in_ms(); - // Create Module with share_memory_arenas=true so prefill and decode - // share mutable buffers (KV cache, conv_state, recurrent_state). 
- std::vector data_files; - if (!FLAGS_data_path.empty()) { - data_files.push_back(FLAGS_data_path); - } - auto module = std::make_unique( - FLAGS_model_path, - data_files, - Module::LoadMode::File, - /*event_tracer=*/nullptr, - /*memory_allocator=*/nullptr, - /*temp_allocator=*/nullptr, - /*share_memory_arenas=*/true); - - // Get metadata - auto metadata_result = llm::get_llm_metadata(tokenizer.get(), module.get()); - if (metadata_result.error() != Error::Ok) { - ET_LOG(Error, "Failed to get metadata from model"); - return 1; - } - auto metadata = metadata_result.get(); - -#ifdef EXECUTORCH_BUILD_CUDA - // Set CUDA graph option if requested (must be before load_method) - if (FLAGS_cuda_graph) { - executorch::runtime::BackendOptions<2> cuda_opts; - cuda_opts.set_option("enable_cuda_graph_for_method", "decode"); - executorch::runtime::set_option("CudaBackend", cuda_opts.view()); - printf("CUDA graph enabled for decode method\n"); - } -#else - if (FLAGS_cuda_graph) { - ET_LOG(Info, "--cuda_graph ignored on non-CUDA build"); - } -#endif + // Build engine (reads tokenizer + metadata) and a session (loads weights and + // the prefill/decode methods). + llm::Qwen35MoEConfig config; + config.model_path = FLAGS_model_path; + config.data_path = FLAGS_data_path; + config.tokenizer_path = FLAGS_tokenizer_path; + config.cuda_graph = FLAGS_cuda_graph; printf("Loading methods...\n"); - -#ifdef EXECUTORCH_BUILD_CUDA - // Enable cross-method per-FQN weight sharing in the CUDA backend so that - // prefill and decode (which share KV cache and other mutable buffers / - // weights) avoid duplicate GPU allocations. This is critical for fitting - // Qwen 3.5 MoE on a single GPU. MUST be set BEFORE load_method, since the - // backend reads this flag during init() to decide between the per-weight - // cache path and the legacy per-method blob load. 
- { - executorch::runtime::BackendOptions<1> backend_options; - auto set_err = - backend_options.set_option("weight_sharing_across_methods", true); - if (set_err != Error::Ok) { - ET_LOG( - Error, - "Failed to construct weight_sharing_across_methods option: %d", - static_cast(set_err)); - return 1; - } - const auto opt_err = - executorch::runtime::set_option("CudaBackend", backend_options.view()); - if (opt_err != Error::Ok) { - ET_LOG( - Error, - "Failed to enable weight_sharing_across_methods: %d", - static_cast(opt_err)); - return 1; - } - } -#endif - - auto err = module->load_method("prefill"); - if (err != Error::Ok) { - ET_LOG(Error, "Failed to load prefill method"); + auto engine_result = llm::Qwen35MoEEngine::create(config); + if (engine_result.error() != Error::Ok) { + ET_LOG(Error, "Failed to create Qwen3.5 MoE engine"); return 1; } - err = module->load_method("decode"); - if (err != Error::Ok) { - ET_LOG(Error, "Failed to load decode method"); + auto engine = std::move(engine_result.get()); + + auto session_result = engine->create_session(); + if (session_result.error() != Error::Ok) { + ET_LOG(Error, "Failed to create session"); return 1; } + auto session = std::move(session_result.get()); stats.model_load_end_ms = llm::time_in_ms(); #ifdef EXECUTORCH_BUILD_CUDA - // GPU memory: after load { size_t free = 0, total = 0; if (cudaMemGetInfo(&free, &total) == cudaSuccess) { @@ -236,10 +113,7 @@ int main(int argc, char** argv) { } #endif - // Get EOS ids - auto eos_ids = llm::get_eos_ids(tokenizer.get(), module.get()); - - // Read prompt from file or flag + // Read prompt from file or flag. std::string prompt_text = FLAGS_prompt; if (!FLAGS_prompt_file.empty()) { std::ifstream f(FLAGS_prompt_file); @@ -252,157 +126,103 @@ int main(int argc, char** argv) { (std::istreambuf_iterator(f)), std::istreambuf_iterator()); } - // Encode prompt - auto encode_result = tokenizer->encode(prompt_text); + // Encode prompt via the engine's tokenizer. 
+ auto encode_result = engine->tokenizer()->encode(prompt_text); if (!encode_result.ok()) { ET_LOG(Error, "Failed to encode prompt"); return 1; } - auto prompt_tokens = std::move(*encode_result); - int64_t num_prompt_tokens = prompt_tokens.size(); + std::vector prompt_tokens = std::move(*encode_result); + const int64_t num_prompt_tokens = static_cast(prompt_tokens.size()); printf("Prompt tokens: %" PRId64 "\n", num_prompt_tokens); stats.num_prompt_tokens = num_prompt_tokens; - stats.inference_start_ms = llm::time_in_ms(); - // --------------------------------------------------------------- - // Sampling tensors (shared between prefill and decode) - // --------------------------------------------------------------- - auto S = [](int64_t v) -> SizesType { return static_cast(v); }; - -#ifdef EXECUTORCH_BUILD_CUDA - // CUDA build: model fuses the sampler in. Pass a temperature tensor as - // a third input. Use a very small temperature for greedy to avoid - // division by zero while keeping the Gumbel noise negligible relative - // to logit differences. - float temp_val = - FLAGS_temperature <= 0.0 ? 
1e-6f : static_cast(FLAGS_temperature); - auto temp_tensor = - from_blob(&temp_val, {1}, executorch::aten::ScalarType::Float); -#endif - - stats.inference_start_ms = llm::time_in_ms(); - stats.num_prompt_tokens = num_prompt_tokens; - - // --------------------------------------------------------------- - // Prefill - // --------------------------------------------------------------- - uint64_t cur_token = 0; - - // Use prefill method for T>=2, decode method for T=1 - // (prefill was exported with min seq_len=2) - std::string run_method = "prefill"; - if (num_prompt_tokens == 1) { - run_method = "decode"; - } - - std::vector pos_data(num_prompt_tokens); - for (int64_t i = 0; i < num_prompt_tokens; i++) { - pos_data[i] = i; - } - std::vector token_data(prompt_tokens.begin(), prompt_tokens.end()); - auto tokens_tensor = from_blob( - token_data.data(), - {1, S(num_prompt_tokens)}, - executorch::aten::ScalarType::Long); - auto pos_tensor = from_blob( - pos_data.data(), - {S(num_prompt_tokens)}, - executorch::aten::ScalarType::Long); - - std::vector prefill_inputs; - prefill_inputs.push_back(tokens_tensor); - prefill_inputs.push_back(pos_tensor); -#ifdef EXECUTORCH_BUILD_CUDA - prefill_inputs.push_back(temp_tensor); -#endif - - auto prefill_result = module->execute(run_method, prefill_inputs); - if (prefill_result.error() != Error::Ok) { - ET_LOG(Error, "Prefill failed"); - return 1; - } - auto& prefill_outputs = prefill_result.get(); - - cur_token = read_token(prefill_outputs[0].toTensor()); - - stats.prompt_eval_end_ms = llm::time_in_ms(); - stats.first_token_ms = stats.prompt_eval_end_ms; - double prefill_ms = - (double)(stats.prompt_eval_end_ms - stats.inference_start_ms); - printf( - "Prefill: %" PRId64 " tokens in %.1f ms (%.1f tok/s)\n", - num_prompt_tokens, - prefill_ms, - num_prompt_tokens / prefill_ms * stats.SCALING_FACTOR_UNITS_PER_SECOND); - -#ifdef EXECUTORCH_BUILD_CUDA - // Synchronize CUDA device to ensure prefill's writes to shared mutable - // buffers 
(KV cache, conv_state, recurrent_state) are visible to the - // decode method, which may run on a different CUDA stream. - cudaDeviceSynchronize(); -#endif - - // --------------------------------------------------------------- - // Decode — generate tokens one at a time - // --------------------------------------------------------------- - int64_t pos = num_prompt_tokens; - uint64_t prev_token; - - std::vector decode_token_data = {static_cast(cur_token)}; - std::vector decode_pos_data = {pos}; - auto decode_tokens = from_blob( - decode_token_data.data(), {1, 1}, executorch::aten::ScalarType::Long); - auto decode_pos = from_blob( - decode_pos_data.data(), {1}, executorch::aten::ScalarType::Long); - - for (int32_t step = 0; step < FLAGS_max_new_tokens; step++) { - decode_token_data[0] = static_cast(cur_token); - decode_pos_data[0] = pos; - - std::vector decode_inputs; - decode_inputs.push_back(EValue(decode_tokens)); - decode_inputs.push_back(EValue(decode_pos)); -#ifdef EXECUTORCH_BUILD_CUDA - decode_inputs.push_back(EValue(temp_tensor)); -#endif - - auto decode_result = module->execute("decode", decode_inputs); - if (decode_result.error() != Error::Ok) { - ET_LOG(Error, "Decode step %d failed", step); + // Warmup + timed iterations on one loaded session (reset between). The first + // FLAGS_warmup iterations are discarded; they trigger CUDA-graph capture, + // allocator growth, and GPU clock ramp so the timed iterations reflect steady + // state. Text is printed only on the first iteration (coherence check). 
+ llm::SamplingConfig sampling; + sampling.temperature = static_cast(FLAGS_temperature); + const int total_iters = FLAGS_warmup + std::max(1, FLAGS_num_iters); + std::vector prefill_tps_samples; + std::vector decode_tps_samples; + double prefill_ms = 0.0; + int64_t num_generated = 0; + + for (int iter = 0; iter < total_iters; ++iter) { + if (iter > 0 && session->reset() != Error::Ok) { + ET_LOG(Error, "Session reset failed before iteration %d", iter); return 1; } - auto& decode_outputs = decode_result.get(); + const bool measured = iter >= FLAGS_warmup; + const bool print_text = (iter == 0); - prev_token = cur_token; - cur_token = read_token(decode_outputs[0].toTensor()); - - if (step == 0) { - stats.first_token_ms = llm::time_in_ms(); + stats.inference_start_ms = llm::time_in_ms(); + if (session->prefill_tokens(prompt_tokens, &sampling) != Error::Ok) { + ET_LOG(Error, "Prefill failed"); + return 1; } - - pos++; - - auto decode_str = tokenizer->decode(prev_token, cur_token); - if (decode_str.ok()) { - printf("%s", decode_str->c_str()); - fflush(stdout); + stats.prompt_eval_end_ms = llm::time_in_ms(); + stats.first_token_ms = stats.prompt_eval_end_ms; + + num_generated = 0; + for (int32_t step = 0; step < FLAGS_max_new_tokens; step++) { + auto step_result = session->decode_one(sampling); + if (step_result.error() != Error::Ok) { + ET_LOG(Error, "Decode step %d failed", step); + return 1; + } + const auto& d = step_result.get(); + // A terminal step (EOS or cooperative stop) is the loop terminator, not + // generated output: don't count or emit it (matches the JSONL workers and + // the LLMSession contract). 
+ if (d.is_terminal) { + if (print_text) { + printf("\n"); + } + break; + } + num_generated++; + if (step == 0) { + stats.first_token_ms = llm::time_in_ms(); + } + if (print_text && !d.text_piece.empty()) { + fwrite(d.text_piece.data(), 1, d.text_piece.size(), stdout); + fflush(stdout); + } } - - if (eos_ids.find(cur_token) != eos_ids.end()) { - printf("\n"); - break; + stats.inference_end_ms = llm::time_in_ms(); + stats.num_generated_tokens = num_generated; + + prefill_ms = (double)(stats.prompt_eval_end_ms - stats.inference_start_ms); + const double decode_ms_iter = + (double)(stats.inference_end_ms - stats.prompt_eval_end_ms); + const double pf_tps = + num_prompt_tokens / prefill_ms * stats.SCALING_FACTOR_UNITS_PER_SECOND; + const double dc_tps = + num_generated / decode_ms_iter * stats.SCALING_FACTOR_UNITS_PER_SECOND; + printf( + "[iter %d%s] prefill %.1f tok/s (%" PRId64 + " tok, %.1f ms) | " + "decode %.1f tok/s (%" PRId64 " tok, %.1f ms)\n", + iter, + measured ? "" : " warmup", + pf_tps, + num_prompt_tokens, + prefill_ms, + dc_tps, + num_generated, + decode_ms_iter); + if (measured) { + prefill_tps_samples.push_back(pf_tps); + decode_tps_samples.push_back(dc_tps); } } - stats.inference_end_ms = llm::time_in_ms(); - printf("\n"); - int64_t num_generated = pos - num_prompt_tokens; - stats.num_generated_tokens = num_generated; #ifdef EXECUTORCH_BUILD_CUDA - // GPU memory: after generate + peak usage { size_t free = 0, total = 0; if (cudaMemGetInfo(&free, &total) == cudaSuccess) { @@ -420,8 +240,7 @@ int main(int argc, char** argv) { #endif printf("\n"); - - double decode_ms = + const double decode_ms = (double)(stats.inference_end_ms - stats.prompt_eval_end_ms); printf( "Prefill: %" PRId64 " tokens in %.1f ms (%.1f tok/s)\n", @@ -435,21 +254,20 @@ int main(int argc, char** argv) { num_generated / decode_ms * stats.SCALING_FACTOR_UNITS_PER_SECOND); printf("Prompt tokens: %" PRId64 "\n", num_prompt_tokens); - // Structured stats report (matches stats.h 
print_report) printf("PyTorchObserver %s\n", llm::stats_to_json_string(stats).c_str()); - double ms_per_s = stats.SCALING_FACTOR_UNITS_PER_SECOND; - - double model_load_s = + const double ms_per_s = stats.SCALING_FACTOR_UNITS_PER_SECOND; + const double model_load_s = (double)(stats.model_load_end_ms - stats.model_load_start_ms) / ms_per_s; - double inference_time_ms = + const double inference_time_ms = (double)(stats.inference_end_ms - stats.inference_start_ms); - double prompt_eval_ms = + const double prompt_eval_ms = (double)(stats.prompt_eval_end_ms - stats.inference_start_ms); - double eval_ms = (double)(stats.inference_end_ms - stats.prompt_eval_end_ms); - double ttft_s = + const double eval_ms = + (double)(stats.inference_end_ms - stats.prompt_eval_end_ms); + const double ttft_s = (double)(stats.first_token_ms - stats.inference_start_ms) / ms_per_s; - double sampling_s = (double)stats.aggregate_sampling_time_ms / ms_per_s; + const double sampling_s = (double)stats.aggregate_sampling_time_ms / ms_per_s; printf("\n"); printf( @@ -477,7 +295,6 @@ int main(int argc, char** argv) { stats.num_prompt_tokens + stats.num_generated_tokens, sampling_s); - // GPU memory reporting if (stats.gpu_total_bytes != static_cast(-1)) { printf( "\tGPU total memory: %.2f MB\n", @@ -502,5 +319,24 @@ int main(int argc, char** argv) { } } + if (!prefill_tps_samples.empty()) { + auto mean = [](const std::vector& v) { + double s = 0.0; + for (double x : v) { + s += x; + } + return s / v.size(); + }; + printf( + "\n=== mean over %zu timed iter(s) (warmup %d) | prompt %" PRId64 + ", gen %" PRId64 " ===\n", + prefill_tps_samples.size(), + FLAGS_warmup, + num_prompt_tokens, + num_generated); + printf("\tPrefill: %.1f tok/s\n", mean(prefill_tps_samples)); + printf("\tDecode: %.1f tok/s\n", mean(decode_tps_samples)); + } + return 0; } diff --git a/examples/models/qwen3_5_moe/model.md b/examples/models/qwen3_5_moe/model.md index 32510859b28..d29177c4c87 100644 --- 
a/examples/models/qwen3_5_moe/model.md +++ b/examples/models/qwen3_5_moe/model.md @@ -136,6 +136,35 @@ matmul). Visual and MTP keys are skipped. `lm_head.weight` is cloned from `embed_tokens.weight` if not present in checkpoint (tied embeddings). +## Serving (Engine/Session adapter) + +`main.cpp` is a thin CLI over `Qwen35MoEEngine` / `Qwen35MoESession` +(`qwen35_moe_engine.{h,cpp}`), which implement the model-agnostic +`LLMEngine` / `LLMSession` serving contract in +`extension/llm/runner/llm_session.h`. This lets an OpenAI-compatible server (or +any harness) drive the model without knowing it is Qwen-MoE or CUDA. + +- **`Qwen35MoEEngine`** owns immutable resources (tokenizer, metadata, EOS ids, + config). `create_session()` builds a `Module` with `share_memory_arenas=true` + and, on CUDA, sets the backend options that must precede `load_method` + (`weight_sharing_across_methods`, optional `enable_cuda_graph_for_method`), + then loads the `prefill`/`decode` methods. `serving_capacity()` reports a + single physical session — cross-session weight sharing is not yet proven, so + it fails closed to 1. +- **`Qwen35MoESession`** owns the mutable conversation state (KV / conv / + recurrent arenas via the Module, position cursor, pending token). + `prefill_tokens` dispatches to `prefill` (T≥2) or `decode` (T==1); + `decode_one` emits the pending token and forwards it, stopping at EOS without + forwarding it (EOS is not made resident and `position()` does not advance). + `seek()` returns `NotSupported` — the recurrent/conv state cannot be rewound + by logical position. `reset()` is a logical rewind to position 0; the model + zeroes `conv_state`/`recurrent_state` whenever prefill runs at + `input_pos[0]==0`, so no Module rebuild is needed. +- Backend-specific execution (CUDA in-graph sampling via a temperature input, + device sync, backend options) is isolated behind `EXECUTORCH_BUILD_CUDA` — the + extension point where an MLX runtime would slot in. 
The public + `LLMEngine`/`LLMSession` surface stays backend-agnostic. + ## References - [HF Transformers Qwen3.5 MoE](https://github.com/huggingface/transformers) — `transformers/models/qwen3_5_moe/` diff --git a/examples/models/qwen3_5_moe/qwen35_moe_engine.cpp b/examples/models/qwen3_5_moe/qwen35_moe_engine.cpp new file mode 100644 index 00000000000..32ff1b4c0f9 --- /dev/null +++ b/examples/models/qwen3_5_moe/qwen35_moe_engine.cpp @@ -0,0 +1,430 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include +#include +#include +#include + +#include +#include + +#ifdef EXECUTORCH_BUILD_CUDA +#include +#else +#include +#endif + +namespace executorch::extension::llm { + +using ::executorch::extension::from_blob; +using ::executorch::extension::Module; +using ::executorch::extension::TensorPtr; +using ::executorch::runtime::Error; +using ::executorch::runtime::EValue; +using ::executorch::runtime::Result; +using SizesType = executorch::aten::SizesType; + +namespace { + +// --------------------------------------------------------------------------- +// Backend-specific helpers (the MLX extension points live here). On CUDA the +// model fuses the sampler in and returns the sampled token id as a [B,1] float; +// non-CUDA returns logits and we sample on host. Keep these isolated so the +// session logic below stays backend-agnostic. 
+// --------------------------------------------------------------------------- + +Result read_sampled_token( + const executorch::aten::Tensor& output, + float temperature) { +#ifdef EXECUTORCH_BUILD_CUDA + (void)temperature; + const void* ptr = output.const_data_ptr(); + cudaPointerAttributes attrs; + const bool on_device = cudaPointerGetAttributes(&attrs, ptr) == cudaSuccess && + attrs.type == cudaMemoryTypeDevice; + float val = 0.0f; + if (on_device) { + if (cudaMemcpy(&val, ptr, sizeof(float), cudaMemcpyDeviceToHost) != + cudaSuccess) { + // Don't fabricate token id 0 (a valid token) on a copy failure — that is + // silent corruption. Surface it so the caller aborts the request. + ET_LOG(Error, "read_sampled_token: cudaMemcpy D2H failed"); + return Error::Internal; + } + } else { + std::memcpy(&val, ptr, sizeof(float)); + } + return static_cast(val); +#else + return static_cast( + logits_to_token(output, temperature < 0.0f ? 0.0f : temperature)); +#endif +} + +// Build a Qwen Module with shared mutable arenas (so prefill and decode share +// KV/conv/recurrent state) and, on CUDA, the weight-sharing/cuda-graph backend +// options that MUST be set before load_method. Loads the prefill+decode methods +// (this is the heavy ~weights load). Shared by create_session() and reset(). +Result> build_qwen_module( + const Qwen35MoEConfig& config) { + std::vector data_files; + if (!config.data_path.empty()) { + data_files.push_back(config.data_path); + } + auto module = std::make_unique( + config.model_path, + data_files, + Module::LoadMode::File, + /*event_tracer=*/nullptr, + /*memory_allocator=*/nullptr, + /*temp_allocator=*/nullptr, + /*share_memory_arenas=*/true); + +#ifdef EXECUTORCH_BUILD_CUDA + // Backend options are read during backend init(), so they must be set before + // load_method. 
+ if (config.cuda_graph) { + executorch::runtime::BackendOptions<1> cuda_opts; + cuda_opts.set_option("enable_cuda_graph_for_method", "decode"); + ET_CHECK_OK_OR_RETURN_ERROR( + executorch::runtime::set_option("CudaBackend", cuda_opts.view())); + } + { + // Cross-method per-FQN weight sharing: prefill and decode reuse one weight + // allocation instead of duplicating it (critical to fit on one GPU). + executorch::runtime::BackendOptions<1> backend_options; + ET_CHECK_OK_OR_RETURN_ERROR( + backend_options.set_option("weight_sharing_across_methods", true)); + ET_CHECK_OK_OR_RETURN_ERROR( + executorch::runtime::set_option("CudaBackend", backend_options.view())); + } +#endif + + ET_CHECK_OK_OR_RETURN_ERROR(module->load_method("prefill")); + ET_CHECK_OK_OR_RETURN_ERROR(module->load_method("decode")); + return module; +} + +// LLMSession over the Qwen3.5 MoE prefill/decode methods. Owns one physical +// Module (one weight allocation + its KV/recurrent/conv state). Internal: the +// server depends only on the LLMSession base. +class Qwen35MoESession : public LLMSession { + public: + Qwen35MoESession( + std::unique_ptr module, + ::tokenizers::Tokenizer* tokenizer, + std::unordered_map metadata, + std::unordered_set eos_ids) + : module_(std::move(module)), + tokenizer_(tokenizer), + metadata_(std::move(metadata)), + eos_ids_(std::move(eos_ids)) { + // Persistent single-step decode buffers: stable addresses are required so + // CUDA-graph capture (which records buffer pointers) can replay each step. 
+ decode_tokens_ = from_blob( + decode_token_data_, {1, 1}, executorch::aten::ScalarType::Long); + decode_pos_ = + from_blob(decode_pos_data_, {1}, executorch::aten::ScalarType::Long); +#ifdef EXECUTORCH_BUILD_CUDA + temp_tensor_ = + from_blob(&temp_val_, {1}, executorch::aten::ScalarType::Float); +#endif + } + + Error prefill_tokens( + std::vector tokens, + const SamplingConfig* initial_sampling) override { + if (tokens.empty()) { + ET_LOG(Error, "prefill_tokens: empty token list"); + return Error::InvalidArgument; + } + // The model samples the FIRST generated token in-graph during prefill, so + // it must use the request's sampling, not a stale session default. Only + // temperature is plumbed; reject non-default top_p/top_k/seed (parity with + // decode_one). + float first_token_temp = temperature_; + if (initial_sampling != nullptr) { + if (initial_sampling->top_p != 1.0f || initial_sampling->top_k != 0 || + initial_sampling->seed != 0) { + ET_LOG( + Error, + "prefill_tokens: only temperature is supported; top_p/top_k/seed " + "are not yet implemented"); + return Error::NotSupported; + } + first_token_temp = initial_sampling->temperature; + } + const int64_t T = static_cast(tokens.size()); + const auto ctx_it = metadata_.find(kMaxContextLen); + // Require room for at least one generated token: after prefill, pos_ == T + // and decode_one() forwards the first token at pos_, which must be < the + // context length. Rejecting pos_ + T == max_context (not just > it) keeps a + // full prompt from reaching decode_one with no room to step. + if (ctx_it != metadata_.end() && pos_ + T >= ctx_it->second) { + ET_LOG( + Error, + "prefill_tokens would leave no room to generate (pos %" PRId64 + " + %" PRId64 " >= max_context %" PRId64 ")", + pos_, + T, + ctx_it->second); + return Error::InvalidArgument; + } + + // A new prefill starts a fresh generation turn; clear any prior stop. 
+ stop_.store(false, std::memory_order_relaxed); + std::vector token_data(tokens.begin(), tokens.end()); + std::vector pos_data(T); + for (int64_t i = 0; i < T; ++i) { + pos_data[i] = pos_ + i; + } + auto tokens_tensor = from_blob( + token_data.data(), + {1, static_cast(T)}, + executorch::aten::ScalarType::Long); + auto pos_tensor = from_blob( + pos_data.data(), + {static_cast(T)}, + executorch::aten::ScalarType::Long); + + // prefill method handles T>=2; the model exports decode for the T==1 case. + const char* method = (T >= 2) ? "prefill" : "decode"; + std::vector inputs; + inputs.push_back(tokens_tensor); + inputs.push_back(pos_tensor); +#ifdef EXECUTORCH_BUILD_CUDA + set_temp(first_token_temp); + inputs.push_back(EValue(temp_tensor_)); +#endif + auto res = module_->execute(method, inputs); + ET_CHECK_OK_OR_RETURN_ERROR(res.error()); + auto sampled = + read_sampled_token(res.get()[0].toTensor(), first_token_temp); + ET_CHECK_OK_OR_RETURN_ERROR(sampled.error()); + pending_ = sampled.get(); + prev_decode_token_.reset(); + pos_ += T; // the prompt tokens are now resident in KV/state +#ifdef EXECUTORCH_BUILD_CUDA + // Make prefill's writes to the shared mutable arenas visible to decode + // (which may run on a different stream). This barrier is relied upon for + // correctness, and it also surfaces any async error from the prefill + // launch, so a non-success here must abort the request rather than decode + // on stale or corrupt state. + if (cudaDeviceSynchronize() != cudaSuccess) { + ET_LOG(Error, "prefill_tokens: cudaDeviceSynchronize failed"); + return Error::Internal; + } +#endif + return Error::Ok; + } + + Result decode_one(const SamplingConfig& sampling) override { + // Only temperature is plumbed; reject the rest rather than silently ignore + // (callers must not assume top_p/top_k/seed are applied). 
+ if (sampling.top_p != 1.0f || sampling.top_k != 0 || sampling.seed != 0) { + ET_LOG( + Error, + "Qwen35MoESession: only temperature is supported; top_p/top_k/seed " + "are not implemented"); + return Error::NotSupported; + } + ET_CHECK_OR_RETURN_ERROR( + pending_.has_value(), + InvalidState, + "decode_one requires a pending token; call prefill_tokens() first"); + temperature_ = sampling.temperature; + + const uint64_t token = pending_.value(); + const bool is_eos = eos_ids_.find(token) != eos_ids_.end(); + + // Decode the text piece with BPE context (previous token); surface + // tokenizer errors instead of hiding them as empty text. + const uint64_t prev = prev_decode_token_.value_or(token); + auto dec = tokenizer_->decode(prev, token); + if (!dec.ok()) { + ET_LOG( + Error, + "Tokenizers error code %d", + static_cast(dec.error())); + return Error::InvalidArgument; + } + std::string text_piece = std::move(*dec); + + // Terminate WITHOUT forwarding the token: at EOS (like the reference + // runner, EOS is not made resident and position() does not advance) or at a + // cooperative stop() observed at this boundary. No pending token remains. + // is_eos stays literal; is_terminal ends the loop either way. + if (is_eos || stop_.load(std::memory_order_relaxed)) { + pending_.reset(); + return DecodeResult{ + token, std::move(text_piece), is_eos, /*is_terminal=*/true}; + } + + // Only a NON-EOS, non-stopped token is forwarded (made resident at pos_), + // so the capacity check belongs here — after the short-circuit, so a final + // EOS is still emitted when state is exactly full. Without it, decode would + // write KV/recurrent state past the context window. 
+ const auto ctx_it = metadata_.find(kMaxContextLen); + if (ctx_it != metadata_.end()) { + ET_CHECK_OR_RETURN_ERROR( + pos_ < ctx_it->second, + InvalidArgument, + "decode_one would exceed context capacity: pos_ %" PRId64 + " >= max_context %" PRId64, + pos_, + ctx_it->second); + } + + // Forward `token` at pos_ through the decode method to get the next pending + // token. Update the persistent buffers in place (stable addresses). + decode_token_data_[0] = static_cast(token); + decode_pos_data_[0] = pos_; + std::vector inputs; + inputs.push_back(EValue(decode_tokens_)); + inputs.push_back(EValue(decode_pos_)); +#ifdef EXECUTORCH_BUILD_CUDA + set_temp(temperature_); + inputs.push_back(EValue(temp_tensor_)); +#endif + auto res = module_->execute("decode", inputs); + ET_CHECK_OK_OR_RETURN_ERROR(res.error()); + auto sampled = read_sampled_token(res.get()[0].toTensor(), temperature_); + ET_CHECK_OK_OR_RETURN_ERROR(sampled.error()); + pending_ = sampled.get(); + prev_decode_token_ = token; + pos_ += 1; + return DecodeResult{ + token, std::move(text_piece), /*is_eos=*/false, /*is_terminal=*/false}; + } + + Error seek(int64_t pos) override { + // The hybrid model carries recurrent/conv state that cannot be safely + // rewound by logical position the way contiguous KV can. Fail closed so the + // prefix cache falls back to reset + full prefill (V1). + (void)pos; + return Error::NotSupported; + } + + int64_t position() const override { + return pos_; + } + + Error reset() override { + // Logical reset is sufficient: the model zeroes conv_state/recurrent_state + // whenever prefill runs at input_pos[0]==0 (model.py), and a fresh prefill + // overwrites the KV cache at [0, T). So rewinding to position 0 and + // clearing the pending token gives a clean conversation without a Module + // rebuild. 
+ pos_ = 0; + pending_.reset(); + prev_decode_token_.reset(); + stop_.store(false, std::memory_order_relaxed); + return Error::Ok; + } + + void stop() override { + // Cooperative, token-boundary: the driving loop checks between decode_one() + // calls. A single decode_one() forward is not interruptible. + stop_.store(true, std::memory_order_relaxed); + } + + private: +#ifdef EXECUTORCH_BUILD_CUDA + // Greedy (temperature <= 0) maps to a tiny temperature so the in-graph + // sampler avoids division by zero while staying effectively argmax. + void set_temp(float t) { + temp_val_ = (t <= 0.0f) ? 1e-6f : t; + } +#endif + + std::unique_ptr module_; + ::tokenizers::Tokenizer* tokenizer_; // non-owning; owned by the engine + std::unordered_map metadata_; + std::unordered_set eos_ids_; + + int64_t pos_ = 0; + std::optional pending_; + std::optional prev_decode_token_; + float temperature_ = -1.0f; + std::atomic stop_{false}; + + // Persistent single-step decode buffers (stable addresses for CUDA graph). + int64_t decode_token_data_[1] = {0}; + int64_t decode_pos_data_[1] = {0}; + TensorPtr decode_tokens_; + TensorPtr decode_pos_; +#ifdef EXECUTORCH_BUILD_CUDA + float temp_val_ = 1e-6f; + TensorPtr temp_tensor_; +#endif +}; + +} // namespace + +Result> Qwen35MoEEngine::create( + const Qwen35MoEConfig& config) { + if (config.model_path.empty() || config.tokenizer_path.empty()) { + ET_LOG( + Error, "Qwen35MoEEngine: model_path and tokenizer_path are required"); + return Error::InvalidArgument; + } + + auto tokenizer = std::make_unique<::tokenizers::HFTokenizer>(); + if (tokenizer->load(config.tokenizer_path) != ::tokenizers::Error::Ok) { + ET_LOG( + Error, + "Qwen35MoEEngine: failed to load tokenizer from %s", + config.tokenizer_path.c_str()); + return Error::InvalidArgument; + } + + // Read metadata + eos from a lightweight Module (program + tiny metadata + // methods only; the heavy prefill/decode weights are NOT loaded here). 
+ std::vector data_files; + if (!config.data_path.empty()) { + data_files.push_back(config.data_path); + } + auto meta_module = std::make_unique( + config.model_path, data_files, Module::LoadMode::File); + auto metadata_result = get_llm_metadata(tokenizer.get(), meta_module.get()); + if (metadata_result.error() != Error::Ok) { + ET_LOG(Error, "Qwen35MoEEngine: failed to read metadata"); + return metadata_result.error(); + } + auto eos_ids = get_eos_ids(tokenizer.get(), meta_module.get()); + // This export's metadata doesn't carry the chat-turn EOS (config.json has no + // eos_token_id and the .pte exports no get_eos_ids method), so get_eos_ids() + // misses it and a session would never terminate — it would decode to + // max_new_tokens every turn. <|im_end|> ends every Qwen assistant turn; add + // it explicitly so decode_one() stops at end of turn. + if (auto im_end = tokenizer->piece_to_id("<|im_end|>"); im_end.ok()) { + eos_ids.insert(*im_end); + } else { + ET_LOG( + Error, + "Qwen35MoEEngine: could not resolve <|im_end|> token id; the model may " + "not stop at end of turn"); + } + + return std::unique_ptr(new Qwen35MoEEngine( + config, std::move(tokenizer), metadata_result.get(), std::move(eos_ids))); +} + +Result> Qwen35MoEEngine::create_session() { + auto module = build_qwen_module(config_); + ET_CHECK_OK_OR_RETURN_ERROR(module.error()); + return std::unique_ptr(new Qwen35MoESession( + std::move(module.get()), tokenizer_.get(), metadata_, eos_ids_)); +} + +} // namespace executorch::extension::llm diff --git a/examples/models/qwen3_5_moe/qwen35_moe_engine.h b/examples/models/qwen3_5_moe/qwen35_moe_engine.h new file mode 100644 index 00000000000..9fb9e99d71e --- /dev/null +++ b/examples/models/qwen3_5_moe/qwen35_moe_engine.h @@ -0,0 +1,96 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +// Engine/Session adapter for the Qwen3.5 MoE model, implementing the +// model-agnostic LLMEngine/LLMSession serving contract (llm_session.h) over the +// model's exported prefill/decode methods. +// +// The public surface is backend-agnostic: the server receives an LLMEngine and +// never branches on CUDA vs MLX. Backend-specific execution (CUDA in-graph +// sampling, weight-sharing/cuda-graph backend options, device sync) is isolated +// behind EXECUTORCH_BUILD_CUDA inside the .cpp; those isolated points are where +// an MLX runtime would slot in. MLX is NOT implemented or validated here. +// +// V1: serving_capacity() reports a single physical session (one Module = one +// weight allocation). Multiple weight-sharing sessions are a measured V2 step. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace executorch::extension::llm { + +/// Immutable configuration for a Qwen3.5 MoE engine. +struct Qwen35MoEConfig { + std::string model_path; // .pte + std::string data_path; // .ptd (CUDA delegate blob); empty if none + std::string tokenizer_path; // HuggingFace tokenizer.json + bool cuda_graph = false; // enable CUDA graph capture for the decode method +}; + +/// Engine over one loaded Qwen3.5 MoE Program. Owns immutable resources +/// (tokenizer, metadata, eos ids, config) and creates sessions that each own a +/// physical Module with its own KV/recurrent/conv state. +class ET_EXPERIMENTAL Qwen35MoEEngine : public LLMEngine { + public: + static ::executorch::runtime::Result> create( + const Qwen35MoEConfig& config); + + ::executorch::runtime::Result> create_session() + override; + + // V1: one physical session; weight sharing across sessions is unproven, so we + // fail closed to 1 (the server queues concurrent requests on the resident + // session rather than duplicating ~18GB of weights). 
+ LLMServingCapacity serving_capacity() const override { + return LLMServingCapacity{}; + } + + const std::unordered_map& metadata() const override { + return metadata_; + } + + // Non-owning; valid for the engine's lifetime (the engine must outlive any + // session and any caller using this). Used by the runner to encode prompts; + // not part of the model-agnostic LLMEngine surface the server depends on. + ::tokenizers::Tokenizer* tokenizer() const { + return tokenizer_.get(); + } + + Qwen35MoEEngine(const Qwen35MoEEngine&) = delete; + Qwen35MoEEngine& operator=(const Qwen35MoEEngine&) = delete; + + private: + Qwen35MoEEngine( + Qwen35MoEConfig config, + std::unique_ptr<::tokenizers::Tokenizer> tokenizer, + std::unordered_map metadata, + std::unordered_set eos_ids) + : config_(std::move(config)), + tokenizer_(std::move(tokenizer)), + metadata_(std::move(metadata)), + eos_ids_(std::move(eos_ids)) {} + + Qwen35MoEConfig config_; + std::unique_ptr<::tokenizers::Tokenizer> tokenizer_; + std::unordered_map metadata_; + std::unordered_set eos_ids_; +}; + +} // namespace executorch::extension::llm diff --git a/examples/models/qwen3_5_moe/qwen35_moe_worker.cpp b/examples/models/qwen3_5_moe/qwen35_moe_worker.cpp new file mode 100644 index 00000000000..2cc705f96e1 --- /dev/null +++ b/examples/models/qwen3_5_moe/qwen35_moe_worker.cpp @@ -0,0 +1,77 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Model-execution worker for Qwen3.5 MoE (CUDA/AOTI). +// +// All model execution lives here in C++ via Qwen35MoEEngine / LLMSession; no +// Python model code, no pybind. 
The OpenAI control plane (serve.py) spawns this +// process and drives it over JSONL through the generic WorkerClient — the same +// protocol and decode loop every worker uses (worker_loop.h); this file only +// constructs the engine/session. +// +// Isolation rationale: executing the AOTI CUDA model inside a live asyncio HTTP +// process segfaults in the int4 matmul (validated). Here the model runs in a +// plain synchronous loop in its own process, which is reliable. +// +// V1: single-slot (one engine == one ~18GB weight allocation == one session); +// the control plane queues concurrent requests on the resident session. + +#include + +#include +#include +#include +#include + +#include + +DEFINE_string(model_path, "", "Model .pte file path."); +DEFINE_string(tokenizer_path, "", "HuggingFace tokenizer.json path."); +DEFINE_string(data_path, "", "Data file (.ptd) for the CUDA backend."); +DEFINE_bool(cuda_graph, false, "Enable CUDA graph for the decode method."); + +namespace { +namespace llm = ::executorch::extension::llm; +using ::executorch::runtime::Error; +} // namespace + +int main(int argc, char** argv) { + gflags::ParseCommandLineFlags(&argc, &argv, true); + + if (FLAGS_model_path.empty() || FLAGS_tokenizer_path.empty()) { + ET_LOG( + Error, "qwen35_moe_worker: --model_path and --tokenizer_path required"); + return 1; + } + + llm::Qwen35MoEConfig config; + config.model_path = FLAGS_model_path; + config.data_path = FLAGS_data_path; + config.tokenizer_path = FLAGS_tokenizer_path; + config.cuda_graph = FLAGS_cuda_graph; + + auto engine_result = llm::Qwen35MoEEngine::create(config); + if (engine_result.error() != Error::Ok) { + ET_LOG(Error, "qwen35_moe_worker: failed to create engine"); + return 1; + } + auto engine = std::move(engine_result.get()); + + auto session_result = engine->create_session(); + if (session_result.error() != Error::Ok) { + ET_LOG(Error, "qwen35_moe_worker: failed to create session"); + return 1; + } + auto session = 
std::move(session_result.get()); + + // The engine's tokenizer encodes the rendered prompt to ids; the session + // decodes ids back to text internally. + ::tokenizers::Tokenizer* tokenizer = engine->tokenizer(); + + return llm::run_worker_stdio_loop(*session, *tokenizer, engine->metadata()); +} diff --git a/examples/models/qwen3_5_moe/serve.py b/examples/models/qwen3_5_moe/serve.py new file mode 100644 index 00000000000..229a84425fb --- /dev/null +++ b/examples/models/qwen3_5_moe/serve.py @@ -0,0 +1,172 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""OpenAI-compatible HTTP server for Qwen3.5 MoE (process-isolated). + +This is the CONTROL PLANE only: FastAPI/uvicorn + OpenAI protocol, chat +templating, tool parsing, request validation. It runs NO CUDA model code and +imports no model pybind. Model execution lives in a separate C++ worker +process (qwen3_5_moe_worker) that this process drives over JSONL via the generic +WorkerClient — the same protocol the generic text_llm_worker speaks. + +Why two processes: executing the AOTI CUDA model inside a live asyncio server +process segfaults in the int4 matmul (validated by elimination — the trigger is +CUDA execution while a live asyncio loop is resident). Isolating CUDA in a plain +(no-asyncio) C++ worker process is the reliable shape, and it loads weights once. + +V1 constraints: + * single-slot: one worker, one session; concurrent HTTP requests queue. + * prefix cache off (Qwen seek() is NotSupported). + * The control plane only does blocking pipe I/O on its executor thread (no + CUDA), which is safe under asyncio. 
+ +Launch (LD_LIBRARY_PATH shim is forwarded to the worker for the CUDA blob): + + LD_LIBRARY_PATH=$CONDA_PREFIX/lib:$LD_LIBRARY_PATH \\ + python -m executorch.examples.models.qwen3_5_moe.serve \\ + --model-path qwen35_moe_exports/model.pte \\ + --data-path qwen35_moe_exports/aoti_cuda_blob.ptd \\ + --tokenizer-path ~/models/Qwen3.5-35B-A3B/tokenizer.json \\ + --hf-tokenizer ~/models/Qwen3.5-35B-A3B \\ + --model-id qwen3.5-moe --no-think +""" + +import argparse +import logging +import os +from pathlib import Path + +from executorch.extension.llm.server.python.chat_template import ChatTemplate +from executorch.extension.llm.server.python.serving_chat import ServingChat +from executorch.extension.llm.server.python.session_runtime import SessionRuntime +from executorch.extension.llm.server.python.tool_parsers import QwenFunctionCallDetector +from executorch.extension.llm.server.python.worker_client import spawn_worker + +logger = logging.getLogger(__name__) + + +def _default_worker_bin() -> str: + repo_root = Path(__file__).resolve().parents[3] + return str( + repo_root + / "cmake-out" + / "examples" + / "models" + / "qwen3_5_moe" + / "qwen3_5_moe_worker" + ) + + +def _spawn(args): + """Spawn the C++ Qwen worker and return a ready WorkerClient.""" + env = dict(os.environ) + conda = os.environ.get("CONDA_PREFIX") + if conda: + # The AOTI CUDA blob needs the conda libstdc++; forward it to the worker. + env["LD_LIBRARY_PATH"] = f"{conda}/lib:" + env.get("LD_LIBRARY_PATH", "") + worker_bin = args.worker_bin or _default_worker_bin() + cmd = [ + worker_bin, + "--model_path", + args.model_path, + "--tokenizer_path", + args.tokenizer_path, + ] + if args.data_path: + cmd += ["--data_path", args.data_path] + logger.info("Starting Qwen worker subprocess (loads the model once)...") + return spawn_worker(cmd, env=env) + + +def build_app_from_args(args): + """Construct the FastAPI app + the model worker. 
Returns (app, model_id).""" + default_template_kwargs = {"enable_thinking": False} if args.no_think else None + template = ChatTemplate( + args.hf_tokenizer, default_template_kwargs=default_template_kwargs + ) + + worker = _spawn(args) # one worker == one session (single-slot V1) + runtime = SessionRuntime(worker) + serving = ServingChat( + runtime, + template, + args.model_id, + max_context=args.max_context, + # Qwen3.5-MoE emits the XML tool format. + tool_detector_cls=QwenFunctionCallDetector, + ) + + from executorch.extension.llm.server.python.server import build_app + + app = build_app(serving, args.model_id) + + @app.on_event("shutdown") + def _stop_worker(): + runtime.close_worker() + + return app, args.model_id + + +def main() -> None: + p = argparse.ArgumentParser( + description="OpenAI-compatible LLM server for Qwen3.5 MoE (process-isolated, V1)" + ) + p.add_argument("--model-path", required=True, help="Path to the .pte model") + p.add_argument( + "--data-path", default=None, help="Path to the .ptd CUDA delegate blob" + ) + p.add_argument( + "--tokenizer-path", required=True, help="Path to the HuggingFace tokenizer.json" + ) + p.add_argument( + "--hf-tokenizer", + required=True, + help="HF tokenizer id/dir for the model's chat template", + ) + p.add_argument("--model-id", default="qwen3.5-moe") + p.add_argument("--host", default="127.0.0.1") + p.add_argument("--port", type=int, default=8000) + p.add_argument( + "--max-context", + type=int, + default=None, + help="Context window; prompts exceeding it are rejected with 400.", + ) + p.add_argument( + "--no-think", + action="store_true", + help="Default reasoning off (enable_thinking=False).", + ) + p.add_argument( + "--num-runners", + type=int, + default=1, + help="V1 supports 1 only (single-slot).", + ) + p.add_argument( + "--worker-bin", + default=None, + help="Path to the qwen3_5_moe_worker binary " + "(default: cmake-out/examples/models/qwen3_5_moe/qwen3_5_moe_worker).", + ) + args = p.parse_args() + 
logging.basicConfig(level=logging.INFO) + + if args.num_runners != 1: + p.error( + "Qwen3.5 MoE V1 is single-slot: one worker serves one session; " + "concurrent requests queue." + ) + + app, _ = build_app_from_args(args) + + import uvicorn + + uvicorn.run(app, host=args.host, port=args.port) + + +if __name__ == "__main__": + main() diff --git a/examples/models/qwen3_5_moe/test_serve.py b/examples/models/qwen3_5_moe/test_serve.py new file mode 100644 index 00000000000..fdaa6a1ea62 --- /dev/null +++ b/examples/models/qwen3_5_moe/test_serve.py @@ -0,0 +1,128 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Tests for the Qwen3.5 MoE process-isolated OpenAI launcher (serve.py). + +Hermetic: no model, GPU, or worker subprocess. Covers layering (Qwen stays an +example; the control plane runs no CUDA and imports no model pybind), the worker +spawn command, and the single-slot CLI guard. The generic JSONL protocol is +covered by extension/llm/server/python/tests/test_worker_client.py; the live +HTTP smoke test is documented in README.md and run on a CUDA box. 
+""" + +import pathlib +from types import SimpleNamespace + +import pytest + +from executorch.examples.models.qwen3_5_moe import serve + +_HERE = pathlib.Path(serve.__file__).resolve().parent +_REPO_ROOT = _HERE.parents[2] # qwen3_5_moe -> models -> examples -> repo root + + +# --- Layering --------------------------------------------------------------- + + +def test_generic_runner_pybind_has_no_qwen_include(): + src = (_REPO_ROOT / "extension/llm/runner/pybindings.cpp").read_text() + assert "qwen3_5_moe" not in src and "qwen35_moe" not in src + + +def test_generic_server_does_not_reference_qwen(): + server_dir = _REPO_ROOT / "extension/llm/server" + offenders = [ + p + for p in server_dir.rglob("*.py") + if "qwen3_5_moe" in p.read_text() or "_qwen35_moe" in p.read_text() + ] + assert offenders == [], f"generic server must not reference Qwen: {offenders}" + + +def test_control_plane_runs_no_model_code(): + # serve.py is the control plane: it constructs no engine and imports no model + # pybind. Model execution lives entirely in the C++ worker. + serve_src = (_HERE / "serve.py").read_text() + assert "Qwen35MoEEngine" not in serve_src + assert "_qwen35_moe" not in serve_src + worker_src = (_HERE / "qwen35_moe_worker.cpp").read_text() + assert "Qwen35MoEEngine" in worker_src + + +def test_python_worker_and_pybind_are_gone(): + # The Python worker and the model pybind have been replaced by the C++ worker. 
+ assert not (_HERE / "worker.py").exists() + assert not (_HERE / "qwen35_moe_pybindings.cpp").exists() + + +# --- Worker spawn wiring ---------------------------------------------------- + + +def test_spawn_builds_worker_command(monkeypatch): + captured = {} + + def fake_spawn(cmd, env=None): + captured["cmd"] = cmd + return object() # stand-in WorkerClient + + monkeypatch.setattr(serve, "spawn_worker", fake_spawn) + serve._spawn( + SimpleNamespace( + worker_bin="/bin/qwen_worker", + model_path="m.pte", + tokenizer_path="t.json", + data_path="d.ptd", + ) + ) + assert captured["cmd"] == [ + "/bin/qwen_worker", + "--model_path", + "m.pte", + "--tokenizer_path", + "t.json", + "--data_path", + "d.ptd", + ] + + +def test_spawn_defaults_worker_bin_and_omits_empty_data_path(monkeypatch): + captured = {} + monkeypatch.setattr( + serve, "spawn_worker", lambda cmd, env=None: captured.update(cmd=cmd) + ) + serve._spawn( + SimpleNamespace( + worker_bin=None, model_path="m.pte", tokenizer_path="t.json", data_path=None + ) + ) + cmd = captured["cmd"] + assert cmd[0].endswith("qwen3_5_moe_worker") # default binary path + assert "--data_path" not in cmd # omitted when no .ptd + + +# --- CLI guard -------------------------------------------------------------- + + +def test_rejects_multiple_runners(monkeypatch): + import sys + + monkeypatch.setattr( + sys, + "argv", + [ + "serve.py", + "--model-path", + "m.pte", + "--tokenizer-path", + "t.json", + "--hf-tokenizer", + "hf", + "--num-runners", + "2", + ], + ) + with pytest.raises(SystemExit): + serve.main()