@@ -74,13 +74,15 @@ or masked-scatter like the original non-realtime Voxtral).
 
 ## Memory Footprint
 
-Decoder KV cache: 26 layers × 2 (K, V) × 4096 × 8 × 128 × 4 bytes
-≈ 832 MB. Encoder KV caches (streaming): 32 layers × 2 × 1500 × 32 ×
-64 × 4 bytes ≈ 786 MB.
+Decoder KV cache: 26 layers × 2 (K, V) × 4096 × 8 × 128 × bytes_per_elem.
+fp32: ≈ 832 MB, bf16: ≈ 416 MB. Encoder KV caches (streaming):
+32 layers × 2 × 1500 × 32 × 64 × bytes_per_elem. fp32: ≈ 786 MB,
+bf16: ≈ 393 MB.
 
 Runtime memory = model weights (from `.pte`) + KV caches + working
-memory. Weight sizes depend on quantization: ~16 GB (fp32), ~4 GB
-(8w), ~2 GB (4w/8da4w).
+memory. Weight sizes depend on quantization: ~16 GB (fp32), ~8 GB
+(bf16), ~4 GB (8w), ~2 GB (4w/8da4w). bf16 (`--dtype bf16`) is
+recommended for the Metal and CUDA backends when quantization is enabled.
 
 ## Class Hierarchy
 
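The KV-cache figures in this hunk can be sanity-checked with a quick back-of-envelope calculation (a sketch; layer, head, and sequence counts are taken from the text above):

```python
def kv_cache_bytes(layers, seq_len, heads, head_dim, bytes_per_elem):
    """Size of a K+V cache: layers x 2 (K and V) x seq x heads x head_dim x elem size."""
    return layers * 2 * seq_len * heads * head_dim * bytes_per_elem

decoder_fp32 = kv_cache_bytes(26, 4096, 8, 128, 4)
encoder_fp32 = kv_cache_bytes(32, 1500, 32, 64, 4)
decoder_bf16 = kv_cache_bytes(26, 4096, 8, 128, 2)

print(f"decoder fp32: {decoder_fp32 / 2**20:.0f} MiB")  # ≈ 832
print(f"encoder fp32: {encoder_fp32 / 1e6:.0f} MB")     # ≈ 786
print(f"decoder bf16: {decoder_bf16 / 2**20:.0f} MiB")  # ≈ 416
```

Halving `bytes_per_elem` from 4 (fp32) to 2 (bf16) is exactly where the 832 → 416 MB and 786 → 393 MB reductions come from.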
@@ -163,8 +165,9 @@ fused kernel with causal masking via `start_pos` + `is_causal=True`.
 Handles GQA expansion internally and upcasts to float32.
 
 **Metal:** `MetalSDPA` uses `torch.ops.aten._scaled_dot_product_attention_math_for_mps`
-which handles GQA natively via `gqa_factor`, avoiding the memory bandwidth
-overhead of `repeat_interleave`. Uses explicit additive attention masks
+which handles GQA natively (the kernel infers the group ratio from differing
+Q vs K/V head counts), avoiding the memory bandwidth overhead of
+`repeat_interleave`. Uses explicit additive attention masks
 that must match the Q/K/V dtype (the kernel reads masks as `device T*`).
 Used for both decoder (GQA, `transpose_kv=False`) and streaming encoder
 (no GQA, `transpose_kv=True`).
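For contrast, the `repeat_interleave` expansion that the GQA-native Metal kernel avoids looks roughly like this (shapes here are illustrative assumptions, not the model's exact config):

```python
import torch
import torch.nn.functional as F

B, Hq, Hkv, T, D = 1, 32, 8, 16, 128   # assumed GQA shapes: 32 query heads, 8 KV heads
q = torch.randn(B, Hq, T, D)
k = torch.randn(B, Hkv, T, D)
v = torch.randn(B, Hkv, T, D)

# Portable fallback: materialize Hq/Hkv copies of every K and V head before SDPA.
# A GQA-native kernel instead indexes K/V heads by group, skipping this copy
# (and the extra memory traffic it implies).
group = Hq // Hkv
out = F.scaled_dot_product_attention(
    q,
    k.repeat_interleave(group, dim=1),
    v.repeat_interleave(group, dim=1),
    is_causal=True,
)
print(out.shape)  # torch.Size([1, 32, 16, 128])
```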
@@ -280,7 +283,7 @@ enabling streaming of arbitrary length audio.
   5-8, giving query 5 full access to its window.
 - Default `max_enc_len=750` (matching the model's trained
   sliding window). Configurable via `--max-enc-len`.
-- Memory: 32 layers × 2 × 1500 × 32 × 64 × 4 bytes ≈ 786 MB (fp32)
+- Memory: 32 layers × 2 × 1500 × 32 × 64 × bytes_per_elem ≈ 786 MB (fp32), 393 MB (bf16)
 - Duration: unlimited (ring buffer overwrites old entries, RoPE computed on-the-fly)
 
 **Naming note:** `max_enc_len` in `StreamingAudioEncoderExport` (default
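The ring-buffer behavior described above reduces to simple modular indexing; a minimal sketch (hypothetical helper name, not the export code's actual API):

```python
MAX_ENC_LEN = 750  # default sliding window, per the text above

def cache_slot(abs_pos: int) -> int:
    """Map an absolute timestep to its ring-buffer slot.

    Positions wrap modulo the window, so writing position 750 overwrites
    the slot that held position 0 -- audio duration is unbounded while
    cache memory stays fixed.
    """
    return abs_pos % MAX_ENC_LEN

print(cache_slot(0), cache_slot(749), cache_slot(750))  # 0 749 0
```

RoPE being computed on the fly from the absolute position is what keeps this wrap-around transparent to the attention math.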
@@ -370,7 +373,7 @@ Parakeet pattern), allowing different configs for encoder vs decoder:
 --qlinear 8da4w # decoder linear layers
 --qembedding 8w # embedding layer
 
-# Metal
+# Metal (use --dtype bf16 for reduced memory and improved throughput)
 --qlinear-encoder fpa4w # encoder linear layers
 --qlinear fpa4w # decoder linear layers
 
@@ -428,7 +431,8 @@ of ~34 GB for the full-size model):
 1. **Meta device construction** — `with torch.device("meta"):` builds the
    model with zero-storage parameter tensors (shape/dtype metadata only).
 2. **safetensors lazy access** — `safe_open` loads tensors on demand, cast
-   to the configured dtype (`--dtype`, default fp32; CUDA uses bf16).
+   to the configured dtype (`--dtype`, default fp32; bf16 recommended for
+   Metal and CUDA with quantization).
 3. **`assign=True` state dict loading** — replaces meta tensors by reference
    instead of copying into pre-allocated storage. No duplication.
 4. **Post-load fixups** — re-tie `output.weight = tok_embeddings.weight`