Skip to content

Commit e324766

Browse files
committed
fix group_size=32
1 parent 23cad7c commit e324766

6 files changed

Lines changed: 169 additions & 8 deletions

File tree

csrc/bit_decode/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ set(CMAKE_CUDA_STANDARD 17)
77
set(CMAKE_CUDA_ARCHITECTURES 80)
88

99
# set(INCLUDE_DIR ${PROJECT_SOURCE_DIR}/../../libs/cutlass/include)
10-
set(INCLUDE_DIR /home/ddy/Projects/BitAttn_v2/3rdparty/cutlass/include)
10+
set(INCLUDE_DIR /home/ddy/Projects/BitDecoding/libs/cutlass)
1111

1212
# Enable ccache if available
1313
find_program(CCACHE_PROGRAM ccache)

csrc/bit_decode/decode_api.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split
202202
} else if (params.group_size == 64) {
203203
// run_mha_fwd_splitkv_dispatch<cutlass::half_t, 128, false, 1, num_bits, 64>(params, stream);
204204
} else if (params.group_size == 32) {
205-
// run_mha_fwd_splitkv_dispatch<cutlass::half_t, 128, false, 1, num_bits, 32>(params, stream);
205+
run_mha_fwd_splitkv_dispatch<cutlass::half_t, 128, false, 1, num_bits, 32>(params, stream);
206206
}
207207
} else {
208208
if (params.group_size == 128) {
@@ -220,7 +220,7 @@ template <int num_bits>
220220
void run_kvcache_qpack(Flash_fwd_params &params, cudaStream_t stream) {
221221
if (params.quant_mode == "k-channel") {
222222
if (params.group_size == 32) {
223-
// run_kvcache_qpack_<cutlass::half_t, 128, 1, num_bits, 32>(params, stream);
223+
run_kvcache_qpack_<cutlass::half_t, 128, 1, num_bits, 32>(params, stream);
224224
} else if (params.group_size == 64) {
225225
// run_kvcache_qpack_<cutlass::half_t, 128, 1, num_bits, 64>(params, stream);
226226
} else if (params.group_size == 128) {

csrc/bit_decode/src/genfile/flash_fwd_split_hdim128_fp16_sm80_4bit.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,4 @@
1111

1212
template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 128, false, 1, 4, 128>(Flash_fwd_params &params, cudaStream_t stream);
1313
// template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 128, false, 1, 4, 64>(Flash_fwd_params &params, cudaStream_t stream);
14-
// template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 128, false, 1, 4, 32>(Flash_fwd_params &params, cudaStream_t stream);
14+
template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 128, false, 1, 4, 32>(Flash_fwd_params &params, cudaStream_t stream);

csrc/bit_decode/src/genfile/flash_qpack_hdim128_fp16_sm80_4bit.cu

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44

55
#include "../flash_fwd_launch_template.h"
66

7-
// template<>
8-
// void run_kvcache_qpack_<cutlass::half_t, 128, 1, 4, 32>(Flash_fwd_params &params, cudaStream_t stream) {
9-
// run_kvcache_qpack_hdim128<cutlass::half_t, 1, 4, 32>(params, stream);
10-
// }
7+
template<>
8+
void run_kvcache_qpack_<cutlass::half_t, 128, 1, 4, 32>(Flash_fwd_params &params, cudaStream_t stream) {
9+
run_kvcache_qpack_hdim128<cutlass::half_t, 1, 4, 32>(params, stream);
10+
}
1111
// template<>
1212
// void run_kvcache_qpack_<cutlass::half_t, 128, 1, 4, 64>(Flash_fwd_params &params, cudaStream_t stream) {
1313
// run_kvcache_qpack_hdim128<cutlass::half_t, 1, 4, 64>(params, stream);

csrc/bit_decode/src/include/kernel_traits.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,7 @@ struct Flash_fwd_kernel_traits : public Base {
303303
array_aligned<ElementKVPack, cosize_v<SmemLayoutKSize>> smem_Kpack;
304304
array_aligned<ElementKVPack, cosize_v<SmemLayoutVSize>> smem_Vpack;
305305
array_aligned<Element, cosize_v<SmemLayoutAcc>> smem_acc;
306+
array_aligned<__half2, cosize_v<SmemLayoutKParams>> smem_Kparams;
306307
};
307308
static constexpr int kSmemSize_res = int(sizeof(SharedStorage_residual));
308309

evaluation/test.py

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
import torch
2+
import torch.nn as nn
3+
import math
4+
5+
import triton
6+
7+
import numpy as np
8+
import bit_decode_cuda as bit_decode_cuda
9+
from bit_decode import kvcache_pack_int, fwd_kvcache_int
10+
from bit_decode import DynamicCache
11+
12+
13+
def attention_ref(
14+
q,
15+
k,
16+
v,
17+
):
18+
"""
19+
Arguments:
20+
q: (batch_size, seqlen_q, nheads, head_dim)
21+
k: (batch_size, seqlen_k, nheads_k, head_dim)
22+
v: (batch_size, seqlen_k, nheads_k, head_dim)
23+
Output:
24+
output: (batch_size, seqlen_q, nheads, head_dim)
25+
attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout
26+
"""
27+
dtype_og = q.dtype
28+
29+
d = q.shape[-1]
30+
31+
scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
32+
33+
attention = torch.softmax(scores, dim=-1).to(v.dtype)
34+
35+
output = torch.einsum("bhts,bshd->bthd", attention, v)
36+
37+
return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
38+
39+
40+
# Quantization parameters
41+
quant_mode = "k-channel"
42+
num_bits = 4
43+
pack_nums = 16 / num_bits
44+
group_size = 32
45+
residual_block_size = 128
46+
47+
device = "cuda"
48+
dtype = torch.float16
49+
50+
layer_idx = 0
51+
batch_size = 1
52+
nheads = 32
53+
nheads_k = 32
54+
d = 128
55+
56+
seqlen_q = 1
57+
seqlen_k = 1024
58+
sm_scale = 1.0 / math.sqrt(d)
59+
60+
61+
####### Round 1 : Prefill #######
62+
torch.manual_seed(42)
63+
64+
q = torch.rand(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)
65+
k_state = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype)
66+
v_state = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype)
67+
68+
residual_len = seqlen_k % residual_block_size
69+
residual = residual_len > 0
70+
seqlen_k_pack = seqlen_k - residual_len
71+
72+
print(f"residual_len: {residual_len}, residual: {residual}, seqlen_k_pack: {seqlen_k_pack}")
73+
74+
cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k_pack, seqlen_k_pack,
75+
dtype=torch.int32, device=device)
76+
77+
# Initialize quantization tensors
78+
k_pack = torch.zeros((batch_size, int(seqlen_k_pack // pack_nums), nheads_k, d), dtype=torch.uint16, device=device)
79+
k_params = torch.zeros((batch_size, int(seqlen_k_pack // group_size), nheads_k, d), dtype=torch.float32, device=device)
80+
81+
v_pack = torch.zeros((batch_size, seqlen_k_pack, nheads_k, int(d // pack_nums)), dtype=torch.uint16, device=device)
82+
v_params = torch.zeros((batch_size, int(d // group_size), nheads_k, seqlen_k_pack), dtype=torch.float32, device=device)
83+
84+
# KV Cache Dynamic Cache
85+
past_key_value = DynamicCache()
86+
87+
if residual:
88+
k_state_residual = k_state[:, -residual_len:, :, :]
89+
v_state_residual = v_state[:, -residual_len:, :, :]
90+
k_state_past = k_state[:, :-residual_len, :, :]
91+
v_state_past = v_state[:, :-residual_len, :, :]
92+
past_key_value.update_residual(k_state_residual, v_state_residual, layer_idx)
93+
else:
94+
k_state_past = k_state
95+
v_state_past = v_state
96+
97+
kvcache_pack_int(
98+
k_state_past, k_pack, k_params,
99+
v_state_past, v_pack, v_params,
100+
None, # opt_block_table
101+
cu_seqlens_k,
102+
seqlen_k_pack,
103+
quant_mode,
104+
group_size,
105+
num_bits
106+
)
107+
past_key_value.update_pack(k_pack, k_params, v_pack, v_params, layer_idx)
108+
109+
# self
110+
k_pack_new = torch.empty((batch_size, int(residual_block_size // pack_nums), nheads_k, k_pack.size(-1)), dtype=torch.uint16, device=device)
111+
k_params_new = torch.empty((batch_size, int(residual_block_size // group_size), nheads_k, k_params.size(-1)), dtype=torch.float32, device=device)
112+
v_pack_new = torch.empty((batch_size, residual_block_size, nheads_k, v_pack.size(-1)), dtype=torch.uint16, device=device)
113+
v_params_new = torch.empty((batch_size, v_params.size(1), nheads_k, residual_block_size), dtype=torch.float32, device=device)
114+
115+
####### Round 2-3 : Decode #######
116+
for round_idx in range(32):
117+
k_new = torch.randn(batch_size, 1, nheads_k, d, device=device, dtype=dtype)
118+
v_new = torch.randn(batch_size, 1, nheads_k, d, device=device, dtype=dtype)
119+
120+
# Get kv cache_pack
121+
k_pack, k_params, v_pack, v_params = past_key_value.update_pack(None, None, None, None, layer_idx)
122+
123+
seqlen_pack = v_pack.shape[1]
124+
seqlens_k = torch.full((batch_size,), seqlen_pack, dtype=torch.int32, device=device)
125+
126+
# Get kv cache_residual and append new kv
127+
k_residual = torch.zeros((batch_size, residual_block_size, nheads_k, d), device=device, dtype=dtype)
128+
v_residual = torch.zeros((batch_size, residual_block_size, nheads_k, d), device=device, dtype=dtype)
129+
k_residual_cache, v_residual_cache = past_key_value.update_residual(k_new, v_new, layer_idx)
130+
131+
cur_residual_len = k_residual_cache.shape[1]
132+
print(f"cur_residual_len: {cur_residual_len}")
133+
134+
k_residual[:, :cur_residual_len, :, :] = k_residual_cache
135+
v_residual[:, :cur_residual_len, :, :] = v_residual_cache
136+
137+
out_bitdecode, k_pack_new, k_params_new, v_pack_new, v_params_new = fwd_kvcache_int(
138+
q,
139+
k_pack, k_params,
140+
v_pack, v_params,
141+
k_residual, v_residual, seqlens_k, #seqlens_k
142+
k_pack_new, k_params_new, v_pack_new, v_params_new,
143+
None, # opt_block_table
144+
sm_scale,
145+
quant_mode,
146+
group_size,
147+
residual_block_size,
148+
cur_residual_len, # new_lens
149+
num_bits
150+
)
151+
152+
if cur_residual_len == residual_block_size:
153+
past_key_value.update_pack(k_pack_new, k_params_new, v_pack_new, v_params_new, layer_idx)
154+
past_key_value.clear_residual(layer_idx)
155+
156+
k_state = torch.cat([k_state, k_new], dim=1)
157+
v_state = torch.cat([v_state, v_new], dim=1)
158+
159+
out_ref = attention_ref(q, k_state, v_state)[0]
160+
print(f"Round {round_idx+2}: bitdecode vs pytorch: {(out_bitdecode - out_ref).abs().mean().item()}")

0 commit comments

Comments
 (0)