Commit 202c6af

Voxtral Realtime: enable bf16 for Metal backend with quantization (#17845)
The Metal AOTI backend already handles bf16 correctly (fp32 attention masks, fp32 RoPE upcast, dtype-agnostic KV caches and SDPA). Enable `--dtype bf16` in the default Metal CI recipe and update all documentation to recommend bf16 together with fpa4w quantization.
1 parent: 69094af

4 files changed: 17 additions & 11 deletions

.ci/scripts/export_model_artifact.sh (1 addition, 0 deletions)

```diff
@@ -334,6 +334,7 @@ if [ "$MODEL_NAME" = "voxtral_realtime" ]; then
   VR_QUANT_ARGS="--qlinear-encoder 8da4w --qlinear 8da4w --qlinear-group-size 32 --qembedding 8w"
 elif [ "$QUANT_NAME" = "quantized-int4-metal" ]; then
   VR_QUANT_ARGS="--qlinear-encoder fpa4w --qlinear fpa4w"
+  VR_DTYPE_ARGS="--dtype bf16"
 elif [ "$QUANT_NAME" = "quantized-int4-tile-packed" ]; then
   VR_QUANT_ARGS="--qlinear-encoder 4w --qlinear-encoder-packing-format tile_packed_to_4d --qlinear 4w --qlinear-packing-format tile_packed_to_4d --qembedding 8w"
   VR_DTYPE_ARGS="--dtype bf16"
```

examples/models/voxtral_realtime/README.md (1 addition, 0 deletions)

```diff
@@ -84,6 +84,7 @@ python export_voxtral_rt.py \
 python export_voxtral_rt.py \
   --model-path ~/models/Voxtral-Mini-4B-Realtime-2602 \
   --backend metal \
+  --dtype bf16 \
   --streaming \
   --output-dir ./voxtral_rt_exports \
   --qlinear-encoder fpa4w \
```

examples/models/voxtral_realtime/export_voxtral_rt.py (1 addition, 1 deletion)

```diff
@@ -30,7 +30,7 @@
 Usage:
   python export_voxtral_rt.py --model-path ~/models/Voxtral-Mini-4B-Realtime-2602
   python export_voxtral_rt.py --model-path ~/models/Voxtral-Mini-4B-Realtime-2602 --streaming
-  python export_voxtral_rt.py --model-path ~/models/Voxtral-Mini-4B-Realtime-2602 --backend metal
+  python export_voxtral_rt.py --model-path ~/models/Voxtral-Mini-4B-Realtime-2602 --backend metal --dtype bf16 --qlinear-encoder fpa4w --qlinear fpa4w
   python export_voxtral_rt.py --model-path ~/models/Voxtral-Mini-4B-Realtime-2602 --backend cuda --qlinear 4w
 """
 
```

examples/models/voxtral_realtime/model.md (14 additions, 10 deletions)

```diff
@@ -74,13 +74,15 @@ or masked-scatter like the original non-realtime Voxtral).
 
 ## Memory Footprint
 
-Decoder KV cache: 26 layers × 2 (K, V) × 4096 × 8 × 128 × 4 bytes
-≈ 832 MB. Encoder KV caches (streaming): 32 layers × 2 × 1500 × 32 ×
-64 × 4 bytes ≈ 786 MB.
+Decoder KV cache: 26 layers × 2 (K, V) × 4096 × 8 × 128 × bytes_per_elem.
+fp32: ≈ 832 MB, bf16: ≈ 416 MB. Encoder KV caches (streaming):
+32 layers × 2 × 1500 × 32 × 64 × bytes_per_elem. fp32: ≈ 786 MB,
+bf16: ≈ 393 MB.
 
 Runtime memory = model weights (from `.pte`) + KV caches + working
-memory. Weight sizes depend on quantization: ~16 GB (fp32), ~4 GB
-(8w), ~2 GB (4w/8da4w).
+memory. Weight sizes depend on quantization: ~16 GB (fp32), ~8 GB
+(bf16), ~4 GB (8w), ~2 GB (4w/8da4w). bf16 (`--dtype bf16`) is recommended
+for the Metal and CUDA backends when quantization is enabled.
 
 ## Class Hierarchy
 
```

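As a quick arithmetic check of the KV-cache figures above, the product layers × 2 (K, V) × sequence length × heads × head dim × element size reproduces the quoted numbers. A sketch; `kv_cache_bytes` is a hypothetical helper for this document, not part of the export tooling:

```python
# Verify the KV-cache sizes quoted in the Memory Footprint section.
def kv_cache_bytes(layers: int, seq: int, heads: int, head_dim: int,
                   bytes_per_elem: int) -> int:
    # layers × 2 (K and V) × sequence length × heads × head dim × element size
    return layers * 2 * seq * heads * head_dim * bytes_per_elem

dec_fp32 = kv_cache_bytes(26, 4096, 8, 128, 4)   # ≈ 832 MB (decoder, fp32)
dec_bf16 = kv_cache_bytes(26, 4096, 8, 128, 2)   # ≈ 416 MB (half of fp32)
enc_fp32 = kv_cache_bytes(32, 1500, 32, 64, 4)   # ≈ 786 MB (encoder, fp32)
enc_bf16 = kv_cache_bytes(32, 1500, 32, 64, 2)   # ≈ 393 MB
```

Switching `bytes_per_elem` from 4 (fp32) to 2 (bf16) halves every cache, which is where the 832 → 416 MB and 786 → 393 MB savings come from.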
```diff
@@ -163,8 +165,9 @@ fused kernel with causal masking via `start_pos` + `is_causal=True`.
 Handles GQA expansion internally and upcasts to float32.
 
 **Metal:** `MetalSDPA` uses `torch.ops.aten._scaled_dot_product_attention_math_for_mps`
-which handles GQA natively via `gqa_factor`, avoiding the memory bandwidth
-overhead of `repeat_interleave`. Uses explicit additive attention masks
+which handles GQA natively (the kernel infers the group ratio from differing
+Q vs K/V head counts), avoiding the memory bandwidth overhead of
+`repeat_interleave`. Uses explicit additive attention masks
 that must match the Q/K/V dtype (the kernel reads masks as `device T*`).
 Used for both decoder (GQA, `transpose_kv=False`) and streaming encoder
 (no GQA, `transpose_kv=True`).
```

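The native-GQA point above can be illustrated in plain numpy: expanding K/V with repeats (the `repeat_interleave` path) and letting each group of query heads read one shared KV head produce identical outputs, but the grouped path never materializes the expanded copies. A toy sketch with made-up shapes, not the actual Metal kernel:

```python
import numpy as np

def sdpa(q, k, v):
    """Plain softmax attention; q, k, v: [heads, seq, dim]."""
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(q.shape[-1])
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))
    w /= w.sum(axis=-1, keepdims=True)
    return w @ v

rng = np.random.default_rng(0)
n_q_heads, n_kv_heads, seq, dim = 8, 2, 4, 16
q = rng.standard_normal((n_q_heads, seq, dim))
k = rng.standard_normal((n_kv_heads, seq, dim))
v = rng.standard_normal((n_kv_heads, seq, dim))

# Group ratio inferred from the differing Q vs K/V head counts.
g = n_q_heads // n_kv_heads

# Expansion path: materialize g copies of every KV head (extra bandwidth).
out_expanded = sdpa(q, np.repeat(k, g, axis=0), np.repeat(v, g, axis=0))

# Grouped path: each block of g query heads attends to one shared KV head
# via broadcasting; the expanded K/V is never materialized.
out_grouped = np.concatenate(
    [sdpa(q[i * g:(i + 1) * g], k[i:i + 1], v[i:i + 1])
     for i in range(n_kv_heads)]
)
```

The two outputs match elementwise; the grouped path simply avoids the g× K/V reads that `repeat_interleave` would add.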
```diff
@@ -280,7 +283,7 @@ enabling streaming of arbitrary length audio.
 5-8, giving query 5 full access to its window.
 - Default `max_enc_len=750` (matching the model's trained
 sliding window). Configurable via `--max-enc-len`.
-- Memory: 32 layers × 2 × 1500 × 32 × 64 × 4 bytes ≈ 786 MB (fp32)
+- Memory: 32 layers × 2 × 1500 × 32 × 64 × bytes_per_elem ≈ 786 MB (fp32), 393 MB (bf16)
 - Duration: unlimited (ring buffer overwrites old entries, RoPE computed on-the-fly)
 
 **Naming note:** `max_enc_len` in `StreamingAudioEncoderExport` (default
```

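The "ring buffer overwrites old entries" behavior can be shown with a toy cache. `RingKVCache` and its shapes are hypothetical, purely to illustrate the slot arithmetic:

```python
import numpy as np

class RingKVCache:
    """Toy ring-buffer KV cache: absolute position p lands in slot
    p % capacity, so old entries are overwritten and total audio
    duration is unbounded (positional encoding is computed on-the-fly,
    not stored in the cache)."""

    def __init__(self, capacity: int, dim: int):
        self.capacity = capacity
        self.k = np.zeros((capacity, dim))

    def write(self, pos: int, k_vec: np.ndarray) -> None:
        self.k[pos % self.capacity] = k_vec

cache = RingKVCache(capacity=4, dim=2)
for p in range(6):   # positions 4 and 5 wrap around, overwriting slots 0 and 1
    cache.write(p, np.full(2, float(p)))
```

After the loop the cache holds positions 4, 5, 2, 3 in slots 0-3: the fixed-size buffer bounds memory while audio length stays unlimited.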
```diff
@@ -370,7 +373,7 @@ Parakeet pattern), allowing different configs for encoder vs decoder:
 --qlinear 8da4w          # decoder linear layers
 --qembedding 8w          # embedding layer
 
-# Metal
+# Metal (use --dtype bf16 for reduced memory and improved throughput)
 --qlinear-encoder fpa4w  # encoder linear layers
 --qlinear fpa4w          # decoder linear layers
 
```

```diff
@@ -428,7 +431,8 @@ of ~34 GB for the full-size model):
 1. **Meta device construction** — `with torch.device("meta"):` builds the
 model with zero-storage parameter tensors (shape/dtype metadata only).
 2. **safetensors lazy access** — `safe_open` loads tensors on demand, cast
-to the configured dtype (`--dtype`, default fp32; CUDA uses bf16).
+to the configured dtype (`--dtype`, default fp32; bf16 recommended for
+Metal and CUDA with quantization).
 3. **`assign=True` state dict loading** — replaces meta tensors by reference
 instead of copying into pre-allocated storage. No duplication.
 4. **Post-load fixups** — re-tie `output.weight = tok_embeddings.weight`
```
