1+ import torch
2+ import torch .nn as nn
3+ import math
4+ import triton
5+ from einops import rearrange , repeat
6+ import numpy as np
7+
8+ from flash_attn import flash_attn_with_kvcache
9+ from bit_decode import kvcache_pack_int , fwd_kvcache_int
10+
11+
12+ def attention_ref (
13+ q ,
14+ k ,
15+ v ,
16+ ):
17+ """
18+ Arguments:
19+ q: (batch_size, seqlen_q, nheads, head_dim)
20+ k: (batch_size, seqlen_k, nheads_k, head_dim)
21+ v: (batch_size, seqlen_k, nheads_k, head_dim)
22+ Output:
23+ output: (batch_size, seqlen_q, nheads, head_dim)
24+ attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout
25+ """
26+ dtype_og = q .dtype
27+
28+ d = q .shape [- 1 ]
29+
30+ scores = torch .einsum ("bthd,bshd->bhts" , q / math .sqrt (d ), k )
31+
32+ attention = torch .softmax (scores , dim = - 1 ).to (v .dtype )
33+
34+ output = torch .einsum ("bhts,bshd->bthd" , attention , v )
35+
36+ return output .to (dtype = dtype_og ), attention .to (dtype = dtype_og )
37+
38+
39+ # Define constants
40+ batch_size = 1
41+ nheads = 32
42+ nheads_k = 32
43+ d = 128
44+
45+ # Sequence length
46+ seqlen_q = 1
47+ seqlen_kv = 4096
48+
49+ # Quantization parameters
50+ quant_mode = "k-channel"
51+ num_bits = 4
52+ pack_nums = 16 / num_bits
53+ group_size = 128
54+
55+
56+ # Set seed and parameters
57+ device = "cuda"
58+ dtype = torch .float16
59+ torch .random .manual_seed (0 )
60+
61+ # Initialize tensors
62+ q = torch .randn (batch_size , seqlen_q , nheads , d , device = device , dtype = dtype )
63+ k_cache = torch .randn (batch_size , seqlen_kv , nheads_k , d , device = device , dtype = dtype )
64+ v_cache = torch .randn (batch_size , seqlen_kv , nheads_k , d , device = device , dtype = dtype )
65+
66+ k_cache_rep = repeat (k_cache , "b s h d -> b s (h g) d" , g = nheads // nheads_k )
67+ v_cache_rep = repeat (v_cache , "b s h d -> b s (h g) d" , g = nheads // nheads_k )
68+
69+ # Reference attention computation
70+ out_ref , _ = attention_ref (q , k_cache_rep , v_cache_rep )
71+
72+ ##################### BitDecoding Packing Kernel #####################
73+
74+ # Initialize quantization tensors
75+ if quant_mode == "k-channel" :
76+ k_pack = torch .zeros ((batch_size , int (seqlen_kv // pack_nums ), nheads_k , d ), dtype = torch .uint16 , device = device )
77+ k_params = torch .zeros ((batch_size , int (seqlen_kv // group_size ), nheads_k , d ), dtype = torch .float32 , device = device )
78+ else :
79+ k_pack = torch .zeros ((batch_size , seqlen_kv , nheads_k , int (d // pack_nums )), dtype = torch .uint16 , device = device )
80+ k_params = torch .zeros ((batch_size , int (d // group_size ), nheads_k , seqlen_kv ), dtype = torch .float32 , device = device )
81+
82+ v_pack = torch .zeros ((batch_size , seqlen_kv , nheads_k , int (d // pack_nums )), dtype = torch .uint16 , device = device )
83+ v_params = torch .zeros ((batch_size , int (d // group_size ), nheads_k , seqlen_kv ), dtype = torch .float32 , device = device )
84+
85+ cu_seqlens_k = torch .arange (0 , (batch_size + 1 ) * seqlen_kv , seqlen_kv , dtype = torch .int32 , device = device )
86+
87+ kvcache_pack_int (
88+ k_cache , k_pack , k_params ,
89+ v_cache , v_pack , v_params ,
90+ None , # opt_block_table
91+ cu_seqlens_k ,
92+ seqlen_kv ,
93+ quant_mode ,
94+ group_size ,
95+ num_bits
96+ )
97+
98+ sm_scale = 1.0 / math .sqrt (d )
99+ out_bitdecode = fwd_kvcache_int (
100+ q ,
101+ k_pack , k_params ,
102+ v_pack , v_params ,
103+ None , # opt_block_table
104+ sm_scale ,
105+ quant_mode ,
106+ group_size ,
107+ num_bits
108+ )
109+
110+ print (f"seqlen_kv:{ seqlen_kv } BitDecode vs Pytorch: { (out_bitdecode - out_ref ).abs ().mean ().item ()} " )
111+
112+ print (f"out_ref: \n { out_ref [0 ,0 ,0 ,:8 ]} " )
113+ print (f"out_bitdecode: \n { out_bitdecode [0 ,0 ,0 ,:8 ]} " )
0 commit comments