
dflash-mlx

DFlash implementation for Apple Silicon, using MLX.

Benchmarks

Demo video: demo-comparison.mp4

Quick start

git clone https://github.com/aryagm/dflash-mlx.git && cd dflash-mlx
uv sync

uv run dflash-mlx --max-new-tokens 128

Defaults to Qwen3-4B BF16. The first run downloads the target and draft checkpoints into the Hugging Face cache (roughly 12 GB for the default pair). Pass --target-model and --draft-model to override the defaults, use dflash-mlx-chat for interactive chat, and pass --json for machine-readable output. Benchmark history is opt-in via --history or --history-file.

from dflash_mlx import DFlashGenerator

# First run also downloads the default Qwen3-4B target and DFlash draft weights.
runner = DFlashGenerator()
result = runner.generate("Write a quicksort in Python.", max_new_tokens=128)
print(result.text)

for event in runner.stream("Write a quicksort in Python.", max_new_tokens=128):
    if not event.finished:
        print(event.delta, end="", flush=True)

Use uv run dflash-mlx --stream or uv run dflash-mlx-chat --stream to print verified text as it is committed.

OpenAI-compatible local server

A minimal text-only OpenAI-compatible HTTP server is included for local integrations:

dflash-mlx-openai-server \
  --host 127.0.0.1 \
  --port 8098 \
  --model-id qwen35-27b-dflash \
  --target-model /path/to/target \
  --draft-model /path/to/draft

Endpoints:

  • GET /health
  • GET /v1/models
  • POST /v1/chat/completions

Current limitations:

  • text-only message content
  • no image input

POST /v1/chat/completions supports both full responses and streaming SSE chunks with "stream": true.
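As a minimal sketch of a client for this server, using only the Python standard library: it assumes the server started with the command above is running on 127.0.0.1:8098 and was given the model id qwen35-27b-dflash (both are just the values from that example, not fixed defaults).

```python
# Hedged sketch of a text-only client for the local OpenAI-compatible server.
# Assumes the server is running on 127.0.0.1:8098 with --model-id qwen35-27b-dflash.
import json
import urllib.request

BASE_URL = "http://127.0.0.1:8098"

def chat_payload(prompt, stream=False):
    """Build a /v1/chat/completions request body (text-only messages)."""
    return {
        "model": "qwen35-27b-dflash",
        "messages": [{"role": "user", "content": prompt}],
        "stream": stream,
    }

def chat(prompt):
    """POST the prompt and return the assistant's reply text."""
    req = urllib.request.Request(
        f"{BASE_URL}/v1/chat/completions",
        data=json.dumps(chat_payload(prompt)).encode(),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req) as resp:
        body = json.load(resp)
    return body["choices"][0]["message"]["content"]

# With the server running:
#   print(chat("Write a quicksort in Python."))
```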

Supported models

Target                                  Draft
mlx-community/Qwen3-4B-bf16 (default)   z-lab/Qwen3-4B-DFlash-b16
mlx-community/Qwen3.5-4B-MLX-bf16       z-lab/Qwen3.5-4B-DFlash

Qwen3.5 support is functional but incomplete. It is not as fast as the Qwen3 path today because Qwen3.5 uses a more complicated hybrid attention stack with recurrent linear-attention state, so exact partial-block acceptance needs custom cache rollback and currently has weaker long-generation acceptance.

Upstream DFlash has checkpoints for Llama 3.1, Qwen3 Coder, Kimi-K2.5, GPT-OSS, and more in the Hugging Face collection. Adding a new family starts with an adapter in dflash_mlx/adapters.py — see ADDING_MODELS.md.

Benchmarks

Full run details, acceptance stats, and quantized comparisons:

How it works

DFlash trains a small block-diffusion model to propose multiple tokens at once. The target verifies them in a single forward pass and accepts the longest correct prefix — identical output, fewer forward passes, higher throughput.
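The accept rule can be sketched in a few lines. This toy uses made-up token IDs rather than real model outputs: the target's greedy pick at each drafted position is compared to the draft's proposal, the longest matching prefix is committed, and the target's own pick at the first mismatch (or after a full match) is committed as a bonus token.

```python
def accept_block(draft_tokens, target_tokens):
    """Return the tokens committed for one draft/verify round.

    draft_tokens:  the block proposed by the draft model.
    target_tokens: the target's greedy pick at each drafted position,
                   plus one extra pick after the final position.
    """
    accepted = []
    for drafted, verified in zip(draft_tokens, target_tokens):
        if drafted != verified:
            break
        accepted.append(drafted)
    # The target's own choice at the first mismatch (or after a full match)
    # is always correct, so it is committed as a free bonus token. This is
    # what keeps the output bit-for-bit identical to plain greedy decoding.
    accepted.append(target_tokens[len(accepted)])
    return accepted

# Full match: all 3 drafted tokens plus one bonus token are committed.
print(accept_block([5, 9, 2], [5, 9, 2, 7]))   # [5, 9, 2, 7]
# Mismatch at position 1: only the matching prefix plus the correction.
print(accept_block([5, 4, 2], [5, 9, 2, 7]))   # [5, 9]
```

One target forward pass scores every drafted position at once, so a round that commits four tokens costs the same target compute as one token of plain decoding.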

The original DFlash targets CUDA. dflash-mlx is a native MLX port for Apple Silicon. MLX has no speculative-decoding primitives, so every piece of the draft/verify loop had to be built from scratch on top of Metal:

  • Hidden-state extraction from the target. DFlash's drafter conditions on intermediate layer activations, not just logits. We hook into specific Qwen layers and surface those tensors without breaking the standard forward path or the KV cache, so a single target pass gives us both verification logits and the hidden states the next draft block needs.

  • Parallel block proposal. The draft model runs a block-diffusion denoising loop to propose several tokens at once. This runs entirely on the GPU with its own cache, sharing tokenization and positional context with the target.

  • Single-pass batched verification. Every proposed block is verified in one target forward pass. The target's logits are compared greedily against the draft's samples; we accept the longest matching prefix plus one bonus correction token, which is what makes the output bit-for-bit identical to plain target decoding.

  • Per-layer KV cache rollback on rejection. When the target rejects the tail of a proposal, the KV cache has to be rewound to the exact accepted length — per layer, because Qwen3.5 mixes full attention, sliding-window attention, and recurrent linear-attention state, and each has its own cache shape and rollback rule. Plain MLX caches don't expose this; we extend them.

  • Pluggable adapters. Target-specific concerns (layer ids to tap, cache types, stop tokens, chat template) are isolated in dflash_mlx/adapters.py. The core draft/verify loop is architecture-agnostic, so adding a new family is one adapter file rather than a rewrite.

  • Warm-path throughput engineering. MLX kernel compilation, lazy evaluation, and graph caching all affect the numbers. The bench CLI separates warmup from measurement and pins evaluation points so the reported tok/s reflects steady-state Metal performance, not first-run overhead.
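The per-layer rollback point above can be illustrated with two toy cache classes (plain Python standing in for the MLX cache tensors; the real classes in dflash_mlx are more involved). A standard attention cache stores one entry per committed token, so rollback is a truncation; a recurrent linear-attention cache holds a single accumulated state, so rollback must restore a snapshot taken before the speculative block and replay only the accepted tokens.

```python
class ToyKVCache:
    """Standard attention cache: one (key, value) entry per committed token,
    so rollback simply truncates to the accepted length."""
    def __init__(self):
        self.entries = []
    def append(self, kv):
        self.entries.append(kv)
    def rollback(self, accepted_len):
        del self.entries[accepted_len:]

class ToyRecurrentCache:
    """Linear-attention state: a single accumulated value with no per-token
    entries, so rollback restores a pre-block snapshot and replays only the
    accepted tokens' updates."""
    def __init__(self):
        self.state = 0
        self._snapshot = 0
    def snapshot(self):
        # Taken once before a speculative block is appended.
        self._snapshot = self.state
    def append(self, delta):
        self.state += delta
    def rollback(self, accepted_deltas):
        self.state = self._snapshot
        for d in accepted_deltas:
            self.state += d
```

A hybrid model carries one cache per layer, each with its own rollback rule, which is why rejection handling has to walk the layers rather than trim a single shared cache.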

Citation

@article{chen2026dflash,
  title   = {DFlash: Block Diffusion for Flash Speculative Decoding},
  author  = {Chen, Jian and Liang, Yesheng and Liu, Zhijian},
  journal = {arXiv preprint arXiv:2602.06036},
  year    = {2026}
}

License

MIT

About

Exact speculative decoding on Apple Silicon, powered by MLX.
