1+ import torch
2+ import torch .nn as nn
3+ import math
4+
5+ import triton
6+
7+ import numpy as np
8+ import bit_decode_cuda as bit_decode_cuda
9+ from bit_decode import kvcache_pack_int , fwd_kvcache_int
10+ from bit_decode import DynamicCache
11+
12+
def attention_ref(
    q,
    k,
    v,
):
    """Compute plain softmax attention (no mask, no dropout) as a reference.

    Half-precision inputs (float16 / bfloat16) are upcast to float32 for the
    computation so that the reference does not carry its own low-precision
    rounding error; other dtypes are computed as given. Results are cast back
    to the dtype of ``q`` before returning.

    Grouped-query / multi-query layouts are supported: when ``nheads`` is a
    multiple of ``nheads_k``, every key/value head is shared by
    ``nheads // nheads_k`` consecutive query heads.

    Arguments:
        q: (batch_size, seqlen_q, nheads, head_dim)
        k: (batch_size, seqlen_k, nheads_k, head_dim)
        v: (batch_size, seqlen_k, nheads_k, head_dim)
    Output:
        output: (batch_size, seqlen_q, nheads, head_dim)
        attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax
            probabilities over the key dimension
    """
    dtype_og = q.dtype

    # Only low-precision inputs are promoted; float32 / float64 callers keep
    # exactly the precision they asked for.
    if dtype_og in (torch.float16, torch.bfloat16):
        q, k, v = q.float(), k.float(), v.float()

    nheads, nheads_k = q.shape[2], k.shape[2]
    # Expand shared KV heads so the einsum below sees matching head counts.
    # Incompatible head counts are left untouched and fall through to einsum,
    # which reports the shape mismatch as before.
    if nheads != nheads_k and nheads_k > 0 and nheads % nheads_k == 0:
        group = nheads // nheads_k
        k = k.repeat_interleave(group, dim=2)
        v = v.repeat_interleave(group, dim=2)

    d = q.shape[-1]

    # Scale q (rather than the scores) by 1/sqrt(head_dim) before the matmul.
    scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)

    attention = torch.softmax(scores, dim=-1).to(v.dtype)

    output = torch.einsum("bhts,bshd->bthd", attention, v)

    return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
38+
39+
# Quantization parameters
quant_mode = "k-channel"  # forwarded unchanged to the bit_decode calls below
num_bits = 4
# Number of quantized values stored per 16-bit packed word (the packed caches
# in this script are allocated as torch.uint16). This is a count used in shape
# arithmetic, so it must be an integer: fail loudly if num_bits does not
# divide 16 instead of silently producing fractional sizes.
if 16 % num_bits != 0:
    raise ValueError(f"num_bits must divide 16, got {num_bits}")
pack_nums = 16 // num_bits
group_size = 32
# Tokens are handed to the packing call in multiples of this block size; any
# remainder of the prompt is stored unquantized as a residual.
residual_block_size = 128

device = "cuda"
dtype = torch.float16

# Problem size
layer_idx = 0
batch_size = 1
nheads = 32
nheads_k = 32
d = 128  # head dimension

seqlen_q = 1
seqlen_k = 1024
# Softmax scaling factor, 1 / sqrt(head_dim).
sm_scale = 1.0 / math.sqrt(d)
59+
60+
####### Round 1: Prefill #######
# Fixed seed so the random q / k / v drawn below are reproducible.
torch.manual_seed(42)

_tensor_opts = {"device": device, "dtype": dtype}
_kv_shape = (batch_size, seqlen_k, nheads_k, d)

# Draw order matters for reproducibility: q first, then keys, then values.
# Note that q is uniform in [0, 1) (torch.rand) while k / v are standard normal.
q = torch.rand((batch_size, seqlen_q, nheads, d), **_tensor_opts)
k_state = torch.randn(_kv_shape, **_tensor_opts)
v_state = torch.randn(_kv_shape, **_tensor_opts)

# Split the prompt into a prefix whose length is a whole number of residual
# blocks (this part gets packed) and a tail of `residual_len` leftover tokens.
residual_len = seqlen_k % residual_block_size
seqlen_k_pack = seqlen_k - residual_len
residual = residual_len > 0

print(f"residual_len: { residual_len } , residual: { residual } , seqlen_k_pack: { seqlen_k_pack } ")

# Offsets 0, seqlen_k_pack, ..., batch_size * seqlen_k_pack (batch_size + 1 entries).
_cu_stop = (batch_size + 1) * seqlen_k_pack
cu_seqlens_k = torch.arange(0, _cu_stop, seqlen_k_pack, dtype=torch.int32, device=device)

# Zero-initialised buffers handed to kvcache_pack_int below.
# Keys:   packed along the sequence axis, one parameter row per `group_size` tokens.
# Values: packed along the head-dim axis, one parameter row per `group_size` channels.
_pack_opts = {"dtype": torch.uint16, "device": device}
_param_opts = {"dtype": torch.float32, "device": device}

k_pack = torch.zeros((batch_size, int(seqlen_k_pack // pack_nums), nheads_k, d), **_pack_opts)
k_params = torch.zeros((batch_size, int(seqlen_k_pack // group_size), nheads_k, d), **_param_opts)

v_pack = torch.zeros((batch_size, seqlen_k_pack, nheads_k, int(d // pack_nums)), **_pack_opts)
v_params = torch.zeros((batch_size, int(d // group_size), nheads_k, seqlen_k_pack), **_param_opts)

# Cache object that holds both the packed tensors and the residual tokens.
past_key_value = DynamicCache()

if not residual:
    # The prompt length is a multiple of the block size: everything is packed.
    k_state_past = k_state
    v_state_past = v_state
else:
    # Leading seqlen_k_pack tokens are packed; the trailing residual_len tokens
    # go into the cache's residual store as-is.
    k_state_past = k_state[:, :seqlen_k_pack]
    v_state_past = v_state[:, :seqlen_k_pack]
    k_state_residual = k_state[:, seqlen_k_pack:]
    v_state_residual = v_state[:, seqlen_k_pack:]
    past_key_value.update_residual(k_state_residual, v_state_residual, layer_idx)

# No return value is used here; presumably the call fills k_pack / k_params /
# v_pack / v_params in place -- confirm against the bit_decode API.
kvcache_pack_int(
    k_state_past, k_pack, k_params,
    v_state_past, v_pack, v_params,
    None,  # opt_block_table
    cu_seqlens_k,
    seqlen_k_pack,
    quant_mode,
    group_size,
    num_bits,
)
past_key_value.update_pack(k_pack, k_params, v_pack, v_params, layer_idx)

# Uninitialised buffers sized for exactly one residual block. They are passed
# to fwd_kvcache_int on every decode step (presumably as the destination for a
# freshly packed block -- confirm against the bit_decode API).
k_pack_new = torch.empty(
    (batch_size, int(residual_block_size // pack_nums), nheads_k, k_pack.shape[-1]),
    **_pack_opts,
)
k_params_new = torch.empty(
    (batch_size, int(residual_block_size // group_size), nheads_k, k_params.shape[-1]),
    **_param_opts,
)
v_pack_new = torch.empty(
    (batch_size, residual_block_size, nheads_k, v_pack.shape[-1]),
    **_pack_opts,
)
v_params_new = torch.empty(
    (batch_size, v_params.shape[1], nheads_k, residual_block_size),
    **_param_opts,
)
114+
####### Rounds 2 and later: Decode #######
# Each iteration appends one new token's K/V, runs the bit_decode attention
# call over (packed cache + residual), and compares against attention_ref on
# the full unquantized K/V history.
for round_idx in range(32):
    # Draw order (k before v) is kept fixed so runs stay reproducible.
    _new_token_shape = (batch_size, 1, nheads_k, d)
    k_new = torch.randn(_new_token_shape, device=device, dtype=dtype)
    v_new = torch.randn(_new_token_shape, device=device, dtype=dtype)

    # Calling update_pack with all-None tensors is used here to fetch the
    # currently stored packed cache for this layer.
    k_pack, k_params, v_pack, v_params = past_key_value.update_pack(
        None, None, None, None, layer_idx
    )

    # Per-batch packed sequence length, taken from the packed values' token axis.
    seqlen_pack = v_pack.shape[1]
    seqlens_k = torch.full((batch_size,), seqlen_pack, dtype=torch.int32, device=device)

    # Fresh zero-filled, fixed-size residual buffers (one full block each).
    k_residual = torch.zeros(
        (batch_size, residual_block_size, nheads_k, d), device=device, dtype=dtype
    )
    v_residual = torch.zeros_like(k_residual)

    # Append the new token to the cached residual and get the residual so far.
    k_residual_cache, v_residual_cache = past_key_value.update_residual(k_new, v_new, layer_idx)

    cur_residual_len = k_residual_cache.shape[1]
    print(f"cur_residual_len: { cur_residual_len } ")

    # Copy the valid residual tokens to the front; the tail stays zero.
    k_residual[:, :cur_residual_len] = k_residual_cache
    v_residual[:, :cur_residual_len] = v_residual_cache

    out_bitdecode, k_pack_new, k_params_new, v_pack_new, v_params_new = fwd_kvcache_int(
        q,
        k_pack, k_params,
        v_pack, v_params,
        k_residual, v_residual, seqlens_k,
        k_pack_new, k_params_new, v_pack_new, v_params_new,
        None,  # opt_block_table
        sm_scale,
        quant_mode,
        group_size,
        residual_block_size,
        cur_residual_len,  # new_lens
        num_bits,
    )

    # Once the residual holds a full block, hand the returned packed block to
    # the cache and start a new, empty residual.
    if cur_residual_len == residual_block_size:
        past_key_value.update_pack(k_pack_new, k_params_new, v_pack_new, v_params_new, layer_idx)
        past_key_value.clear_residual(layer_idx)

    # Grow the unquantized history used by the PyTorch reference.
    k_state = torch.cat((k_state, k_new), dim=1)
    v_state = torch.cat((v_state, v_new), dim=1)

    out_ref, _ = attention_ref(q, k_state, v_state)
    print(f"Round { round_idx + 2 } : bitdecode vs pytorch: { (out_bitdecode - out_ref).abs().mean().item()} ")
0 commit comments