Skip to content

Commit 4a15ed3

Browse files
committed
add page tests
1 parent bbb6d6e commit 4a15ed3

9 files changed

Lines changed: 262 additions & 57 deletions

README.md

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,7 @@ python setup.py install
2727
1. See benchmark/bench_single_decode.ipynb
2828
2. (Optional) Play with libtorch c++
2929
```
30-
cd libs/
31-
wget https://download.pytorch.org/libtorch/cu124/libtorch-shared-with-deps-2.5.1%2Bcu124.zip
32-
unzip libtorch-shared-with-deps-2.5.1+cu124.zip
33-
rm libtorch-shared-with-deps-2.5.1+cu124.zip
30+
# download libtorch
3431
3532
cd BitDecoding/csrc/bit_decode
3633
mkdir build && cd build

csrc/bit_decode/CMakeLists.txt

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,19 @@ target_link_libraries(test_single_packdecode "${TORCH_LIBRARIES}")
3131
target_include_directories(test_single_packdecode PRIVATE ${INCLUDE_DIR})
3232
target_compile_options(test_single_packdecode PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-maxrregcount=255 -gencode arch=compute_80,code=sm_80 -w>)
3333

34+
message(STATUS "Compile testing packdecode kernel.")
35+
add_executable(test_batch_packdecode
36+
${PROJECT_SOURCE_DIR}/src/test_batch_packdecode.cu
37+
${PROJECT_SOURCE_DIR}/src/genfile/flash_fwd_hdim128_fp16_sm80.cu
38+
${PROJECT_SOURCE_DIR}/src/genfile/flash_qpack_hdim128_fp16_sm80_2bit.cu
39+
${PROJECT_SOURCE_DIR}/src/genfile/flash_qpack_hdim128_fp16_sm80_4bit.cu
40+
${PROJECT_SOURCE_DIR}/src/genfile/flash_fwd_split_hdim128_fp16_sm80_2bit.cu
41+
${PROJECT_SOURCE_DIR}/src/genfile/flash_fwd_split_hdim128_fp16_sm80_4bit.cu
42+
)
43+
target_link_libraries(test_batch_packdecode "${TORCH_LIBRARIES}")
44+
target_include_directories(test_batch_packdecode PRIVATE ${INCLUDE_DIR})
45+
target_compile_options(test_batch_packdecode PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-maxrregcount=255 -gencode arch=compute_80,code=sm_80 -w>)
46+
3447
message(STATUS "Compile benchmarking kernel.")
3548
add_executable(bench_single_packdecode
3649
${PROJECT_SOURCE_DIR}/src/bench_single_packdecode.cu

csrc/bit_decode/src/flash_api.h

Lines changed: 29 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -88,19 +88,19 @@ void set_params_fprop(Flash_fwd_params &params,
8888
params.o_head_stride = out.stride(-2);
8989

9090
if (cu_seqlens_q_d == nullptr) {
91-
params.q_batch_stride = q.stride(0);
92-
// params.k_batch_stride = k.stride(0);
93-
params.K_pack_batch_stride = k_pack.stride(0);
94-
params.k_params_batch_stride = k_params.stride(0);
95-
// params.v_batch_stride = v.stride(0);
96-
params.v_pack_batch_stride = v_pack.stride(0);
97-
params.v_params_batch_stride = v_params.stride(0);
98-
params.o_batch_stride = out.stride(0);
99-
100-
if (seqlenq_ngroups_swapped) {
101-
params.q_batch_stride *= seqlen_q;
102-
params.o_batch_stride *= seqlen_q;
103-
}
91+
params.q_batch_stride = q.stride(0);
92+
// params.k_batch_stride = k.stride(0);
93+
params.K_pack_batch_stride = k_pack.stride(0);
94+
params.k_params_batch_stride = k_params.stride(0);
95+
// params.v_batch_stride = v.stride(0);
96+
params.v_pack_batch_stride = v_pack.stride(0);
97+
params.v_params_batch_stride = v_params.stride(0);
98+
params.o_batch_stride = out.stride(0);
99+
100+
if (seqlenq_ngroups_swapped) {
101+
params.q_batch_stride *= seqlen_q;
102+
params.o_batch_stride *= seqlen_q;
103+
}
104104
}
105105

106106
params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);
@@ -130,14 +130,14 @@ void set_params_fprop(Flash_fwd_params &params,
130130
TORCH_CHECK(softcap <= 0.0, "This flash attention build does not support softcap.");
131131
#endif
132132
if (softcap > 0.0) {
133-
params.softcap = softmax_scale / softcap;
134-
params.scale_softmax = softcap;
135-
params.scale_softmax_log2 = softcap * M_LOG2E;
133+
params.softcap = softmax_scale / softcap;
134+
params.scale_softmax = softcap;
135+
params.scale_softmax_log2 = softcap * M_LOG2E;
136136
} else{
137-
// Remove potential NaN
138-
params.softcap = 0.0;
139-
params.scale_softmax = softmax_scale;
140-
params.scale_softmax_log2 = softmax_scale * M_LOG2E;
137+
// Remove potential NaN
138+
params.softcap = 0.0;
139+
params.scale_softmax = softmax_scale;
140+
params.scale_softmax_log2 = softmax_scale * M_LOG2E;
141141
}
142142

143143
// Set this to probability of keeping an element to simplify things.
@@ -337,7 +337,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x
337337
const auto sizes = q.sizes();
338338

339339
const int batch_size = sizes[0];
340-
int seqlen_q = sizes[1];
340+
int seqlen_q = sizes[1];
341341
int num_heads = sizes[2];
342342
const int head_size_og = sizes[3]; // dim
343343

@@ -456,7 +456,8 @@ void set_params_fprop_qpack(Flash_fwd_params &params,
456456
const at::Tensor v, at::Tensor v_pack, at::Tensor v_params,
457457
void *cu_seqlens_k_d,
458458
const std::string quant_mode,
459-
const int group_size
459+
const int group_size,
460+
bool page_kv
460461
) {
461462

462463
// Reset the parameters
@@ -489,12 +490,12 @@ void set_params_fprop_qpack(Flash_fwd_params &params,
489490
params.v_pack_head_stride = v_pack.stride(-2);
490491
params.v_params_head_stride = v_params.stride(-2);
491492

492-
// params.k_batch_stride = k.stride(0);
493-
params.k_batch_stride = seqlen_k * k.size(-2) * k.size(-1);
493+
if (page_kv) params.k_batch_stride = k.stride(0);
494+
else params.k_batch_stride = seqlen_k * k.size(-2) * k.size(-1);
494495
params.K_pack_batch_stride = k_pack.stride(0);
495496
params.k_params_batch_stride = k_params.stride(0);
496-
// params.v_batch_stride = v.stride(0);
497-
params.v_batch_stride = seqlen_k * v.size(-2) * v.size(-1);
497+
if (page_kv) params.v_batch_stride = v.stride(0);
498+
else params.v_batch_stride = seqlen_k * v.size(-2) * v.size(-1);
498499
params.v_pack_batch_stride = v_pack.stride(0);
499500
params.v_params_batch_stride = v_params.stride(0);
500501

@@ -583,7 +584,8 @@ void kvcache_qpack(const at::Tensor &k,
583584
v, v_pack, v_params,
584585
/*cu_seqlens_k_d=*/nullptr,
585586
quant_mode,
586-
group_size
587+
group_size,
588+
paged_KV
587589
);
588590

589591
if (paged_KV) {

csrc/bit_decode/src/flash_fwd_kernel.h

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -893,7 +893,7 @@ inline __device__ void compute_qpack_1rowblock(const Params &params, const int b
893893
Tensor tSrV_view = smem_thr_copy_V.retile_D(tSrV);
894894
Tensor tSsV_pack_s2r = smem_thr_copy_V_pack.partition_S(sVt_pack);
895895
Tensor tSrV_pack_s2r_view = smem_thr_copy_V_pack.retile_D(tSrV_pack);
896-
896+
897897
// Advance gK
898898
cute::copy(gmem_tiled_copy_QKV, tKgK, tKsK);
899899

@@ -914,7 +914,7 @@ inline __device__ void compute_qpack_1rowblock(const Params &params, const int b
914914

915915
TensorParamsKC tScales_k_c, tZeros_k_c;
916916
TensorParamsVG tScales_v_c, tZeros_v_c;
917-
TensorParamsG tScales_k_g, tZeros_k_g;
917+
TensorParamsG tScales_k_g, tZeros_k_g;
918918

919919
if (Kernel_traits::quant_mode == 1) {
920920
quant::qpack_Kchannel_Vtensor<num_bits>(tSrK, tSrK_pack, tScales_k_c, tZeros_k_c, sReduce_tmp, num_params);
@@ -979,7 +979,6 @@ inline __device__ void compute_qpack_1rowblock(const Params &params, const int b
979979
cute::copy(gmem_tiled_copy_k_pack, tKsK_pack_s2g, tKgK_pack_s2g);
980980
__syncthreads();
981981
cute::copy(gmem_tiled_copy_v_pack, tVsV_pack_s2g, tVgV_pack_s2g);
982-
983982
__syncthreads();
984983
// //////////////////////////////////////////////////////////////////////////////
985984
// // verify the quantize
@@ -1019,7 +1018,7 @@ inline __device__ void compute_qpack_1rowblock(const Params &params, const int b
10191018
// quant::dequant_Kchannel_Vtensor<num_bits>(tSrV_pack(_,_,i), tSrV_dequant(_,_,i), tScales_v_c(_,i), tZeros_v_c(_,i), num_params);
10201019
// }
10211020

1022-
if (Kernel_traits::quant_mode == 1) {
1021+
// if (Kernel_traits::quant_mode == 1) {
10231022
// CUTE_UNROLL
10241023
// for (int i = 0; i < size<1>(tScales_k_h2_c); ++i) {
10251024
// CUTE_UNROLL
@@ -1033,7 +1032,7 @@ inline __device__ void compute_qpack_1rowblock(const Params &params, const int b
10331032
// for (int i = 0; i < size<2>(tSrK_pack); ++i) {
10341033
// quant::dequant_Kchannel_Vtensor<num_bits>(tSrK_pack(_,_,i), tSrK_dequant(_,_,i), tScales_k_c(_,i), tZeros_k_c(_,i), num_params);
10351034
// }
1036-
} else {
1035+
// } else {
10371036
// CUTE_UNROLL
10381037
// for (int j = 0; j < size<0>(tScales_k_h2_g); ++j) {
10391038
// tScales_k_h2_g(j) = gK_params(0 + 32*j + tidx/4, 0);
@@ -1043,11 +1042,11 @@ inline __device__ void compute_qpack_1rowblock(const Params &params, const int b
10431042
// auto tScales_k_h1_g = cute::recast<__half>(tScales_k_h2_g);
10441043
// auto tZeros_k_h1_g = cute::recast<__half>(tZeros_k_h2_g);
10451044

1046-
CUTE_UNROLL
1047-
for (int i = 0; i < size<2>(tSrK_pack); ++i) {
1048-
quant::dequantize_Ktensor(tSrK_pack, tSrK_dequant, tScales_k_h2_g, tZeros_k_h2_g, 4, group_size, i);
1049-
}
1050-
}
1045+
// CUTE_UNROLL
1046+
// for (int i = 0; i < size<2>(tSrK_pack); ++i) {
1047+
// quant::dequantize_Ktensor(tSrK_pack, tSrK_dequant, tScales_k_h2_g, tZeros_k_h2_g, 4, group_size, i);
1048+
// }
1049+
// }
10511050

10521051
// // //////////////////////////////////////////////////////////////////////////////
10531052
#if DEBUG2
@@ -1132,10 +1131,10 @@ template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi,
11321131
inline __device__ void compute_attn_splitkv(const Params &params) {
11331132
const int m_block = blockIdx.x;
11341133
// The block index for the batch.
1135-
const int bidb = Split ? blockIdx.z / params.h : blockIdx.y;
1134+
const int bidb = Split ? blockIdx.z / params.h : blockIdx.y;
11361135
// The block index for the head.
1137-
const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z;
1138-
const int n_split_idx = Split ? blockIdx.y : 0;
1136+
const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z;
1137+
const int n_split_idx = Split ? blockIdx.y : 0;
11391138
const int num_n_splits = Split ? gridDim.y : 1;
11401139
flash::compute_attn_1rowblock_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Split, Append_KV, Paged_KV>(params, bidb, bidh, m_block, n_split_idx, num_n_splits);
11411140
}

csrc/bit_decode/src/flash_fwd_launch_template.h

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,6 @@ DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_combine_kernel, int kBlockM, int L
6060
flash::combine_attn_seqk_parallel<Kernel_traits, kBlockM, Log_max_splits, Is_even_K>(params);
6161
}
6262

63-
64-
65-
6663
template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
6764
void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
6865
constexpr size_t smem_size = Kernel_traits::kSmemSize;
@@ -118,7 +115,7 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
118115
// LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
119116
BOOL_SWITCH(params.num_splits > 1, Split, [&] {
120117
// BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
121-
// BOOL_SWITCH(params.block_table != nullptr, Paged_KV, [&] {
118+
BOOL_SWITCH(params.block_table != nullptr, Paged_KV, [&] {
122119
// ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
123120
// SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
124121
// If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
@@ -131,7 +128,7 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
131128
// Append_KV:
132129
// Has_alibi: 0
133130
// Is_softcap: 0
134-
auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, false, false, true, false, Split, false, false>;
131+
auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, false, false, true, false, Split, false, Paged_KV>;
135132
// auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split, Append_KV>;
136133
// auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
137134
if (smem_size >= 48 * 1024) {
@@ -141,7 +138,7 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
141138
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
142139
C10_CUDA_KERNEL_LAUNCH_CHECK();
143140
// });
144-
// });
141+
});
145142
// });
146143
});
147144
// });

csrc/bit_decode/src/genfile/flash_fwd_split_hdim128_fp16_sm80_2bit.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,6 @@
44

55
#include "../flash_fwd_launch_template.h"
66

7-
template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 128, false, 1, 2, 128>(Flash_fwd_params &params, cudaStream_t stream);
7+
// template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 128, false, 1, 2, 128>(Flash_fwd_params &params, cudaStream_t stream);
88
// template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 128, false, 1, 2, 64>(Flash_fwd_params &params, cudaStream_t stream);
99
// template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 128, false, 1, 2, 32>(Flash_fwd_params &params, cudaStream_t stream);

csrc/bit_decode/src/genfile/flash_qpack_hdim128_fp16_sm80_2bit.cu

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44

55
#include "../flash_fwd_launch_template.h"
66

7-
template<>
8-
void run_kvcache_qpack_<cutlass::half_t, 128, 1, 2, 128>(Flash_fwd_params &params, cudaStream_t stream) {
9-
run_kvcache_qpack_hdim128<cutlass::half_t, 1, 2, 128>(params, stream);
10-
}
7+
// template<>
8+
// void run_kvcache_qpack_<cutlass::half_t, 128, 1, 2, 128>(Flash_fwd_params &params, cudaStream_t stream) {
9+
// run_kvcache_qpack_hdim128<cutlass::half_t, 1, 2, 128>(params, stream);
10+
// }
1111
// template<>
1212
// void run_kvcache_qpack_<cutlass::half_t, 128, 1, 2, 64>(Flash_fwd_params &params, cudaStream_t stream) {
1313
// run_kvcache_qpack_hdim128<cutlass::half_t, 1, 2, 64>(params, stream);

0 commit comments

Comments
 (0)