
mlx-ddtree-failed

Speculative decoding on Apple Silicon using DFlash + DDTree — what we tried, what we measured, and why we stopped.

This repo contains the actual code we ran, the benchmark results, and the full write-up of a multi-week experiment trying to get speculative decoding to beat raw autoregressive inference for Qwen 3.5/3.6 27B 4-bit on a Mac Studio M3 Ultra. We did not succeed. This is a record of what we tried so others don't have to repeat it.


Background

DFlash is a speculative decoding method designed for hybrid attention+linear models (Mamba-style SSMs). The key idea: a small 2B "drafter" model predicts a block of tokens at once using the large target model's hidden states, the target verifies the block, and accepted tokens are committed in a single forward pass. The z-lab/Qwen3.5-27B-DFlash drafter was officially released alongside Qwen 3.5 and achieves ~74 tok/s with the matched 3.5 target at ~92% acceptance.

DDTree extends DFlash by constructing a speculation tree (multiple candidate paths) instead of a single path. At each tree verify step, the target evaluates multiple draft branches simultaneously under a tree-structured attention mask, then walks the tree to find the deepest accepted path. This recovers some of the per-token acceptance loss from weaker drafters.
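The draft-then-walk step can be sketched in a few lines. This is a toy illustration of the accept walk under greedy acceptance, not the code in ddtree_mlx/runtime.py: the tree is a parent array where parents[t] gives node t's parent and -1 marks a child of the last committed token, and target_tokens[t] stands in for the target's greedy prediction at node t.

```python
# Toy DDTree accept walk (illustrative, not the repo's implementation).
def deepest_accepted_path(parents, draft_tokens, target_tokens):
    """Return the longest root-to-node path whose drafted tokens all
    match the target's greedy predictions."""
    n = len(parents)
    accepted = [False] * n
    depth = [0] * n
    best_node, best_depth = None, 0
    for t in range(n):  # parents[t] < t, so a single pass suffices
        parent_ok = parents[t] == -1 or accepted[parents[t]]
        accepted[t] = parent_ok and draft_tokens[t] == target_tokens[t]
        if accepted[t]:
            depth[t] = 1 + (depth[parents[t]] if parents[t] != -1 else 0)
            if depth[t] > best_depth:
                best_node, best_depth = t, depth[t]
    # Walk back up from the deepest accepted node to recover the path.
    path, node = [], best_node
    while node is not None:
        path.append(node)
        node = parents[node] if parents[node] != -1 else None
    return path[::-1]
```

With parents = [-1, 0, 0, 1] the tree has two branches off the root; if the target agrees with nodes 0, 1, and 2 but rejects node 3, the walk commits a depth-2 path.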

We implemented DDTree for MLX from scratch. The core paper is liranringel/ddtree.


Setup

  • Target model: mlx-community/Qwen3.5-27B-4bit and mlx-community/Qwen3.6-27B-4bit
  • Drafter: z-lab/Qwen3.5-27B-DFlash (stock matched), z-lab/Qwen3.6-27B-DFlash (official, released later), and a custom continue-trained drafter (see below)
  • Hardware: Mac Studio M3 Ultra, 128 GB unified memory, macOS 15.x, mlx 0.25.x
  • Baseline: mlx_lm.server serving raw 4-bit Qwen — no speculation

Qwen 3.5 and 3.6 share byte-identical tokenizers and text configs (verified). The drafter architecture ports unchanged; only the target distribution differs.
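The byte-identity claim is cheap to verify. A minimal sketch, assuming the usual Hugging Face snapshot layout (tokenizer.json, tokenizer_config.json; the file names here are illustrative, not a record of exactly which files we compared):

```python
# Compare tokenizer files between two model snapshots byte-for-byte.
import hashlib
from pathlib import Path

def sha256_of(path):
    return hashlib.sha256(Path(path).read_bytes()).hexdigest()

def tokenizers_identical(dir_a, dir_b,
                         files=("tokenizer.json", "tokenizer_config.json")):
    """True iff every listed file is byte-identical between the snapshots."""
    return all(sha256_of(Path(dir_a) / f) == sha256_of(Path(dir_b) / f)
               for f in files)
```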


What we tried

Round 11 — DFlash baseline with matched drafter (Qwen 3.5)

Ran dflash_server.py (our FastAPI wrapper around dflash_mlx) with the stock z-lab/Qwen3.5-27B-DFlash drafter against the Qwen 3.5 27B 4-bit target.

Result:

  • Code prompts: ~72–74 tok/s ✅ (vs raw 3.5 at ~37 tok/s)
  • Prose prompts: ~14–15 tok/s ❌ (vs raw 3.5 at ~37 tok/s)

The prose regression was immediate and large. On long-form writing, finance analysis, and anything with high per-token entropy, the drafter commits a single candidate path that gets rejected frequently, wasting all draft compute. The single-path-commit design is fundamentally sensitive to target entropy.

Round 12 — DDTree budget=4 with matched Qwen 3.5 drafter

Added tree budgeting: instead of committing one path, the drafter generates a budget-4 tree. The target verifies all branches in one forward pass, walks the tree, and commits the deepest accepted path.

Result:

  • Code prompts: ~39–40 tok/s ✅ (small gain over raw)
  • Prose prompts: ~24–25 tok/s (better than DFlash-alone, still -13 vs raw)
  • Average across 9 prompt types: ~31.9 vs raw 37.3 tok/s ❌

DDTree recovers most of the DFlash prose regression but the verify overhead on tree paths costs enough that the average is still ~5 tok/s under raw. Budget=8 is worse than budget=4: more verify compute, marginal acceptance gain.

Round 13 — Qwen 3.6 with mismatched 3.5 drafter

Qwen 3.6 is a quality upgrade (better outputs) at identical decode speed. We tested whether the 3.5 drafter could still draft for 3.6.

Preflight CE/top-1 check (runtime-faithful, 15-position block):

  • 3.5 drafter on 3.5 target (matched): CE 1.667, top-1 60.4%
  • 3.5 drafter on 3.6 target (mismatched): CE 4.739, top-1 37.3%

Result: DDTree b4 with mismatched drafter: avg 29.7 tok/s. Worse than 3.5 matched DDTree. Raw Qwen 3.6: 37.1 tok/s. Raw 3.6 wins by 7+ tok/s.
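The preflight metric itself is simple. A sketch of computing CE (in nats) and top-1 over a verify block, assuming per-position drafter distributions and the target's greedy tokens have already been extracted (names and data layout are illustrative):

```python
# CE / top-1 over a verify block (illustrative sketch of the preflight metric).
import math

def ce_and_top1(draft_probs, target_ids):
    """draft_probs: one {token_id: prob} dict per position.
    target_ids: the target model's greedy token at each position.
    Returns (mean cross-entropy in nats, top-1 agreement rate)."""
    ce, hits = 0.0, 0
    for probs, tgt in zip(draft_probs, target_ids):
        p = max(probs.get(tgt, 0.0), 1e-12)  # clamp to avoid log(0)
        ce += -math.log(p)
        hits += int(max(probs, key=probs.get) == tgt)
    n = len(target_ids)
    return ce / n, hits / n
```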

Round 14 — Continue-training the drafter on Qwen 3.6 distribution (Phase A)

Since the distribution shift was quantified (CE gap: 3.072 nats), we set up a self-distillation pipeline to continue-train the drafter on 3.6's output:

Pipeline (dflash-train/):

  1. preflight.py — validated runtime-faithful CE/top-1 against a known-good matched pair
  2. baseline_gap.py — measured the 3.5-drafter-on-3.6-target gap
  3. gen_data.py — sampled 1000 Alpaca-train prompts, greedy-generated from 3.6 (temp=0, matching runtime verify), captured hidden states from layers {2, 17, 32, 47, 62} in BF16 (~20 GB on disk)
  4. train.py — Phase A: froze everything except draft_model.fc (131M params) and draft_model.hidden_norm (5.1K params); AdamW + cosine decay; 500 steps; ~9 min wall-clock
  5. verify.py — measured CE and top-1 for the retrained checkpoint

Phase A training results:

| Run | Steps | Best val loss | CE | Top-1 | Gap closed |
|---|---|---|---|---|---|
| Smoke | 3 | 5.094 | 4.740 | 37.3% | 0% |
| POC 1 | 100 | 4.482 | 4.448 | 38.2% | 9.5% |
| POC 2 | 300 | 3.825 | 3.358 | 39.6% | 44.9% |
| Full Phase A | 500 | 3.244 | 3.303 | 39.6% | 46.7% |

Phase A saturated after ~300 steps. fc + hidden_norm can only absorb a linear manifold shift; non-linear per-layer corrections require unfreezing decoder layers (Phase B).

End-to-end tok/s for Phase A drafter vs raw:

| Config | Code | Prose | Reasoning | Overall |
|---|---|---|---|---|
| D — raw 3.6 | 36.6 | 37.5 | 37.2 | 37.1 |
| E — DDTree b4 + Phase A drafter | 39.3 | 24.6 | 33.2 | 31.4 |
| F — DFlash-only + Phase A drafter | 31.4 | 14.3 | 24.2 | 22.2 |

Phase A added +1.7 tok/s over the mismatched drafter on DDTree (29.7 → 31.4) but was still 5.7 tok/s below raw. CE improved 46.7% while tok/s improved only ~6%. Threshold effects dominate: a drafted token only contributes once its acceptance probability is high enough to extend the committed path, so sub-threshold CE gains never show up in tok/s.
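A simplified model shows the shape of the problem. If each drafted position is accepted independently with probability p and a single path commits its longest accepted prefix, the expected committed length over a block of B positions is the sum of p^k, and moving top-1 from 37.3% to 39.6% barely moves it:

```python
# Illustrative model only: assumes independent per-position acceptance and
# single-path longest-prefix commit, a simplification of DFlash/DDTree.
def expected_prefix(p, block=15):
    """Expected number of committed tokens: sum of p**k for k = 1..block."""
    return sum(p ** k for k in range(1, block + 1))

before = expected_prefix(0.373)  # mismatched drafter top-1
after = expected_prefix(0.396)   # Phase A drafter top-1
gain = after - before            # well under one extra token per block
```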

Round 15 — Official z-lab/Qwen3.6-27B-DFlash drafter (best possible case)

Z-Lab released z-lab/Qwen3.6-27B-DFlash, a 2B-parameter drafter trained specifically for Qwen 3.6. This is the ceiling case: a purpose-built, fully-trained matched drafter.

Preflight:

  • 3.6 target + official 3.6 drafter: CE 1.633, top-1 61.8% — slightly better than the 3.5 matched pair. A real matched pair.

Full benchmark (9 prompts × 3 runs, non-streaming; all numbers tok/s):

| Prompt | Raw 3.6 | DFlash b1 | DDTree b4 | DDTree b8 |
|---|---|---|---|---|
| code_twosum | 35.4 | 31.7 | 38.4 | 36.0 |
| code_debug | 36.6 | 27.8 | 40.4 | 32.8 |
| code_lru | 37.7 | 35.5 | 42.2 | 41.1 |
| write_story | 37.6 | 12.8 | 24.1 | 18.2 |
| write_explain | 37.0 | 14.7 | 24.3 | 19.8 |
| reason_sheep | 36.8 | 20.7 | 31.6 | 26.0 |
| reason_alice | 37.7 | 26.2 | 37.7 | 32.0 |
| finance_dcf | 37.7 | 16.0 | 28.3 | 22.7 |
| analysis_arch | 37.8 | 14.9 | 27.4 | 21.5 |

Averages (excl. short_factual; tok/s):

| Config | Code avg | Prose avg | All avg |
|---|---|---|---|
| Raw 3.6 | 36.6 | 37.5 | 37.1 |
| DFlash-only (official drafter) | 31.7 | 14.6 | 22.3 |
| DDTree b4 (official drafter) | 40.3 | 26.0 | 32.7 |
| DDTree b8 (official drafter) | 36.6 | 20.6 | 27.8 |

Final verdict:

  • DDTree b4 wins on code by +3.7 tok/s
  • DDTree b4 loses on prose by -11.5 tok/s
  • Raw 3.6 wins overall: 37.1 vs 32.7

Round 16 — Updated z-lab weights (re-test 2026-04-24)

Z-Lab updated the drafter checkpoint. Re-downloaded and re-tested. Code held at ~35–37 tok/s (+2–3 vs prior round), prose dropped to ~14 tok/s on the first write prompt — identical pattern. Stopped early. The prose regression is structural.


Why it doesn't work

The fundamental issue is single-path speculation and target entropy.

DFlash commits a single speculated sequence. When the target model's output distribution has significant entropy (prose, reasoning, finance analysis), the committed path is rejected frequently enough that draft compute is largely wasted. This is not a drafter quality problem: we tested four drafters across matched and mismatched conditions, all showing the same ~50–60% prose regression.

DDTree partially addresses this with tree branching — instead of one path, you evaluate a tree and commit the best accepted branch. This recovers most of the prose regression but the verify overhead of computing tree attention adds enough fixed cost that the average remains below raw.

The math:

  • Raw Qwen 3.6 27B 4-bit decodes at ~37 tok/s flat across all prompt types
  • Code prompts have low per-token entropy (correct tokens are fairly predictable), so the drafter accepts well → DDTree beats raw by +3–4 tok/s
  • Prose prompts have high per-token entropy, drafter acceptance is low, verify overhead dominates → DDTree loses to raw by 10–12 tok/s
  • Budget=8 is worse than budget=4 everywhere: bigger trees don't compensate for higher verify cost
  • The only path to parity with raw would be a code-only routing layer that sends coding requests to DDTree b4 and everything else to raw
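The bullets above amount to a two-term cost model. The sketch below uses illustrative per-pass costs (not our measurements) to show how the same verify overhead flips the sign between high-acceptance code and low-acceptance prose:

```python
# Illustrative cost model, not measured numbers: each speculation round pays
# one drafter pass plus one tree-verify pass on the target, and commits the
# E accepted draft tokens plus the one token the target emits during verify.
def spec_tok_per_s(E, t_draft_ms, t_verify_ms):
    return (E + 1) / ((t_draft_ms + t_verify_ms) / 1000.0)

raw = 1 / 0.027  # ~37 tok/s, i.e. ~27 ms per raw decode step

# Hypothetical per-round costs: code accepts deeply, prose shallowly.
code = spec_tok_per_s(E=3.5, t_draft_ms=8, t_verify_ms=95)
prose = spec_tok_per_s(E=1.0, t_draft_ms=8, t_verify_ms=70)
```

With these assumed costs, code lands above raw and prose below it; the fixed verify pass is what drags the average under raw once acceptance drops.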

CE → tok/s translation is nonlinear. Phase A training closed 47% of the CE gap but only moved DDTree tok/s by ~6%. Acceptance is a thresholded quantity: a mid-range CE improvement doesn't map to a mid-range acceptance improvement. If CE predicted acceptance linearly, the Phase A checkpoint would have landed far closer to raw than it did.


What's in this repo

mlx-ddtree-failed/
├── ddtree_server.py          # OpenAI-compatible FastAPI server for DDTree
├── dflash_server.py          # OpenAI-compatible FastAPI server for DFlash-only
├── pyproject.toml            # Package setup (ddtree-mlx)
├── ddtree_mlx/
│   ├── __init__.py
│   ├── cache.py              # Snapshot/restore/commit for attention and linear caches
│   ├── compile.py            # Build MLX tensors for tree verify forward
│   ├── kernels.py            # Custom Metal kernels: parent-aware conv1d + gated-delta
│   ├── model_refs.py         # Drafter resolution helpers
│   ├── runtime.py            # Main DDTree generate loop (draft → tree → verify → commit)
│   ├── tree.py               # DDTree construction: Algorithm 1 from the paper
│   └── verify.py             # Tree verification forward pass (attention + linear layers)
└── benchmarks/
    ├── run_bench.py                        # Benchmark script (10 prompts × 3 runs, non-streaming)
    ├── DRAFTER_TRAINING_EXPERIMENT.md      # Phase A continue-training full write-up
    └── QWEN36_OFFICIAL_DRAFTER.md         # Official z-lab drafter results and analysis

Not included:

  • The dflash-mlx package (a dependency, not ours — install from PyPI)
  • Training data and checkpoints (deleted, ~20 GB + 3 GB)
  • Phase A training scripts (preflight.py, gen_data.py, train.py, verify.py) — these were in a separate directory that was cleaned up. The methodology is fully documented in DRAFTER_TRAINING_EXPERIMENT.md.

Usage

Dependencies

pip install dflash-mlx mlx-lm fastapi uvicorn
pip install -e .  # installs ddtree-mlx

dflash-mlx must be installed first — DDTree wraps its runtime internals.

DDTree server

python3 ddtree_server.py \
  --model mlx-community/Qwen3.6-27B-4bit \
  --draft z-lab/Qwen3.6-27B-DFlash \
  --tree-budget 4 \
  --port 8006

The server exposes a standard OpenAI-compatible /v1/chat/completions endpoint.

DFlash-only server

python3 dflash_server.py \
  --model mlx-community/Qwen3.6-27B-4bit \
  --draft z-lab/Qwen3.6-27B-DFlash \
  --port 8007

Benchmark

python3 benchmarks/run_bench.py \
  --port 8006 \
  --model mlx-community/Qwen3.6-27B-4bit \
  --label ddtree_b4_official \
  --out results.json

Architecture notes

Tree verification for hybrid models

Qwen 3.5/3.6 is a hybrid model with interleaved attention and linear (Mamba-style gated-delta) layers. Tree verification works differently for each:

  • Attention layers: each tree node gets its own position ID for RoPE; the attention mask is a (T, T) tree-structured mask where a node can only attend to its ancestors. Standard SDPA handles this with an additive mask.
  • Linear layers (gated-delta recurrence): these are stateful. Each token's state depends on its predecessor, so a tree of T nodes needs T different recurrent states — one per node, forked from that node's parent. We implement this with custom Metal kernels (kernels.py) that fan out states from parent to children in a single GPU launch.
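The attention-side mask is easy to state precisely. A sketch of building the (T, T) additive mask from a parent array (the repo builds the equivalent MLX tensors in compile.py; this plain-Python version is illustrative):

```python
# Tree-structured additive attention mask from a parent array.
NEG_INF = float("-inf")

def tree_attention_mask(parents):
    """mask[i][j] = 0 if node j is node i or an ancestor of i, else -inf."""
    n = len(parents)
    mask = [[NEG_INF] * n for _ in range(n)]
    for i in range(n):
        j = i
        while j != -1:        # walk up to the root
            mask[i][j] = 0.0
            j = parents[j]
    return mask
```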

The tree-aware path keeps logits exact for every branch. The legacy DFS path (fallback when DDTREE_TREE_AWARE_LINEAR=0) keeps only the DFS-prefix exact and re-forwards divergent suffixes.

Cache commit strategies

Three strategies depending on what was accepted:

  1. Fast path: accepted path is a prefix of the DFS traversal order → use tape rollback for linear layers, trim KV cache offset. No re-forward.
  2. Tree-aware path: accepted path is an arbitrary tree branch → pack accepted KV entries from tree-index positions into the cache, set linear layer states from the per-node captured state for the final accepted node.
  3. Slow path: re-forward accepted tokens from cache snapshot. Lossless but costs a forward pass.
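The dispatch between the three strategies can be sketched as follows (names are illustrative; the DFS order and accepted path are lists of tree-node indices):

```python
# Pick the cheapest cache-commit strategy for an accepted tree branch
# (illustrative dispatch, not the repo's code).
def commit_strategy(dfs_order, accepted_path, tree_aware_linear=True):
    if accepted_path == dfs_order[:len(accepted_path)]:
        return "fast"        # DFS prefix: tape rollback + KV trim, no re-forward
    if tree_aware_linear:
        return "tree_aware"  # pack KV entries + restore per-node linear state
    return "slow"            # re-forward accepted tokens from a cache snapshot
```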

Metal kernels

ddtree_gated_delta_tree: parent-indexed gated-delta recurrence over all tree nodes in one Metal launch. Each thread handles one position in the Dk-dimensional state space. The parents array stores, for each tree node t, the index of t's parent; the root stores -1 and reads its base state from the KV cache.

ddtree_tree_conv1d: parent-aware depthwise causal conv over tree nodes. Same parent-indexing scheme. Used for Mamba's short conv layers.

Both kernels are compiled lazily at first use via mx.fast.metal_kernel. They fall back to a sequential Python loop when Metal is unavailable.
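The sequential fallback is the clearest statement of what the kernels compute. A toy scalar version of the parent-indexed recurrence (the real state is a matrix per head over the Dk-dimensional state space; this is illustrative):

```python
# Toy scalar parent-indexed gated-delta recurrence: each node's state is a
# decayed copy of its parent's state plus its own input contribution.
def gated_delta_tree(parents, decays, inputs, base_state=0.0):
    """states[t] = decays[t] * states[parent] + inputs[t];
    parent -1 uses base_state (the committed state from the KV cache)."""
    states = [0.0] * len(parents)
    for t, p in enumerate(parents):  # parents[t] < t, so order is valid
        prev = base_state if p == -1 else states[p]
        states[t] = decays[t] * prev + inputs[t]
    return states
```

Sibling nodes fork from the same parent state, which is exactly the fan-out the Metal kernels perform in a single launch.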


Conclusion

If you're doing speculative decoding with a hybrid attention+linear model on Apple Silicon:

  • DFlash-only is only worth it for code-heavy workloads with a well-matched drafter. Anything with high output entropy will regress severely.
  • DDTree b4 beats raw on code (+3–4 tok/s) but loses on everything else. If you can classify requests at routing time, DDTree-for-code is the only remaining win.
  • Drafter quality above a threshold doesn't help. CE 1.63 (official matched drafter) and CE 3.30 (our Phase A retrain) produce the same pattern: code wins, prose loses, overall loss. The bottleneck is the verify overhead, not drafter acceptance rate.
  • Budget 4 > 8 > 1 for this model size. More tree nodes cost more verify compute than they gain in accepted depth.

Raw mlx_lm.server at 37 tok/s flat is the right answer for a general-purpose workload.
