55import torch
66import torch .nn as nn
77
8+ # isort: off
9+ # We need to import the CUDA kernels after importing torch
810import bit_decode_cuda as bit_decode_cuda
911
1012def kvcache_pack_int (k_cache : torch .Tensor , k_pack : torch .Tensor , k_params : torch .Tensor ,
1113 v_cache : torch .Tensor , v_pack : torch .Tensor , v_params : torch .Tensor ,
1214 opt_block_table : Optional [torch .Tensor ] = None ,
1315 cu_seqlens_k : torch .Tensor = None ,
1416 seqlen_k : int = 0 ,
15- quant_mode : str = "k-channel " ,
17+ quant_mode : str = "k-tensor " ,
1618 group_size : int = 128 ,
1719 num_bits : int = 4 ):
1820
@@ -22,65 +24,80 @@ def kvcache_pack_int(k_cache: torch.Tensor, k_pack: torch.Tensor, k_params: torc
2224 V_unpad = v_cache .reshape (batch_size * seqlen_k , nheads_k , d )
2325
2426 if num_bits == 4 :
25- bit_decode_cuda .kvcache_pack_i4 (K_unpad , k_pack , k_params ,
26- V_unpad , v_pack , v_params ,
27- opt_block_table ,
28- cu_seqlens_k ,
29- seqlen_k ,
30- quant_mode ,
31- group_size
32- )
33- else :
34- bit_decode_cuda .kvcache_pack_i2 (K_unpad , k_pack , k_params ,
35- V_unpad , v_pack , v_params ,
36- opt_block_table ,
37- cu_seqlens_k ,
38- seqlen_k ,
39- quant_mode ,
40- group_size
41- )
27+ bit_decode_cuda .kvcache_pack_int4 (K_unpad , k_pack , k_params ,
28+ V_unpad , v_pack , v_params ,
29+ opt_block_table ,
30+ cu_seqlens_k ,
31+ seqlen_k ,
32+ quant_mode ,
33+ group_size
34+ )
35+ # else:
36+ # bit_decode_cuda.kvcache_pack_int2 (K_unpad, k_pack, k_params,
37+ # V_unpad, v_pack, v_params,
38+ # opt_block_table,
39+ # cu_seqlens_k,
40+ # seqlen_k,
41+ # quant_mode,
42+ # group_size
43+ # )
4244
4345def fwd_kvcache_int (q : torch .Tensor ,
4446 k_pack : torch .Tensor , k_params : torch .Tensor ,
4547 v_pack : torch .Tensor , v_params : torch .Tensor ,
48+ opt_k_new : Optional [torch .Tensor ] = None ,
49+ opt_v_new : Optional [torch .Tensor ] = None ,
50+ opt_seqlens_k : Optional [torch .Tensor ] = None ,
51+ k_pack_new : torch .Tensor = None , k_params_new : torch .Tensor = None ,
52+ v_pack_new : torch .Tensor = None , v_params_new : torch .Tensor = None ,
4653 opt_block_table : Optional [torch .Tensor ] = None ,
4754 softmax_scale : float = 1.0 ,
48- quant_mode : str = "k-channel " ,
55+ quant_mode : str = "k-tensor " ,
4956 group_size : int = 128 ,
57+ residual_block_size : int = 128 ,
58+ new_lens : int = 0 ,
5059 num_bits : int = 4 ):
5160
5261 if num_bits == 4 :
53- out_bit = bit_decode_cuda .fwd_kvcache_i4 (
54- q ,
55- k_pack , k_params ,
56- v_pack , v_params ,
57- opt_block_table ,
58- softmax_scale ,
59- quant_mode ,
60- group_size ,
61- False , # is_causal
62- - 1 , # window_size_left
63- - 1 , # window_size_right
64- 0.0 , # softcap
65- True , # is_rotary_interleaved
66- 0 # num_splits
67- )
68- else :
69- out_bit = bit_decode_cuda .fwd_kvcache_i2 (
62+ out_bit , k_pack_new , k_params_new , v_pack_new , v_params_new = bit_decode_cuda .fwd_kvcache_int4 (
7063 q ,
7164 k_pack , k_params ,
7265 v_pack , v_params ,
66+ opt_k_new , opt_v_new , opt_seqlens_k ,
67+ k_pack_new , k_params_new , v_pack_new , v_params_new ,
7368 opt_block_table ,
7469 softmax_scale ,
7570 quant_mode ,
7671 group_size ,
72+ residual_block_size ,
73+ new_lens ,
7774 False , # Added
7875 - 1 , # Added
7976 - 1 , # Added
8077 0.0 , # Added
8178 True , # Added
8279 0 # Added
8380 )
81+ # else:
82+ # out_bit, k_pack_new, k_params_new, v_pack_new, v_params_new = bit_decode_cuda.fwd_kvcache_int2(
83+ # q,
84+ # k_pack, k_params,
85+ # v_pack, v_params,
86+ # opt_k_new, opt_v_new, opt_seqlens_k,
87+ # k_pack_new, k_params_new, v_pack_new, v_params_new,
88+ # opt_block_table,
89+ # softmax_scale,
90+ # quant_mode,
91+ # group_size,
92+ # residual_block_size,
93+ # new_lens,
94+ # False, # Added
95+ # -1, # Added
96+ # -1, # Added
97+ # 0.0, # Added
98+ # True, # Added
99+ # 0 # Added
100+ # )
84101
85102
86- return out_bit
103+ return out_bit , k_pack_new , k_params_new , v_pack_new , v_params_new
0 commit comments