From 6862777fb692d3f6ff91b7f75844b95cf423dd27 Mon Sep 17 00:00:00 2001 From: ltqin Date: Fri, 28 Nov 2025 06:50:17 +0000 Subject: [PATCH 01/25] add block scale parameters to kernel --- example/ck_tile/01_fmha/CMakeLists.txt | 2 +- .../ck_tile/01_fmha/codegen/cpp_symbol_map.py | 2 + .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 4 +- example/ck_tile/01_fmha/fmha_fwd.hpp | 26 ++ example/ck_tile/01_fmha/fmha_fwd_runner.hpp | 224 ++++++++++++++---- example/ck_tile/01_fmha/quant.hpp | 7 + include/ck_tile/core/utility/functional.hpp | 9 + .../host/reference/reference_batched_gemm.hpp | 39 +++ .../block_attention_quant_scale_enum.hpp | 6 + .../ops/fmha/kernel/fmha_fwd_kernel.hpp | 149 +++++++++++- 10 files changed, 409 insertions(+), 59 deletions(-) diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index 9edf50e89c..5f5587226c 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -136,7 +136,7 @@ list(APPEND FMHA_BWD_PRIVATE_COMPILE_OPTIONS -Wno-undefined-func-template) # Allow comparing floating points directly in order to check sentinel values list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -Wno-float-equal) list(APPEND FMHA_BWD_PRIVATE_COMPILE_OPTIONS -Wno-float-equal) - +list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -g1 --save-temps -Wno-gnu-line-marker) # NOTE: this is dangerous since will change the whole kernel to flush denormals # WIP with compiler team for an exp2 intrinsic..., then remove this if(NOT DEFINED FMHA_FWD_FAST_EXP2) diff --git a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py index 333579ec8d..1009b29d0b 100644 --- a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py +++ b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py @@ -65,11 +65,13 @@ def get_mask_check_map(mask: str): QSCALE_MAP = { "no": "ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE", "pertensor": "ck_tile::BlockAttentionQuantScaleEnum::PERTENSOR", + 
"blockscale": "ck_tile::BlockAttentionQuantScaleEnum::BLOCKSCALE", } QSCALE_CHECK_MAP = { "no": "quant_scale_enum::no_scale", "pertensor": "quant_scale_enum::pertensor", + "blockscale": "quant_scale_enum::blockscale", } BIAS_MAP = { diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 360d6a7c78..9599a57228 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -738,7 +738,7 @@ def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeli # no need lse/dropout kernels for logits, qscale, mask, bias in itertools.product( ["f"], - ["no", "pertensor"], + ["no", "pertensor", "blockscale"], get_mask_map(mask_impl).keys(), ["no"], ): @@ -829,7 +829,7 @@ def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeli elif dtype in ["fp8", "fp8bf16", "fp8fp32"]: # no need lse/dropout kernels for logits, qscale, mask, bias in itertools.product( - ["f"], ["no", "pertensor"], get_mask_map(mask_impl).keys(), ["no"] + ["f"], ["no", "pertensor", "blockscale"], get_mask_map(mask_impl).keys(), ["no"] ): pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index f279ebfcea..47b80ebb32 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -230,6 +230,8 @@ struct fmha_fwd_args // array [batch + 1]. (Used with padding) const void* cu_seqlen_k_ptr = nullptr; // Cumulative logical (excluding padding) sequence length // array [batch + 1]. 
(Used with padding) + const void* bseqstart_q_ptr; + const void* bseqstart_k_ptr; ck_tile::index_t seqlen_q; ck_tile::index_t seqlen_k; @@ -256,6 +258,9 @@ struct fmha_fwd_args ck_tile::index_t nhead_stride_randval; ck_tile::index_t nhead_stride_lse; ck_tile::index_t nhead_stride_o; + ck_tile::index_t nhead_stride_q_descale; + ck_tile::index_t nhead_stride_k_descale; + ck_tile::index_t nhead_stride_v_descale; ck_tile::index_t batch_stride_q; ck_tile::index_t batch_stride_k; ck_tile::index_t batch_stride_v; @@ -263,6 +268,9 @@ struct fmha_fwd_args ck_tile::index_t batch_stride_randval; ck_tile::index_t batch_stride_lse; ck_tile::index_t batch_stride_o; + ck_tile::index_t batch_stride_q_descale; + ck_tile::index_t batch_stride_k_descale; + ck_tile::index_t batch_stride_v_descale; ck_tile::index_t window_size_left; ck_tile::index_t window_size_right; @@ -274,6 +282,9 @@ struct fmha_fwd_args std::variant, std::pair> drop_seed_offset; + + ck_tile::index_t block_scale_m; + ck_tile::index_t block_scale_n; }; struct fmha_fwd_pagedkv_args @@ -592,6 +603,8 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.seqstart_k_ptr, args.seqlen_q_ptr, args.seqlen_k_ptr, + args.bseqstart_q_ptr, + args.bseqstart_k_ptr, args.hdim_q, args.hdim_v, args.nhead_q, @@ -611,6 +624,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.nhead_stride_randval, args.nhead_stride_lse, args.nhead_stride_o, + args.nhead_stride_q_descale, + args.nhead_stride_k_descale, + args.nhead_stride_v_descale, args.window_size_left, args.window_size_right, args.mask_type, @@ -618,6 +634,8 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.p_drop, args.s_randval, args.drop_seed_offset, + args.block_scale_m, + args.block_scale_n, args.cu_seqlen_q_ptr, args.cu_seqlen_k_ptr); } @@ -654,6 +672,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.nhead_stride_randval, args.nhead_stride_lse, args.nhead_stride_o, + args.nhead_stride_q_descale, + 
args.nhead_stride_k_descale, + args.nhead_stride_v_descale, args.batch_stride_q, args.batch_stride_k, args.batch_stride_v, @@ -661,12 +682,17 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.batch_stride_randval, args.batch_stride_lse, args.batch_stride_o, + args.batch_stride_q_descale, + args.batch_stride_k_descale, + args.batch_stride_v_descale, args.window_size_left, args.window_size_right, args.mask_type, args.p_drop, args.s_randval, args.drop_seed_offset, + args.block_scale_m, + args.block_scale_n, args.cu_seqlen_q_ptr, args.cu_seqlen_k_ptr); } diff --git a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp index bca4c60bc6..2db1caf568 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -187,6 +187,9 @@ fwd_result fmha_fwd_run(mode_enum mode, const ck_tile::stream_config& stream_config, std::optional json = std::nullopt) { + constexpr ck_tile::index_t block_scale_m_ = 128; + constexpr ck_tile::index_t block_scale_n_ = 128; + const std::string data_type = []() { if constexpr(std::is_same_v) return "fp32"; @@ -448,7 +451,11 @@ fwd_result fmha_fwd_run(mode_enum mode, std::size_t flop = 0, num_byte = 0; auto max_seqlen_q = std::numeric_limits::min(); // we will use max seqlen to decide grid size - auto max_seqlen_k = std::numeric_limits::min(); + size_t num_block_scale_q = 0; + size_t num_block_scale_k = 0; + std::vector bseqstart_q_host = {0}; + std::vector bseqstart_k_host = {0}; + auto max_seqlen_k = std::numeric_limits::min(); { for(ck_tile::index_t wb = 0; wb < batch; ++wb) { @@ -464,6 +471,10 @@ fwd_result fmha_fwd_run(mode_enum mode, { max_seqlen_k = real_seqlen_k; } + num_block_scale_q += ck_tile::integer_divide_ceil(real_seqlen_q, block_scale_m_); + num_block_scale_k += ck_tile::integer_divide_ceil(real_seqlen_k, block_scale_n_); + bseqstart_q_host.push_back(num_block_scale_q); + bseqstart_k_host.push_back(num_block_scale_k); flop += nhead * 
(static_cast(2) * mask.get_unmaskarea() * hdim_q + static_cast(2) * mask.get_unmaskarea() * hdim_v); @@ -474,6 +485,8 @@ fwd_result fmha_fwd_run(mode_enum mode, sizeof(VDataType) * hdim_v * real_seqlen_k); } } + // std::cout << "bseqstart_q_host: " << bseqstart_q_host + // << "bseqstart_k_host: " << bseqstart_k_host << std::endl; const ck_tile::index_t max_num_page_blocks = (0 < page_block_size @@ -525,6 +538,13 @@ fwd_result fmha_fwd_run(mode_enum mode, ? seqstart_k_with_padding_host.back() : seqstart_k_host.back())); + const ck_tile::index_t num_block_scale_m = + (mode == mode_enum::batch) ? ck_tile::integer_divide_ceil(shape_seqlen_q, block_scale_m_) + : num_block_scale_q; + const ck_tile::index_t num_block_scale_n = + (mode == mode_enum::batch) ? ck_tile::integer_divide_ceil(shape_seqlen_k, block_scale_n_) + : num_block_scale_k; + ck_tile::HostTensor q_host( get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q)); ck_tile::HostTensor k_host( @@ -575,9 +595,18 @@ fwd_result fmha_fwd_run(mode_enum mode, : std::array{1, 1, 1, 1, 1}); // TODO - change the tensor length for different quant scale - ck_tile::HostTensor q_descale_host(get_lengths(i_perm, 1, 1, 1, 1)); - ck_tile::HostTensor k_descale_host(get_lengths(i_perm, 1, 1, 1, 1)); - ck_tile::HostTensor v_descale_host(get_lengths(i_perm, 1, 1, 1, 1)); + ck_tile::HostTensor q_descale_host( + qscale.type == quant_scale_enum::blockscale + ? std::array{shape_batch, nhead, num_block_scale_m} + : std::array{1, 1, 1}); + ck_tile::HostTensor k_descale_host( + qscale.type == quant_scale_enum::blockscale + ? std::array{shape_batch, nhead_k, num_block_scale_n} + : std::array{1, 1, 1}); + ck_tile::HostTensor v_descale_host( + qscale.type == quant_scale_enum::blockscale + ? 
std::array{shape_batch, nhead_k, num_block_scale_n} + : std::array{1, 1, 1}); // batch mode of lse data layout is [batch, nhead, seqlen_q] // group mode of lse data layout is [nhead, total_seqlen_q] @@ -692,6 +721,13 @@ fwd_result fmha_fwd_run(mode_enum mode, k_descale_host(0) = qkv_max / k_dtype_max; v_descale_host(0) = qkv_max / v_dtype_max; } + else if(qscale.type == quant_scale_enum::blockscale) + { + ck_tile::FillUniformDistribution{0.012f, 0.015f, next_seed()}(q_descale_host); + ck_tile::FillUniformDistribution{0.012f, 0.015f, next_seed()}(k_descale_host); + ck_tile::FillUniformDistribution{0.012f, 0.015f, next_seed()}(v_descale_host); + // return fwd_result::no_instance; + } iota_shuffle(block_table_host.begin(), block_table_host.end(), 0, random_engine); iota_shuffle(cache_batch_idx_host.begin(), cache_batch_idx_host.end(), 0, random_engine); @@ -705,6 +741,8 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::DeviceMem q_descale_buf(q_descale_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem k_descale_buf(k_descale_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem v_descale_buf(v_descale_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem bseqstart_q_buf(bseqstart_q_host.size() * sizeof(int32_t)); + ck_tile::DeviceMem bseqstart_k_buf(bseqstart_k_host.size() * sizeof(int32_t)); ck_tile::DeviceMem lse_acc_buf(lse_acc_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem o_acc_buf(o_acc_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem lse_buf(lse_host.get_element_space_size_in_bytes()); @@ -749,6 +787,8 @@ fwd_result fmha_fwd_run(mode_enum mode, q_descale_buf.ToDevice(q_descale_host.data()); k_descale_buf.ToDevice(k_descale_host.data()); v_descale_buf.ToDevice(v_descale_host.data()); + bseqstart_q_buf.ToDevice(bseqstart_q_host.data()); + bseqstart_k_buf.ToDevice(bseqstart_k_host.data()); seqstart_q.ToDevice(seqstart_q_host.data()); // Keep logical starts in seqstart_k; pass padded K via separate pointer 
seqstart_k.ToDevice(seqstart_k_host.data()); @@ -941,11 +981,14 @@ fwd_result fmha_fwd_run(mode_enum mode, }(); const ck_tile::index_t nhead_stride_bias = (i_perm ? 0 * shape_seqlen_q * max_seqlen_k : 0 * max_seqlen_k); - const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k); - const ck_tile::index_t nhead_stride_lse = shape_seqlen_q; - const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q); - const ck_tile::index_t nhead_stride_o_acc = (num_splits * shape_seqlen_q * hdim_v); - const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); + const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k); + const ck_tile::index_t nhead_stride_lse = shape_seqlen_q; + const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q); + const ck_tile::index_t nhead_stride_o_acc = (num_splits * shape_seqlen_q * hdim_v); + const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); + const ck_tile::index_t nhead_stride_q_descale = num_block_scale_m; + const ck_tile::index_t nhead_stride_k_descale = num_block_scale_n; + const ck_tile::index_t nhead_stride_v_descale = num_block_scale_n; // setup batch_stride_* arguments const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); const ck_tile::index_t batch_stride_k = @@ -963,6 +1006,9 @@ fwd_result fmha_fwd_run(mode_enum mode, const ck_tile::index_t batch_stride_o_acc = (nhead * num_splits * shape_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch); + const ck_tile::index_t batch_stride_q_descale = num_block_scale_m * nhead; + const ck_tile::index_t batch_stride_k_descale = num_block_scale_n * nhead_k; + const ck_tile::index_t batch_stride_v_descale = num_block_scale_n * nhead_k; // setup split_stride_* arguments (only used in split-kv kernel) const ck_tile::index_t 
split_stride_lse_acc = (shape_seqlen_q); const ck_tile::index_t split_stride_o_acc = (shape_seqlen_q * hdim_v); @@ -1046,9 +1092,37 @@ fwd_result fmha_fwd_run(mode_enum mode, if constexpr(std::is_same_v>) { - args.q_descale_ptr = q_descale_buf.GetDeviceBuffer(); - args.k_descale_ptr = k_descale_buf.GetDeviceBuffer(); - args.v_descale_ptr = v_descale_buf.GetDeviceBuffer(); + if(qscale.type == quant_scale_enum::blockscale) + { + args.q_descale_ptr = + reinterpret_cast(q_descale_buf.GetDeviceBuffer()); + args.k_descale_ptr = + reinterpret_cast(k_descale_buf.GetDeviceBuffer()); + args.v_descale_ptr = + reinterpret_cast(v_descale_buf.GetDeviceBuffer()); + + args.bseqstart_q_ptr = + (mode == mode_enum::group ? bseqstart_q_buf.GetDeviceBuffer() : nullptr); + args.bseqstart_k_ptr = + (mode == mode_enum::group ? bseqstart_k_buf.GetDeviceBuffer() : nullptr); + + args.nhead_stride_q_descale = nhead_stride_q_descale; + args.nhead_stride_k_descale = nhead_stride_k_descale; + args.nhead_stride_v_descale = nhead_stride_v_descale; + + args.batch_stride_q_descale = batch_stride_q_descale; + args.batch_stride_k_descale = batch_stride_k_descale; + args.batch_stride_v_descale = batch_stride_v_descale; + + args.block_scale_m = block_scale_m_; + args.block_scale_n = block_scale_n_; + } + else + { + args.q_descale_ptr = q_descale_buf.GetDeviceBuffer(); + args.k_descale_ptr = k_descale_buf.GetDeviceBuffer(); + args.v_descale_ptr = v_descale_buf.GetDeviceBuffer(); + } args.rand_val_ptr = randval_buf.GetDeviceBuffer(); @@ -1551,14 +1625,42 @@ fwd_result fmha_fwd_run(mode_enum mode, #endif // reference - ck_tile:: - reference_batched_gemm( + if(qscale.type == quant_scale_enum::blockscale) + { + const ck_tile::index_t q_offset = + (mode == mode_enum::batch) ? 0 : bseqstart_q_host[wb]; + const ck_tile::index_t k_offset = + (mode == mode_enum::batch) ? 
0 : bseqstart_k_host[wb]; + ck_tile::reference_batched_quant_gemm( q_host_ref, k_host_ref, s_host_ref, - ck_tile::identity{}, - ck_tile::identity{}, - ck_tile::scales(scale_s_host)); + ck_tile::idx_identity{}, + ck_tile::idx_identity{}, + [&](auto idx, auto value) { + return value * scale_s * + q_descale_host(b_idx, + std::get<0>(idx), + q_offset + std::get<1>(idx) / block_scale_m_) * + k_descale_host(b_idx, + std::get<0>(idx) / nr, + k_offset + std::get<2>(idx) / block_scale_n_); + }); + } + else + { + ck_tile:: + reference_batched_gemm( + q_host_ref, + k_host_ref, + s_host_ref, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::scales(scale_s_host)); + } if(0.f < logits_soft_cap) { @@ -1716,13 +1818,34 @@ fwd_result fmha_fwd_run(mode_enum mode, } } - ck_tile::reference_batched_gemm( - p_host_ref, - v_host_ref, - o_host_ref, - ck_tile::identity{}, - ck_tile::identity{}, - oacc_element_func); + if(qscale.type == quant_scale_enum::blockscale) + { + const ck_tile::index_t v_offset = + (mode == mode_enum::batch) ? 
0 : bseqstart_k_host[wb]; + ck_tile:: + reference_batched_quant_gemm( + p_host_ref, + v_host_ref, + o_host_ref, + ck_tile::idx_identity{}, + [&](auto idx, auto value) { + return ck_tile::type_convert(value) * + v_descale_host(b_idx, + std::get<0>(idx) / nr, + v_offset + std::get<2>(idx) / block_scale_n_); + }, + ck_tile::idx_identity{}); + } + else + { + ck_tile::reference_batched_gemm( + p_host_ref, + v_host_ref, + o_host_ref, + ck_tile::identity{}, + ck_tile::identity{}, + oacc_element_func); + } ck_tile::HostTensor o_host_result({nhead, real_seqlen_q, hdim_v}); // clang-format off @@ -1730,7 +1853,6 @@ fwd_result fmha_fwd_run(mode_enum mode, if(o_perm) o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[0], idx[1] + query_offset, idx[2]); }); else o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[1] + query_offset, idx[0], idx[2]); }); // clang-format on - auto [rtol, atol] = get_elimit(init_method); bool cur_pass = ck_tile::check_err(o_host_result, o_host_ref, @@ -1788,31 +1910,33 @@ fwd_result fmha_fwd_run(mode_enum mode, if(json) { - dump_fmha_fwd_json_results(*json, - data_type, - mode == mode_enum::batch ? "batch" : "group", - io_layout(i_perm, o_perm), - batch, - nhead, - nhead_k, - seqlen_qs[0], - seqlen_ks[0], - seqlen_kpads[0], - hdim_q, - hdim_v, - scale_s, - p_drop, - lse, - qscale.type == quant_scale_enum::no_scale ? "no_scale" - : "pertensor", - bias.type == bias_enum::elementwise_bias - ? "elementwise_bias" - : (bias.type == bias_enum::alibi ? "alibi" : "no_bias"), - is_v_rowmajor ? "r" : "c", - pass, - ave_time, - tflops, - gb_per_sec); + dump_fmha_fwd_json_results( + *json, + data_type, + mode == mode_enum::batch ? "batch" : "group", + io_layout(i_perm, o_perm), + batch, + nhead, + nhead_k, + seqlen_qs[0], + seqlen_ks[0], + seqlen_kpads[0], + hdim_q, + hdim_v, + scale_s, + p_drop, + lse, + qscale.type == quant_scale_enum::no_scale + ? 
"no_scale" + : (qscale.type == quant_scale_enum::pertensor ? "pertensor" : "blockscale"), + bias.type == bias_enum::elementwise_bias + ? "elementwise_bias" + : (bias.type == bias_enum::alibi ? "alibi" : "no_bias"), + is_v_rowmajor ? "r" : "c", + pass, + ave_time, + tflops, + gb_per_sec); } return pass ? fwd_result::success : fwd_result::failure; diff --git a/example/ck_tile/01_fmha/quant.hpp b/example/ck_tile/01_fmha/quant.hpp index 59d4ac1707..feb28cba24 100644 --- a/example/ck_tile/01_fmha/quant.hpp +++ b/example/ck_tile/01_fmha/quant.hpp @@ -13,6 +13,7 @@ enum class quant_scale_enum { no_scale = 0, pertensor = 1, + blockscale, }; struct quant_scale_info @@ -25,6 +26,8 @@ struct quant_scale_info os << "n"; else if(type == quant_scale_enum::pertensor) os << "pt"; + else if(type == quant_scale_enum::blockscale) + os << "bs"; } static quant_scale_info decode(std::string str) @@ -38,6 +41,10 @@ struct quant_scale_info { info.type = quant_scale_enum::pertensor; } + else if(str == "bs" || str == "2") + { + info.type = quant_scale_enum::blockscale; + } else { throw std::invalid_argument("invalid quant scale value: " + str); diff --git a/include/ck_tile/core/utility/functional.hpp b/include/ck_tile/core/utility/functional.hpp index 90740dcbe3..420eba6609 100644 --- a/include/ck_tile/core/utility/functional.hpp +++ b/include/ck_tile/core/utility/functional.hpp @@ -91,6 +91,15 @@ struct identity } }; +struct idx_identity +{ + template + CK_TILE_HOST_DEVICE constexpr T&& operator()(auto, T&& arg) const noexcept + { + return std::forward(arg); + } +}; + namespace detail { // RemainLengths: sequence<...> diff --git a/include/ck_tile/host/reference/reference_batched_gemm.hpp b/include/ck_tile/host/reference/reference_batched_gemm.hpp index 63f13b1b16..96b54a4093 100644 --- a/include/ck_tile/host/reference/reference_batched_gemm.hpp +++ b/include/ck_tile/host/reference/reference_batched_gemm.hpp @@ -47,4 +47,43 @@ CK_TILE_HOST void reference_batched_gemm(const HostTensor& 
a_b_m_k, make_ParallelTensorFunctor(f, c_b_m_n.mDesc.get_lengths()[0], c_b_m_n.mDesc.get_lengths()[1])( std::thread::hardware_concurrency()); } +template +CK_TILE_HOST void reference_batched_quant_gemm(const HostTensor& a_b_m_k, + const HostTensor& b_b_n_k, + HostTensor& c_b_m_n, + const AElementOp& a_element_op = {}, + const BElementOp& b_element_op = {}, + const ACCElementOp& acc_element_op = {}) +{ + const int N = b_b_n_k.mDesc.get_lengths()[1]; + const int K = b_b_n_k.mDesc.get_lengths()[2]; + + auto f = [&](auto batch, auto m) { + for(int n = 0; n < N; ++n) + { + AccDataType v_acc = 0; + + for(int k = 0; k < K; ++k) + { + AccDataType v_a = ck_tile::type_convert( + a_element_op(std::make_tuple(batch, m, k), a_b_m_k(batch, m, k))); + AccDataType v_b = ck_tile::type_convert( + b_element_op(std::make_tuple(batch, n, k), b_b_n_k(batch, n, k))); + + v_acc += v_a * v_b; + } + + c_b_m_n(batch, m, n) = ck_tile::type_convert(acc_element_op(std::make_tuple(batch, m, n), v_acc)); + } + }; + + make_ParallelTensorFunctor(f, c_b_m_n.mDesc.get_lengths()[0], c_b_m_n.mDesc.get_lengths()[1])( + std::thread::hardware_concurrency()); +} } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp b/include/ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp index 3755a2bc71..7e0f704bef 100644 --- a/include/ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp +++ b/include/ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp @@ -12,6 +12,7 @@ enum class BlockAttentionQuantScaleEnum { NO_SCALE = 0, PERTENSOR = 1, + BLOCKSCALE, }; template @@ -27,5 +28,10 @@ struct BlockAttentionQuantScaleEnumToStr +struct BlockAttentionQuantScaleEnumToStr +{ + static constexpr const char* name = "blockscale"; +}; } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 38830ee6fe..c3628b44b6 100644 --- 
a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -214,6 +214,29 @@ struct FmhaFwdKernel const void* v_descale_ptr = nullptr; }; + struct FmhaFwdCommonBlockScaleKargs : public FmhaFwdCommonQScaleKargs + { + ck_tile::index_t nhead_stride_q_descale; + ck_tile::index_t nhead_stride_k_descale; + ck_tile::index_t nhead_stride_v_descale; + + ck_tile::index_t block_scale_m; + ck_tile::index_t block_scale_n; + }; + + struct FmhaFwdBatchBlockScaleKargs : public FmhaFwdCommonBlockScaleKargs + { + ck_tile::index_t batch_stride_q_descale; + ck_tile::index_t batch_stride_k_descale; + ck_tile::index_t batch_stride_v_descale; + }; + + struct FmhaFwdGroupBlockScaleKargs : public FmhaFwdCommonBlockScaleKargs + { + const int32_t* bseqstart_q_ptr; + const int32_t* bseqstart_k_ptr; + }; + struct FmhaFwdCommonLSEKargs { void* lse_ptr = nullptr; @@ -289,9 +312,12 @@ struct FmhaFwdKernel FmhaFwdEmptyKargs<0>>>, std::conditional_t>, std::conditional_t>, - std::conditional_t>, + std::conditional_t< + QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR, + FmhaFwdCommonQScaleKargs, + std::conditional_t>>, std::conditional_t>, std::conditional_t> { @@ -315,9 +341,12 @@ struct FmhaFwdKernel FmhaFwdEmptyKargs<0>>>, std::conditional_t>, std::conditional_t>, - std::conditional_t>, + std::conditional_t< + QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR, + FmhaFwdCommonQScaleKargs, + std::conditional_t>>, std::conditional_t>, std::conditional_t>, std::conditional_t> @@ -374,6 +403,9 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, + ck_tile::index_t nhead_stride_q_descale, + ck_tile::index_t nhead_stride_k_descale, + ck_tile::index_t nhead_stride_v_descale, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, @@ -381,6 +413,9 @@ struct FmhaFwdKernel ck_tile::index_t batch_stride_randval, 
ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, + ck_tile::index_t batch_stride_q_descale, + ck_tile::index_t batch_stride_k_descale, + ck_tile::index_t batch_stride_v_descale, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, @@ -388,6 +423,8 @@ struct FmhaFwdKernel bool s_randval, std::variant, std::pair> drop_seed_offset, + ck_tile::index_t block_scale_m, + ck_tile::index_t block_scale_n, const void* cu_seqlen_q_ptr = nullptr, const void* cu_seqlen_k_ptr = nullptr) { @@ -455,6 +492,23 @@ struct FmhaFwdKernel kargs.k_descale_ptr = k_descale_ptr; kargs.v_descale_ptr = v_descale_ptr; } + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { + kargs.q_descale_ptr = q_descale_ptr; + kargs.k_descale_ptr = k_descale_ptr; + kargs.v_descale_ptr = v_descale_ptr; + + kargs.nhead_stride_q_descale = nhead_stride_q_descale; + kargs.nhead_stride_k_descale = nhead_stride_k_descale; + kargs.nhead_stride_v_descale = nhead_stride_v_descale; + + kargs.batch_stride_q_descale = batch_stride_q_descale; + kargs.batch_stride_k_descale = batch_stride_k_descale; + kargs.batch_stride_v_descale = batch_stride_v_descale; + + kargs.block_scale_m = block_scale_m; + kargs.block_scale_n = block_scale_n; + } if constexpr(kHasDropout) { if(drop_seed_offset.index() == 0) // seed & offset come from host @@ -520,6 +574,9 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, + ck_tile::index_t nhead_stride_q_descale, + ck_tile::index_t nhead_stride_k_descale, + ck_tile::index_t nhead_stride_v_descale, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, @@ -527,12 +584,17 @@ struct FmhaFwdKernel ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, + ck_tile::index_t batch_stride_q_descale, + ck_tile::index_t batch_stride_k_descale, + 
ck_tile::index_t batch_stride_v_descale, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, bool s_randval, const std::tuple& drop_seed_offset, + ck_tile::index_t block_scale_m, + ck_tile::index_t block_scale_n, const void* cu_seqlen_q_ptr = nullptr, const void* cu_seqlen_k_ptr = nullptr) { @@ -568,6 +630,9 @@ struct FmhaFwdKernel nhead_stride_randval, nhead_stride_lse, nhead_stride_o, + nhead_stride_q_descale, + nhead_stride_k_descale, + nhead_stride_v_descale, batch_stride_q, batch_stride_k, batch_stride_v, @@ -575,12 +640,17 @@ struct FmhaFwdKernel batch_stride_randval, batch_stride_lse, batch_stride_o, + batch_stride_q_descale, + batch_stride_k_descale, + batch_stride_v_descale, window_size_left, window_size_right, mask_type, p_drop, s_randval, std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)), + block_scale_m, + block_scale_n, cu_seqlen_q_ptr, cu_seqlen_k_ptr); } @@ -619,6 +689,9 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, + ck_tile::index_t nhead_stride_q_descale, + ck_tile::index_t nhead_stride_k_descale, + ck_tile::index_t nhead_stride_v_descale, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, @@ -626,12 +699,17 @@ struct FmhaFwdKernel ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, + ck_tile::index_t batch_stride_q_descale, + ck_tile::index_t batch_stride_k_descale, + ck_tile::index_t batch_stride_v_descale, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, bool s_randval, const std::tuple& drop_seed_offset, + ck_tile::index_t block_scale_m, + ck_tile::index_t block_scale_n, const void* cu_seqlen_q_ptr = nullptr, const void* cu_seqlen_k_ptr = nullptr) { @@ -667,6 +745,9 @@ struct FmhaFwdKernel 
nhead_stride_randval, nhead_stride_lse, nhead_stride_o, + nhead_stride_q_descale, + nhead_stride_k_descale, + nhead_stride_v_descale, batch_stride_q, batch_stride_k, batch_stride_v, @@ -674,12 +755,17 @@ struct FmhaFwdKernel batch_stride_randval, batch_stride_lse, batch_stride_o, + batch_stride_q_descale, + batch_stride_k_descale, + batch_stride_v_descale, window_size_left, window_size_right, mask_type, p_drop, s_randval, std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)), + block_scale_m, + block_scale_n, cu_seqlen_q_ptr, cu_seqlen_k_ptr); } @@ -700,6 +786,8 @@ struct FmhaFwdKernel const void* seqstart_k_ptr, const void* seqlen_q_ptr, const void* seqlen_k_ptr, + const void* bseqstart_q_ptr, + const void* bseqstart_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, @@ -719,6 +807,9 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, + ck_tile::index_t nhead_stride_q_descale, + ck_tile::index_t nhead_stride_k_descale, + ck_tile::index_t nhead_stride_v_descale, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, @@ -727,6 +818,8 @@ struct FmhaFwdKernel bool s_randval, std::variant, std::pair> drop_seed_offset, + ck_tile::index_t block_scale_m, + ck_tile::index_t block_scale_n, const void* cu_seqlen_q_ptr = nullptr, const void* cu_seqlen_k_ptr = nullptr) { @@ -793,6 +886,22 @@ struct FmhaFwdKernel kargs.k_descale_ptr = k_descale_ptr; kargs.v_descale_ptr = v_descale_ptr; } + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { + kargs.q_descale_ptr = q_descale_ptr; + kargs.k_descale_ptr = k_descale_ptr; + kargs.v_descale_ptr = v_descale_ptr; + + kargs.nhead_stride_q_descale = nhead_stride_q_descale; + kargs.nhead_stride_k_descale = nhead_stride_k_descale; + kargs.nhead_stride_v_descale = nhead_stride_v_descale; + + kargs.block_scale_m = block_scale_m; + 
kargs.block_scale_n = block_scale_n; + + kargs.bseqstart_q_ptr = reinterpret_cast(bseqstart_q_ptr); + kargs.bseqstart_k_ptr = reinterpret_cast(bseqstart_k_ptr); + } if constexpr(kHasDropout) { if(drop_seed_offset.index() == 0) // seed & offset come from host @@ -844,6 +953,8 @@ struct FmhaFwdKernel const void* seqstart_k_ptr, const void* seqlen_q_ptr, const void* seqlen_k_ptr, + const void* bseqstart_q_ptr, + const void* bseqstart_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, @@ -863,6 +974,9 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, + ck_tile::index_t nhead_stride_q_descale, + ck_tile::index_t nhead_stride_k_descale, + ck_tile::index_t nhead_stride_v_descale, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, @@ -870,6 +984,8 @@ struct FmhaFwdKernel float p_drop, bool s_randval, const std::tuple& drop_seed_offset, + ck_tile::index_t block_scale_m, + ck_tile::index_t block_scale_n, const void* cu_seqlen_q_ptr = nullptr, const void* cu_seqlen_k_ptr = nullptr) { @@ -888,6 +1004,8 @@ struct FmhaFwdKernel seqstart_k_ptr, seqlen_q_ptr, seqlen_k_ptr, + bseqstart_q_ptr, + bseqstart_k_ptr, hdim_q, hdim_v, num_head_q, @@ -907,6 +1025,9 @@ struct FmhaFwdKernel nhead_stride_randval, nhead_stride_lse, nhead_stride_o, + nhead_stride_q_descale, + nhead_stride_k_descale, + nhead_stride_v_descale, window_size_left, window_size_right, mask_type, @@ -914,6 +1035,8 @@ struct FmhaFwdKernel p_drop, s_randval, std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)), + block_scale_m, + block_scale_n, cu_seqlen_q_ptr, cu_seqlen_k_ptr); } @@ -935,6 +1058,8 @@ struct FmhaFwdKernel const void* seqstart_k_ptr, const void* seqlen_q_ptr, const void* seqlen_k_ptr, + const void* bseqstart_q_ptr, + const void* bseqstart_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t 
num_head_q, @@ -954,6 +1079,9 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, + ck_tile::index_t nhead_stride_q_descale, + ck_tile::index_t nhead_stride_k_descale, + ck_tile::index_t nhead_stride_v_descale, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, @@ -961,6 +1089,8 @@ struct FmhaFwdKernel float p_drop, bool s_randval, const std::tuple& drop_seed_offset, + ck_tile::index_t block_scale_m, + ck_tile::index_t block_scale_n, const void* cu_seqlen_q_ptr = nullptr, const void* cu_seqlen_k_ptr = nullptr) { @@ -979,6 +1109,8 @@ struct FmhaFwdKernel seqstart_k_ptr, seqlen_q_ptr, seqlen_k_ptr, + bseqstart_q_ptr, + bseqstart_k_ptr, hdim_q, hdim_v, num_head_q, @@ -998,6 +1130,9 @@ struct FmhaFwdKernel nhead_stride_randval, nhead_stride_lse, nhead_stride_o, + nhead_stride_q_descale, + nhead_stride_k_descale, + nhead_stride_v_descale, window_size_left, window_size_right, mask_type, @@ -1005,6 +1140,8 @@ struct FmhaFwdKernel p_drop, s_randval, std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)), + block_scale_m, + block_scale_n, cu_seqlen_q_ptr, cu_seqlen_k_ptr); } From 7260af6c15fa7b87503455dfd7d77a5fe938becb Mon Sep 17 00:00:00 2001 From: ltqin Date: Sun, 30 Nov 2025 09:26:04 +0000 Subject: [PATCH 02/25] add block scale to kernel --- .../ops/fmha/kernel/fmha_fwd_kernel.hpp | 80 ++++++++++++++++++- .../pipeline/block_fmha_pipeline_qr_ks_vs.hpp | 60 ++++++++++++-- .../block_fmha_pipeline_qr_ks_vs_async.hpp | 10 ++- 3 files changed, 142 insertions(+), 8 deletions(-) diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index c3628b44b6..7d80869d3d 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -1279,6 +1279,9 @@ struct FmhaFwdKernel long_index_t batch_offset_randval = 0; 
long_index_t batch_offset_lse = 0; long_index_t batch_offset_o = 0; + long_index_t batch_offset_q_descale = 0; + long_index_t batch_offset_k_descale = 0; + long_index_t batch_offset_v_descale = 0; if constexpr(kIsGroupMode) { @@ -1310,6 +1313,14 @@ struct FmhaFwdKernel { batch_offset_randval = query_start * kargs.stride_randval; } + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { + const long_index_t bquery_start = kargs.bseqstart_q_ptr[i_batch]; + const long_index_t bkey_start = kargs.bseqstart_k_ptr[i_batch]; + batch_offset_q_descale = bquery_start; + batch_offset_k_descale = bkey_start; + batch_offset_v_descale = bkey_start; + } batch_offset_o = query_start * kargs.stride_o; // real logical lengths (exclude PAD) @@ -1377,6 +1388,15 @@ struct FmhaFwdKernel batch_offset_randval = static_cast(i_batch) * kargs.batch_stride_randval; } + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { + batch_offset_q_descale = + static_cast(i_batch) * kargs.batch_stride_q_descale; + batch_offset_k_descale = + static_cast(i_batch) * kargs.batch_stride_k_descale; + batch_offset_v_descale = + static_cast(i_batch) * kargs.batch_stride_v_descale; + } batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; // If cumulative seqlen pointers are provided, override per-batch effective lengths @@ -1724,7 +1744,65 @@ struct FmhaFwdKernel variant_params, block_indices, smem_ptr, - dropout); + dropout, + nullptr, + nullptr, + 1); + } + else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { + const float* q_descale_ptr = + reinterpret_cast(kargs.q_descale_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_q_descale + + batch_offset_q_descale; + const float* k_descale_ptr = + reinterpret_cast(kargs.k_descale_ptr) + + static_cast(i_nhead / kargs.nhead_ratio_qk) * + kargs.nhead_stride_k_descale + + batch_offset_k_descale; + const float* v_descale_ptr = + reinterpret_cast(kargs.v_descale_ptr) + + static_cast(i_nhead / 
kargs.nhead_ratio_qk) * + kargs.nhead_stride_v_descale + + batch_offset_v_descale; + + size_t idx = i_m0 / kargs.block_scale_m; + float q_descale = q_descale_ptr[idx]; + + float scale_o = 1.0; + + auto o_acc_element_func = [&]() { + if constexpr(std::is_same_v) + return ck_tile::composes(ck_tile::saturates{}, + ck_tile::scales{scale_o}); + else + return ck_tile::scales{scale_o}; + }(); + return FmhaPipeline{}(q_dram_window, + identity{}, // q_element_func + k_dram_window, + identity{}, // k_element_func + v_dram_window, + identity{}, // v_element_func + bias_dram_window, + identity{}, // bias_element_func + randval_dram_window, + lse_dram_window, + identity{}, // lse_element_func + identity{}, // s_acc_element_func + scales{1.0f}, // p_compute_element_func + o_acc_element_func, // o_acc_element_func + mask, + position_encoding, + kargs.scale_s * q_descale, + variant, + variant_params, + block_indices, + smem_ptr, + dropout, + k_descale_ptr, + v_descale_ptr, + kargs.block_scale_n); } else { diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp index 9e1eb3bdec..301ffaebae 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -57,6 +57,7 @@ struct BlockFmhaPipelineQRKSVS static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr bool kHasDropout = Problem::kHasDropout; + static constexpr auto QScaleEnum = Problem::QScaleEnum; static constexpr uint32_t DS_READ = 0x100; // Barrier for DS (data share) read static constexpr uint32_t MFMA = 0x008; // Barrier for MFMA (matrix multiply-accumulate) @@ -165,8 +166,13 @@ struct BlockFmhaPipelineQRKSVS const AttentionVariantParams& variant_params, const BlockIndices& block_indices, void* smem_ptr, - DropoutType& dropout) const + DropoutType& dropout, + const float* 
k_descale_ptr, + const float* v_descale_ptr, + index_t block_scale_n) const { + ignore = block_scale_n; + ignore = v_descale_ptr; static_assert( std::is_same_v> && std::is_same_v> && @@ -318,6 +324,14 @@ struct BlockFmhaPipelineQRKSVS static_assert(1 <= k1_loops); do { + float k_descale = 1.0f; + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { + const auto k_origin = k_dram_block_window.get_window_origin(); + const auto row = k_origin.at(number<0>{}); + const index_t idx = row / block_scale_n; + k_descale = k_descale_ptr[idx]; + } // STAGE 1, QK gemm auto k_dram_window = make_tile_window( k_dram_block_window.get_bottom_tensor_view(), @@ -387,7 +401,12 @@ struct BlockFmhaPipelineQRKSVS k_lds_window); schedule_gemm0(); } - + // dequant + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { + tile_elementwise_inout( + [k_descale](auto& x) { x = x * k_descale; }, s_acc); + } // STAGE 2, scale_s, add bias, mask, softmax if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { @@ -602,18 +621,41 @@ struct BlockFmhaPipelineQRKSVS store_tile(v_lds_window, tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch } + + float v_descale = 1.0f; + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { + const auto v_origin = v_dram_window.get_window_origin(); + const auto col = v_origin.at(number<1>{}); + const index_t idx = col / block_scale_n; + v_descale = v_descale_ptr[idx]; + } move_tile_window(v_dram_window, {0, kK1}); const auto p = cast_tile(tile_elementwise_in(p_compute_element_func, p_compute)); // STAGE 3, KV gemm + auto o_acc0 = decltype(o_acc){}; + clear_tile(o_acc0); + + // C= C0*a + C =a*(C0 + C/a) + auto& o_acc_ = [&o_acc0, &o_acc]() -> auto& { + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { + return o_acc0; + } + else + { + return o_acc; + } + }(); if constexpr(k1_loops > 1) { static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { const auto v = 
load_tile(v_dram_window); // load next v block_sync_lds(); - gemm_1(o_acc, + gemm_1(o_acc_, get_slice_tile( p, sequence<0, i_k1 * kK1>{}, sequence{}), v_lds_window); @@ -640,11 +682,16 @@ struct BlockFmhaPipelineQRKSVS // tail { block_sync_lds(); - gemm_1(o_acc, + gemm_1(o_acc_, get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence{}), v_lds_window); block_sync_lds(); } + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { + tile_elementwise_inout( + [&v_descale](auto& o, auto& o0) { o += o0 * v_descale; }, o_acc, o_acc0); + } } while(++i_total_loops < num_total_loop); // store lse @@ -750,7 +797,10 @@ struct BlockFmhaPipelineQRKSVS variant_params, block_indices, smem_ptr, - dropout); + dropout, + nullptr, + nullptr, + 1); } }; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index e07516cc27..19ecc7c48c 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -187,7 +187,10 @@ struct BlockFmhaPipelineQRKSVSAsync const AttentionVariantParams& variant_params, const BlockIndices& block_indices, void* smem_ptr, - DropoutType& dropout) const + DropoutType& dropout, + const float*, + const float*, + index_t) const { static_assert( std::is_same_v> && @@ -847,7 +850,10 @@ struct BlockFmhaPipelineQRKSVSAsync variant_params, block_indices, smem_ptr, - dropout); + dropout, + nullptr, + nullptr, + 1); } }; From 39104a061dff62a90e1d452f926ab444a4e3a1f1 Mon Sep 17 00:00:00 2001 From: ltqin Date: Mon, 1 Dec 2025 03:40:35 +0000 Subject: [PATCH 03/25] add smoke test --- example/ck_tile/01_fmha/script/smoke_test_fwd.sh | 5 +++-- .../ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp | 14 +++++++------- .../fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp | 8 +++----- 3 files changed, 13 insertions(+), 14 deletions(-) diff --git 
a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh index 596542eb9d..227f26c8f3 100755 --- a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh +++ b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh @@ -95,10 +95,11 @@ run_fp8bf16_tests() { for perm in 0 1 ; do for b in 1 2 ; do for hdim in 64 128 256 ; do + for scale in 1 2; do - $EXE -prec=fp8bf16 -init=3 -b=$b -h=1 -d=$hdim -s=128 -iperm=$perm -operm=$perm -vlayout=r -qscale=1 -kname=$KNAME $COMMON_ARGS + $EXE -prec=fp8bf16 -init=3 -b=$b -h=1 -d=$hdim -s=128 -iperm=$perm -operm=$perm -vlayout=r -qscale=$scale -kname=$KNAME $COMMON_ARGS - done ; done ; done + done ; done ; done ; done } run_fp8fp32_tests() { diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 7d80869d3d..6b239d2363 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -1272,13 +1272,13 @@ struct FmhaFwdKernel const index_t i_m0 = amd_wave_read_first_lane(i_tile_m * FmhaPipeline::kM0); const index_t i_n1 = amd_wave_read_first_lane(i_tile_n * FmhaPipeline::kN1); - long_index_t batch_offset_q = 0; - long_index_t batch_offset_k = 0; - long_index_t batch_offset_v = 0; - long_index_t batch_offset_bias = 0; - long_index_t batch_offset_randval = 0; - long_index_t batch_offset_lse = 0; - long_index_t batch_offset_o = 0; + long_index_t batch_offset_q = 0; + long_index_t batch_offset_k = 0; + long_index_t batch_offset_v = 0; + long_index_t batch_offset_bias = 0; + long_index_t batch_offset_randval = 0; + long_index_t batch_offset_lse = 0; + long_index_t batch_offset_o = 0; long_index_t batch_offset_q_descale = 0; long_index_t batch_offset_k_descale = 0; long_index_t batch_offset_v_descale = 0; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp index 
301ffaebae..1de3d4dd5c 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -404,8 +404,7 @@ struct BlockFmhaPipelineQRKSVS // dequant if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) { - tile_elementwise_inout( - [k_descale](auto& x) { x = x * k_descale; }, s_acc); + tile_elementwise_inout([k_descale](auto& x) { x = x * k_descale; }, s_acc); } // STAGE 2, scale_s, add bias, mask, softmax if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) @@ -638,8 +637,7 @@ struct BlockFmhaPipelineQRKSVS // STAGE 3, KV gemm auto o_acc0 = decltype(o_acc){}; clear_tile(o_acc0); - - // C= C0*a + C =a*(C0 + C/a) + auto& o_acc_ = [&o_acc0, &o_acc]() -> auto& { if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) { @@ -690,7 +688,7 @@ struct BlockFmhaPipelineQRKSVS if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) { tile_elementwise_inout( - [&v_descale](auto& o, auto& o0) { o += o0 * v_descale; }, o_acc, o_acc0); + [&v_descale](auto& o, auto& o0) { o += o0 * v_descale; }, o_acc, o_acc0); } } while(++i_total_loops < num_total_loop); From 356c3c970664af68a04e2694da7e270b8c8338bf Mon Sep 17 00:00:00 2001 From: ltqin Date: Mon, 1 Dec 2025 06:25:03 +0000 Subject: [PATCH 04/25] format --- example/ck_tile/01_fmha/CMakeLists.txt | 2 +- .../host/reference/reference_batched_gemm.hpp | 13 +++++++------ .../elementwise/unary_element_wise_operation.hpp | 12 ++++++------ 3 files changed, 14 insertions(+), 13 deletions(-) diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index 5f5587226c..9edf50e89c 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -136,7 +136,7 @@ list(APPEND FMHA_BWD_PRIVATE_COMPILE_OPTIONS -Wno-undefined-func-template) # Allow comparing floating points directly in order to check sentinel values list(APPEND 
FMHA_FWD_PRIVATE_COMPILE_OPTIONS -Wno-float-equal) list(APPEND FMHA_BWD_PRIVATE_COMPILE_OPTIONS -Wno-float-equal) -list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -g1 --save-temps -Wno-gnu-line-marker) + # NOTE: this is dangerous since will change the whole kernel to flush denormals # WIP with compiler team for an exp2 intrinsic..., then remove this if(NOT DEFINED FMHA_FWD_FAST_EXP2) diff --git a/include/ck_tile/host/reference/reference_batched_gemm.hpp b/include/ck_tile/host/reference/reference_batched_gemm.hpp index 96b54a4093..8d266ffca4 100644 --- a/include/ck_tile/host/reference/reference_batched_gemm.hpp +++ b/include/ck_tile/host/reference/reference_batched_gemm.hpp @@ -55,11 +55,11 @@ template CK_TILE_HOST void reference_batched_quant_gemm(const HostTensor& a_b_m_k, - const HostTensor& b_b_n_k, - HostTensor& c_b_m_n, - const AElementOp& a_element_op = {}, - const BElementOp& b_element_op = {}, - const ACCElementOp& acc_element_op = {}) + const HostTensor& b_b_n_k, + HostTensor& c_b_m_n, + const AElementOp& a_element_op = {}, + const BElementOp& b_element_op = {}, + const ACCElementOp& acc_element_op = {}) { const int N = b_b_n_k.mDesc.get_lengths()[1]; const int K = b_b_n_k.mDesc.get_lengths()[2]; @@ -79,7 +79,8 @@ CK_TILE_HOST void reference_batched_quant_gemm(const HostTensor& a_b_ v_acc += v_a * v_b; } - c_b_m_n(batch, m, n) = ck_tile::type_convert(acc_element_op(std::make_tuple(batch, m, n), v_acc)); + c_b_m_n(batch, m, n) = ck_tile::type_convert( + acc_element_op(std::make_tuple(batch, m, n), v_acc)); } }; diff --git a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp index 2f8d3c6053..a962b5d7b1 100644 --- a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp +++ b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp @@ -1429,7 +1429,7 @@ struct SoftRelu { static constexpr const char* name = "SoftRelu"; - SoftRelu(float alpha = 1.f) : 
alpha_(alpha){}; + SoftRelu(float alpha = 1.f) : alpha_(alpha) {}; template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const @@ -1450,7 +1450,7 @@ struct Power static constexpr const char* name = "Power"; Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f) - : alpha_(alpha), beta_(beta), gamma_(gamma){}; + : alpha_(alpha), beta_(beta), gamma_(gamma) {}; template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const @@ -1474,7 +1474,7 @@ struct ClippedRelu { static constexpr const char* name = "ClippedRelu"; - ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta){}; + ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta) {}; template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const @@ -1495,7 +1495,7 @@ struct LeakyRelu { static constexpr const char* name = "LeakyRelu"; - LeakyRelu(float alpha = 0.01f) : alpha_(alpha){}; + LeakyRelu(float alpha = 0.01f) : alpha_(alpha) {}; template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const @@ -1514,7 +1514,7 @@ struct Elu { static constexpr const char* name = "Elu"; - Elu(float alpha = 1.f) : alpha_(alpha){}; + Elu(float alpha = 1.f) : alpha_(alpha) {}; template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const @@ -1533,7 +1533,7 @@ struct Logistic { static constexpr const char* name = "Logistic"; - Logistic(float alpha = 1.f) : alpha_(alpha){}; + Logistic(float alpha = 1.f) : alpha_(alpha) {}; template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const From 42b5aa4484fde7a04de9c206e9b4cd9c33786d30 Mon Sep 17 00:00:00 2001 From: ltqin Date: Mon, 1 Dec 2025 06:30:28 +0000 Subject: [PATCH 05/25] Revert "format" This reverts commit 356c3c970664af68a04e2694da7e270b8c8338bf. 
--- example/ck_tile/01_fmha/CMakeLists.txt | 2 +- .../host/reference/reference_batched_gemm.hpp | 13 ++++++------- .../elementwise/unary_element_wise_operation.hpp | 12 ++++++------ 3 files changed, 13 insertions(+), 14 deletions(-) diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index 9edf50e89c..5f5587226c 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -136,7 +136,7 @@ list(APPEND FMHA_BWD_PRIVATE_COMPILE_OPTIONS -Wno-undefined-func-template) # Allow comparing floating points directly in order to check sentinel values list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -Wno-float-equal) list(APPEND FMHA_BWD_PRIVATE_COMPILE_OPTIONS -Wno-float-equal) - +list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -g1 --save-temps -Wno-gnu-line-marker) # NOTE: this is dangerous since will change the whole kernel to flush denormals # WIP with compiler team for an exp2 intrinsic..., then remove this if(NOT DEFINED FMHA_FWD_FAST_EXP2) diff --git a/include/ck_tile/host/reference/reference_batched_gemm.hpp b/include/ck_tile/host/reference/reference_batched_gemm.hpp index 8d266ffca4..96b54a4093 100644 --- a/include/ck_tile/host/reference/reference_batched_gemm.hpp +++ b/include/ck_tile/host/reference/reference_batched_gemm.hpp @@ -55,11 +55,11 @@ template CK_TILE_HOST void reference_batched_quant_gemm(const HostTensor& a_b_m_k, - const HostTensor& b_b_n_k, - HostTensor& c_b_m_n, - const AElementOp& a_element_op = {}, - const BElementOp& b_element_op = {}, - const ACCElementOp& acc_element_op = {}) + const HostTensor& b_b_n_k, + HostTensor& c_b_m_n, + const AElementOp& a_element_op = {}, + const BElementOp& b_element_op = {}, + const ACCElementOp& acc_element_op = {}) { const int N = b_b_n_k.mDesc.get_lengths()[1]; const int K = b_b_n_k.mDesc.get_lengths()[2]; @@ -79,8 +79,7 @@ CK_TILE_HOST void reference_batched_quant_gemm(const HostTensor& a_b_ v_acc += v_a * v_b; } - c_b_m_n(batch, m, n) = 
ck_tile::type_convert( - acc_element_op(std::make_tuple(batch, m, n), v_acc)); + c_b_m_n(batch, m, n) = ck_tile::type_convert(acc_element_op(std::make_tuple(batch, m, n), v_acc)); } }; diff --git a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp index a962b5d7b1..2f8d3c6053 100644 --- a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp +++ b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp @@ -1429,7 +1429,7 @@ struct SoftRelu { static constexpr const char* name = "SoftRelu"; - SoftRelu(float alpha = 1.f) : alpha_(alpha) {}; + SoftRelu(float alpha = 1.f) : alpha_(alpha){}; template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const @@ -1450,7 +1450,7 @@ struct Power static constexpr const char* name = "Power"; Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f) - : alpha_(alpha), beta_(beta), gamma_(gamma) {}; + : alpha_(alpha), beta_(beta), gamma_(gamma){}; template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const @@ -1474,7 +1474,7 @@ struct ClippedRelu { static constexpr const char* name = "ClippedRelu"; - ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta) {}; + ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta){}; template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const @@ -1495,7 +1495,7 @@ struct LeakyRelu { static constexpr const char* name = "LeakyRelu"; - LeakyRelu(float alpha = 0.01f) : alpha_(alpha) {}; + LeakyRelu(float alpha = 0.01f) : alpha_(alpha){}; template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const @@ -1514,7 +1514,7 @@ struct Elu { static constexpr const char* name = "Elu"; - Elu(float alpha = 1.f) : alpha_(alpha) {}; + Elu(float alpha = 1.f) : alpha_(alpha){}; template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const @@ -1533,7 +1533,7 @@ struct Logistic { static constexpr const char* name = "Logistic"; - 
Logistic(float alpha = 1.f) : alpha_(alpha) {}; + Logistic(float alpha = 1.f) : alpha_(alpha){}; template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const From 295484747cac70b9dab7f8ac0e787b9d15898028 Mon Sep 17 00:00:00 2001 From: ltqin Date: Mon, 1 Dec 2025 06:33:57 +0000 Subject: [PATCH 06/25] only format my code --- example/ck_tile/01_fmha/CMakeLists.txt | 2 +- .../host/reference/reference_batched_gemm.hpp | 13 +++++++------ 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index 5f5587226c..9edf50e89c 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -136,7 +136,7 @@ list(APPEND FMHA_BWD_PRIVATE_COMPILE_OPTIONS -Wno-undefined-func-template) # Allow comparing floating points directly in order to check sentinel values list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -Wno-float-equal) list(APPEND FMHA_BWD_PRIVATE_COMPILE_OPTIONS -Wno-float-equal) -list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -g1 --save-temps -Wno-gnu-line-marker) + # NOTE: this is dangerous since will change the whole kernel to flush denormals # WIP with compiler team for an exp2 intrinsic..., then remove this if(NOT DEFINED FMHA_FWD_FAST_EXP2) diff --git a/include/ck_tile/host/reference/reference_batched_gemm.hpp b/include/ck_tile/host/reference/reference_batched_gemm.hpp index 96b54a4093..8d266ffca4 100644 --- a/include/ck_tile/host/reference/reference_batched_gemm.hpp +++ b/include/ck_tile/host/reference/reference_batched_gemm.hpp @@ -55,11 +55,11 @@ template CK_TILE_HOST void reference_batched_quant_gemm(const HostTensor& a_b_m_k, - const HostTensor& b_b_n_k, - HostTensor& c_b_m_n, - const AElementOp& a_element_op = {}, - const BElementOp& b_element_op = {}, - const ACCElementOp& acc_element_op = {}) + const HostTensor& b_b_n_k, + HostTensor& c_b_m_n, + const AElementOp& a_element_op = {}, + const BElementOp& b_element_op = {}, + const 
ACCElementOp& acc_element_op = {}) { const int N = b_b_n_k.mDesc.get_lengths()[1]; const int K = b_b_n_k.mDesc.get_lengths()[2]; @@ -79,7 +79,8 @@ CK_TILE_HOST void reference_batched_quant_gemm(const HostTensor& a_b_ v_acc += v_a * v_b; } - c_b_m_n(batch, m, n) = ck_tile::type_convert(acc_element_op(std::make_tuple(batch, m, n), v_acc)); + c_b_m_n(batch, m, n) = ck_tile::type_convert( + acc_element_op(std::make_tuple(batch, m, n), v_acc)); } }; From 49a280b453d92e1279077b2f56859f0e38fdb4f4 Mon Sep 17 00:00:00 2001 From: ltqin Date: Mon, 1 Dec 2025 06:53:12 +0000 Subject: [PATCH 07/25] format py --- example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 9599a57228..8f83404506 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -829,7 +829,10 @@ def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeli elif dtype in ["fp8", "fp8bf16", "fp8fp32"]: # no need lse/dropout kernels for logits, qscale, mask, bias in itertools.product( - ["f"], ["no", "pertensor", "blockscale"], get_mask_map(mask_impl).keys(), ["no"] + ["f"], + ["no", "pertensor", "blockscale"], + get_mask_map(mask_impl).keys(), + ["no"], ): pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip From 326d8c31938815ff2dde2541fc3b32426924b4c7 Mon Sep 17 00:00:00 2001 From: ltqin Date: Mon, 1 Dec 2025 08:42:43 +0000 Subject: [PATCH 08/25] fix auto not allowed in function prototype --- include/ck_tile/core/utility/functional.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/ck_tile/core/utility/functional.hpp 
b/include/ck_tile/core/utility/functional.hpp index 420eba6609..ee2977afd2 100644 --- a/include/ck_tile/core/utility/functional.hpp +++ b/include/ck_tile/core/utility/functional.hpp @@ -93,8 +93,8 @@ struct identity struct idx_identity { - template - CK_TILE_HOST_DEVICE constexpr T&& operator()(auto, T&& arg) const noexcept + template + CK_TILE_HOST_DEVICE constexpr T&& operator()(I&&, T&& arg) const noexcept { return std::forward(arg); } From 4484de91041ca4e128c3911eaa29f044cf204554 Mon Sep 17 00:00:00 2001 From: ltqin Date: Tue, 2 Dec 2025 11:04:31 +0000 Subject: [PATCH 09/25] change instance tttt to ttff --- example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 2 +- .../ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp | 17 ++++------------- 2 files changed, 5 insertions(+), 14 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 8f83404506..577618d0e9 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -743,7 +743,7 @@ def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeli ["no"], ): pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip elif dtype in ["fp8", "fp8fp16", "bf8"]: # TODO None diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 6b239d2363..8efd8b35a3 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -1769,15 +1769,6 @@ struct FmhaFwdKernel size_t idx = i_m0 / kargs.block_scale_m; float q_descale = q_descale_ptr[idx]; - 
float scale_o = 1.0; - - auto o_acc_element_func = [&]() { - if constexpr(std::is_same_v) - return ck_tile::composes(ck_tile::saturates{}, - ck_tile::scales{scale_o}); - else - return ck_tile::scales{scale_o}; - }(); return FmhaPipeline{}(q_dram_window, identity{}, // q_element_func k_dram_window, @@ -1788,10 +1779,10 @@ struct FmhaFwdKernel identity{}, // bias_element_func randval_dram_window, lse_dram_window, - identity{}, // lse_element_func - identity{}, // s_acc_element_func - scales{1.0f}, // p_compute_element_func - o_acc_element_func, // o_acc_element_func + identity{}, // lse_element_func + identity{}, // s_acc_element_func + scales{1.0f}, // p_compute_element_func + scales{1.0f}, // o_acc_element_func mask, position_encoding, kargs.scale_s * q_descale, From 2709f1644e79f521fbaa9f23d338f5f6bdef1dd7 Mon Sep 17 00:00:00 2001 From: ltqin Date: Tue, 2 Dec 2025 12:14:25 +0000 Subject: [PATCH 10/25] fix structured binding issue --- include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 8efd8b35a3..ba11b3e437 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -1703,7 +1703,7 @@ struct FmhaFwdKernel BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk}; - auto o_acc_tile = [&]() { + auto o_acc_tile = [&, i_nhead_ = i_nhead]() { if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR) { // TODO - move global load of descale to pipeline @@ -1753,16 +1753,16 @@ struct FmhaFwdKernel { const float* q_descale_ptr = reinterpret_cast(kargs.q_descale_ptr) + - static_cast(i_nhead) * kargs.nhead_stride_q_descale + + static_cast(i_nhead_) * kargs.nhead_stride_q_descale + batch_offset_q_descale; const float* k_descale_ptr = reinterpret_cast(kargs.k_descale_ptr) + - static_cast(i_nhead / 
kargs.nhead_ratio_qk) * + static_cast(i_nhead_ / kargs.nhead_ratio_qk) * kargs.nhead_stride_k_descale + batch_offset_k_descale; const float* v_descale_ptr = reinterpret_cast(kargs.v_descale_ptr) + - static_cast(i_nhead / kargs.nhead_ratio_qk) * + static_cast(i_nhead_ / kargs.nhead_ratio_qk) * kargs.nhead_stride_v_descale + batch_offset_v_descale; From 259fd9280e67a751b043449d01012d2644f6c9a8 Mon Sep 17 00:00:00 2001 From: ltqin Date: Wed, 3 Dec 2025 06:08:49 +0000 Subject: [PATCH 11/25] change s_acc elementwise op --- include/ck_tile/core/numeric/math.hpp | 7 +++++++ .../ops/fmha/kernel/fmha_fwd_kernel.hpp | 10 +++++----- .../pipeline/block_fmha_pipeline_qr_ks_vs.hpp | 19 ++++++++++++------- 3 files changed, 24 insertions(+), 12 deletions(-) diff --git a/include/ck_tile/core/numeric/math.hpp b/include/ck_tile/core/numeric/math.hpp index 57f3953514..b50399b3c9 100644 --- a/include/ck_tile/core/numeric/math.hpp +++ b/include/ck_tile/core/numeric/math.hpp @@ -37,6 +37,13 @@ struct scales return lhs_ * rhs; } + template + CK_TILE_HOST_DEVICE constexpr auto operator*(OtherScale other) const + { + auto new_scale = lhs_ * other; + return scales(new_scale); + } + private: Scale lhs_; }; diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index ba11b3e437..ea14612a1f 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -1779,13 +1779,13 @@ struct FmhaFwdKernel identity{}, // bias_element_func randval_dram_window, lse_dram_window, - identity{}, // lse_element_func - identity{}, // s_acc_element_func - scales{1.0f}, // p_compute_element_func - scales{1.0f}, // o_acc_element_func + identity{}, // lse_element_func + scales{q_descale}, // s_acc_element_func + scales{1.0f}, // p_compute_element_func + scales{1.0f}, // o_acc_element_func mask, position_encoding, - kargs.scale_s * q_descale, + kargs.scale_s, variant, variant_params, 
block_indices, diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp index 1de3d4dd5c..8bd24fe137 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -402,14 +402,19 @@ struct BlockFmhaPipelineQRKSVS schedule_gemm0(); } // dequant - if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) - { - tile_elementwise_inout([k_descale](auto& x) { x = x * k_descale; }, s_acc); - } + auto s_acc_element_func_ = [&s_acc_element_func, k_descale]() { + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { + return s_acc_element_func * k_descale; + } + else + return s_acc_element_func; + }(); + // STAGE 2, scale_s, add bias, mask, softmax if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { - s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + s_acc = tile_elementwise_in(s_acc_element_func_, s_acc); tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); tile_elementwise_inout( [&](auto& x, const auto& y) { @@ -427,7 +432,7 @@ struct BlockFmhaPipelineQRKSVS { const auto k_origin = k_dram_block_window.get_window_origin(); constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); - s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + s_acc = tile_elementwise_in(s_acc_element_func_, s_acc); sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { const auto tile_idx = get_x_indices_from_distributed_indices( @@ -444,7 +449,7 @@ struct BlockFmhaPipelineQRKSVS } else { - s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + s_acc = tile_elementwise_in(s_acc_element_func_, s_acc); if constexpr(kHasLogitsSoftCap) { auto apply_logits_transform = From 0d462a204ec4d9dd503dfabd29933a9fc1878fd0 Mon Sep 17 00:00:00 2001 From: ltqin Date: Wed, 3 Dec 2025 
09:25:41 +0000 Subject: [PATCH 12/25] async pipeline add block scale --- .../block_fmha_pipeline_qr_ks_vs_async.hpp | 61 ++++++++++++++++--- 1 file changed, 53 insertions(+), 8 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index 19ecc7c48c..8fa098554a 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -46,6 +46,7 @@ struct BlockFmhaPipelineQRKSVSAsync static constexpr index_t kK1 = BlockFmhaShape::kK1; static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim; + static constexpr auto QScaleEnum = Problem::QScaleEnum; static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); @@ -188,9 +189,9 @@ struct BlockFmhaPipelineQRKSVSAsync const BlockIndices& block_indices, void* smem_ptr, DropoutType& dropout, - const float*, - const float*, - index_t) const + const float* k_descale_ptr, + const float* v_descale_ptr, + index_t block_scale_n) const { static_assert( std::is_same_v> && @@ -366,6 +367,14 @@ struct BlockFmhaPipelineQRKSVSAsync // main loop do { + float k_descale = 1.0f; + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { + const auto k_origin = k_dram_block_window.get_window_origin(); + const auto row = k_origin.at(number<0>{}); + const index_t idx = row / block_scale_n; + k_descale = k_descale_ptr[idx]; + } // STAGE 1, QK gemm clear_tile(s_acc); // initialize C if constexpr(k0_loops > 1) @@ -412,11 +421,20 @@ struct BlockFmhaPipelineQRKSVSAsync sequence<(LdsSeq.at(number{}) + 1) * kN0, kK0>{})); } __builtin_amdgcn_sched_barrier(1); + // dequant + auto s_acc_element_func_ = [&s_acc_element_func, k_descale]() { + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { + 
return s_acc_element_func * k_descale; + } + else + return s_acc_element_func; + }(); // STAGE 2, scale_s, add bias, mask, softmax if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { - s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + s_acc = tile_elementwise_in(s_acc_element_func_, s_acc); tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); tile_elementwise_inout( [&](auto& x, const auto& y) { @@ -434,7 +452,7 @@ struct BlockFmhaPipelineQRKSVSAsync { const auto k_origin = k_dram_block_window.get_window_origin(); constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); - s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + s_acc = tile_elementwise_in(s_acc_element_func_, s_acc); sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { const auto tile_idx = get_x_indices_from_distributed_indices( @@ -451,7 +469,7 @@ struct BlockFmhaPipelineQRKSVSAsync } else { - s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + s_acc = tile_elementwise_in(s_acc_element_func_, s_acc); if constexpr(kHasLogitsSoftCap) { auto apply_logits_transform = @@ -673,7 +691,28 @@ struct BlockFmhaPipelineQRKSVSAsync #endif }(); + float v_descale = 1.0f; + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { + const auto v_origin = v_dram_window.get_window_origin(); + const auto col = v_origin.at(number<1>{}); + const index_t idx = col / block_scale_n; + v_descale = v_descale_ptr[idx]; + } // STAGE 3, KV gemm + auto o_acc0 = decltype(o_acc){}; + clear_tile(o_acc0); + + auto& o_acc_ = [&o_acc0, &o_acc]() -> auto& { + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { + return o_acc0; + } + else + { + return o_acc; + } + }(); if constexpr(k1_loops > 1) { static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { @@ -683,7 +722,7 @@ struct BlockFmhaPipelineQRKSVSAsync v_dram_window, number<-1>{}, bool_constant{}); // load next v_buf } 
block_sync_lds(); - gemm_1(o_acc, + gemm_1(o_acc_, get_slice_tile( p, sequence<0, i_k1 * kK1>{}, sequence{}), get_slice_tile( @@ -738,13 +777,19 @@ struct BlockFmhaPipelineQRKSVSAsync { block_sync_lds(); gemm_1( - o_acc, + o_acc_, get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence{}), get_slice_tile( v_lds_window, sequence<(LdsSeq.at(number{})) * kN1, 0>{}, sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{})); } + + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { + tile_elementwise_inout( + [&v_descale](auto& o, auto& o0) { o += o0 * v_descale; }, o_acc, o_acc0); + } } while(i_total_loops < num_total_loop); // store lse From 163fd5eccd8d6b4de4cbaf486d860f6961203045 Mon Sep 17 00:00:00 2001 From: ltqin Date: Tue, 6 Jan 2026 08:05:22 +0000 Subject: [PATCH 13/25] add quantation P using shift exp2 --- .../ops/fmha/kernel/fmha_fwd_kernel.hpp | 56 ++++++++++--------- .../block_fmha_pipeline_qr_ks_vs_async.hpp | 51 +++++++++++++++-- 2 files changed, 77 insertions(+), 30 deletions(-) diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 0c82cc3613..8cffb855f4 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -1734,32 +1734,36 @@ struct FmhaFwdKernel size_t idx = i_m0 / kargs.block_scale_m; float q_descale = q_descale_ptr[idx]; - - return FmhaPipeline{}(q_dram_window, - identity{}, // q_element_func - k_dram_window, - identity{}, // k_element_func - v_dram_window, - identity{}, // v_element_func - bias_dram_window, - identity{}, // bias_element_func - randval_dram_window, - lse_dram_window, - identity{}, // lse_element_func - scales{q_descale}, // s_acc_element_func - scales{1.0f}, // p_compute_element_func - scales{1.0f}, // o_acc_element_func - mask, - position_encoding, - kargs.scale_s, - variant, - variant_params, - block_indices, - smem_ptr, - dropout, - k_descale_ptr, - v_descale_ptr, - 
kargs.block_scale_n); + // BLOCKSCALE: P is scaled in exp2(x+shift) where shift=7 or 8 + // Both P and rowsum are scaled by 2^shift, canceling in normalization + // No additional scaling needed in p_compute_element_func or o_acc_element_func + + return FmhaPipeline{}( + q_dram_window, + identity{}, // q_element_func + k_dram_window, + identity{}, // k_element_func + v_dram_window, + identity{}, // v_element_func + bias_dram_window, + identity{}, // bias_element_func + randval_dram_window, + lse_dram_window, + identity{}, // lse_element_func + scales{q_descale}, // s_acc_element_func + identity{}, // p_compute_element_func - No scaling (done in exp2) + identity{}, // o_acc_element_func - No dequant needed (canceled by rowsum) + mask, + position_encoding, + kargs.scale_s, + variant, + variant_params, + block_indices, + smem_ptr, + dropout, + k_descale_ptr, + v_descale_ptr, + kargs.block_scale_n); } else { diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index 6227742497..492a706e9d 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -642,21 +642,64 @@ struct BlockFmhaPipelineQRKSVSAsync if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || BiasEnum == BlockAttentionBiasEnum::ALIBI) { - p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { +#if CK_TILE_USE_OCP_FP8 + p_compute(i_j_idx) = + exp2(s[i_j_idx] - get_validated_m(m[i_idx]) + 8.0f); +#else + p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]) + 7.0f); +#endif + } + else + { + p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); + } } - else + else if constexpr(kHasLogitsSoftCap) { - if constexpr(kHasLogitsSoftCap) + if constexpr(QScaleEnum == 
BlockAttentionQuantScaleEnum::BLOCKSCALE) + { +#if CK_TILE_USE_OCP_FP8 + p_compute(i_j_idx) = + exp2(s[i_j_idx] - get_validated_m(m[i_idx]) + 8.0f); +#else + p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]) + 7.0f); +#endif + } + else { p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); } + } + else + { + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { +#if CK_TILE_USE_OCP_FP8 + p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max + 8.0f); +#else + p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max + 7.0f); +#endif + } else { p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max); } } #else - p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx])); + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { +#if CK_TILE_USE_OCP_FP8 + p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx])) * 256.0f; +#else + p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx])) * 128.0f; +#endif + } + else + { + p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx])); + } #endif }); }); From 01ccfc5e092432dc251cc0e40eaf8ae398cd900a Mon Sep 17 00:00:00 2001 From: ltqin Date: Tue, 6 Jan 2026 13:28:30 +0000 Subject: [PATCH 14/25] precompute (m - shift) once per row --- .../block_fmha_pipeline_qr_ks_vs_async.hpp | 68 ++++++------------- 1 file changed, 20 insertions(+), 48 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index 492a706e9d..9c8d168b2f 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -634,7 +634,22 @@ struct BlockFmhaPipelineQRKSVSAsync sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); #if CK_TILE_FMHA_FWD_FAST_EXP2 - auto row_max = scale_s * 
get_validated_m(m[i_idx]); + // For BLOCKSCALE: precompute (m - shift) once per row + // Bias/Alibi/SoftCap: exp2(s - m + shift) = exp2(s - (m - shift)) + // else: exp2(scale_s*s - scale_s*m + shift) = exp2(scale_s*s - (scale_s*m - shift)) + auto validated_m_raw = get_validated_m(m[i_idx]); + auto validated_m = validated_m_raw; + auto row_max = scale_s * validated_m_raw; + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { +#if CK_TILE_USE_OCP_FP8 + validated_m -= 8.0f; // for Bias/Alibi/SoftCap + row_max -= 8.0f; // for else branch +#else + validated_m -= 7.0f; + row_max -= 7.0f; +#endif + } #endif sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); @@ -642,45 +657,13 @@ struct BlockFmhaPipelineQRKSVSAsync if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || BiasEnum == BlockAttentionBiasEnum::ALIBI) { - if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) - { -#if CK_TILE_USE_OCP_FP8 - p_compute(i_j_idx) = - exp2(s[i_j_idx] - get_validated_m(m[i_idx]) + 8.0f); -#else - p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]) + 7.0f); -#endif - } - else - { - p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); - } - } - else if constexpr(kHasLogitsSoftCap) - { - if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) - { -#if CK_TILE_USE_OCP_FP8 - p_compute(i_j_idx) = - exp2(s[i_j_idx] - get_validated_m(m[i_idx]) + 8.0f); -#else - p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]) + 7.0f); -#endif - } - else - { - p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); - } + p_compute(i_j_idx) = exp2(s[i_j_idx] - validated_m); } else { - if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + if constexpr(kHasLogitsSoftCap) { -#if CK_TILE_USE_OCP_FP8 - p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max + 8.0f); -#else - p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max + 
7.0f); -#endif + p_compute(i_j_idx) = exp2(s[i_j_idx] - validated_m); } else { @@ -688,18 +671,7 @@ struct BlockFmhaPipelineQRKSVSAsync } } #else - if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) - { -#if CK_TILE_USE_OCP_FP8 - p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx])) * 256.0f; -#else - p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx])) * 128.0f; -#endif - } - else - { - p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx])); - } + p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx])); #endif }); }); From b08c25bfed99483cfddd8b32b30c165045929266 Mon Sep 17 00:00:00 2001 From: ltqin Date: Wed, 7 Jan 2026 03:50:02 +0000 Subject: [PATCH 15/25] change blk scale seqstrt ptr name --- example/ck_tile/01_fmha/fmha_fwd.hpp | 8 ++++---- example/ck_tile/01_fmha/fmha_fwd_runner.hpp | 6 ++---- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 364811cd82..5b64cfddb7 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -230,8 +230,8 @@ struct fmha_fwd_args // array [batch + 1]. (Used with padding) const void* cu_seqlen_k_ptr = nullptr; // Cumulative logical (excluding padding) sequence length // array [batch + 1]. 
(Used with padding) - const void* bseqstart_q_ptr; - const void* bseqstart_k_ptr; + const void* blk_scale_seqstart_q_ptr; + const void* blk_scale_seqstart_k_ptr; ck_tile::index_t seqlen_q; ck_tile::index_t seqlen_k; @@ -621,8 +621,8 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.seqstart_k_ptr, args.seqlen_q_ptr, args.seqlen_k_ptr, - args.bseqstart_q_ptr, - args.bseqstart_k_ptr, + args.blk_scale_seqstart_q_ptr, + args.blk_scale_seqstart_k_ptr, args.hdim_q, args.hdim_v, args.nhead_q, diff --git a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp index 5441502352..dee8079701 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -485,8 +485,6 @@ fwd_result fmha_fwd_run(mode_enum mode, sizeof(VDataType) * hdim_v * real_seqlen_k); } } - // std::cout << "bseqstart_q_host: " << bseqstart_q_host - // << "bseqstart_k_host: " << bseqstart_k_host << std::endl; const ck_tile::index_t max_num_page_blocks = (0 < page_block_size @@ -1103,9 +1101,9 @@ fwd_result fmha_fwd_run(mode_enum mode, args.v_descale_ptr = reinterpret_cast(v_descale_buf.GetDeviceBuffer()); - args.bseqstart_q_ptr = + args.blk_scale_seqstart_q_ptr = (mode == mode_enum::group ? bseqstart_q_buf.GetDeviceBuffer() : nullptr); - args.bseqstart_k_ptr = + args.blk_scale_seqstart_k_ptr = (mode == mode_enum::group ? 
bseqstart_k_buf.GetDeviceBuffer() : nullptr); args.nhead_stride_q_descale = nhead_stride_q_descale; From 86ad0d1c5f25ae4ae528f66be6a5f212e3c8893d Mon Sep 17 00:00:00 2001 From: ltqin Date: Wed, 7 Jan 2026 07:38:22 +0000 Subject: [PATCH 16/25] fix some name --- example/ck_tile/01_fmha/fmha_fwd.hpp | 12 ++-- example/ck_tile/01_fmha/fmha_fwd_runner.hpp | 59 ++++++++++--------- .../ops/fmha/kernel/fmha_fwd_kernel.hpp | 56 +++++++++--------- .../pipeline/block_fmha_pipeline_qr_ks_vs.hpp | 8 +-- .../block_fmha_pipeline_qr_ks_vs_async.hpp | 6 +- 5 files changed, 72 insertions(+), 69 deletions(-) diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 5b64cfddb7..0bebb2f6cb 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -284,8 +284,8 @@ struct fmha_fwd_args std::variant, std::pair> drop_seed_offset; - ck_tile::index_t block_scale_m; - ck_tile::index_t block_scale_n; + ck_tile::index_t block_scale_size_q; + ck_tile::index_t block_scale_size_kv; }; struct fmha_fwd_pagedkv_args @@ -653,8 +653,8 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.p_drop, args.s_randval, args.drop_seed_offset, - args.block_scale_m, - args.block_scale_n, + args.block_scale_size_q, + args.block_scale_size_kv, args.cu_seqlen_q_ptr, args.cu_seqlen_k_ptr); } @@ -711,8 +711,8 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.p_drop, args.s_randval, args.drop_seed_offset, - args.block_scale_m, - args.block_scale_n, + args.block_scale_size_q, + args.block_scale_size_kv, args.cu_seqlen_q_ptr, args.cu_seqlen_k_ptr); } diff --git a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp index dee8079701..2909c57335 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -187,8 +187,8 @@ fwd_result fmha_fwd_run(mode_enum mode, const ck_tile::stream_config& stream_config, std::optional json = std::nullopt) 
{ - constexpr ck_tile::index_t block_scale_m_ = 128; - constexpr ck_tile::index_t block_scale_n_ = 128; + constexpr ck_tile::index_t block_scale_size_q_ = 128; + constexpr ck_tile::index_t block_scale_size_kv_ = 128; const std::string data_type = []() { if constexpr(std::is_same_v) @@ -451,8 +451,8 @@ fwd_result fmha_fwd_run(mode_enum mode, std::size_t flop = 0, num_byte = 0; auto max_seqlen_q = std::numeric_limits::min(); // we will use max seqlen to decide grid size - size_t num_block_scale_q = 0; - size_t num_block_scale_k = 0; + size_t i_block_scale_q = 0; + size_t i_block_scale_k = 0; std::vector bseqstart_q_host = {0}; std::vector bseqstart_k_host = {0}; auto max_seqlen_k = std::numeric_limits::min(); @@ -471,10 +471,10 @@ fwd_result fmha_fwd_run(mode_enum mode, { max_seqlen_k = real_seqlen_k; } - num_block_scale_q += ck_tile::integer_divide_ceil(real_seqlen_q, block_scale_m_); - num_block_scale_k += ck_tile::integer_divide_ceil(real_seqlen_k, block_scale_n_); - bseqstart_q_host.push_back(num_block_scale_q); - bseqstart_k_host.push_back(num_block_scale_k); + i_block_scale_q += ck_tile::integer_divide_ceil(real_seqlen_q, block_scale_size_q_); + i_block_scale_k += ck_tile::integer_divide_ceil(real_seqlen_k, block_scale_size_kv_); + bseqstart_q_host.push_back(i_block_scale_q); + bseqstart_k_host.push_back(i_block_scale_k); flop += nhead * (static_cast(2) * mask.get_unmaskarea() * hdim_q + static_cast(2) * mask.get_unmaskarea() * hdim_v); @@ -536,12 +536,14 @@ fwd_result fmha_fwd_run(mode_enum mode, ? seqstart_k_with_padding_host.back() : seqstart_k_host.back())); - const ck_tile::index_t num_block_scale_m = - (mode == mode_enum::batch) ? ck_tile::integer_divide_ceil(shape_seqlen_q, block_scale_m_) - : num_block_scale_q; - const ck_tile::index_t num_block_scale_n = - (mode == mode_enum::batch) ? ck_tile::integer_divide_ceil(shape_seqlen_k, block_scale_n_) - : num_block_scale_k; + const ck_tile::index_t num_block_scale_q = + (mode == mode_enum::batch) + ? 
ck_tile::integer_divide_ceil(shape_seqlen_q, block_scale_size_q_) + : i_block_scale_q; + const ck_tile::index_t num_block_scale_kv = + (mode == mode_enum::batch) + ? ck_tile::integer_divide_ceil(shape_seqlen_k, block_scale_size_kv_) + : i_block_scale_k; ck_tile::HostTensor q_host( get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q)); @@ -595,15 +597,15 @@ fwd_result fmha_fwd_run(mode_enum mode, // TODO - change the tensor length for different quant scale ck_tile::HostTensor q_descale_host( qscale.type == quant_scale_enum::blockscale - ? std::array{shape_batch, nhead, num_block_scale_m} + ? std::array{shape_batch, nhead, num_block_scale_q} : std::array{1, 1, 1}); ck_tile::HostTensor k_descale_host( qscale.type == quant_scale_enum::blockscale - ? std::array{shape_batch, nhead_k, num_block_scale_n} + ? std::array{shape_batch, nhead_k, num_block_scale_kv} : std::array{1, 1, 1}); ck_tile::HostTensor v_descale_host( qscale.type == quant_scale_enum::blockscale - ? std::array{shape_batch, nhead_k, num_block_scale_n} + ? std::array{shape_batch, nhead_k, num_block_scale_kv} : std::array{1, 1, 1}); // batch mode of lse data layout is [batch, nhead, seqlen_q] @@ -985,9 +987,9 @@ fwd_result fmha_fwd_run(mode_enum mode, const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q); const ck_tile::index_t nhead_stride_o_acc = (num_splits * shape_seqlen_q * hdim_v); const ck_tile::index_t nhead_stride_o = (o_perm ? 
shape_seqlen_q * hdim_v : hdim_v); - const ck_tile::index_t nhead_stride_q_descale = num_block_scale_m; - const ck_tile::index_t nhead_stride_k_descale = num_block_scale_n; - const ck_tile::index_t nhead_stride_v_descale = num_block_scale_n; + const ck_tile::index_t nhead_stride_q_descale = num_block_scale_q; + const ck_tile::index_t nhead_stride_k_descale = num_block_scale_kv; + const ck_tile::index_t nhead_stride_v_descale = num_block_scale_kv; // setup batch_stride_* arguments const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); const ck_tile::index_t batch_stride_k = @@ -1005,9 +1007,9 @@ fwd_result fmha_fwd_run(mode_enum mode, const ck_tile::index_t batch_stride_o_acc = (nhead * num_splits * shape_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch); - const ck_tile::index_t batch_stride_q_descale = num_block_scale_m * nhead; - const ck_tile::index_t batch_stride_k_descale = num_block_scale_n * nhead_k; - const ck_tile::index_t batch_stride_v_descale = num_block_scale_n * nhead_k; + const ck_tile::index_t batch_stride_q_descale = num_block_scale_q * nhead; + const ck_tile::index_t batch_stride_k_descale = num_block_scale_kv * nhead_k; + const ck_tile::index_t batch_stride_v_descale = num_block_scale_kv * nhead_k; // setup split_stride_* arguments (only used in split-kv kernel) const ck_tile::index_t split_stride_lse_acc = (shape_seqlen_q); const ck_tile::index_t split_stride_o_acc = (shape_seqlen_q * hdim_v); @@ -1114,8 +1116,8 @@ fwd_result fmha_fwd_run(mode_enum mode, args.batch_stride_k_descale = batch_stride_k_descale; args.batch_stride_v_descale = batch_stride_v_descale; - args.block_scale_m = block_scale_m_; - args.block_scale_n = block_scale_n_; + args.block_scale_size_q = block_scale_size_q_; + args.block_scale_size_kv = block_scale_size_kv_; } else { @@ -1644,10 +1646,10 @@ fwd_result fmha_fwd_run(mode_enum 
mode, return value * scale_s * q_descale_host(b_idx, std::get<0>(idx), - q_offset + std::get<1>(idx) / block_scale_m_) * + q_offset + std::get<1>(idx) / block_scale_size_q_) * k_descale_host(b_idx, std::get<0>(idx) / nr, - k_offset + std::get<2>(idx) / block_scale_n_); + k_offset + std::get<2>(idx) / block_scale_size_kv_); }); } else @@ -1834,7 +1836,8 @@ fwd_result fmha_fwd_run(mode_enum mode, return ck_tile::type_convert(value) * v_descale_host(b_idx, std::get<0>(idx) / nr, - v_offset + std::get<2>(idx) / block_scale_n_); + v_offset + + std::get<2>(idx) / block_scale_size_kv_); }, ck_tile::idx_identity{}); } diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 0975c306f3..be50c3dc9e 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -173,8 +173,8 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_k_descale; ck_tile::index_t nhead_stride_v_descale; - ck_tile::index_t block_scale_m; - ck_tile::index_t block_scale_n; + ck_tile::index_t block_scale_size_q; + ck_tile::index_t block_scale_size_kv; }; struct FmhaFwdBatchBlockScaleKargs : public FmhaFwdCommonBlockScaleKargs @@ -377,8 +377,8 @@ struct FmhaFwdKernel bool s_randval, std::variant, std::pair> drop_seed_offset, - ck_tile::index_t block_scale_m, - ck_tile::index_t block_scale_n, + ck_tile::index_t block_scale_size_q, + ck_tile::index_t block_scale_size_kv, const void* cu_seqlen_q_ptr = nullptr, const void* cu_seqlen_k_ptr = nullptr) { @@ -461,8 +461,8 @@ struct FmhaFwdKernel kargs.batch_stride_k_descale = batch_stride_k_descale; kargs.batch_stride_v_descale = batch_stride_v_descale; - kargs.block_scale_m = block_scale_m; - kargs.block_scale_n = block_scale_n; + kargs.block_scale_size_q = block_scale_size_q; + kargs.block_scale_size_kv = block_scale_size_kv; } if constexpr(kHasDropout) { @@ -549,8 +549,8 @@ struct FmhaFwdKernel float p_drop, bool s_randval, const 
std::tuple& drop_seed_offset, - ck_tile::index_t block_scale_m, - ck_tile::index_t block_scale_n, + ck_tile::index_t block_scale_size_q, + ck_tile::index_t block_scale_size_kv, const void* cu_seqlen_q_ptr = nullptr, const void* cu_seqlen_k_ptr = nullptr) { @@ -606,8 +606,8 @@ struct FmhaFwdKernel p_drop, s_randval, std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)), - block_scale_m, - block_scale_n, + block_scale_size_q, + block_scale_size_kv, cu_seqlen_q_ptr, cu_seqlen_k_ptr); } @@ -666,8 +666,8 @@ struct FmhaFwdKernel float p_drop, bool s_randval, const std::tuple& drop_seed_offset, - ck_tile::index_t block_scale_m, - ck_tile::index_t block_scale_n, + ck_tile::index_t block_scale_size_q, + ck_tile::index_t block_scale_size_kv, const void* cu_seqlen_q_ptr = nullptr, const void* cu_seqlen_k_ptr = nullptr) { @@ -723,8 +723,8 @@ struct FmhaFwdKernel p_drop, s_randval, std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)), - block_scale_m, - block_scale_n, + block_scale_size_q, + block_scale_size_kv, cu_seqlen_q_ptr, cu_seqlen_k_ptr); } @@ -778,8 +778,8 @@ struct FmhaFwdKernel bool s_randval, std::variant, std::pair> drop_seed_offset, - ck_tile::index_t block_scale_m, - ck_tile::index_t block_scale_n, + ck_tile::index_t block_scale_size_q, + ck_tile::index_t block_scale_size_kv, const void* cu_seqlen_q_ptr = nullptr, const void* cu_seqlen_k_ptr = nullptr) { @@ -857,8 +857,8 @@ struct FmhaFwdKernel kargs.nhead_stride_k_descale = nhead_stride_k_descale; kargs.nhead_stride_v_descale = nhead_stride_v_descale; - kargs.block_scale_m = block_scale_m; - kargs.block_scale_n = block_scale_n; + kargs.block_scale_size_q = block_scale_size_q; + kargs.block_scale_size_kv = block_scale_size_kv; kargs.bseqstart_q_ptr = reinterpret_cast(bseqstart_q_ptr); kargs.bseqstart_k_ptr = reinterpret_cast(bseqstart_k_ptr); @@ -946,8 +946,8 @@ struct FmhaFwdKernel float p_drop, bool s_randval, const std::tuple& drop_seed_offset, - ck_tile::index_t 
block_scale_m, - ck_tile::index_t block_scale_n, + ck_tile::index_t block_scale_size_q, + ck_tile::index_t block_scale_size_kv, const void* cu_seqlen_q_ptr = nullptr, const void* cu_seqlen_k_ptr = nullptr) { @@ -998,8 +998,8 @@ struct FmhaFwdKernel p_drop, s_randval, std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)), - block_scale_m, - block_scale_n, + block_scale_size_q, + block_scale_size_kv, cu_seqlen_q_ptr, cu_seqlen_k_ptr); } @@ -1053,8 +1053,8 @@ struct FmhaFwdKernel float p_drop, bool s_randval, const std::tuple& drop_seed_offset, - ck_tile::index_t block_scale_m, - ck_tile::index_t block_scale_n, + ck_tile::index_t block_scale_size_q, + ck_tile::index_t block_scale_size_kv, const void* cu_seqlen_q_ptr = nullptr, const void* cu_seqlen_k_ptr = nullptr) { @@ -1105,8 +1105,8 @@ struct FmhaFwdKernel p_drop, s_randval, std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)), - block_scale_m, - block_scale_n, + block_scale_size_q, + block_scale_size_kv, cu_seqlen_q_ptr, cu_seqlen_k_ptr); } @@ -1743,7 +1743,7 @@ struct FmhaFwdKernel kargs.nhead_stride_v_descale + batch_offset_v_descale; - size_t idx = i_m0 / kargs.block_scale_m; + size_t idx = i_m0 / kargs.block_scale_size_q; float q_descale = q_descale_ptr[idx]; // BLOCKSCALE: P is scaled in exp2(x+shift) where shift=7 or 8 // Both P and rowsum are scaled by 2^shift, canceling in normalization @@ -1774,7 +1774,7 @@ struct FmhaFwdKernel dropout, k_descale_ptr, v_descale_ptr, - kargs.block_scale_n); + kargs.block_scale_size_kv); } else { diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp index 89f5fd648b..f26192cf2e 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -170,9 +170,9 @@ struct BlockFmhaPipelineQRKSVS DropoutType& dropout, const float* k_descale_ptr, const 
float* v_descale_ptr, - index_t block_scale_n) const + index_t block_scale_size_kv) const { - ignore = block_scale_n; + ignore = block_scale_size_kv; ignore = v_descale_ptr; static_assert( std::is_same_v> && @@ -346,7 +346,7 @@ struct BlockFmhaPipelineQRKSVS { const auto k_origin = k_dram_block_window.get_window_origin(); const auto row = k_origin.at(number<0>{}); - const index_t idx = row / block_scale_n; + const index_t idx = row / block_scale_size_kv; k_descale = k_descale_ptr[idx]; } // STAGE 1, QK gemm @@ -682,7 +682,7 @@ struct BlockFmhaPipelineQRKSVS { const auto v_origin = v_dram_window.get_window_origin(); const auto col = v_origin.at(number<1>{}); - const index_t idx = col / block_scale_n; + const index_t idx = col / block_scale_size_kv; v_descale = v_descale_ptr[idx]; } move_tile_window(v_dram_window, {0, kK1}); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index 9c8d168b2f..086b4ac2aa 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -192,7 +192,7 @@ struct BlockFmhaPipelineQRKSVSAsync DropoutType& dropout, const float* k_descale_ptr, const float* v_descale_ptr, - index_t block_scale_n) const + index_t block_scale_size_kv) const { static_assert( std::is_same_v> && @@ -388,7 +388,7 @@ struct BlockFmhaPipelineQRKSVSAsync { const auto k_origin = k_dram_block_window.get_window_origin(); const auto row = k_origin.at(number<0>{}); - const index_t idx = row / block_scale_n; + const index_t idx = row / block_scale_size_kv; k_descale = k_descale_ptr[idx]; } // STAGE 1, QK gemm @@ -759,7 +759,7 @@ struct BlockFmhaPipelineQRKSVSAsync { const auto v_origin = v_dram_window.get_window_origin(); const auto col = v_origin.at(number<1>{}); - const index_t idx = col / block_scale_n; + const index_t idx = col / block_scale_size_kv; v_descale = 
v_descale_ptr[idx]; } // STAGE 3, KV gemm From 99b720d8ca12fc6e54e3d69bdec6577aae46a5e0 Mon Sep 17 00:00:00 2001 From: ltqin Date: Thu, 15 Jan 2026 07:23:20 +0000 Subject: [PATCH 17/25] fix for deduction guide --- include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 50e83d3dd8..10f4992871 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -1778,9 +1778,9 @@ struct FmhaFwdKernel identity{}, // bias_element_func randval_dram_window, lse_dram_window, - identity{}, // lse_element_func - scales{q_descale}, // s_acc_element_func - identity{}, // p_compute_element_func - No scaling (done in exp2) + identity{}, // lse_element_func + scales(q_descale), // s_acc_element_func + identity{}, // p_compute_element_func - No scaling (done in exp2) identity{}, // o_acc_element_func - No dequant needed (canceled by rowsum) mask, position_encoding, From 6d61a5e154eed31fb5601ab704126ddd8024085e Mon Sep 17 00:00:00 2001 From: ltqin Date: Thu, 15 Jan 2026 08:23:59 +0000 Subject: [PATCH 18/25] fix some comments --- example/ck_tile/01_fmha/fmha_fwd_runner.hpp | 1 - .../pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp | 11 +++++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp index 02ee55091d..d5e935eacb 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -751,7 +751,6 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::FillUniformDistribution{0.012f, 0.015f, next_seed()}(q_descale_host); ck_tile::FillUniformDistribution{0.012f, 0.015f, next_seed()}(k_descale_host); ck_tile::FillUniformDistribution{0.012f, 0.015f, next_seed()}(v_descale_host); - // return 
fwd_result::no_instance; } iota_shuffle(block_table_host.begin(), block_table_host.end(), 0, random_engine); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index 2aed4756d2..8b4b6dbd47 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -65,6 +65,9 @@ struct BlockFmhaPipelineQRKSVSAsync static constexpr bool kHasDropout = Problem::kHasDropout; static constexpr bool kHasSink = Problem::kHasSink; + static constexpr float OCP_FP8_SHIFT = 8.0f; + static constexpr float FNUZ_FP8_SHIFT = 7.0f; + static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 && (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS || !kHasLogitsSoftCap)) || @@ -660,11 +663,11 @@ struct BlockFmhaPipelineQRKSVSAsync if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) { #if CK_TILE_USE_OCP_FP8 - validated_m -= 8.0f; // for Bias/Alibi/SoftCap - row_max -= 8.0f; // for else branch + validated_m -= OCP_FP8_SHIFT; // for Bias/Alibi/SoftCap + row_max -= OCP_FP8_SHIFT; // for else branch #else - validated_m -= 7.0f; - row_max -= 7.0f; + validated_m -= FNUZ_FP8_SHIFT; + row_max -= FNUZ_FP8_SHIFT; #endif } #endif From bbe780e1bee8306d81762eb6246af8abbed46bfe Mon Sep 17 00:00:00 2001 From: ltqin Date: Fri, 16 Jan 2026 06:05:18 +0000 Subject: [PATCH 19/25] add P scale to qr_ksvs_pipeline --- .../pipeline/block_fmha_pipeline_qr_ks_vs.hpp | 24 ++++++++++++++++--- .../block_fmha_pipeline_qr_ks_vs_async.hpp | 6 ++--- 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp index 6fa926809e..365896566c 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ 
b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -60,6 +60,10 @@ struct BlockFmhaPipelineQRKSVS static constexpr auto QScaleEnum = Problem::QScaleEnum; static constexpr bool kHasSink = Problem::kHasSink; + // For BLOCKSCALE: shift value for exp2(x + shift) to scale P to [0, 2^shift] + static constexpr float OCP_FP8_SHIFT = 8.0f; + static constexpr float FNUZ_FP8_SHIFT = 7.0f; + static constexpr uint32_t DS_READ = 0x100; // Barrier for DS (data share) read static constexpr uint32_t MFMA = 0x008; // Barrier for MFMA (matrix multiply-accumulate) @@ -592,7 +596,21 @@ struct BlockFmhaPipelineQRKSVS sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); #if CK_TILE_FMHA_FWD_FAST_EXP2 - auto row_max = scale_s * get_validated_m(m[i_idx]); + // For BLOCKSCALE: precompute (m - shift) once per row + // Bias/Alibi/SoftCap: exp2(s - m + shift) = exp2(s - (m - shift)) + // else: exp2(scale_s*s - scale_s*m + shift) = exp2(scale_s*s - (scale_s*m - shift)) + auto validated_m = get_validated_m(m[i_idx]); + auto row_max = scale_s * validated_m; + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { +#if CK_TILE_USE_OCP_FP8 + validated_m -= OCP_FP8_SHIFT; // for Bias/Alibi/SoftCap + row_max -= OCP_FP8_SHIFT; // for else branch +#else + validated_m -= FNUZ_FP8_SHIFT; + row_max -= FNUZ_FP8_SHIFT; +#endif + } #endif sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); @@ -600,13 +618,13 @@ struct BlockFmhaPipelineQRKSVS if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || BiasEnum == BlockAttentionBiasEnum::ALIBI) { - p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); + p_compute(i_j_idx) = exp2(s[i_j_idx] - validated_m); } else { if constexpr(kHasLogitsSoftCap) { - p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); + p_compute(i_j_idx) = exp2(s[i_j_idx] - validated_m); } else { diff --git 
a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index 8b4b6dbd47..87b972a2ab 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -65,6 +65,7 @@ struct BlockFmhaPipelineQRKSVSAsync static constexpr bool kHasDropout = Problem::kHasDropout; static constexpr bool kHasSink = Problem::kHasSink; + // For BLOCKSCALE: shift value for exp2(x + shift) to scale P to [0, 2^shift] static constexpr float OCP_FP8_SHIFT = 8.0f; static constexpr float FNUZ_FP8_SHIFT = 7.0f; @@ -657,9 +658,8 @@ struct BlockFmhaPipelineQRKSVSAsync // For BLOCKSCALE: precompute (m - shift) once per row // Bias/Alibi/SoftCap: exp2(s - m + shift) = exp2(s - (m - shift)) // else: exp2(scale_s*s - scale_s*m + shift) = exp2(scale_s*s - (scale_s*m - shift)) - auto validated_m_raw = get_validated_m(m[i_idx]); - auto validated_m = validated_m_raw; - auto row_max = scale_s * validated_m_raw; + auto validated_m = get_validated_m(m[i_idx]); + auto row_max = scale_s * validated_m; if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) { #if CK_TILE_USE_OCP_FP8 From e8c31d710cd892281a92853c6999493b29f1dd9a Mon Sep 17 00:00:00 2001 From: ltqin Date: Fri, 16 Jan 2026 06:28:14 +0000 Subject: [PATCH 20/25] add comment to idx_identity --- include/ck_tile/core/numeric/math.hpp | 2 +- include/ck_tile/core/utility/functional.hpp | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/include/ck_tile/core/numeric/math.hpp b/include/ck_tile/core/numeric/math.hpp index e4b79d0187..a46ae509dd 100644 --- a/include/ck_tile/core/numeric/math.hpp +++ b/include/ck_tile/core/numeric/math.hpp @@ -41,7 +41,7 @@ struct scales CK_TILE_HOST_DEVICE constexpr auto operator*(OtherScale other) const { auto new_scale = lhs_ * other; - return scales(new_scale); + return scales>(new_scale); } private: 
diff --git a/include/ck_tile/core/utility/functional.hpp b/include/ck_tile/core/utility/functional.hpp index 1172f1537c..aa4bfa3f15 100644 --- a/include/ck_tile/core/utility/functional.hpp +++ b/include/ck_tile/core/utility/functional.hpp @@ -119,10 +119,13 @@ struct identity } }; +// Similar to identity, but takes an additional index parameter as the first argument. +// The index is ignored and only the second argument (value) is forwarded. +// Useful for indexed element-wise operations where the functor signature requires an index. struct idx_identity { template - CK_TILE_HOST_DEVICE constexpr T&& operator()(I&&, T&& arg) const noexcept + CK_TILE_HOST_DEVICE constexpr T&& operator()(I&& /*idx*/, T&& arg) const noexcept { return std::forward(arg); } From fb9a3f2e0f1169416913a7eba90bca5488d69532 Mon Sep 17 00:00:00 2001 From: ltqin Date: Fri, 16 Jan 2026 07:44:14 +0000 Subject: [PATCH 21/25] change the method of calculating descale block index --- .../pipeline/block_fmha_pipeline_qr_ks_vs.hpp | 22 +++++++++---------- .../block_fmha_pipeline_qr_ks_vs_async.hpp | 14 +++++------- 2 files changed, 16 insertions(+), 20 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp index 365896566c..2fbc9fdb54 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -369,10 +369,9 @@ struct BlockFmhaPipelineQRKSVS float k_descale = 1.0f; if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) { - const auto k_origin = k_dram_block_window.get_window_origin(); - const auto row = k_origin.at(number<0>{}); - const index_t idx = row / block_scale_size_kv; - k_descale = k_descale_ptr[idx]; + // K and V share the same seqlen_k position within a block + const index_t kv_idx = (kv_load_start + i_total_loops * kN0) / block_scale_size_kv; + k_descale = k_descale_ptr[kv_idx]; 
} // STAGE 1, QK gemm auto k_dram_window = make_tile_window( @@ -716,19 +715,18 @@ struct BlockFmhaPipelineQRKSVS tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch } - float v_descale = 1.0f; - if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) - { - const auto v_origin = v_dram_window.get_window_origin(); - const auto col = v_origin.at(number<1>{}); - const index_t idx = col / block_scale_size_kv; - v_descale = v_descale_ptr[idx]; - } move_tile_window(v_dram_window, {0, kK1}); const auto p = cast_tile(tile_elementwise_in(p_compute_element_func, p_compute)); + float v_descale = 1.0f; + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { + // K and V share the same seqlen_k position within a block + const index_t kv_idx = (kv_load_start + i_total_loops * kN0) / block_scale_size_kv; + v_descale = v_descale_ptr[kv_idx]; + } // STAGE 3, KV gemm auto o_acc0 = decltype(o_acc){}; clear_tile(o_acc0); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index 87b972a2ab..046a2f0b9e 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -414,10 +414,9 @@ struct BlockFmhaPipelineQRKSVSAsync float k_descale = 1.0f; if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) { - const auto k_origin = k_dram_block_window.get_window_origin(); - const auto row = k_origin.at(number<0>{}); - const index_t idx = row / block_scale_size_kv; - k_descale = k_descale_ptr[idx]; + // K and V share the same seqlen_k position within a block + const index_t kv_idx = (kv_load_start + i_total_loops * kN0) / block_scale_size_kv; + k_descale = k_descale_ptr[kv_idx]; } // STAGE 1, QK gemm clear_tile(s_acc); // initialize C @@ -777,10 +776,9 @@ struct BlockFmhaPipelineQRKSVSAsync float v_descale = 1.0f; if 
constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) { - const auto v_origin = v_dram_window.get_window_origin(); - const auto col = v_origin.at(number<1>{}); - const index_t idx = col / block_scale_size_kv; - v_descale = v_descale_ptr[idx]; + // K and V share the same seqlen_k position within a block + const index_t kv_idx = (kv_load_start + i_total_loops * kN0) / block_scale_size_kv; + v_descale = v_descale_ptr[kv_idx]; } // STAGE 3, KV gemm auto o_acc0 = decltype(o_acc){}; From 162cda76e5770db3c02664360aedb9cd3e861d75 Mon Sep 17 00:00:00 2001 From: ltqin Date: Fri, 16 Jan 2026 09:26:54 +0000 Subject: [PATCH 22/25] unify naming style: use block_scale_ as name prefix --- example/ck_tile/01_fmha/fmha_fwd.hpp | 8 ++--- example/ck_tile/01_fmha/fmha_fwd_runner.hpp | 4 +-- .../host/reference/reference_batched_gemm.hpp | 6 ++-- .../ops/fmha/kernel/fmha_fwd_kernel.hpp | 34 ++++++++++--------- 4 files changed, 27 insertions(+), 25 deletions(-) diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 09eb8b8df2..aedbb0e17c 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -230,8 +230,8 @@ struct fmha_fwd_args // array [batch + 1]. (Used with padding) const void* cu_seqlen_k_ptr = nullptr; // Cumulative logical (excluding padding) sequence length // array [batch + 1]. 
(Used with padding) - const void* blk_scale_seqstart_q_ptr; - const void* blk_scale_seqstart_k_ptr; + const void* block_scale_seqstart_q_ptr; + const void* block_scale_seqstart_k_ptr; const void* sink_ptr; ck_tile::index_t seqlen_q; @@ -626,8 +626,8 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.seqstart_k_ptr, args.seqlen_q_ptr, args.seqlen_k_ptr, - args.blk_scale_seqstart_q_ptr, - args.blk_scale_seqstart_k_ptr, + args.block_scale_seqstart_q_ptr, + args.block_scale_seqstart_k_ptr, args.hdim_q, args.hdim_v, args.nhead_q, diff --git a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp index d5e935eacb..493f6544ed 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -1138,9 +1138,9 @@ fwd_result fmha_fwd_run(mode_enum mode, args.v_descale_ptr = reinterpret_cast(v_descale_buf.GetDeviceBuffer()); - args.blk_scale_seqstart_q_ptr = + args.block_scale_seqstart_q_ptr = (mode == mode_enum::group ? bseqstart_q_buf.GetDeviceBuffer() : nullptr); - args.blk_scale_seqstart_k_ptr = + args.block_scale_seqstart_k_ptr = (mode == mode_enum::group ? 
bseqstart_k_buf.GetDeviceBuffer() : nullptr); args.nhead_stride_q_descale = nhead_stride_q_descale; diff --git a/include/ck_tile/host/reference/reference_batched_gemm.hpp b/include/ck_tile/host/reference/reference_batched_gemm.hpp index 8d266ffca4..d742426740 100644 --- a/include/ck_tile/host/reference/reference_batched_gemm.hpp +++ b/include/ck_tile/host/reference/reference_batched_gemm.hpp @@ -51,9 +51,9 @@ template + typename AElementOp = ck_tile::idx_identity, + typename BElementOp = ck_tile::idx_identity, + typename ACCElementOp = ck_tile::idx_identity> CK_TILE_HOST void reference_batched_quant_gemm(const HostTensor& a_b_m_k, const HostTensor& b_b_n_k, HostTensor& c_b_m_n, diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 10f4992871..0039c57cfc 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -187,8 +187,8 @@ struct FmhaFwdKernel struct FmhaFwdGroupBlockScaleKargs : public FmhaFwdCommonBlockScaleKargs { - const int32_t* bseqstart_q_ptr; - const int32_t* bseqstart_k_ptr; + const int32_t* block_scale_seqstart_q_ptr; + const int32_t* block_scale_seqstart_k_ptr; }; struct FmhaFwdCommonLSEKargs @@ -752,8 +752,8 @@ struct FmhaFwdKernel const void* seqstart_k_ptr, const void* seqlen_q_ptr, const void* seqlen_k_ptr, - const void* bseqstart_q_ptr, - const void* bseqstart_k_ptr, + const void* block_scale_seqstart_q_ptr, + const void* block_scale_seqstart_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, @@ -869,8 +869,10 @@ struct FmhaFwdKernel kargs.block_scale_size_q = block_scale_size_q; kargs.block_scale_size_kv = block_scale_size_kv; - kargs.bseqstart_q_ptr = reinterpret_cast(bseqstart_q_ptr); - kargs.bseqstart_k_ptr = reinterpret_cast(bseqstart_k_ptr); + kargs.block_scale_seqstart_q_ptr = + reinterpret_cast(block_scale_seqstart_q_ptr); + kargs.block_scale_seqstart_k_ptr = + 
reinterpret_cast(block_scale_seqstart_k_ptr); } if constexpr(kHasDropout) { @@ -923,8 +925,8 @@ struct FmhaFwdKernel const void* seqstart_k_ptr, const void* seqlen_q_ptr, const void* seqlen_k_ptr, - const void* bseqstart_q_ptr, - const void* bseqstart_k_ptr, + const void* block_scale_seqstart_q_ptr, + const void* block_scale_seqstart_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, @@ -976,8 +978,8 @@ struct FmhaFwdKernel seqstart_k_ptr, seqlen_q_ptr, seqlen_k_ptr, - bseqstart_q_ptr, - bseqstart_k_ptr, + block_scale_seqstart_q_ptr, + block_scale_seqstart_k_ptr, hdim_q, hdim_v, num_head_q, @@ -1032,8 +1034,8 @@ struct FmhaFwdKernel const void* seqstart_k_ptr, const void* seqlen_q_ptr, const void* seqlen_k_ptr, - const void* bseqstart_q_ptr, - const void* bseqstart_k_ptr, + const void* block_scale_seqstart_q_ptr, + const void* block_scale_seqstart_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, @@ -1085,8 +1087,8 @@ struct FmhaFwdKernel seqstart_k_ptr, seqlen_q_ptr, seqlen_k_ptr, - bseqstart_q_ptr, - bseqstart_k_ptr, + block_scale_seqstart_q_ptr, + block_scale_seqstart_k_ptr, hdim_q, hdim_v, num_head_q, @@ -1295,8 +1297,8 @@ struct FmhaFwdKernel } if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) { - const long_index_t bquery_start = kargs.bseqstart_q_ptr[i_batch]; - const long_index_t bkey_start = kargs.bseqstart_k_ptr[i_batch]; + const long_index_t bquery_start = kargs.block_scale_seqstart_q_ptr[i_batch]; + const long_index_t bkey_start = kargs.block_scale_seqstart_k_ptr[i_batch]; batch_offset_q_descale = bquery_start; batch_offset_k_descale = bkey_start; batch_offset_v_descale = bkey_start; From 6d19a464eed024ca9fbb89b7370d9c3965e49c4d Mon Sep 17 00:00:00 2001 From: ltqin Date: Fri, 16 Jan 2026 09:46:17 +0000 Subject: [PATCH 23/25] unify naming style --- example/ck_tile/01_fmha/fmha_fwd_runner.hpp | 36 ++++++++++++--------- 1 file changed, 20 insertions(+), 16 
deletions(-) diff --git a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp index 493f6544ed..3838291636 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -474,11 +474,11 @@ fwd_result fmha_fwd_run(mode_enum mode, std::size_t flop = 0, num_byte = 0; auto max_seqlen_q = std::numeric_limits::min(); // we will use max seqlen to decide grid size - size_t i_block_scale_q = 0; - size_t i_block_scale_k = 0; - std::vector bseqstart_q_host = {0}; - std::vector bseqstart_k_host = {0}; - auto max_seqlen_k = std::numeric_limits::min(); + size_t i_block_scale_q = 0; + size_t i_block_scale_k = 0; + std::vector block_scale_seqstart_q_host = {0}; + std::vector block_scale_seqstart_k_host = {0}; + auto max_seqlen_k = std::numeric_limits::min(); { for(ck_tile::index_t wb = 0; wb < batch; ++wb) { @@ -496,8 +496,8 @@ fwd_result fmha_fwd_run(mode_enum mode, } i_block_scale_q += ck_tile::integer_divide_ceil(real_seqlen_q, block_scale_size_q_); i_block_scale_k += ck_tile::integer_divide_ceil(real_seqlen_k, block_scale_size_kv_); - bseqstart_q_host.push_back(i_block_scale_q); - bseqstart_k_host.push_back(i_block_scale_k); + block_scale_seqstart_q_host.push_back(i_block_scale_q); + block_scale_seqstart_k_host.push_back(i_block_scale_k); flop += nhead * (static_cast(2) * mask.get_unmaskarea() * hdim_q + static_cast(2) * mask.get_unmaskarea() * hdim_v); @@ -772,8 +772,10 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::DeviceMem q_descale_buf(q_descale_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem k_descale_buf(k_descale_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem v_descale_buf(v_descale_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem bseqstart_q_buf(bseqstart_q_host.size() * sizeof(int32_t)); - ck_tile::DeviceMem bseqstart_k_buf(bseqstart_k_host.size() * sizeof(int32_t)); + ck_tile::DeviceMem 
block_scale_seqstart_q_buf(block_scale_seqstart_q_host.size() * + sizeof(int32_t)); + ck_tile::DeviceMem block_scale_seqstart_k_buf(block_scale_seqstart_k_host.size() * + sizeof(int32_t)); ck_tile::DeviceMem lse_acc_buf(lse_acc_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem o_acc_buf(o_acc_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem lse_buf(lse_host.get_element_space_size_in_bytes()); @@ -819,8 +821,8 @@ fwd_result fmha_fwd_run(mode_enum mode, q_descale_buf.ToDevice(q_descale_host.data()); k_descale_buf.ToDevice(k_descale_host.data()); v_descale_buf.ToDevice(v_descale_host.data()); - bseqstart_q_buf.ToDevice(bseqstart_q_host.data()); - bseqstart_k_buf.ToDevice(bseqstart_k_host.data()); + block_scale_seqstart_q_buf.ToDevice(block_scale_seqstart_q_host.data()); + block_scale_seqstart_k_buf.ToDevice(block_scale_seqstart_k_host.data()); seqstart_q.ToDevice(seqstart_q_host.data()); // Keep logical starts in seqstart_k; pass padded K via separate pointer seqstart_k.ToDevice(seqstart_k_host.data()); @@ -1139,9 +1141,11 @@ fwd_result fmha_fwd_run(mode_enum mode, reinterpret_cast(v_descale_buf.GetDeviceBuffer()); args.block_scale_seqstart_q_ptr = - (mode == mode_enum::group ? bseqstart_q_buf.GetDeviceBuffer() : nullptr); + (mode == mode_enum::group ? block_scale_seqstart_q_buf.GetDeviceBuffer() + : nullptr); args.block_scale_seqstart_k_ptr = - (mode == mode_enum::group ? bseqstart_k_buf.GetDeviceBuffer() : nullptr); + (mode == mode_enum::group ? block_scale_seqstart_k_buf.GetDeviceBuffer() + : nullptr); args.nhead_stride_q_descale = nhead_stride_q_descale; args.nhead_stride_k_descale = nhead_stride_k_descale; @@ -1665,9 +1669,9 @@ fwd_result fmha_fwd_run(mode_enum mode, if(qscale.type == quant_scale_enum::blockscale) { const ck_tile::index_t q_offset = - (mode == mode_enum::batch) ? 0 : bseqstart_q_host[wb]; + (mode == mode_enum::batch) ? 0 : block_scale_seqstart_q_host[wb]; const ck_tile::index_t k_offset = - (mode == mode_enum::batch) ? 
0 : bseqstart_k_host[wb]; + (mode == mode_enum::batch) ? 0 : block_scale_seqstart_k_host[wb]; ck_tile::reference_batched_quant_gemm( p_host_ref, From 536cc5673fe547a544c3cab5c210a58e89b16344 Mon Sep 17 00:00:00 2001 From: ltqin Date: Wed, 21 Jan 2026 04:07:23 +0000 Subject: [PATCH 24/25] update the CHANGELOG.md --- CHANGELOG.md | 1 + example/ck_tile/01_fmha/fmha_fwd_runner.hpp | 2 ++ 2 files changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index c3a257e464..f1f5121439 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added FMHA batch prefill kernel support for several KV cache layouts, flexible page sizes, and different lookup table configurations. * Added gpt-oss sink support for FMHA FWD, include qr_ks_vs, qr_async, qr_async_trload and splitkv pipelines. * Added persistent async input scheduler for CK Tile universal GEMM kernels to support asynchronous input streaming. +* Added block scale support for FMHA forward kernels. 
### Changed diff --git a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp index 3838291636..b6287245a0 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -210,6 +210,8 @@ fwd_result fmha_fwd_run(mode_enum mode, const ck_tile::stream_config& stream_config, std::optional json = std::nullopt) { + // Note: block_scale_size_q_ and block_scale_size_kv_ should be greater than or equal to the + // compute block size constexpr ck_tile::index_t block_scale_size_q_ = 128; constexpr ck_tile::index_t block_scale_size_kv_ = 128; From 1df0a1246b85f8ba093b8a3af214d96238c42e61 Mon Sep 17 00:00:00 2001 From: ltqin Date: Wed, 21 Jan 2026 04:19:46 +0000 Subject: [PATCH 25/25] Add FP8 block scale quantization support for FMHA forward kernel --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f1f5121439..dfb50e9bdd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,7 +16,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added FMHA batch prefill kernel support for several KV cache layouts, flexible page sizes, and different lookup table configurations. * Added gpt-oss sink support for FMHA FWD, include qr_ks_vs, qr_async, qr_async_trload and splitkv pipelines. * Added persistent async input scheduler for CK Tile universal GEMM kernels to support asynchronous input streaming. -* Added block scale support for FMHA forward kernels. +* Added FP8 block scale quantization for FMHA forward kernel. ### Changed