
Alberto-Codes/turboquant-vllm


turboquant-vllm

TurboQuant KV cache compression as a drop-in vLLM plugin. 3.76x KV cache compression with asymmetric K/V support, validated across 8 models.

Implements Google's TurboQuant (ICLR 2026), an online vector quantization method with provably near-optimal distortion rates, applied here to the KV cache.

Install

pip install "turboquant-vllm[vllm]"

Or with uv:

uv add turboquant-vllm --extra vllm

Quick Start (vLLM)

The TQ4 attention backend registers automatically via vLLM's plugin system:

vllm serve meta-llama/Llama-3.1-8B-Instruct --attention-backend CUSTOM

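No code changes are required. The plugin compresses KV cache pages to 68 bytes per token per head (at head_dim=128: 64 bytes of nibble-packed 4-bit indices plus one 4-byte fp32 norm), versus 256 bytes in FP16. For asymmetric K/V compression, set the key and value bit widths separately: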

TQ4_K_BITS=4 TQ4_V_BITS=3 vllm serve meta-llama/Llama-3.1-8B-Instruct --attention-backend CUSTOM
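The 68-byte figure follows from the storage format listed under What Gets Compressed. The quick check below assumes two 4-bit indices per byte plus a single fp32 norm per vector. The helper names are hypothetical and not part of turboquant_vllm.

```python
# Hypothetical helpers (not part of turboquant_vllm) that reproduce the
# storage arithmetic for one token, one attention head, and one of K or V.

def fp16_bytes(head_dim: int) -> int:
    # FP16 uses 2 bytes per coordinate.
    return head_dim * 2

def tq4_bytes(head_dim: int) -> int:
    # Assumption: 4-bit indices packed two per byte, plus one fp32 norm (4 bytes).
    return head_dim // 2 + 4

for head_dim in (96, 128, 256):
    ratio = fp16_bytes(head_dim) / tq4_bytes(head_dim)
    print(f"head_dim={head_dim}: {tq4_bytes(head_dim)} B vs {fp16_bytes(head_dim)} B ({ratio:.2f}x)")
```

At head_dim=128 this gives 68 B vs 256 B, the 3.76x quoted above. Under the same assumption the ratio shifts slightly with head_dim, because the 4-byte norm is a fixed overhead per vector.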

Quick Start (HuggingFace)

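from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache
from turboquant_vllm import CompressedDynamicCache

model_id = "meta-llama/Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", device_map="auto")

cache = DynamicCache()
# head_dim should match the model's attention heads (128 for Llama 3.1 8B)
compressed = CompressedDynamicCache(cache, head_dim=128, k_bits=4, v_bits=3)

# Pass cache (not the wrapper) to model.generate();
# compression happens transparently on every cache.update()
inputs = tokenizer("What is a KV cache?", return_tensors="pt").to(model.device)
output = model.generate(**inputs, past_key_values=cache, max_new_tokens=64)
print(tokenizer.decode(output[0], skip_special_tokens=True))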

Compression Quality

Minimum per-layer cosine similarity (worst layer) on real model activations (128-token prefill, RTX 4090):

| Model | head_dim | K4/V4 min cosine | K4/V3 min cosine |
|---|---|---|---|
| Llama 3.1 8B | 128 | 0.9947 | 0.9823 |
| Qwen2.5 3B | 128 | 0.9935 | 0.9823 |
| Mistral 7B | 128 | 0.9947 | 0.9825 |
| Phi-3-mini | 96 | 0.9950 | 0.9827 |
| Phi-4 | 128 | 0.9945 | 0.9824 |
| Gemma 2 2B | 256 | 0.9948 | 0.9823 |
| Gemma 3 4B | 256 | 0.9911 | 0.9794 |
| Molmo2 4B | 128 | 0.9943 | 0.9821 |

Validate any model yourself with the verify CLI:

python -m turboquant_vllm.verify --model meta-llama/Llama-3.1-8B --bits 4
python -m turboquant_vllm.verify --model meta-llama/Llama-3.1-8B --k-bits 4 --v-bits 3 --threshold 0.97
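The README does not spell out exactly how the verify CLI aggregates its score. The sketch below is only one plausible reading of "minimum per-layer cosine similarity against a threshold". It assumes cosine similarity is averaged over all vectors within a layer and the lowest layer is reported. The function names are hypothetical, the data is synthetic, and this is not the CLI's code.

```python
import numpy as np

def cosine_rows(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    # Row-wise cosine similarity between two (n_vectors, head_dim) arrays.
    num = np.sum(a * b, axis=-1)
    den = np.linalg.norm(a, axis=-1) * np.linalg.norm(b, axis=-1)
    return num / np.maximum(den, 1e-12)

def min_layer_cosine(original_layers, reconstructed_layers) -> float:
    # Assumption: mean cosine within each layer, then the worst (lowest) layer.
    per_layer = [cosine_rows(o, r).mean() for o, r in zip(original_layers, reconstructed_layers)]
    return float(min(per_layer))

rng = np.random.default_rng(0)
orig = [rng.standard_normal((128, 128)) for _ in range(4)]
# Synthetic stand-in for dequantized K/V: each layer gets progressively more noise.
recon = [x + 0.05 * (i + 1) * rng.standard_normal(x.shape) for i, x in enumerate(orig)]

threshold = 0.97
score = min_layer_cosine(orig, recon)
print(f"min per-layer cosine = {score:.4f}, pass = {score >= threshold}")
```

Reporting the worst layer rather than an average keeps a single badly reconstructed layer from hiding behind good ones.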

Serving Performance

Llama-3.1-8B-Instruct on RTX 4090, 200 concurrent requests (Exp 029):

| Metric | Baseline | TQ4 (K4/V4) | Delta |
|---|---|---|---|
| Request throughput (req/s) | 8.14 | 7.55 | -7.3% |
| Output throughput (tok/s) | 1,042 | 967 | -7.3% |
| Median TTFT (ms) | 9,324 | 6,977 | -25.2% |
| Median TPOT (ms) | 47.6 | 143.6 | +201% |

TQ4 reduces median time-to-first-token by 25% (smaller cache pages make prefill faster) but roughly triples per-token decode latency because of online decompression. Net throughput drops about 7% at high concurrency. TQ4 is best suited to memory-bound workloads: long contexts, large batch sizes, or limited VRAM.
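A rough way to see the prefill/decode trade-off for a single request is to model end-to-end latency as TTFT + (n - 1) * TPOT for n output tokens. Plugging in the medians from the table is only an approximation, since medians of separate metrics do not combine exactly. The helper names below are illustrative.

```python
# Back-of-envelope only: treats the table's *median* TTFT/TPOT as if they
# described one request, which they do not exactly.
BASELINE = {"ttft_ms": 9324, "tpot_ms": 47.6}
TQ4 = {"ttft_ms": 6977, "tpot_ms": 143.6}

def e2e_latency_ms(ttft_ms: float, tpot_ms: float, output_tokens: int) -> float:
    # The first token arrives after TTFT; each later token adds one TPOT.
    return ttft_ms + (output_tokens - 1) * tpot_ms

# Output length at which TQ4's prefill savings are used up by slower decode.
break_even = 1 + (BASELINE["ttft_ms"] - TQ4["ttft_ms"]) / (TQ4["tpot_ms"] - BASELINE["tpot_ms"])
print(f"break-even output length ~ {break_even:.1f} tokens")

for n in (16, 256):
    b = e2e_latency_ms(BASELINE["ttft_ms"], BASELINE["tpot_ms"], n)
    t = e2e_latency_ms(TQ4["ttft_ms"], TQ4["tpot_ms"], n)
    print(f"{n:>4} tokens: baseline {b:,.0f} ms, TQ4 {t:,.0f} ms")
```

With these medians the break-even is about 25 output tokens. Beyond that, single-request latency is higher under TQ4, which is why the benefit is framed around memory capacity rather than latency.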

How It Works

Implements Google's TurboQuant algorithm (ICLR 2026):

  1. Random orthogonal rotation of each unit-normalized KV vector makes every coordinate follow a known Beta distribution (the norm is stored separately in fp32)
  2. Lloyd-Max scalar quantization finds optimal centroids for that distribution at 3-4 bits per coordinate
  3. Nibble packing stores two 4-bit indices per byte for 3.76x compression
  4. Incremental dequantization only decompresses new tokens each decode step, keeping overhead at 1.78x
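Steps 1 and 2, plus the matching dequantization, can be sketched in NumPy. This sketch makes three assumptions:

- A Haar-random rotation.
- One stored norm per vector.
- A 4-bit codebook fitted by empirical Lloyd iterations on samples from the Beta law, rather than by numerical integration.

It is not the package's implementation, and every function name is hypothetical.

```python
import numpy as np

def random_rotation(d: int, rng: np.random.Generator) -> np.ndarray:
    # Haar-random orthogonal matrix: QR of a Gaussian matrix, with column signs fixed.
    q, r = np.linalg.qr(rng.standard_normal((d, d)))
    return q * np.sign(np.diag(r))

def lloyd_max_codebook(d: int, bits: int, rng: np.random.Generator,
                       n_samples: int = 200_000, iters: int = 50) -> np.ndarray:
    # Each coordinate y of a uniformly random unit vector in R^d satisfies
    # (y + 1) / 2 ~ Beta((d - 1) / 2, (d - 1) / 2); sample from that law.
    a = (d - 1) / 2
    samples = 2.0 * rng.beta(a, a, size=n_samples) - 1.0
    k = 2 ** bits
    # Start from evenly spaced quantiles, then alternate Lloyd steps:
    # boundaries at centroid midpoints, centroids at the mean of each cell.
    centroids = np.quantile(samples, (np.arange(k) + 0.5) / k)
    for _ in range(iters):
        cells = np.searchsorted((centroids[:-1] + centroids[1:]) / 2, samples)
        centroids = np.bincount(cells, weights=samples, minlength=k) / np.bincount(cells, minlength=k)
    return centroids

def quantize(x, rotation, centroids):
    # Store the norm, rotate the unit vector, map each coordinate to its nearest centroid.
    norms = np.linalg.norm(x, axis=-1, keepdims=True).astype(np.float32)
    y = (x / np.maximum(norms, 1e-12)) @ rotation.T
    edges = (centroids[:-1] + centroids[1:]) / 2
    return np.searchsorted(edges, y).astype(np.uint8), norms

def dequantize(idx, norms, rotation, centroids):
    # Look up centroids, undo the rotation, restore the stored norm.
    return (centroids[idx] @ rotation) * norms

rng = np.random.default_rng(0)
d, bits = 128, 4
rotation = random_rotation(d, rng)
centroids = lloyd_max_codebook(d, bits, rng)

# Synthetic stand-in activations with uneven per-channel scales (real K/V differ).
x = rng.standard_normal((512, d)) * rng.uniform(0.1, 3.0, size=d)
idx, norms = quantize(x, rotation, centroids)
x_hat = dequantize(idx, norms, rotation, centroids)

cos = np.sum(x * x_hat, axis=-1) / (np.linalg.norm(x, axis=-1) * np.linalg.norm(x_hat, axis=-1))
print(f"min cosine over {len(x)} vectors: {cos.min():.4f}")
```

The rotation is what allows a single fixed codebook to serve every vector. Whatever the input's channel statistics, the rotated unit-vector coordinates follow the same distribution, so the codebook can be computed once.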

What Gets Compressed

| Data | Compressed | Format |
|---|---|---|
| Key cache vectors | Yes (k_bits, default 4) | uint8 nibble-packed indices + fp32 norms |
| Value cache vectors | Yes (v_bits, default 4) | uint8 nibble-packed indices + fp32 norms |
| Rotation matrices | No | Generated once per layer from a fixed seed |
| Lloyd-Max codebook | No | Computed once, shared across all layers |
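The uint8 nibble-packed format in the table can be sketched as follows. The nibble order (even coordinate in the low nibble) is an assumption, and the package's actual layout may differ. The helpers are hypothetical.

```python
import numpy as np

def pack_nibbles(idx: np.ndarray) -> np.ndarray:
    # idx: 4-bit codes (0..15) with an even last dimension.
    # Assumed layout: even position in the low nibble, odd position in the high nibble.
    idx = idx.astype(np.uint8)
    return (idx[..., 0::2] | (idx[..., 1::2] << 4)).astype(np.uint8)

def unpack_nibbles(packed: np.ndarray) -> np.ndarray:
    # Split each byte back into its two 4-bit codes, restoring the original order.
    out = np.empty(packed.shape[:-1] + (packed.shape[-1] * 2,), dtype=np.uint8)
    out[..., 0::2] = packed & 0x0F
    out[..., 1::2] = packed >> 4
    return out

rng = np.random.default_rng(0)
head_dim = 128
codes = rng.integers(0, 16, size=(3, head_dim), dtype=np.uint8)  # 3 tokens, one head
norms = rng.random((3, 1)).astype(np.float32)                     # one fp32 norm per vector
packed = pack_nibbles(codes)
print(packed.shape, packed.dtype, norms.dtype)
```

Packing halves index storage compared with one code per byte. Unpacking costs only a mask and a shift, which is why decompression can be fused into the attention kernels.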

Roadmap

  • Core TurboQuant algorithm (Lloyd-Max, MSE quantizer, compressors)
  • CompressedDynamicCache with incremental dequantization
  • vLLM TQ4 attention backend plugin
  • Fused Triton kernels (4.5x compress, 4x decompress speedup)
  • Fused paged TQ4 decode with 8.5x HBM bandwidth reduction
  • INT8 Q@K^T prefill path
  • CUDA graph compatibility (buffer pre-allocation)
  • Multi-model validation (8 families, head_dim 64/96/128/256)
  • Sliding window attention bypass (Gemma 2/3)
  • Asymmetric K/V compression (k_bits/v_bits)
  • Sparse V decompression for decode acceleration
  • Container image with turboquant-vllm baked in
  • Full Flash Attention fusion with fp32 online softmax

Documentation

  • Architecture -- Module map, dependency DAG, data flow diagrams
  • Roadmap -- Detailed implementation status and experiment results
  • Development Guide -- Setup, build, test, lint commands

Citation

@inproceedings{zandieh2025turboquant,
  title={TurboQuant: Online Vector Quantization with Near-optimal Distortion Rate},
  author={Zandieh, Amir and Han, Insu and Daliri, Majid and Karbasi, Amin},
  booktitle={International Conference on Learning Representations (ICLR)},
  year={2025}
}

License

Apache 2.0

About

TurboQuant KV cache compression plugin for vLLM: asymmetric K/V, 8 models validated, consumer GPUs
