Skip to content

Commit 3b93787

Browse files
committed
update e2e
1 parent 30f0f81 commit 3b93787

18 files changed

Lines changed: 6374 additions & 1329 deletions

benchmark/bench_single_decode.ipynb

Lines changed: 0 additions & 414 deletions
This file was deleted.

bit_decode/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,5 @@
44
kvcache_pack_int,
55
fwd_kvcache_int
66
)
7+
8+
from bit_decode.models.cache_utils import Cache, DynamicCache, StaticCache

bit_decode/bit_decode_interface.py

Lines changed: 54 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,16 @@
55
import torch
66
import torch.nn as nn
77

8+
# isort: off
9+
# We need to import the CUDA kernels after importing torch
810
import bit_decode_cuda as bit_decode_cuda
911

1012
def kvcache_pack_int(k_cache: torch.Tensor, k_pack: torch.Tensor, k_params: torch.Tensor,
1113
v_cache: torch.Tensor, v_pack: torch.Tensor, v_params: torch.Tensor,
1214
opt_block_table: Optional[torch.Tensor] = None,
1315
cu_seqlens_k: torch.Tensor = None,
1416
seqlen_k: int = 0,
15-
quant_mode: str = "k-channel",
17+
quant_mode: str = "k-tensor",
1618
group_size: int = 128,
1719
num_bits: int = 4):
1820

@@ -22,65 +24,80 @@ def kvcache_pack_int(k_cache: torch.Tensor, k_pack: torch.Tensor, k_params: torc
2224
V_unpad = v_cache.reshape(batch_size * seqlen_k, nheads_k, d)
2325

2426
if num_bits == 4:
25-
bit_decode_cuda.kvcache_pack_i4(K_unpad, k_pack, k_params,
26-
V_unpad, v_pack, v_params,
27-
opt_block_table,
28-
cu_seqlens_k,
29-
seqlen_k,
30-
quant_mode,
31-
group_size
32-
)
33-
else:
34-
bit_decode_cuda.kvcache_pack_i2(K_unpad, k_pack, k_params,
35-
V_unpad, v_pack, v_params,
36-
opt_block_table,
37-
cu_seqlens_k,
38-
seqlen_k,
39-
quant_mode,
40-
group_size
41-
)
27+
bit_decode_cuda.kvcache_pack_int4(K_unpad, k_pack, k_params,
28+
V_unpad, v_pack, v_params,
29+
opt_block_table,
30+
cu_seqlens_k,
31+
seqlen_k,
32+
quant_mode,
33+
group_size
34+
)
35+
# else:
36+
# bit_decode_cuda.kvcache_pack_int2(K_unpad, k_pack, k_params,
37+
# V_unpad, v_pack, v_params,
38+
# opt_block_table,
39+
# cu_seqlens_k,
40+
# seqlen_k,
41+
# quant_mode,
42+
# group_size
43+
# )
4244

4345
def fwd_kvcache_int(q: torch.Tensor,
4446
k_pack: torch.Tensor, k_params: torch.Tensor,
4547
v_pack: torch.Tensor, v_params: torch.Tensor,
48+
opt_k_new: Optional[torch.Tensor] = None,
49+
opt_v_new: Optional[torch.Tensor] = None,
50+
opt_seqlens_k: Optional[torch.Tensor] = None,
51+
k_pack_new: torch.Tensor = None, k_params_new: torch.Tensor = None,
52+
v_pack_new: torch.Tensor = None, v_params_new: torch.Tensor = None,
4653
opt_block_table: Optional[torch.Tensor] = None,
4754
softmax_scale: float = 1.0,
48-
quant_mode: str = "k-channel",
55+
quant_mode: str = "k-tensor",
4956
group_size: int = 128,
57+
residual_block_size: int = 128,
58+
new_lens: int = 0,
5059
num_bits: int = 4):
5160

5261
if num_bits == 4:
53-
out_bit = bit_decode_cuda.fwd_kvcache_i4(
54-
q,
55-
k_pack, k_params,
56-
v_pack, v_params,
57-
opt_block_table,
58-
softmax_scale,
59-
quant_mode,
60-
group_size,
61-
False, # is_causal
62-
-1, # window_size_left
63-
-1, # window_size_right
64-
0.0, # softcap
65-
True, # is_rotary_interleaved
66-
0 # num_splits
67-
)
68-
else:
69-
out_bit = bit_decode_cuda.fwd_kvcache_i2(
62+
out_bit, k_pack_new, k_params_new, v_pack_new, v_params_new = bit_decode_cuda.fwd_kvcache_int4(
7063
q,
7164
k_pack, k_params,
7265
v_pack, v_params,
66+
opt_k_new, opt_v_new, opt_seqlens_k,
67+
k_pack_new, k_params_new, v_pack_new, v_params_new,
7368
opt_block_table,
7469
softmax_scale,
7570
quant_mode,
7671
group_size,
72+
residual_block_size,
73+
new_lens,
7774
False, # Added
7875
-1, # Added
7976
-1, # Added
8077
0.0, # Added
8178
True, # Added
8279
0 # Added
8380
)
81+
# else:
82+
# out_bit, k_pack_new, k_params_new, v_pack_new, v_params_new = bit_decode_cuda.fwd_kvcache_int2(
83+
# q,
84+
# k_pack, k_params,
85+
# v_pack, v_params,
86+
# opt_k_new, opt_v_new, opt_seqlens_k,
87+
# k_pack_new, k_params_new, v_pack_new, v_params_new,
88+
# opt_block_table,
89+
# softmax_scale,
90+
# quant_mode,
91+
# group_size,
92+
# residual_block_size,
93+
# new_lens,
94+
# False, # Added
95+
# -1, # Added
96+
# -1, # Added
97+
# 0.0, # Added
98+
# True, # Added
99+
# 0 # Added
100+
# )
84101

85102

86-
return out_bit
103+
return out_bit, k_pack_new, k_params_new, v_pack_new, v_params_new

bit_decode/models/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)