Skip to content

Commit c1545a1

Browse files
committed
update examples, readme
1 parent 0ebd3f4 commit c1545a1

20 files changed

Lines changed: 1708 additions & 952 deletions

README.md

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,28 @@
11
# BitDecoding
2+
BitDecoding is a high-performance, GPU-optimized system
3+
designed to accelerate long-context LLMs decoding with a low-bit KV
4+
cache. Acheive more than **3x speedup** than FlashDecoding-v2.
5+
![overview](imgs/overview.png)
6+
![scheme](imgs/scheme.png)
7+
8+
## Benchmark
9+
* Kernel Performance in RTX4090
10+
![overview](imgs/4090.png)
11+
* Kernel Performance in A100
12+
![overview](imgs/a100.png)
13+
14+
## Installation
15+
```
16+
git clone --recursive https://github.com/DD-DuDa/BitDecoding.git
17+
conda create -n bitdecode python=3.10
18+
conda activate bitdecode
19+
pip install -r requirements.txt
20+
python setup.py install
21+
```
222

323
## Quick Start
4-
2. Run with libtorch c++
24+
1. See benchmark/bench_single_decode.ipynb
25+
2. (Optional) Play with libtorch c++
526
```
627
cd libs/
728
wget https://download.pytorch.org/libtorch /cu124/libtorch-shared-with-deps-2.5.1%2Bcu124.zip
@@ -12,4 +33,13 @@
1233
mkdir build && cd build
1334
cmake -DCMAKE_PREFIX_PATH=<libtorch_path> ..
1435
make -j12
15-
```
36+
```
37+
38+
## Release Progress
39+
40+
- [ ] Page Implementation
41+
- [ ] Hopper Implementation
42+
- [ ] End-2-end LLMs Inference
43+
44+
## Citation
45+
If you find BitDecoding useful or want to use in your projects, please kindly cite our paper:

benchmark/bench_single_decode.ipynb

Lines changed: 414 additions & 0 deletions
Large diffs are not rendered by default.

bit_decode/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
__version__ = "1.0.0.post1"
2+
3+
from bit_decode.bit_decode_interface import (
4+
kvcache_pack_int,
5+
fwd_kvcache_int
6+
)

bit_decode/bit_decode_interface.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
# Copyright (c) 2025, Dayou Du.
2+
3+
from typing import Optional, Union
4+
5+
import torch
6+
import torch.nn as nn
7+
8+
import bit_decode_cuda as bit_decode_cuda
9+
10+
def kvcache_pack_int(k_cache: torch.Tensor, k_pack: torch.Tensor, k_params: torch.Tensor,
11+
v_cache: torch.Tensor, v_pack: torch.Tensor, v_params: torch.Tensor,
12+
opt_block_table: Optional[torch.Tensor] = None,
13+
cu_seqlens_k: torch.Tensor = None,
14+
seqlen_k: int = 0,
15+
quant_mode: str = "k-channel",
16+
group_size: int = 128,
17+
num_bits: int = 4):
18+
19+
batch_size, seqlen_k, nheads_k, d = k_cache.shape
20+
21+
K_unpad = k_cache.reshape(batch_size * seqlen_k, nheads_k, d)
22+
V_unpad = v_cache.reshape(batch_size * seqlen_k, nheads_k, d)
23+
24+
if num_bits == 4:
25+
bit_decode_cuda.kvcache_pack_i4(K_unpad, k_pack, k_params,
26+
V_unpad, v_pack, v_params,
27+
opt_block_table,
28+
cu_seqlens_k,
29+
seqlen_k,
30+
quant_mode,
31+
group_size
32+
)
33+
else:
34+
bit_decode_cuda.kvcache_pack_i2(K_unpad, k_pack, k_params,
35+
V_unpad, v_pack, v_params,
36+
opt_block_table,
37+
cu_seqlens_k,
38+
seqlen_k,
39+
quant_mode,
40+
group_size
41+
)
42+
43+
def fwd_kvcache_int(q: torch.Tensor,
44+
k_pack: torch.Tensor, k_params: torch.Tensor,
45+
v_pack: torch.Tensor, v_params: torch.Tensor,
46+
opt_block_table: Optional[torch.Tensor] = None,
47+
softmax_scale: float = 1.0,
48+
quant_mode: str = "k-channel",
49+
group_size: int = 128,
50+
num_bits: int = 4):
51+
52+
if num_bits == 4:
53+
out_bit = bit_decode_cuda.fwd_kvcache_i4(
54+
q,
55+
k_pack, k_params,
56+
v_pack, v_params,
57+
opt_block_table,
58+
softmax_scale,
59+
quant_mode,
60+
group_size,
61+
False, # is_causal
62+
-1, # window_size_left
63+
-1, # window_size_right
64+
0.0, # softcap
65+
True, # is_rotary_interleaved
66+
0 # num_splits
67+
)
68+
else:
69+
out_bit = bit_decode_cuda.fwd_kvcache_i2(
70+
q,
71+
k_pack, k_params,
72+
v_pack, v_params,
73+
opt_block_table,
74+
softmax_scale,
75+
quant_mode,
76+
group_size,
77+
False, # Added
78+
-1, # Added
79+
-1, # Added
80+
0.0, # Added
81+
True, # Added
82+
0 # Added
83+
)
84+
85+
86+
return out_bit

csrc/bit_decode/CMakeLists.txt

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,16 @@ target_link_libraries(test_single_packdecode "${TORCH_LIBRARIES}")
3131
target_include_directories(test_single_packdecode PRIVATE ${INCLUDE_DIR})
3232
target_compile_options(test_single_packdecode PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-maxrregcount=255 -gencode arch=compute_80,code=sm_80 -w>)
3333

34+
message(STATUS "Compile benchmarking kernel.")
35+
add_executable(bench_single_packdecode
36+
${PROJECT_SOURCE_DIR}/src/bench_single_packdecode.cu
37+
${PROJECT_SOURCE_DIR}/src/genfile/flash_fwd_hdim128_fp16_sm80.cu
38+
${PROJECT_SOURCE_DIR}/src/genfile/flash_qpack_hdim128_fp16_sm80_2bit.cu
39+
${PROJECT_SOURCE_DIR}/src/genfile/flash_qpack_hdim128_fp16_sm80_4bit.cu
40+
${PROJECT_SOURCE_DIR}/src/genfile/flash_fwd_split_hdim128_fp16_sm80_2bit.cu
41+
${PROJECT_SOURCE_DIR}/src/genfile/flash_fwd_split_hdim128_fp16_sm80_4bit.cu
42+
)
43+
target_link_libraries(bench_single_packdecode "${TORCH_LIBRARIES}")
44+
target_include_directories(bench_single_packdecode PRIVATE ${INCLUDE_DIR})
45+
target_compile_options(bench_single_packdecode PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-maxrregcount=255 -gencode arch=compute_80,code=sm_80 -w>)
46+

0 commit comments

Comments
 (0)