Skip to content

Commit 33cffa2

Browse files
Fix 2-bit BitDecoding residual path
1 parent 9981c1d commit 33cffa2

6 files changed

Lines changed: 51 additions & 44 deletions

File tree

bit_decode/bit_decode_interface.py

Lines changed: 33 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,17 @@ def kvcache_pack_int(k_cache: torch.Tensor, k_pack: torch.Tensor, k_params: torc
3232
quant_mode,
3333
group_size
3434
)
35-
# else:
36-
# bit_decode_cuda.kvcache_pack_int2(K_unpad, k_pack, k_params,
37-
# V_unpad, v_pack, v_params,
38-
# opt_block_table,
39-
# cu_seqlens_k,
40-
# seqlen_k,
41-
# quant_mode,
42-
# group_size
43-
# )
35+
elif num_bits == 2:
36+
bit_decode_cuda.kvcache_pack_int2(K_unpad, k_pack, k_params,
37+
V_unpad, v_pack, v_params,
38+
opt_block_table,
39+
cu_seqlens_k,
40+
seqlen_k,
41+
quant_mode,
42+
group_size
43+
)
44+
else:
45+
raise ValueError(f"Unsupported num_bits={num_bits}; expected 2 or 4")
4446

4547
def fwd_kvcache_int(q: torch.Tensor,
4648
k_pack: torch.Tensor, k_params: torch.Tensor,
@@ -78,26 +80,28 @@ def fwd_kvcache_int(q: torch.Tensor,
7880
True, # Added
7981
0 # Added
8082
)
81-
# else:
82-
# out_bit, k_pack_new, k_params_new, v_pack_new, v_params_new = bit_decode_cuda.fwd_kvcache_int2(
83-
# q,
84-
# k_pack, k_params,
85-
# v_pack, v_params,
86-
# opt_k_new, opt_v_new, opt_seqlens_k,
87-
# k_pack_new, k_params_new, v_pack_new, v_params_new,
88-
# opt_block_table,
89-
# softmax_scale,
90-
# quant_mode,
91-
# group_size,
92-
# residual_block_size,
93-
# new_lens,
94-
# False, # Added
95-
# -1, # Added
96-
# -1, # Added
97-
# 0.0, # Added
98-
# True, # Added
99-
# 0 # Added
100-
# )
83+
elif num_bits == 2:
84+
out_bit, k_pack_new, k_params_new, v_pack_new, v_params_new = bit_decode_cuda.fwd_kvcache_int2(
85+
q,
86+
k_pack, k_params,
87+
v_pack, v_params,
88+
opt_k_new, opt_v_new, opt_seqlens_k,
89+
k_pack_new, k_params_new, v_pack_new, v_params_new,
90+
opt_block_table,
91+
softmax_scale,
92+
quant_mode,
93+
group_size,
94+
residual_block_size,
95+
new_lens,
96+
False, # Added
97+
-1, # Added
98+
-1, # Added
99+
0.0, # Added
100+
True, # Added
101+
0 # Added
102+
)
103+
else:
104+
raise ValueError(f"Unsupported num_bits={num_bits}; expected 2 or 4")
101105

102106

103107
return out_bit, k_pack_new, k_params_new, v_pack_new, v_params_new

csrc/bit_decode/decode_api.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -687,8 +687,8 @@ void kvcache_qpack(const at::Tensor &k, at::Tensor &k_pack, at::Tensor &k_params
687687

688688
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
689689
m.doc() = "BitDecoding";
690-
// m.def("kvcache_pack_i2", &kvcache_qpack<2>, "Forward pass, kvcache quantization and packing (2-bit)");
690+
m.def("kvcache_pack_int2", &kvcache_qpack<2>, "Forward pass, kvcache quantization and packing (2-bit)");
691691
m.def("kvcache_pack_int4", &kvcache_qpack<4>, "Forward pass, kvcache quantization and packing (4-bit)");
692-
// m.def("fwd_kvcache_i2", &mha_fwd_kvcache<2>, "Forward pass, with 2-bit KV-cache");
692+
m.def("fwd_kvcache_int2", &mha_fwd_kvcache<2>, "Forward pass, with 2-bit KV-cache");
693693
m.def("fwd_kvcache_int4", &mha_fwd_kvcache<4>, "Forward pass, with 4-bit KV-cache");
694694
}

csrc/bit_decode/src/flash_fwd_kernel.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ inline __device__ void compute_attn_1rowblock_residualkv(const Params &params, c
7070
using ElementKVPack = typename Kernel_traits::ElementKVPack;
7171
using ElementAccum = typename Kernel_traits::ElementAccum;
7272
using index_t = typename Kernel_traits::index_t;
73-
using SharedStorage = typename Kernel_traits::SharedStorage;
73+
using SharedStorage = typename Kernel_traits::SharedStorage_residual;
7474

7575
// Shared memory.
7676
extern __shared__ char smem_[];
@@ -1881,4 +1881,4 @@ inline __device__ void combine_attn_seqk_parallel(const Params &params) {
18811881
}
18821882
}
18831883

1884-
} // namespace flash
1884+
} // namespace flash

csrc/bit_decode/src/genfile/flash_fwd_split_hdim128_fp16_sm80_2bit.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,6 @@
44

55
#include "../flash_fwd_launch_template.h"
66

7-
// template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 128, false, 1, 2, 128>(Flash_fwd_params &params, cudaStream_t stream);
7+
template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 128, false, 1, 2, 128>(Flash_fwd_params &params, cudaStream_t stream);
88
// template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 128, false, 1, 2, 64>(Flash_fwd_params &params, cudaStream_t stream);
9-
// template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 128, false, 1, 2, 32>(Flash_fwd_params &params, cudaStream_t stream);
9+
template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 128, false, 1, 2, 32>(Flash_fwd_params &params, cudaStream_t stream);

csrc/bit_decode/src/genfile/flash_qpack_hdim128_fp16_sm80_2bit.cu

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,15 @@
44

55
#include "../flash_fwd_launch_template.h"
66

7-
// template<>
8-
// void run_kvcache_qpack_<cutlass::half_t, 128, 1, 2, 128>(Flash_fwd_params &params, cudaStream_t stream) {
9-
// run_kvcache_qpack_hdim128<cutlass::half_t, 1, 2, 128>(params, stream);
10-
// }
7+
template<>
8+
void run_kvcache_qpack_<cutlass::half_t, 128, 1, 2, 128>(Flash_fwd_params &params, cudaStream_t stream) {
9+
run_kvcache_qpack_hdim128<cutlass::half_t, 1, 2, 128>(params, stream);
10+
}
1111
// template<>
1212
// void run_kvcache_qpack_<cutlass::half_t, 128, 1, 2, 64>(Flash_fwd_params &params, cudaStream_t stream) {
1313
// run_kvcache_qpack_hdim128<cutlass::half_t, 1, 2, 64>(params, stream);
1414
// }
15-
// template<>
16-
// void run_kvcache_qpack_<cutlass::half_t, 128, 1, 2, 32>(Flash_fwd_params &params, cudaStream_t stream) {
17-
// run_kvcache_qpack_hdim128<cutlass::half_t, 1, 2, 32>(params, stream);
18-
// }
15+
template<>
16+
void run_kvcache_qpack_<cutlass::half_t, 128, 1, 2, 32>(Flash_fwd_params &params, cudaStream_t stream) {
17+
run_kvcache_qpack_hdim128<cutlass::half_t, 1, 2, 32>(params, stream);
18+
}

evaluation/example.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,17 @@ def main():
2323
parser.add_argument('--max_length', type=int, default=131072, help='Maximum length of the input sequence')
2424
parser.add_argument('--num_bits', type=int, default=4, help='Number of bits for quantization')
2525
parser.add_argument('--quant_mode', type=str, default='k-channel', help='Quantization mode')
26-
parser.add_argument('--group_size', type=int, default=128, help='Group size for quantization')
26+
parser.add_argument('--group_size', type=int, default=None, help='Group size for quantization')
2727
parser.add_argument('--attn_backend', type=str, default='flash_attention_2', help='Attention implementation')
2828
args = parser.parse_args()
2929

3030
# For reproducibility
3131
random.seed(0)
3232
torch.manual_seed(0)
3333

34+
if args.group_size is None:
35+
args.group_size = 32 if args.num_bits == 2 else 128
36+
3437
if "Llama" in args.model_path:
3538
config = LlamaConfig.from_pretrained(args.model_path)
3639
elif "Qwen" in args.model_path:

0 commit comments

Comments
 (0)