-![overview](imgs/title.png)
+<p align="center">
+  <img src="imgs/title.png" width="400">
+</p>

 <div align="center">

-## Efficient low-bit KV cache decoding
+## Efficient LLMs decoding with low-bit KV cache

 [![arXiv](https://img.shields.io/badge/arXiv-2410.13276-b31b1b.svg)](https://arxiv.org/abs/2503.18773)
 [![License](https://img.shields.io/badge/License-MIT-green.svg)](LICENSE)
@@ -16,13 +18,12 @@ cache. Achieve **3-9x speedup** than Flash-Decoding-v2.


 ## News
-* [2025.11] 🔥 BitDecoding has been accepted to HPCA 2025!
+* [2025.11] 🔥 BitDecoding has been accepted to HPCA 2026!

 ## Benchmark
-* Kernel Performance in RTX4090
-![overview](imgs/4090.png)
-* Kernel Performance in A100
-![overview](imgs/a100.png)
+* Kernel Performance in Blackwell GPU
+![overview](imgs/blackwell.jpg)
+

 ## Installation
 ```
@@ -34,17 +35,51 @@ python setup.py install
 ```

 ## Quick Start
-1. See benchmark/bench_single_decode.ipynb
-2. (Optional) Play with libtorch c++
-    ```
-    # download libtorch

+```python
+import torch
+import math
+from bit_decode import kvcache_pack_int, fwd_kvcache_int
+
+# Parameters
+batch_size, nheads, nheads_k, d = 1, 32, 32, 128
+seqlen_q, seqlen_kv = 1, 4096
+num_bits, group_size = 4, 128  # 4-bit quantization
+quant_mode = "k-channel"
+pack_nums = int(16 / num_bits)
+
+# Input tensors
+q = torch.randn(batch_size, seqlen_q, nheads, d, device="cuda", dtype=torch.float16)
+k_cache = torch.randn(batch_size, seqlen_kv, nheads_k, d, device="cuda", dtype=torch.float16)
+v_cache = torch.randn(batch_size, seqlen_kv, nheads_k, d, device="cuda", dtype=torch.float16)
+
+# Quantized KV cache buffers
+k_pack = torch.zeros((batch_size, seqlen_kv // pack_nums, nheads_k, d), dtype=torch.uint16, device="cuda")
+k_params = torch.zeros((batch_size, seqlen_kv // group_size, nheads_k, d), dtype=torch.float32, device="cuda")
+v_pack = torch.zeros((batch_size, seqlen_kv, nheads_k, d // pack_nums), dtype=torch.uint16, device="cuda")
+v_params = torch.zeros((batch_size, d // group_size, nheads_k, seqlen_kv), dtype=torch.float32, device="cuda")
+cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_kv, seqlen_kv, dtype=torch.int32, device="cuda")
+
+# Pack KV cache
+kvcache_pack_int(k_cache, k_pack, k_params, v_cache, v_pack, v_params,
+                 None, cu_seqlens_k, seqlen_kv, quant_mode, group_size, num_bits)
+
+# Decode with BitDecoding
+output = fwd_kvcache_int(q, k_pack, k_params, v_pack, v_params, None,
+                         1.0 / math.sqrt(d), quant_mode, group_size, num_bits)
+```
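The buffer shapes in the Quick Start snippet above all follow from `pack_nums = 16 / num_bits` and `group_size`. As a rough check, the sketch below recomputes those shapes in plain Python (no GPU or `bit_decode` install needed) and compares the packed footprint with the float16 cache. The helper name `quantized_kv_shapes` and the `BYTES_PER_ELEMENT` table are hypothetical and not part of `bit_decode`; the shapes and dtypes are copied from the snippet and are only assumed to hold for the `"k-channel"` mode shown there.

```python
from math import prod


def quantized_kv_shapes(batch_size, seqlen_kv, nheads_k, d, num_bits, group_size):
    """Recompute the Quick Start buffer shapes (hypothetical helper, not a bit_decode API)."""
    pack_nums = 16 // num_bits  # number of low-bit values stored per 16-bit word
    return {
        "k_pack": (batch_size, seqlen_kv // pack_nums, nheads_k, d),
        "k_params": (batch_size, seqlen_kv // group_size, nheads_k, d),
        "v_pack": (batch_size, seqlen_kv, nheads_k, d // pack_nums),
        "v_params": (batch_size, d // group_size, nheads_k, seqlen_kv),
    }


# Bytes per element for the dtypes used in the snippet: uint16 packs, float32 params.
BYTES_PER_ELEMENT = {"k_pack": 2, "v_pack": 2, "k_params": 4, "v_params": 4}

shapes = quantized_kv_shapes(
    batch_size=1, seqlen_kv=4096, nheads_k=32, d=128, num_bits=4, group_size=128
)
packed_bytes = sum(prod(shape) * BYTES_PER_ELEMENT[name] for name, shape in shapes.items())

# float16 K and V caches: two tensors of shape (1, 4096, 32, 128), 2 bytes per element.
fp16_bytes = 2 * prod((1, 4096, 32, 128)) * 2

print(f"packed: {packed_bytes / 2**20:.1f} MiB, fp16: {fp16_bytes / 2**20:.1f} MiB")
print(f"reduction: {fp16_bytes / packed_bytes:.2f}x")
```

With these parameters the packed buffers plus their parameter tensors take 17 MiB against 64 MiB for the float16 cache. This counts storage only and says nothing about kernel speed.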
+
+## Examples
+
+- **Benchmark notebook**: See [benchmark/bench_single_decode.ipynb](benchmark/bench_single_decode.ipynb)
+- **End-to-end inference**: See [e2e branch](https://github.com/DD-DuDa/BitDecoding/tree/e2e)
+- **(Optional) LibTorch C++ build**:
+  ```bash
     cd BitDecoding/csrc/bit_decode
     mkdir build && cd build
     cmake -DCMAKE_PREFIX_PATH=<libtorch_path> ..
     make -j12
     ```
-3. End2end inference example, please see [e2e](https://github.com/DD-DuDa/BitDecoding/tree/e2e)

 ## Citation
 If you find BitDecoding useful or want to use in your projects, please kindly cite our paper: