
Phase O: Backend-specific model architecture #20

Merged
dexwritescode merged 20 commits into main from phase-o-backend-specific-models on May 6, 2026

Conversation

dexwritescode (Owner) commented May 1, 2026

Summary

O.1–O.2: Structural refactor

  • Extract LayerKVCache / GemmaLayerKVCache into shared kv_cache.h
  • Extract shared non-computational state from Qwen3MoeModel into Qwen3MoeModelBase

O.3: Qwen3MoeModelMLX

  • New class with a pure mx::array forward pass, routed from LanguageModel::load() on Apple Silicon

O.6: LlamaModel MLX fast path (~71 tok/s)

  • Batched prefill over all T prompt tokens in one eager pass (one SDPA dispatch per layer)
  • Growing KV cache with lazy eval and mx::compile + mx::async_eval pipelined decode

O.7: Qwen3 MoE decode — 11 → 51 tok/s

  • Root cause: mx::compile(shapeless=true) was silently falling back to eager because two MLX primitives (CustomKernel, GatherQMM) lacked output_shapes() — required when the compiled graph is reused across changing KV shapes
  • Fix: upstream PR ml-explore/mlx#3485 adds output_shapes() to both; FetchContent points to the fork until the PR merges
  • Batched prefill (O.7a): T SSM states processed in one Metal dispatch per layer instead of T serial calls
  • Compiled decode (O.7b): mx::compile(shapeless=true) traces the decode graph once; all subsequent steps reuse it regardless of growing KV shape — no re-tracing, no eager fallback

Phase O.1: move LayerKVCache and GemmaLayerKVCache out of their
respective model headers into a shared kv_cache.h.

Phase O.2: extract non-computational shared state (config_, tokenizer_,
weights_, get_weight, infer_quant_bits, num_parameters) from
Qwen3MoeModel into Qwen3MoeModelBase. Qwen3MoeModel now inherits from
both Qwen3MoeModelBase and LanguageModel.

No behavioral change — all compute and service tests pass.
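The split is roughly this shape — a minimal sketch, where `ModelConfig`, `Tokenizer`, and `Tensor` are placeholder type names, not necessarily the repo's exact ones:

```cpp
// Shared, non-computational state lives in the base; the compute path
// stays in the derived class.
class Qwen3MoeModelBase {
 protected:
  ModelConfig config_;
  Tokenizer tokenizer_;
  std::unordered_map<std::string, Tensor> weights_;

  const Tensor& get_weight(const std::string& name) const;
  int infer_quant_bits(const std::string& prefix) const;   // argument assumed

 public:
  size_t num_parameters() const;
};

class Qwen3MoeModel : public Qwen3MoeModelBase, public LanguageModel { /* compute path */ };
```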
dexwritescode added the release:skip label (Skips release creation on merge) on May 1, 2026
mx::dequantize returns bfloat16 (matching the scales dtype), so the
entire forward pass in Qwen3MoeModelMLX runs in bfloat16. Calling
array.data<float>() on a bfloat16 array silently reinterprets the
16-bit values as 32-bit floats, producing completely wrong logits.

Fix: mx::astype(logits, mx::float32) before reading data<float>().

Qwen3MoeIntegrationTest.GenerateCapitalOfFrance now passes:
output "<think>\n\n</think>\n\nThe capital of France is **Paris**."
- Batched prefill: process all prompt tokens as a single {seq_len, hidden}
  matrix with causal mask instead of seq_len serial 1-token passes
- Growing KV cache: grow with mx::concatenate rather than pre-allocating and
  slice_update-ing a fixed buffer (sketched below); avoids SDPA over padded
  positions and eliminates wasted GPU work per decode step
- Causal mask dtype inferred from embedding matrix (em.dtype()) so float16
  models (Mistral) and bfloat16 models (Llama3) both work correctly
- Add mlx_prefill_batch() and mlx_run_step() split; mlx_build_decode_fn()
  stores compiled_fn slot for upcoming mx::compile(shapeless=true)
- Expose context_size param through LanguageModel::load() and CLI --context flag
- Qwen3 MoE: explicit tensor reshapes for SSM kernel inputs; prep for shapeless compile
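A rough sketch of the growing-cache append mentioned in the list above (struct layout, member names, and the [batch, heads, seq, head_dim] axis order are assumptions):

```cpp
#include <optional>

struct KvSketch {
  std::optional<mx::array> keys, values;
};

void append(KvSketch& cache, const mx::array& k_new, const mx::array& v_new) {
  if (!cache.keys) {
    cache.keys = k_new;      // first (prefill) chunk
    cache.values = v_new;
  } else {
    // Grow along the sequence axis so SDPA only ever sees filled positions.
    cache.keys = mx::concatenate({*cache.keys, k_new}, /*axis=*/2);
    cache.values = mx::concatenate({*cache.values, v_new}, /*axis=*/2);
  }
}
```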
…per suite

Each test suite now loads its model and backend once in SetUpTestSuite() and
tears it down in TearDownTestSuite(). Individual tests check skip_reason_ in
SetUp() and call GTEST_SKIP() if the model is absent.

Prevents reloading multi-GB models for every test case, cutting peak memory
from ~60 GB to a single model's footprint and halving test runtime.
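The pattern looks roughly like this — a sketch only; the fixture layout, model path, and the return type of `LanguageModel::load()` are assumptions:

```cpp
#include <gtest/gtest.h>
#include <filesystem>
#include <memory>
#include <string>

class Qwen3MoeIntegrationTest : public ::testing::Test {
 protected:
  static void SetUpTestSuite() {
    if (!std::filesystem::exists(kModelDir)) {
      skip_reason_ = "model directory not found";
      return;
    }
    model_ = LanguageModel::load(kModelDir);   // heavy multi-GB load happens once
  }
  static void TearDownTestSuite() { model_.reset(); }

  void SetUp() override {
    if (!skip_reason_.empty()) GTEST_SKIP() << skip_reason_;
  }

  static inline std::unique_ptr<LanguageModel> model_;
  static inline std::string skip_reason_;
  static constexpr const char* kModelDir = "models/qwen3-moe-4bit";  // assumed path
};
```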
- Wrap decode fn in mx::compile(shapeless=true) — Metal kernels now
  compiled once and reused every decode step
- Use -1 for KV sequence dimension in reshape so the compiled graph
  accepts growing KV length without retracing
- Evaluate only logits (outputs[0]) per step; KV caches remain lazy
  until consumed as inputs to the next step

Measured: ~40 tok/s on Mistral-7B-Instruct-v0.3-4bit (M2 Ultra).
Full GPU pipelining (O.6.4) is the next step to close the 2x gap
vs Python mlx-lm (~80 tok/s).
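A compact sketch of the compiled decode path (the names `decode_graph`, `decode_fn_`, `inputs`, and the shape literals are assumptions):

```cpp
// Trace once; shapeless=true lets the traced graph be reused as the KV
// sequence length grows.
decode_fn_ = mx::compile(
    [this](const std::vector<mx::array>& in) { return decode_graph(in); },
    /*shapeless=*/true);

// Per step: force only the logits; KV outputs remain lazy until they feed
// the next step's inputs.
auto outputs = decode_fn_(inputs);
mx::eval(outputs[0]);

// Inside decode_graph(): the KV reshape uses -1 for the sequence dimension,
// e.g. mx::reshape(k, {1, n_kv_heads, -1, head_dim});
```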
Pass the unevaluated argmax from step N as input token for step N+1
before waiting for step N to complete. async_eval() enqueues GPU work
for step N+1 immediately, so the GPU runs decode steps back-to-back
with no CPU round-trip stall between them.

Adds LlamaModel::mlx_generate_pipelined() and overrides generate() to
dispatch greedy paths through it on Apple Silicon.

Also fixes a service bug where temperature=0.0 from the GUI was being
overridden to 0.7 (the proto3 float default is 0.0, indistinguishable
from "not set" — the service now trusts the value directly).

Mistral-7B-Instruct-v0.3-4bit: ~40 tok/s → ~71 tok/s (warm, greedy).
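Roughly, the pipelined loop works like this (a sketch; `decode_fn_`, `make_inputs`, `collect_state`, `emit_token`, and the prefill outputs are hypothetical helpers):

```cpp
mx::array token = first_token;        // lazy argmax from prefill (hypothetical)
auto state = prefill_state;           // KV arrays from prefill (hypothetical)
for (int i = 0; i < max_new_tokens; ++i) {
  // Feed the still-unevaluated token from step N into step N+1.
  auto outputs = decode_fn_(make_inputs(token, state));
  mx::array next = mx::argmax(outputs[0], -1);   // stays lazy
  mx::async_eval({next});                        // enqueue step N+1 on the GPU now
  token.eval();                                  // only then block on step N
  emit_token(static_cast<int>(token.item<uint32_t>()));
  state = collect_state(outputs);
  token = next;
}
```

The effect is that the CPU round-trip to read step N's token overlaps with the GPU already running step N+1.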
…mework

Good changes (keep):
- language_model.cpp: add "qwen3_moe" alias alongside "qwen3_5_moe" in factory
- qwen3_moe_model_mlx.h: add generate() override + moe_generate_pipelined() declarations
- qwen3_moe_model_mlx.cpp: add generate() + moe_generate_pipelined() (async_eval pipeline,
  same O.6.4 pattern as LlamaModel)

Broken experiments (need to rework — see O.7 beads issue):
- Chunked KV (kKvChunk=64) with full-model mx::compile: 7.5 tok/s, garbled output
  after first chunk boundary crossing. WORSE than 12.2 tok/s baseline.
- Tried shapeless=true + CustomKernel output_shapes() patch: 3.3 tok/s (generic
  shapeless Metal kernels are slower than shape-specialised ones).
- Tried fixed KV (max_ctx=4096) + plain compile: 5.2-9.8 tok/s (SDPA over all
  4096 positions even when mostly empty kills memory bandwidth).

Root cause confirmed via mlx-lm investigation:
- Full-model mx::compile is the WRONG granularity for hybrid MoE models.
- mlx-lm never compiles the full forward pass; it compiles only small sub-functions
  (gate computations with shapeless=True, the SSM step with fixed shapes, MoE routing).
- The GatedDeltaNet custom Metal kernel blocks compile fusion across the graph.

Next step (O.7 issue):
1. Remove full-model compile — store raw fn
2. Replace custom Metal kernel with pure MLX ops for SSM step
3. Wrap the SSM step as a separate compiled fn (shapeless=true; fixed shapes in decode) — sketched below
4. Revert KV to growing-via-concat, slice to actual length before SDPA
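For step 3 above, the intended granularity looks roughly like this (a sketch; `gated_delta_step` stands in for a pure-MLX reimplementation of the SSM step):

```cpp
// Compile only the small SSM sub-function, not the full hybrid forward pass.
auto ssm_step_fn = mx::compile(
    [](const std::vector<mx::array>& in) { return gated_delta_step(in); },
    /*shapeless=*/true);
```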
…layer

mlx_ssm_step and mlx_full_attn_step now accept [1,T,hidden] for both
prefill (T>1) and decode (T=1). run_prefill() runs eagerly: one Metal
dispatch per SSM layer (T-loop inside kernel), one SDPA per attention
layer. Previously prefill looped T sequential decode steps.

Remove mx::compile from build_decode_fn() — stock MLX CustomKernel does
not implement Primitive::output_shapes(), so shapeless compile throws.
Raw function stored instead; element-wise fusion is a future concern.

Makefile: add stage-service target so `make run` always bundles a fresh
neurons_service binary into the Flutter app.
…eless=true

mx::compile(shapeless=true) on the decode function was silently falling back to
eager mode because two MLX primitives — CustomKernel and GatherQMM — lacked
output_shapes() implementations, which compile needs when re-using a traced
graph with new shapes.

Fix: point FetchContent to a fork (PR #3485) that adds output_shapes() to both
primitives.  With this in place, compile(shapeless=true) traces the decode graph
once on the first call and reuses it for all subsequent steps regardless of the
growing KV cache shape — no re-tracing, no eager fallback.

Growing-KV + shapeless=true outperforms the pre-allocated fixed-KV alternative
(Option A, also implemented here as a reference path) because SDPA only covers
filled positions rather than a full 4096-slot window from step 1.

Results on Qwen3.6-35B-A3B-4bit (Release, M-series):
  before: 11.2 tok/s  (no compile)
  after:  51.5 tok/s  (compile, shapeless=true, growing KV)

TODO: revert CMakeLists.txt to upstream mlx once PR #3485 merges.
run_decode_step was only calling mx::eval(outputs[0]) (logits). All 20
KV/SSM arrays stayed lazy, each holding a reference to the previous
step's arrays. After N steps this built an N-deep computation chain;
MLX could not free any intermediate graph state.

Fix: async_eval the logits + all KV/SSM outputs together, then block on
logits only. This materialises the state arrays before the next step
begins, breaking the chain and allowing MLX to free intermediate buffers.

Observed before: 20 GB → 32 GB RSS on a 3574-token generation.
After: memory stays flat. Throughput unchanged at ~55 tok/s.
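The per-step evaluation now looks roughly like this (a sketch; the output ordering is assumed):

```cpp
// outputs = {logits, kv_0, ssm_0, kv_1, ...}
std::vector<mx::array> outputs = decode_fn_(inputs);
mx::async_eval(outputs);   // schedule logits AND every KV/SSM state array
outputs[0].eval();         // block only on the logits needed for sampling
// The state arrays are materialised before the next step consumes them,
// so the lazy graph no longer deepens step over step.
```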
LanguageModel::load() now dispatches qwen3_5_moe / qwen3_moe on Apple
Silicon directly to Qwen3MoeModelMLX, so the #if MLX_BACKEND_ENABLED
block inside Qwen3MoeModel is unreachable dead code.

Remove MlxDecodeState struct, initialize_mlx_state(), build_mlx_compile_fn(),
mlx_state_ member, and the entire 574-line MLX implementation block.
Qwen3MoeModel is now a clean ComputeBackend-only class with no MLX awareness.
Implements the Phase O backend-specific model pattern for Gemma:
GemmaModelBase holds shared config/tokenizer/weights; GemmaModelMLX
(Apple Silicon only) owns mx::array weights and runs the full forward
pass directly against MLX with mx::compile(shapeless=true) and
async_eval GPU pipelining — same architecture as Qwen3MoeModelMLX.

Gemma-specific MLX differences:
- Embedding scale: h * sqrt(hidden_size)
- RMSNorm: rms_norm(x, 1+weight, eps) — weights are stored zero-centred (as scale − 1)
- 4 norms per block (post_attention + post_feedforward residuals)
- GeGLU: gelu(gate)*up, using the exact erf-based GELU formula via mx::erf (sketched below)
- Per-layer rope_theta captured as vector<float> before compile
- LM head: tied embed or separate lm_head.weight
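A sketch of the Gemma-specific ops listed above (tensor names, `hidden_size`, and `eps` are placeholders, not the exact member layout):

```cpp
#include <cmath>

// Embedding scale: h * sqrt(hidden_size)
mx::array h = mx::take(embed_mat, tokens, /*axis=*/0) *
              std::sqrt(static_cast<float>(hidden_size));

// RMSNorm with zero-centred weights: effective scale is (1 + weight).
mx::array normed = mx::fast::rms_norm(h, norm_weight + 1.0f, eps);

// GeGLU via the exact erf-based GELU: gelu(gate) * up.
mx::array gelu_gate = gate * 0.5f * (1.0f + mx::erf(gate / std::sqrt(2.0f)));
mx::array mlp_out = gelu_gate * up;
```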

GemmaModel (ComputeBackend fallback) now extends GemmaModelBase,
removing duplicated config_/tokenizer_/weights_ fields.

Verified: GemmaIntegrationTest passes (510ms), CLI returns
"The capital of France is **Paris**." at 19.6 tok/s.
Remove prefill(), decode(), reset_cache(), and backend() from the
LanguageModel public interface. The generate loop is now owned by
GenerateHelper::run() in sampler.h/.cpp, called via lambdas from each
subclass's generate() override.

Changes per class:
- LlamaModel: generate() fallback → GenerateHelper::run(); prefill/
  decode/reset_cache kept public (diagnostic tests use LlamaModel* directly)
- GemmaModel, Qwen3MoeModel: generate() override added; prefill/decode/
  reset_cache moved to private
- GemmaModelMLX, Qwen3MoeModelMLX: same private treatment; generate()
  fallback (non-greedy) → GenerateHelper::run()

LanguageModel now exposes only: generate(), config(), model_type(),
tokenizer(), num_parameters(), and the tool-use / factory interface.

125/125 tests pass. Gemma 48.2 tok/s, Qwen3 MoE 33.7 tok/s (greedy CLI).
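A generate() override under this scheme looks roughly like the following (the `GenerationRequest`/`GenerationResult` types and the exact signature of GenerateHelper::run are assumptions):

```cpp
GenerationResult GemmaModel::generate(const GenerationRequest& request) {
  // The shared loop owns sampling and stopping; the model only supplies
  // lambdas over its now-private prefill/decode/reset_cache.
  return GenerateHelper::run(
      request,
      /*prefill=*/[this](const std::vector<int32_t>& prompt) { return prefill(prompt); },
      /*decode=*/[this](int32_t token) { return decode(token); },
      /*reset=*/[this] { reset_cache(); });
}
```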
mlx_ops::bits() and mlx_ops::lin() replace three identical copies:
- llama_mlx_bits/llama_mlx_lin in llama_model.cpp
- gemma_mlx_bits/gemma_mlx_lin in gemma_model_mlx.cpp
- mlx_bits/mlx_lin in qwen3_moe_model_mlx.cpp

Each file now pulls them in via `using namespace compute::mlx_ops` inside
its anonymous namespace. The WM type alias (unordered_map<string,array>)
is also consolidated to compute::WM in mlx_ops.h.
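The shared header is roughly shaped like this (declarations inferred from the description; exact signatures may differ):

```cpp
// mlx_ops.h (sketch)
#include <string>
#include <unordered_map>

namespace compute {
using WM = std::unordered_map<std::string, mx::array>;   // shared weight-map alias

namespace mlx_ops {
// Infer quantization bit-width for the weights under `prefix`.
int bits(const WM& w, const std::string& prefix);
// Quantized linear layer applied to `x` using the weights under `prefix`.
mx::array lin(const WM& w, const std::string& prefix, const mx::array& x);
}  // namespace mlx_ops
}  // namespace compute
```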
…ouble-hop

On Apple Silicon (MLX_BACKEND_ENABLED), from_model_dir() now calls
ModelLoader::load_model_mlx() directly to obtain mx::array weights,
eliminating the Tensor intermediary that required a full weight-map
conversion after load. weights_ is left empty on the MLX path; all
inference goes through mlx_weights_ / mlx_embed_mat_.

Non-Apple-Silicon builds use the ComputeBackend/Tensor path unchanged.

Three diagnostic tests that relied on the Tensor-path forward()/
attention_layer() APIs are skipped on Apple Silicon with an explanatory
message — these methods are intentionally unavailable in the MLX path.
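In outline, the load-path split looks like this (a sketch; `ModelLoader::load_model_mlx` is named in the change, but its return type, the non-MLX loader name, and the embedding key shown are assumptions):

```cpp
#if MLX_BACKEND_ENABLED
  // Direct MLX load: no Tensor intermediary, no post-load conversion.
  mlx_weights_ = ModelLoader::load_model_mlx(model_dir);
  mlx_embed_mat_ = mlx_weights_.at("model.embed_tokens.weight");  // assumed key
#else
  // ComputeBackend/Tensor path, unchanged on non-Apple-Silicon builds.
  weights_ = ModelLoader::load_model(model_dir);                  // hypothetical name
#endif
```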
dexwritescode marked this pull request as ready for review on May 6, 2026 at 18:37
dexwritescode merged commit 61aebab into main on May 6, 2026
3 checks passed
dexwritescode deleted the phase-o-backend-specific-models branch on May 6, 2026 at 18:37

Labels

release:skip Skips release creation on merge
