@@ -32,15 +32,17 @@ def kvcache_pack_int(k_cache: torch.Tensor, k_pack: torch.Tensor, k_params: torc
3232 quant_mode ,
3333 group_size
3434 )
35- # else:
36- # bit_decode_cuda.kvcache_pack_int2(K_unpad, k_pack, k_params,
37- # V_unpad, v_pack, v_params,
38- # opt_block_table,
39- # cu_seqlens_k,
40- # seqlen_k,
41- # quant_mode,
42- # group_size
43- # )
35+ elif num_bits == 2 :
36+ bit_decode_cuda .kvcache_pack_int2 (K_unpad , k_pack , k_params ,
37+ V_unpad , v_pack , v_params ,
38+ opt_block_table ,
39+ cu_seqlens_k ,
40+ seqlen_k ,
41+ quant_mode ,
42+ group_size
43+ )
44+ else :
45+ raise ValueError (f"Unsupported num_bits={ num_bits } ; expected 2 or 4" )
4446
4547def fwd_kvcache_int (q : torch .Tensor ,
4648 k_pack : torch .Tensor , k_params : torch .Tensor ,
@@ -78,26 +80,28 @@ def fwd_kvcache_int(q: torch.Tensor,
7880 True , # Added
7981 0 # Added
8082 )
81- # else:
82- # out_bit, k_pack_new, k_params_new, v_pack_new, v_params_new = bit_decode_cuda.fwd_kvcache_int2(
83- # q,
84- # k_pack, k_params,
85- # v_pack, v_params,
86- # opt_k_new, opt_v_new, opt_seqlens_k,
87- # k_pack_new, k_params_new, v_pack_new, v_params_new,
88- # opt_block_table,
89- # softmax_scale,
90- # quant_mode,
91- # group_size,
92- # residual_block_size,
93- # new_lens,
94- # False, # Added
95- # -1, # Added
96- # -1, # Added
97- # 0.0, # Added
98- # True, # Added
99- # 0 # Added
100- # )
83+ elif num_bits == 2 :
84+ out_bit , k_pack_new , k_params_new , v_pack_new , v_params_new = bit_decode_cuda .fwd_kvcache_int2 (
85+ q ,
86+ k_pack , k_params ,
87+ v_pack , v_params ,
88+ opt_k_new , opt_v_new , opt_seqlens_k ,
89+ k_pack_new , k_params_new , v_pack_new , v_params_new ,
90+ opt_block_table ,
91+ softmax_scale ,
92+ quant_mode ,
93+ group_size ,
94+ residual_block_size ,
95+ new_lens ,
96+ False , # Added
97+ - 1 , # Added
98+ - 1 , # Added
99+ 0.0 , # Added
100+ True , # Added
101+ 0 # Added
102+ )
103+ else :
104+ raise ValueError (f"Unsupported num_bits={ num_bits } ; expected 2 or 4" )
101105
102106
103107 return out_bit , k_pack_new , k_params_new , v_pack_new , v_params_new
0 commit comments