Skip to content

Commit 409f118

Browse files
committed
update
1 parent c531ba4 commit 409f118

7 files changed

Lines changed: 145 additions & 57 deletions

benchmark/bench_single_decode.ipynb

Lines changed: 20 additions & 45 deletions
Large diffs are not rendered by default.

csrc/bit_decode/decode_api.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split
190190
} else if (params.group_size == 64) {
191191
// run_mha_fwd_splitkv_dispatch<cutlass::half_t, 128, false, 1, num_bits, 64>(params, stream);
192192
} else if (params.group_size == 32) {
193-
// run_mha_fwd_splitkv_dispatch<cutlass::half_t, 128, false, 1, num_bits, 32>(params, stream);
193+
run_mha_fwd_splitkv_dispatch<cutlass::half_t, 128, false, 1, num_bits, 32>(params, stream);
194194
}
195195
} else {
196196
if (params.group_size == 128) {
@@ -212,7 +212,7 @@ void run_kvcache_qpack(Flash_fwd_params &params, cudaStream_t stream) {
212212
} else if (params.group_size == 64) {
213213
// run_kvcache_qpack_<cutlass::half_t, 128, 1, num_bits, 64>(params, stream);
214214
} else if (params.group_size == 32) {
215-
// run_kvcache_qpack_<cutlass::half_t, 128, 1, num_bits, 32>(params, stream);
215+
run_kvcache_qpack_<cutlass::half_t, 128, 1, num_bits, 32>(params, stream);
216216
}
217217
} else {
218218
if (params.group_size == 128) {

csrc/bit_decode/src/genfile/flash_fwd_split_hdim128_fp16_sm80_2bit.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,4 @@
66

77
template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 128, false, 1, 2, 128>(Flash_fwd_params &params, cudaStream_t stream);
88
// template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 128, false, 1, 2, 64>(Flash_fwd_params &params, cudaStream_t stream);
9-
// template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 128, false, 1, 2, 32>(Flash_fwd_params &params, cudaStream_t stream);
9+
template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 128, false, 1, 2, 32>(Flash_fwd_params &params, cudaStream_t stream);

csrc/bit_decode/src/genfile/flash_fwd_split_hdim128_fp16_sm80_4bit.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,4 @@
1111

1212
template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 128, false, 1, 4, 128>(Flash_fwd_params &params, cudaStream_t stream);
1313
// template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 128, false, 1, 4, 64>(Flash_fwd_params &params, cudaStream_t stream);
14-
// template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 128, false, 1, 4, 32>(Flash_fwd_params &params, cudaStream_t stream);
14+
template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 128, false, 1, 4, 32>(Flash_fwd_params &params, cudaStream_t stream);

csrc/bit_decode/src/genfile/flash_qpack_hdim128_fp16_sm80_2bit.cu

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ void run_kvcache_qpack_<cutlass::half_t, 128, 1, 2, 128>(Flash_fwd_params &param
1212
// void run_kvcache_qpack_<cutlass::half_t, 128, 1, 2, 64>(Flash_fwd_params &params, cudaStream_t stream) {
1313
// run_kvcache_qpack_hdim128<cutlass::half_t, 1, 2, 64>(params, stream);
1414
// }
15-
// template<>
16-
// void run_kvcache_qpack_<cutlass::half_t, 128, 1, 2, 32>(Flash_fwd_params &params, cudaStream_t stream) {
17-
// run_kvcache_qpack_hdim128<cutlass::half_t, 1, 2, 32>(params, stream);
18-
// }
15+
template<>
16+
void run_kvcache_qpack_<cutlass::half_t, 128, 1, 2, 32>(Flash_fwd_params &params, cudaStream_t stream) {
17+
run_kvcache_qpack_hdim128<cutlass::half_t, 1, 2, 32>(params, stream);
18+
}

csrc/bit_decode/src/genfile/flash_qpack_hdim128_fp16_sm80_4bit.cu

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44

55
#include "../flash_fwd_launch_template.h"
66

7-
// template<>
8-
// void run_kvcache_qpack_<cutlass::half_t, 128, 1, 4, 32>(Flash_fwd_params &params, cudaStream_t stream) {
9-
// run_kvcache_qpack_hdim128<cutlass::half_t, 1, 4, 32>(params, stream);
10-
// }
7+
template<>
8+
void run_kvcache_qpack_<cutlass::half_t, 128, 1, 4, 32>(Flash_fwd_params &params, cudaStream_t stream) {
9+
run_kvcache_qpack_hdim128<cutlass::half_t, 1, 4, 32>(params, stream);
10+
}
1111
// template<>
1212
// void run_kvcache_qpack_<cutlass::half_t, 128, 1, 4, 64>(Flash_fwd_params &params, cudaStream_t stream) {
1313
// run_kvcache_qpack_hdim128<cutlass::half_t, 1, 4, 64>(params, stream);

evaluation/test.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
import torch
2+
import torch.nn as nn
3+
import math
4+
import triton
5+
from einops import rearrange, repeat
6+
import numpy as np
7+
8+
from flash_attn import flash_attn_with_kvcache
9+
from bit_decode import kvcache_pack_int, fwd_kvcache_int
10+
11+
12+
def attention_ref(
13+
q,
14+
k,
15+
v,
16+
):
17+
"""
18+
Arguments:
19+
q: (batch_size, seqlen_q, nheads, head_dim)
20+
k: (batch_size, seqlen_k, nheads_k, head_dim)
21+
v: (batch_size, seqlen_k, nheads_k, head_dim)
22+
Output:
23+
output: (batch_size, seqlen_q, nheads, head_dim)
24+
attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout
25+
"""
26+
dtype_og = q.dtype
27+
28+
d = q.shape[-1]
29+
30+
scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
31+
32+
attention = torch.softmax(scores, dim=-1).to(v.dtype)
33+
34+
output = torch.einsum("bhts,bshd->bthd", attention, v)
35+
36+
return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
37+
38+
39+
# Define constants
40+
batch_size = 1
41+
nheads = 32
42+
nheads_k = 32
43+
d = 128
44+
45+
# Sequence length
46+
seqlen_q = 1
47+
seqlen_kv = 4096
48+
49+
# Quantization parameters
50+
quant_mode = "k-channel"
51+
num_bits = 4
52+
pack_nums = 16 / num_bits
53+
group_size = 128
54+
55+
56+
# Set seed and parameters
57+
device = "cuda"
58+
dtype = torch.float16
59+
torch.random.manual_seed(0)
60+
61+
# Initialize tensors
62+
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)
63+
k_cache = torch.randn(batch_size, seqlen_kv, nheads_k, d, device=device, dtype=dtype)
64+
v_cache = torch.randn(batch_size, seqlen_kv, nheads_k, d, device=device, dtype=dtype)
65+
66+
k_cache_rep = repeat(k_cache, "b s h d -> b s (h g) d", g=nheads // nheads_k)
67+
v_cache_rep = repeat(v_cache, "b s h d -> b s (h g) d", g=nheads // nheads_k)
68+
69+
# Reference attention computation
70+
out_ref, _ = attention_ref(q, k_cache_rep, v_cache_rep)
71+
72+
##################### BitDecoding Packing Kernel #####################
73+
74+
# Initialize quantization tensors
75+
if quant_mode == "k-channel":
76+
k_pack = torch.zeros((batch_size, int(seqlen_kv // pack_nums), nheads_k, d), dtype=torch.uint16, device=device)
77+
k_params = torch.zeros((batch_size, int(seqlen_kv // group_size), nheads_k, d), dtype=torch.float32, device=device)
78+
else:
79+
k_pack = torch.zeros((batch_size, seqlen_kv, nheads_k, int(d // pack_nums)), dtype=torch.uint16, device=device)
80+
k_params = torch.zeros((batch_size, int(d // group_size), nheads_k, seqlen_kv), dtype=torch.float32, device=device)
81+
82+
v_pack = torch.zeros((batch_size, seqlen_kv, nheads_k, int(d // pack_nums)), dtype=torch.uint16, device=device)
83+
v_params = torch.zeros((batch_size, int(d // group_size), nheads_k, seqlen_kv), dtype=torch.float32, device=device)
84+
85+
cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_kv, seqlen_kv, dtype=torch.int32, device=device)
86+
87+
kvcache_pack_int(
88+
k_cache, k_pack, k_params,
89+
v_cache, v_pack, v_params,
90+
None, # opt_block_table
91+
cu_seqlens_k,
92+
seqlen_kv,
93+
quant_mode,
94+
group_size,
95+
num_bits
96+
)
97+
98+
sm_scale = 1.0 / math.sqrt(d)
99+
out_bitdecode = fwd_kvcache_int(
100+
q,
101+
k_pack, k_params,
102+
v_pack, v_params,
103+
None, # opt_block_table
104+
sm_scale,
105+
quant_mode,
106+
group_size,
107+
num_bits
108+
)
109+
110+
print(f"seqlen_kv:{seqlen_kv} BitDecode vs Pytorch: {(out_bitdecode - out_ref).abs().mean().item()}")
111+
112+
print(f"out_ref: \n{out_ref[0,0,0,:8]}")
113+
print(f"out_bitdecode: \n{out_bitdecode[0,0,0,:8]}")

0 commit comments

Comments
 (0)