@@ -893,7 +893,7 @@ inline __device__ void compute_qpack_1rowblock(const Params ¶ms, const int b
893893 Tensor tSrV_view = smem_thr_copy_V.retile_D (tSrV);
894894 Tensor tSsV_pack_s2r = smem_thr_copy_V_pack.partition_S (sVt_pack );
895895 Tensor tSrV_pack_s2r_view = smem_thr_copy_V_pack.retile_D (tSrV_pack);
896-
896+
897897 // Advance gK
898898 cute::copy (gmem_tiled_copy_QKV, tKgK, tKsK);
899899
@@ -914,7 +914,7 @@ inline __device__ void compute_qpack_1rowblock(const Params ¶ms, const int b
914914
915915 TensorParamsKC tScales_k_c, tZeros_k_c;
916916 TensorParamsVG tScales_v_c, tZeros_v_c;
917- TensorParamsG tScales_k_g, tZeros_k_g;
917+ TensorParamsG tScales_k_g, tZeros_k_g;
918918
919919 if (Kernel_traits::quant_mode == 1 ) {
920920 quant::qpack_Kchannel_Vtensor<num_bits>(tSrK, tSrK_pack, tScales_k_c, tZeros_k_c, sReduce_tmp , num_params);
@@ -979,7 +979,6 @@ inline __device__ void compute_qpack_1rowblock(const Params ¶ms, const int b
979979 cute::copy (gmem_tiled_copy_k_pack, tKsK_pack_s2g, tKgK_pack_s2g);
980980 __syncthreads ();
981981 cute::copy (gmem_tiled_copy_v_pack, tVsV_pack_s2g, tVgV_pack_s2g);
982-
983982 __syncthreads ();
984983 // //////////////////////////////////////////////////////////////////////////////
985984 // // verify the quantize
@@ -1019,7 +1018,7 @@ inline __device__ void compute_qpack_1rowblock(const Params ¶ms, const int b
10191018 // quant::dequant_Kchannel_Vtensor<num_bits>(tSrV_pack(_,_,i), tSrV_dequant(_,_,i), tScales_v_c(_,i), tZeros_v_c(_,i), num_params);
10201019 // }
10211020
1022- if (Kernel_traits::quant_mode == 1 ) {
1021+ // if (Kernel_traits::quant_mode == 1) {
10231022 // CUTE_UNROLL
10241023 // for (int i = 0; i < size<1>(tScales_k_h2_c); ++i) {
10251024 // CUTE_UNROLL
@@ -1033,7 +1032,7 @@ inline __device__ void compute_qpack_1rowblock(const Params ¶ms, const int b
10331032 // for (int i = 0; i < size<2>(tSrK_pack); ++i) {
10341033 // quant::dequant_Kchannel_Vtensor<num_bits>(tSrK_pack(_,_,i), tSrK_dequant(_,_,i), tScales_k_c(_,i), tZeros_k_c(_,i), num_params);
10351034 // }
1036- } else {
1035+ // } else {
10371036 // CUTE_UNROLL
10381037 // for (int j = 0; j < size<0>(tScales_k_h2_g); ++j) {
10391038 // tScales_k_h2_g(j) = gK_params(0 + 32*j + tidx/4, 0);
@@ -1043,11 +1042,11 @@ inline __device__ void compute_qpack_1rowblock(const Params ¶ms, const int b
10431042 // auto tScales_k_h1_g = cute::recast<__half>(tScales_k_h2_g);
10441043 // auto tZeros_k_h1_g = cute::recast<__half>(tZeros_k_h2_g);
10451044
1046- CUTE_UNROLL
1047- for (int i = 0 ; i < size<2 >(tSrK_pack); ++i) {
1048- quant::dequantize_Ktensor (tSrK_pack, tSrK_dequant, tScales_k_h2_g, tZeros_k_h2_g, 4 , group_size, i);
1049- }
1050- }
1045+ // CUTE_UNROLL
1046+ // for (int i = 0; i < size<2>(tSrK_pack); ++i) {
1047+ // quant::dequantize_Ktensor(tSrK_pack, tSrK_dequant, tScales_k_h2_g, tZeros_k_h2_g, 4, group_size, i);
1048+ // }
1049+ // }
10511050
10521051 // // //////////////////////////////////////////////////////////////////////////////
10531052 #if DEBUG2
@@ -1132,10 +1131,10 @@ template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi,
11321131inline __device__ void compute_attn_splitkv (const Params ¶ms) {
11331132 const int m_block = blockIdx.x ;
11341133 // The block index for the batch.
1135- const int bidb = Split ? blockIdx.z / params.h : blockIdx.y ;
1134+ const int bidb = Split ? blockIdx.z / params.h : blockIdx.y ;
11361135 // The block index for the head.
1137- const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z ;
1138- const int n_split_idx = Split ? blockIdx.y : 0 ;
1136+ const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z ;
1137+ const int n_split_idx = Split ? blockIdx.y : 0 ;
11391138 const int num_n_splits = Split ? gridDim.y : 1 ;
11401139 flash::compute_attn_1rowblock_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Split, Append_KV, Paged_KV>(params, bidb, bidh, m_block, n_split_idx, num_n_splits);
11411140}
0 commit comments