Phase O: Backend-specific model architecture #20
Merged
dexwritescode merged 20 commits into main on May 6, 2026
Conversation
Phase O.1: move LayerKVCache and GemmaLayerKVCache out of their respective model headers into a shared kv_cache.h. Phase O.2: extract non-computational shared state (config_, tokenizer_, weights_, get_weight, infer_quant_bits, num_parameters) from Qwen3MoeModel into Qwen3MoeModelBase. Qwen3MoeModel now inherits from both Qwen3MoeModelBase and LanguageModel. No behavioral change — all compute and service tests pass.
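A minimal sketch of the resulting class shape, assuming simplified stand-in types (`Config`/`Tokenizer`/`Tensor` and the member lists are abbreviations, not the repo's exact code):

```cpp
#include <string>
#include <unordered_map>

struct Config {};     // stand-ins for the repo's real types
struct Tokenizer {};
struct Tensor {};

class LanguageModel {  // compute/service interface (simplified)
 public:
  virtual ~LanguageModel() = default;
  virtual std::string generate(const std::string& prompt) = 0;
};

class Qwen3MoeModelBase {  // non-computational shared state only
 protected:
  Config config_;
  Tokenizer tokenizer_;
  std::unordered_map<std::string, Tensor> weights_;

  const Tensor* get_weight(const std::string& name) const {
    auto it = weights_.find(name);
    return it == weights_.end() ? nullptr : &it->second;
  }
};

// The compute-capable model inherits shared state from the base and
// the service interface from LanguageModel.
class Qwen3MoeModel : public Qwen3MoeModelBase, public LanguageModel {
 public:
  std::string generate(const std::string& prompt) override {
    return {};  // real decode loop elided in this sketch
  }
};
```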
mx::dequantize returns bfloat16 (matching the scales dtype), so the entire forward pass in Qwen3MoeModelMLX runs in bfloat16. Calling array.data<float>() on a bfloat16 array silently reinterprets the 16-bit values as 32-bit floats, producing completely wrong logits. Fix: mx::astype(logits, mx::float32) before reading data<float>(). Qwen3MoeIntegrationTest.GenerateCapitalOfFrance now passes: output "<think>\n\n</think>\n\nThe capital of France is **Paris**."
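A minimal sketch of the fix, assuming a small helper that reads one logit on the host (the function name is hypothetical):

```cpp
#include <mlx/mlx.h>
namespace mx = mlx::core;

// logits may be bfloat16 (mx::dequantize matches the scales dtype).
// data<float>() on a bfloat16 array reinterprets the raw 16-bit values
// as 32-bit floats, so cast to float32 before reading on the host.
float logit_at(const mx::array& logits, int token_id) {
  mx::array f32 = mx::astype(logits, mx::float32);
  mx::eval(f32);
  return f32.data<float>()[token_id];
}
```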
- Batched prefill: process all prompt tokens as a single {seq_len, hidden} matrix with a causal mask instead of seq_len serial 1-token passes
- Growing KV cache: mx::concatenate instead of a pre-allocated slice_update; avoids SDPA over padded positions and eliminates wasted GPU work per decode step (see the sketch after this list)
- Causal mask dtype inferred from the embedding matrix (em.dtype()) so float16 models (Mistral) and bfloat16 models (Llama3) both work correctly
- Add mlx_prefill_batch() / mlx_run_step() split; mlx_build_decode_fn() stores a compiled_fn slot for the upcoming mx::compile(shapeless=true)
- Expose a context_size param through LanguageModel::load() and a CLI --context flag
- Qwen3 MoE: explicit tensor reshapes for SSM kernel inputs; prep for shapeless compile
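A hedged sketch of the growing-KV update and the dtype-matched causal mask; the cache layout [batch, heads, seq, head_dim] and helper names are assumptions, not the repo's exact code:

```cpp
#include <mlx/mlx.h>
#include <utility>
#include <vector>
namespace mx = mlx::core;

// Append this step's K/V along the sequence axis; SDPA then only ever
// sees filled positions, never padded slots.
std::pair<mx::array, mx::array> append_kv(const mx::array& k_cache,
                                          const mx::array& v_cache,
                                          const mx::array& k_step,
                                          const mx::array& v_step) {
  return {mx::concatenate({k_cache, k_step}, /*axis=*/2),
          mx::concatenate({v_cache, v_step}, /*axis=*/2)};
}

// Causal mask in the embedding matrix's dtype so float16 (Mistral) and
// bfloat16 (Llama3) models both stay in their native precision.
mx::array causal_mask(int seq_len, const mx::array& em) {
  auto neg = mx::full({seq_len, seq_len}, -1e9f);
  auto mask = mx::triu(neg, /*k=*/1);  // strictly upper triangle = masked
  return mx::astype(mask, em.dtype());
}
```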
…per suite

Each test suite now loads its model and backend once in SetUpTestSuite() and tears it down in TearDownTestSuite(). Individual tests check skip_reason_ in SetUp() and call GTEST_SKIP() if the model is absent. This prevents reloading multi-GB models for every test case, cutting peak memory from ~60 GB to a single model's footprint and halving test runtime.
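The pattern, as a minimal self-contained sketch — the model type, path, and skip condition are stand-ins for the repo's real classes:

```cpp
#include <gtest/gtest.h>
#include <memory>
#include <string>

struct Model {};  // stand-in for the repo's LanguageModel

class ModelSuiteTest : public ::testing::Test {
 protected:
  static void SetUpTestSuite() {
    // Runs once per suite: the multi-GB model is loaded a single time.
    model_ = std::make_unique<Model>();
    if (!model_) skip_reason_ = "model not found";
  }
  static void TearDownTestSuite() { model_.reset(); }

  void SetUp() override {
    // Each test checks the shared skip reason instead of reloading.
    if (!skip_reason_.empty()) GTEST_SKIP() << skip_reason_;
  }

  static std::unique_ptr<Model> model_;
  static std::string skip_reason_;
};

std::unique_ptr<Model> ModelSuiteTest::model_;
std::string ModelSuiteTest::skip_reason_;

TEST_F(ModelSuiteTest, ModelIsShared) { ASSERT_NE(model_, nullptr); }
```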
- Wrap the decode fn in mx::compile(shapeless=true) — Metal kernels are now compiled once and reused every decode step
- Use -1 for the KV sequence dimension in reshape so the compiled graph accepts growing KV length without retracing
- Evaluate only the logits (outputs[0]) per step; KV caches remain lazy until consumed as inputs to the next step

Measured: ~40 tok/s on Mistral-7B-Instruct-v0.3-4bit (M2 Ultra). Full GPU pipelining (O.6.4) is the next step to close the 2x gap vs Python mlx-lm (~80 tok/s).
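A toy sketch of the shapeless-compile pattern: the real decode fn is the model's full step, but a two-line body stands in here, and the -1 reshape is what lets the traced graph accept a growing KV length:

```cpp
#include <mlx/mlx.h>
#include <vector>
namespace mx = mlx::core;

int main() {
  // in[0]: current hidden state [1, 1, d]; in[1]: growing KV cache.
  auto step = [](const std::vector<mx::array>& in) {
    auto kv = mx::reshape(in[1], {1, -1, 8});  // -1: growing seq axis
    auto scores = mx::matmul(in[0], mx::transpose(kv, {0, 2, 1}));
    return std::vector<mx::array>{scores};
  };
  // shapeless=true: trace once on the first call, reuse ever after.
  auto compiled = mx::compile(step, /*shapeless=*/true);

  auto h = mx::zeros({1, 1, 8});
  auto kv = mx::zeros({1, 4, 8});
  auto out = compiled({h, kv});
  mx::eval(out[0]);  // evaluate logits only; caches stay lazy
  return 0;
}
```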
Pass the unevaluated argmax from step N as the input token for step N+1 before waiting for step N to complete. async_eval() enqueues GPU work for step N+1 immediately, so the GPU runs decode steps back-to-back with no CPU round-trip stall between them. Adds LlamaModel::mlx_generate_pipelined() and overrides generate() to dispatch greedy paths through it on Apple Silicon.

Also fixes a service bug where temperature=0.0 from the GUI was being overridden to 0.7 (the proto3 float default is 0.0, indistinguishable from "not set" — the service now trusts the value directly).

Mistral-7B-Instruct-v0.3-4bit: ~40 tok/s → ~71 tok/s (warm, greedy).
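A hedged sketch of the pipelined greedy loop; the decode-fn signature and state plumbing are assumptions (the real code lives in LlamaModel::mlx_generate_pipelined()):

```cpp
#include <mlx/mlx.h>
#include <functional>
#include <vector>
namespace mx = mlx::core;

using DecodeFn =
    std::function<std::vector<mx::array>(const std::vector<mx::array>&)>;

// Feed the UNEVALUATED argmax of step N into step N+1, then async_eval:
// the GPU runs decode steps back-to-back; the CPU reads each token one
// step behind, so there is no round-trip stall between steps.
std::vector<int> pipelined_greedy(const DecodeFn& decode_fn,
                                  mx::array token,
                                  std::vector<mx::array> kv, int n_steps) {
  std::vector<int> out;
  mx::array pending = token;  // token whose value we haven't read yet
  for (int i = 0; i < n_steps; ++i) {
    std::vector<mx::array> in{token};
    in.insert(in.end(), kv.begin(), kv.end());
    auto o = decode_fn(in);
    token = mx::argmax(o[0], /*axis=*/-1);  // still lazy: feeds step N+1
    kv.assign(o.begin() + 1, o.end());
    mx::async_eval({token});                // enqueue this step's GPU work
    if (i > 0)                              // read the PREVIOUS step's token
      out.push_back(mx::astype(pending, mx::int32).item<int>());
    pending = token;
  }
  out.push_back(mx::astype(pending, mx::int32).item<int>());  // drain last
  return out;
}
```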
…mework

Good changes (keep):
- language_model.cpp: add "qwen3_moe" alias alongside "qwen3_5_moe" in the factory
- qwen3_moe_model_mlx.h: add generate() override + moe_generate_pipelined() declarations
- qwen3_moe_model_mlx.cpp: add generate() + moe_generate_pipelined() (async_eval pipeline, same O.6.4 pattern as LlamaModel)

Broken experiments (need rework — see the O.7 beads issue):
- Chunked KV (kKvChunk=64) with full-model mx::compile: 7.5 tok/s, garbled output after the first chunk-boundary crossing — worse than the 12.2 tok/s baseline
- shapeless=true + CustomKernel output_shapes() patch: 3.3 tok/s (generic shapeless Metal kernels are slower than shape-specialised ones)
- Fixed KV (max_ctx=4096) + plain compile: 5.2–9.8 tok/s (SDPA over all 4096 positions, even when mostly empty, kills memory bandwidth)

Root cause confirmed via mlx-lm investigation:
- Full-model mx::compile is the wrong granularity for hybrid MoE models
- mlx-lm never compiles the full forward pass; it compiles only small sub-functions (gate computations with shapeless=True, SSM step with fixed shapes, MoE routing)
- The GatedDeltaNet custom Metal kernel blocks compile fusion across the graph

Next step (O.7 issue):
1. Remove full-model compile — store the raw fn
2. Replace the custom Metal kernel with pure MLX ops for the SSM step
3. Wrap the SSM step as a separate compiled fn (shapeless=True, fixed shapes in decode)
4. Revert KV to growing-via-concat, slice to actual length before SDPA
…layer

mlx_ssm_step and mlx_full_attn_step now accept [1, T, hidden] for both prefill (T>1) and decode (T=1). run_prefill() runs eagerly: one Metal dispatch per SSM layer (the T-loop lives inside the kernel) and one SDPA per attention layer. Previously prefill looped over T sequential decode steps.

Remove mx::compile from build_decode_fn() — stock MLX CustomKernel does not implement Primitive::output_shapes(), so shapeless compile throws. The raw function is stored instead; element-wise fusion is a future concern.

Makefile: add a stage-service target so `make run` always bundles a fresh neurons_service binary into the Flutter app.
…eless=true

mx::compile(shapeless=true) on the decode function was silently falling back to eager mode because two MLX primitives — CustomKernel and GatherQMM — lacked output_shapes() implementations, which compile needs when reusing a traced graph with new shapes. Fix: point FetchContent to a fork (PR #3485) that adds output_shapes() to both primitives.

With this in place, compile(shapeless=true) traces the decode graph once on the first call and reuses it for all subsequent steps regardless of the growing KV cache shape — no re-tracing, no eager fallback. Growing-KV + shapeless=true outperforms the pre-allocated fixed-KV alternative (Option A, also implemented here as a reference path) because SDPA only covers filled positions rather than a full 4096-slot window from step 1.

Results on Qwen3.6-35B-A3B-4bit (Release, M-series):
- before: 11.2 tok/s (no compile)
- after: 51.5 tok/s (compile, shapeless=true, growing KV)

TODO: revert CMakeLists.txt to upstream mlx once PR #3485 merges.
run_decode_step was only calling mx::eval(outputs[0]) (logits). All 20 KV/SSM arrays stayed lazy, each holding a reference to the previous step's arrays. After N steps this built an N-deep computation chain, and MLX could not free any intermediate graph state.

Fix: async_eval the logits plus all KV/SSM outputs together, then block on the logits only. This materialises the state arrays before the next step begins, breaking the chain and allowing MLX to free intermediate buffers.

Observed before: 20 GB → 32 GB RSS on a 3574-token generation. After: memory stays flat. Throughput unchanged at ~55 tok/s.
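The shape of the fix, as a sketch (the function name and output layout are assumed):

```cpp
#include <mlx/mlx.h>
#include <vector>
namespace mx = mlx::core;

// outputs[0] = logits; outputs[1..] = the per-layer KV/SSM state arrays.
mx::array run_decode_step_fixed(std::vector<mx::array> outputs) {
  // Schedule logits AND all state arrays together: materialising the
  // state breaks the N-deep lazy chain so MLX can free intermediates.
  mx::async_eval(outputs);
  mx::eval(outputs[0]);  // block on the logits only
  return outputs[0];
}
```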
LanguageModel::load() now dispatches qwen3_5_moe / qwen3_moe on Apple Silicon directly to Qwen3MoeModelMLX, so the #if MLX_BACKEND_ENABLED block inside Qwen3MoeModel is unreachable dead code. Remove MlxDecodeState struct, initialize_mlx_state(), build_mlx_compile_fn(), mlx_state_ member, and the entire 574-line MLX implementation block. Qwen3MoeModel is now a clean ComputeBackend-only class with no MLX awareness.
Implements the Phase O backend-specific model pattern for Gemma: GemmaModelBase holds shared config/tokenizer/weights; GemmaModelMLX (Apple Silicon only) owns mx::array weights and runs the full forward pass directly against MLX with mx::compile(shapeless=true) and async_eval GPU pipelining — the same architecture as Qwen3MoeModelMLX.

Gemma-specific MLX differences (sketched below):
- Embedding scale: h * sqrt(hidden_size)
- RMSNorm: rms_norm(x, 1 + weight, eps) — weights are stored as zero-centred deviations
- 4 norms per block (post_attention + post_feedforward residuals)
- GeGLU: gelu(gate) * up — implemented via the exact mx::erf formula
- Per-layer rope_theta captured as vector<float> before compile
- LM head: tied embed or separate lm_head.weight

GemmaModel (the ComputeBackend fallback) now extends GemmaModelBase, removing duplicated config_/tokenizer_/weights_ fields.

Verified: GemmaIntegrationTest passes (510 ms); the CLI returns "The capital of France is **Paris**." at 19.6 tok/s.
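Minimal sketches of the Gemma-specific ops listed above; the helper names are mine, not the repo's:

```cpp
#include <mlx/mlx.h>
#include <cmath>
namespace mx = mlx::core;

// Embedding scale: h * sqrt(hidden_size).
mx::array scale_embedding(const mx::array& h, int hidden_size) {
  return h * std::sqrt(static_cast<float>(hidden_size));
}

// RMSNorm with zero-centred weights: normalise, then scale by (1 + w).
mx::array gemma_rms_norm(const mx::array& x, const mx::array& w, float eps) {
  return mx::fast::rms_norm(x, w + mx::array(1.0f), eps);
}

// GeGLU with the exact (erf-based) GELU on the gate branch:
// gelu(g) = g * 0.5 * (1 + erf(g / sqrt(2))), then gated by up.
mx::array geglu(const mx::array& gate, const mx::array& up) {
  auto g = gate * 0.5f * (mx::array(1.0f) + mx::erf(gate / std::sqrt(2.0f)));
  return g * up;
}
```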
Remove prefill(), decode(), reset_cache(), and backend() from the LanguageModel public interface. The generate loop is now owned by GenerateHelper::run() in sampler.h/.cpp, called via lambdas from each subclass's generate() override.

Changes per class:
- LlamaModel: generate() fallback → GenerateHelper::run(); prefill/decode/reset_cache kept public (diagnostic tests use LlamaModel* directly)
- GemmaModel, Qwen3MoeModel: generate() override added; prefill/decode/reset_cache moved to private
- GemmaModelMLX, Qwen3MoeModelMLX: same private treatment; generate() fallback (non-greedy) → GenerateHelper::run()

LanguageModel now exposes only generate(), config(), model_type(), tokenizer(), num_parameters(), and the tool-use / factory interface. 125/125 tests pass. Gemma 48.2 tok/s, Qwen3 MoE 33.7 tok/s (greedy CLI).
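A hedged sketch of the GenerateHelper::run() shape under these assumptions: the helper owns the loop, each model supplies prefill/decode lambdas, and detokenisation is stubbed out — the real signatures in sampler.h may differ:

```cpp
#include <cstdint>
#include <functional>
#include <string>
#include <vector>

struct GenerateHelper {
  using Prefill = std::function<void(const std::vector<int32_t>&)>;
  using Decode  = std::function<int32_t(int32_t)>;  // token in, token out

  static std::string run(const std::vector<int32_t>& prompt,
                         const Prefill& prefill, const Decode& decode,
                         int max_tokens) {
    prefill(prompt);                       // model-specific, via lambda
    std::string out;
    int32_t tok = prompt.back();
    for (int i = 0; i < max_tokens; ++i) {
      tok = decode(tok);                   // model-specific, via lambda
      out += std::to_string(tok) + ' ';    // stand-in for detokenisation
    }
    return out;
  }
};
```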
mlx_ops::bits() and mlx_ops::lin() replace three identical copies:
- llama_mlx_bits / llama_mlx_lin in llama_model.cpp
- gemma_mlx_bits / gemma_mlx_lin in gemma_model_mlx.cpp
- mlx_bits / mlx_lin in qwen3_moe_model_mlx.cpp

Each file now pulls them in via `using namespace compute::mlx_ops` inside its anonymous namespace. The WM type alias (unordered_map<string, array>) is also consolidated to compute::WM in mlx_ops.h.
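A guess at what the consolidated lin() might look like, assuming it wraps mx::quantized_matmul over a prefix in the weight map — the real helper's signature and behaviour may differ:

```cpp
#include <mlx/mlx.h>
#include <string>
#include <unordered_map>
namespace mx = mlx::core;

namespace compute::mlx_ops {
using WM = std::unordered_map<std::string, mx::array>;

// Quantized linear layer: looks up weight/scales/biases under `p` and
// applies the 4-bit (or n-bit) matmul. Layout assumptions are mine.
inline mx::array lin(const mx::array& x, const WM& w, const std::string& p,
                     int group_size, int bits) {
  return mx::quantized_matmul(x, w.at(p + ".weight"), w.at(p + ".scales"),
                              w.at(p + ".biases"), /*transpose=*/true,
                              group_size, bits);
}
}  // namespace compute::mlx_ops
```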
…ouble-hop

On Apple Silicon (MLX_BACKEND_ENABLED), from_model_dir() now calls ModelLoader::load_model_mlx() directly to obtain mx::array weights, eliminating the Tensor intermediary that required a full weight-map conversion after load. weights_ is left empty on the MLX path; all inference goes through mlx_weights_ / mlx_embed_mat_. Non-Apple-Silicon builds use the ComputeBackend/Tensor path unchanged.

Three diagnostic tests that relied on the Tensor-path forward()/attention_layer() APIs are skipped on Apple Silicon with an explanatory message — these methods are intentionally unavailable on the MLX path.
Summary
O.1–O.2: Structural refactor
- LayerKVCache / GemmaLayerKVCache into shared kv_cache.h
- Qwen3MoeModel split into Qwen3MoeModelBase

O.3: Qwen3MoeModelMLX
- mx::array forward pass, routed from LanguageModel::load() on Apple Silicon

O.6: LlamaModel MLX fast path (~71 tok/s)
- mx::compile + mx::async_eval pipelined decode

O.7: Qwen3 MoE decode — 11 → 51 tok/s
- mx::compile(shapeless=true) was silently falling back to eager because two MLX primitives (CustomKernel, GatherQMM) lacked output_shapes() — required when the compiled graph is reused across changing KV shapes
- A fork adds output_shapes() to both; FetchContent points to the fork until the PR merges
- mx::compile(shapeless=true) now traces the decode graph once; all subsequent steps reuse it regardless of the growing KV shape — no re-tracing, no eager fallback