diff --git a/CHANGELOG.md b/CHANGELOG.md index c3a257e464..dfb50e9bdd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added FMHA batch prefill kernel support for several KV cache layouts, flexible page sizes, and different lookup table configurations. * Added gpt-oss sink support for FMHA FWD, include qr_ks_vs, qr_async, qr_async_trload and splitkv pipelines. * Added persistent async input scheduler for CK Tile universal GEMM kernels to support asynchronous input streaming. +* Added FP8 block scale quantization for FMHA forward kernel. ### Changed diff --git a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py index a3cfe2622a..cac6671ca5 100644 --- a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py +++ b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py @@ -77,11 +77,13 @@ def get_mask_cpp_check_expr(mask: str) -> str: QSCALE_MAP = { "no": "ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE", "pertensor": "ck_tile::BlockAttentionQuantScaleEnum::PERTENSOR", + "blockscale": "ck_tile::BlockAttentionQuantScaleEnum::BLOCKSCALE", } QSCALE_CHECK_MAP = { "no": "quant_scale_enum::no_scale", "pertensor": "quant_scale_enum::pertensor", + "blockscale": "quant_scale_enum::blockscale", } BIAS_MAP = { diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index dd65c0298b..ed86f57232 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -1018,7 +1018,7 @@ def get_pipelines( # no need lse/dropout kernels for logits, qscale, mask, bias, sink in itertools.product( ["t", "f"], - ["no", "pertensor"], + ["no", "pertensor", "blockscale"], get_mask_map(mask_impl).keys(), ["no"], ["f", "t"], @@ -1146,7 +1146,10 @@ def get_pipelines( elif dtype in cls._DT_FP8_FP8BF16 or dtype in cls._DT_FP8FP32: # no need lse/dropout kernels for logits, qscale, mask, bias in itertools.product( - ["f"], ["no", "pertensor"], get_mask_map(mask_impl).keys(), ["no"] + ["f"], + ["no", "pertensor", "blockscale"], + get_mask_map(mask_impl).keys(), + ["no"], ): pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", qscale, mask, "f", "f", "f")) # fmt: skip pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", "f")) # fmt: skip diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index fdd720fd75..aedbb0e17c 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -230,6 +230,8 @@ struct fmha_fwd_args // array [batch + 1]. (Used with padding) const void* cu_seqlen_k_ptr = nullptr; // Cumulative logical (excluding padding) sequence length // array [batch + 1]. 
(Used with padding) + const void* block_scale_seqstart_q_ptr; + const void* block_scale_seqstart_k_ptr; const void* sink_ptr; ck_tile::index_t seqlen_q; @@ -257,6 +259,9 @@ struct fmha_fwd_args ck_tile::index_t nhead_stride_randval; ck_tile::index_t nhead_stride_lse; ck_tile::index_t nhead_stride_o; + ck_tile::index_t nhead_stride_q_descale; + ck_tile::index_t nhead_stride_k_descale; + ck_tile::index_t nhead_stride_v_descale; ck_tile::index_t batch_stride_q; ck_tile::index_t batch_stride_k; ck_tile::index_t batch_stride_v; @@ -264,6 +269,9 @@ struct fmha_fwd_args ck_tile::index_t batch_stride_randval; ck_tile::index_t batch_stride_lse; ck_tile::index_t batch_stride_o; + ck_tile::index_t batch_stride_q_descale; + ck_tile::index_t batch_stride_k_descale; + ck_tile::index_t batch_stride_v_descale; ck_tile::index_t window_size_left; ck_tile::index_t window_size_right; @@ -276,6 +284,9 @@ struct fmha_fwd_args std::variant, std::pair> drop_seed_offset; + + ck_tile::index_t block_scale_size_q; + ck_tile::index_t block_scale_size_kv; }; struct fmha_fwd_pagedkv_args @@ -615,6 +626,8 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.seqstart_k_ptr, args.seqlen_q_ptr, args.seqlen_k_ptr, + args.block_scale_seqstart_q_ptr, + args.block_scale_seqstart_k_ptr, args.hdim_q, args.hdim_v, args.nhead_q, @@ -634,6 +647,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.nhead_stride_randval, args.nhead_stride_lse, args.nhead_stride_o, + args.nhead_stride_q_descale, + args.nhead_stride_k_descale, + args.nhead_stride_v_descale, args.window_size_left, args.window_size_right, args.sink_size, @@ -642,6 +658,8 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.p_drop, args.s_randval, args.drop_seed_offset, + args.block_scale_size_q, + args.block_scale_size_kv, args.cu_seqlen_q_ptr, args.cu_seqlen_k_ptr, args.sink_ptr); @@ -679,6 +697,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.nhead_stride_randval, args.nhead_stride_lse, args.nhead_stride_o, + args.nhead_stride_q_descale, + args.nhead_stride_k_descale, + args.nhead_stride_v_descale, args.batch_stride_q, args.batch_stride_k, args.batch_stride_v, @@ -686,6 +707,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.batch_stride_randval, args.batch_stride_lse, args.batch_stride_o, + args.batch_stride_q_descale, + args.batch_stride_k_descale, + args.batch_stride_v_descale, args.window_size_left, args.window_size_right, args.sink_size, @@ -693,6 +717,8 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.p_drop, args.s_randval, args.drop_seed_offset, + args.block_scale_size_q, + args.block_scale_size_kv, args.cu_seqlen_q_ptr, args.cu_seqlen_k_ptr, args.sink_ptr); diff --git a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp index 0c988b2acc..b6287245a0 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -210,6 +210,11 @@ fwd_result fmha_fwd_run(mode_enum mode, const ck_tile::stream_config& stream_config, std::optional json = std::nullopt) { + // Note: block_scale_size_q_ and block_scale_size_kv_ should be greater than or equal to the + // compute block size + constexpr ck_tile::index_t block_scale_size_q_ = 128; + constexpr ck_tile::index_t block_scale_size_kv_ = 128; + const std::string data_type = []() { if constexpr(std::is_same_v) return "fp32"; @@ -471,7 +476,11 @@ fwd_result fmha_fwd_run(mode_enum mode, std::size_t flop = 0, num_byte = 0; auto max_seqlen_q = 
std::numeric_limits::min(); // we will use max seqlen to decide grid size - auto max_seqlen_k = std::numeric_limits::min(); + size_t i_block_scale_q = 0; + size_t i_block_scale_k = 0; + std::vector block_scale_seqstart_q_host = {0}; + std::vector block_scale_seqstart_k_host = {0}; + auto max_seqlen_k = std::numeric_limits::min(); { for(ck_tile::index_t wb = 0; wb < batch; ++wb) { @@ -487,6 +496,10 @@ fwd_result fmha_fwd_run(mode_enum mode, { max_seqlen_k = real_seqlen_k; } + i_block_scale_q += ck_tile::integer_divide_ceil(real_seqlen_q, block_scale_size_q_); + i_block_scale_k += ck_tile::integer_divide_ceil(real_seqlen_k, block_scale_size_kv_); + block_scale_seqstart_q_host.push_back(i_block_scale_q); + block_scale_seqstart_k_host.push_back(i_block_scale_k); flop += nhead * (static_cast(2) * mask.get_unmaskarea() * hdim_q + static_cast(2) * mask.get_unmaskarea() * hdim_v); @@ -548,6 +561,15 @@ fwd_result fmha_fwd_run(mode_enum mode, ? seqstart_k_with_padding_host.back() : seqstart_k_host.back())); + const ck_tile::index_t num_block_scale_q = + (mode == mode_enum::batch) + ? ck_tile::integer_divide_ceil(shape_seqlen_q, block_scale_size_q_) + : i_block_scale_q; + const ck_tile::index_t num_block_scale_kv = + (mode == mode_enum::batch) + ? ck_tile::integer_divide_ceil(shape_seqlen_k, block_scale_size_kv_) + : i_block_scale_k; + ck_tile::HostTensor q_host( get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q)); ck_tile::HostTensor sink_host({nhead}); @@ -599,9 +621,18 @@ fwd_result fmha_fwd_run(mode_enum mode, : std::array{1, 1, 1, 1, 1}); // TODO - change the tensor length for different quant scale - ck_tile::HostTensor q_descale_host(get_lengths(i_perm, 1, 1, 1, 1)); - ck_tile::HostTensor k_descale_host(get_lengths(i_perm, 1, 1, 1, 1)); - ck_tile::HostTensor v_descale_host(get_lengths(i_perm, 1, 1, 1, 1)); + ck_tile::HostTensor q_descale_host( + qscale.type == quant_scale_enum::blockscale + ? std::array{shape_batch, nhead, num_block_scale_q} + : std::array{1, 1, 1}); + ck_tile::HostTensor k_descale_host( + qscale.type == quant_scale_enum::blockscale + ? std::array{shape_batch, nhead_k, num_block_scale_kv} + : std::array{1, 1, 1}); + ck_tile::HostTensor v_descale_host( + qscale.type == quant_scale_enum::blockscale + ? 
std::array{shape_batch, nhead_k, num_block_scale_kv} + : std::array{1, 1, 1}); // batch mode of lse data layout is [batch, nhead, seqlen_q] // group mode of lse data layout is [nhead, total_seqlen_q] @@ -717,6 +748,12 @@ fwd_result fmha_fwd_run(mode_enum mode, k_descale_host(0) = qkv_max / k_dtype_max; v_descale_host(0) = qkv_max / v_dtype_max; } + else if(qscale.type == quant_scale_enum::blockscale) + { + ck_tile::FillUniformDistribution{0.012f, 0.015f, next_seed()}(q_descale_host); + ck_tile::FillUniformDistribution{0.012f, 0.015f, next_seed()}(k_descale_host); + ck_tile::FillUniformDistribution{0.012f, 0.015f, next_seed()}(v_descale_host); + } iota_shuffle(block_table_host.begin(), block_table_host.end(), 0, random_engine); iota_shuffle(cache_batch_idx_host.begin(), cache_batch_idx_host.end(), 0, random_engine); @@ -737,6 +774,10 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::DeviceMem q_descale_buf(q_descale_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem k_descale_buf(k_descale_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem v_descale_buf(v_descale_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem block_scale_seqstart_q_buf(block_scale_seqstart_q_host.size() * + sizeof(int32_t)); + ck_tile::DeviceMem block_scale_seqstart_k_buf(block_scale_seqstart_k_host.size() * + sizeof(int32_t)); ck_tile::DeviceMem lse_acc_buf(lse_acc_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem o_acc_buf(o_acc_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem lse_buf(lse_host.get_element_space_size_in_bytes()); @@ -782,6 +823,8 @@ fwd_result fmha_fwd_run(mode_enum mode, q_descale_buf.ToDevice(q_descale_host.data()); k_descale_buf.ToDevice(k_descale_host.data()); v_descale_buf.ToDevice(v_descale_host.data()); + block_scale_seqstart_q_buf.ToDevice(block_scale_seqstart_q_host.data()); + block_scale_seqstart_k_buf.ToDevice(block_scale_seqstart_k_host.data()); seqstart_q.ToDevice(seqstart_q_host.data()); // Keep logical starts in seqstart_k; pass padded K via separate pointer seqstart_k.ToDevice(seqstart_k_host.data()); @@ -975,11 +1018,14 @@ fwd_result fmha_fwd_run(mode_enum mode, }(); const ck_tile::index_t nhead_stride_bias = (i_perm ? 0 * shape_seqlen_q * max_seqlen_k : 0 * max_seqlen_k); - const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k); - const ck_tile::index_t nhead_stride_lse = shape_seqlen_q; - const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q); - const ck_tile::index_t nhead_stride_o_acc = (num_splits * shape_seqlen_q * hdim_v); - const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); + const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k); + const ck_tile::index_t nhead_stride_lse = shape_seqlen_q; + const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q); + const ck_tile::index_t nhead_stride_o_acc = (num_splits * shape_seqlen_q * hdim_v); + const ck_tile::index_t nhead_stride_o = (o_perm ? 
shape_seqlen_q * hdim_v : hdim_v); + const ck_tile::index_t nhead_stride_q_descale = num_block_scale_q; + const ck_tile::index_t nhead_stride_k_descale = num_block_scale_kv; + const ck_tile::index_t nhead_stride_v_descale = num_block_scale_kv; // setup batch_stride_* arguments const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); const ck_tile::index_t batch_stride_k = @@ -997,6 +1043,9 @@ fwd_result fmha_fwd_run(mode_enum mode, const ck_tile::index_t batch_stride_o_acc = (nhead * num_splits * shape_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch); + const ck_tile::index_t batch_stride_q_descale = num_block_scale_q * nhead; + const ck_tile::index_t batch_stride_k_descale = num_block_scale_kv * nhead_k; + const ck_tile::index_t batch_stride_v_descale = num_block_scale_kv * nhead_k; // setup split_stride_* arguments (only used in split-kv kernel) const ck_tile::index_t split_stride_lse_acc = (shape_seqlen_q); const ck_tile::index_t split_stride_o_acc = (shape_seqlen_q * hdim_v); @@ -1084,9 +1133,39 @@ fwd_result fmha_fwd_run(mode_enum mode, if constexpr(std::is_same_v>) { - args.q_descale_ptr = q_descale_buf.GetDeviceBuffer(); - args.k_descale_ptr = k_descale_buf.GetDeviceBuffer(); - args.v_descale_ptr = v_descale_buf.GetDeviceBuffer(); + if(qscale.type == quant_scale_enum::blockscale) + { + args.q_descale_ptr = + reinterpret_cast(q_descale_buf.GetDeviceBuffer()); + args.k_descale_ptr = + reinterpret_cast(k_descale_buf.GetDeviceBuffer()); + args.v_descale_ptr = + reinterpret_cast(v_descale_buf.GetDeviceBuffer()); + + args.block_scale_seqstart_q_ptr = + (mode == mode_enum::group ? block_scale_seqstart_q_buf.GetDeviceBuffer() + : nullptr); + args.block_scale_seqstart_k_ptr = + (mode == mode_enum::group ? block_scale_seqstart_k_buf.GetDeviceBuffer() + : nullptr); + + args.nhead_stride_q_descale = nhead_stride_q_descale; + args.nhead_stride_k_descale = nhead_stride_k_descale; + args.nhead_stride_v_descale = nhead_stride_v_descale; + + args.batch_stride_q_descale = batch_stride_q_descale; + args.batch_stride_k_descale = batch_stride_k_descale; + args.batch_stride_v_descale = batch_stride_v_descale; + + args.block_scale_size_q = block_scale_size_q_; + args.block_scale_size_kv = block_scale_size_kv_; + } + else + { + args.q_descale_ptr = q_descale_buf.GetDeviceBuffer(); + args.k_descale_ptr = k_descale_buf.GetDeviceBuffer(); + args.v_descale_ptr = v_descale_buf.GetDeviceBuffer(); + } args.rand_val_ptr = randval_buf.GetDeviceBuffer(); @@ -1589,14 +1668,42 @@ fwd_result fmha_fwd_run(mode_enum mode, #endif // reference - ck_tile:: - reference_batched_gemm( + if(qscale.type == quant_scale_enum::blockscale) + { + const ck_tile::index_t q_offset = + (mode == mode_enum::batch) ? 0 : block_scale_seqstart_q_host[wb]; + const ck_tile::index_t k_offset = + (mode == mode_enum::batch) ? 
0 : block_scale_seqstart_k_host[wb]; + ck_tile::reference_batched_quant_gemm( q_host_ref, k_host_ref, s_host_ref, - ck_tile::identity{}, - ck_tile::identity{}, - ck_tile::scales(scale_s_host)); + ck_tile::idx_identity{}, + ck_tile::idx_identity{}, + [&](auto idx, auto value) { + return value * scale_s * + q_descale_host(b_idx, + std::get<0>(idx), + q_offset + std::get<1>(idx) / block_scale_size_q_) * + k_descale_host(b_idx, + std::get<0>(idx) / nr, + k_offset + std::get<2>(idx) / block_scale_size_kv_); + }); + } + else + { + ck_tile:: + reference_batched_gemm( + q_host_ref, + k_host_ref, + s_host_ref, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::scales(scale_s_host)); + } if(0.f < logits_soft_cap) { @@ -1794,13 +1901,35 @@ fwd_result fmha_fwd_run(mode_enum mode, } } - ck_tile::reference_batched_gemm( - p_host_ref, - v_host_ref, - o_host_ref, - ck_tile::identity{}, - ck_tile::identity{}, - oacc_element_func); + if(qscale.type == quant_scale_enum::blockscale) + { + const ck_tile::index_t v_offset = + (mode == mode_enum::batch) ? 0 : block_scale_seqstart_k_host[wb]; + ck_tile:: + reference_batched_quant_gemm( + p_host_ref, + v_host_ref, + o_host_ref, + ck_tile::idx_identity{}, + [&](auto idx, auto value) { + return ck_tile::type_convert(value) * + v_descale_host(b_idx, + std::get<0>(idx) / nr, + v_offset + + std::get<2>(idx) / block_scale_size_kv_); + }, + ck_tile::idx_identity{}); + } + else + { + ck_tile::reference_batched_gemm( + p_host_ref, + v_host_ref, + o_host_ref, + ck_tile::identity{}, + ck_tile::identity{}, + oacc_element_func); + } ck_tile::HostTensor o_host_result({nhead, real_seqlen_q, hdim_v}); // clang-format off @@ -1808,7 +1937,6 @@ fwd_result fmha_fwd_run(mode_enum mode, if(o_perm) o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[0], idx[1] + query_offset, idx[2]); }); else o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[1] + query_offset, idx[0], idx[2]); }); // clang-format on - auto [rtol, atol] = get_elimit(init_method); bool cur_pass = ck_tile::check_err(o_host_result, o_host_ref, @@ -1866,31 +1994,33 @@ fwd_result fmha_fwd_run(mode_enum mode, if(json) { - dump_fmha_fwd_json_results(*json, - data_type, - mode == mode_enum::batch ? "batch" : "group", - io_layout(i_perm, o_perm), - batch, - nhead, - nhead_k, - seqlen_qs[0], - seqlen_ks[0], - seqlen_kpads[0], - hdim_q, - hdim_v, - scale_s, - p_drop, - lse, - qscale.type == quant_scale_enum::no_scale ? "no_scale" - : "pertensor", - bias.type == bias_enum::elementwise_bias - ? "elementwise_bias" - : (bias.type == bias_enum::alibi ? "alibi" : "no_bias"), - is_v_rowmajor ? "r" : "c", - pass, - ave_time, - tflops, - gb_per_sec); + dump_fmha_fwd_json_results( + *json, + data_type, + mode == mode_enum::batch ? "batch" : "group", + io_layout(i_perm, o_perm), + batch, + nhead, + nhead_k, + seqlen_qs[0], + seqlen_ks[0], + seqlen_kpads[0], + hdim_q, + hdim_v, + scale_s, + p_drop, + lse, + qscale.type == quant_scale_enum::no_scale + ? "no_scale" + : (qscale.type == quant_scale_enum::pertensor ? "pertensor" : "blockscale"), + bias.type == bias_enum::elementwise_bias + ? "elementwise_bias" + : (bias.type == bias_enum::alibi ? "alibi" : "no_bias"), + is_v_rowmajor ? "r" : "c", + pass, + ave_time, + tflops, + gb_per_sec); } return pass ? 
fwd_result::success : fwd_result::failure;
diff --git a/example/ck_tile/01_fmha/quant.hpp b/example/ck_tile/01_fmha/quant.hpp
index 59d4ac1707..feb28cba24 100644
--- a/example/ck_tile/01_fmha/quant.hpp
+++ b/example/ck_tile/01_fmha/quant.hpp
@@ -13,6 +13,7 @@ enum class quant_scale_enum
 {
     no_scale  = 0,
     pertensor = 1,
+    blockscale,
 };
 
 struct quant_scale_info
@@ -25,6 +26,8 @@ struct quant_scale_info
             os << "n";
         else if(type == quant_scale_enum::pertensor)
             os << "pt";
+        else if(type == quant_scale_enum::blockscale)
+            os << "bs";
     }
 
     static quant_scale_info decode(std::string str)
@@ -38,6 +41,10 @@ struct quant_scale_info
         {
             info.type = quant_scale_enum::pertensor;
         }
+        else if(str == "bs" || str == "2")
+        {
+            info.type = quant_scale_enum::blockscale;
+        }
         else
         {
             throw std::invalid_argument("invalid quant scale value: " + str);
diff --git a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh
index 596542eb9d..227f26c8f3 100755
--- a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh
+++ b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh
@@ -95,10 +95,11 @@ run_fp8bf16_tests() {
     for perm in 0 1 ; do
     for b in 1 2 ; do
     for hdim in 64 128 256 ; do
+    for scale in 1 2; do
 
-    $EXE -prec=fp8bf16 -init=3 -b=$b -h=1 -d=$hdim -s=128 -iperm=$perm -operm=$perm -vlayout=r -qscale=1 -kname=$KNAME $COMMON_ARGS
+    $EXE -prec=fp8bf16 -init=3 -b=$b -h=1 -d=$hdim -s=128 -iperm=$perm -operm=$perm -vlayout=r -qscale=$scale -kname=$KNAME $COMMON_ARGS
 
-    done ; done ; done
+    done ; done ; done ; done
 }
 
 run_fp8fp32_tests() {
diff --git a/include/ck_tile/core/numeric/math.hpp b/include/ck_tile/core/numeric/math.hpp
index 96e76f669d..a46ae509dd 100644
--- a/include/ck_tile/core/numeric/math.hpp
+++ b/include/ck_tile/core/numeric/math.hpp
@@ -37,6 +37,13 @@ struct scales
         return lhs_ * rhs;
     }
 
+    template <typename OtherScale>
+    CK_TILE_HOST_DEVICE constexpr auto operator*(OtherScale other) const
+    {
+        auto new_scale = lhs_ * other;
+        return scales<decltype(new_scale)>(new_scale);
+    }
+
 private:
     Scale lhs_;
 };
diff --git a/include/ck_tile/core/utility/functional.hpp b/include/ck_tile/core/utility/functional.hpp
index 898d21574e..aa4bfa3f15 100644
--- a/include/ck_tile/core/utility/functional.hpp
+++ b/include/ck_tile/core/utility/functional.hpp
@@ -119,6 +119,18 @@ struct identity
     }
 };
 
+// Similar to identity, but takes an additional index parameter as the first argument.
+// The index is ignored and only the second argument (value) is forwarded.
+// Useful for indexed element-wise operations where the functor signature requires an index.
+struct idx_identity
+{
+    template <typename I, typename T>
+    CK_TILE_HOST_DEVICE constexpr T&& operator()(I&& /*idx*/, T&& arg) const noexcept
+    {
+        return std::forward<T>(arg);
+    }
+};
+
 namespace detail {
 // RemainLengths: sequence<...>
diff --git a/include/ck_tile/host/reference/reference_batched_gemm.hpp b/include/ck_tile/host/reference/reference_batched_gemm.hpp
index 63f13b1b16..d742426740 100644
--- a/include/ck_tile/host/reference/reference_batched_gemm.hpp
+++ b/include/ck_tile/host/reference/reference_batched_gemm.hpp
@@ -47,4 +47,44 @@ CK_TILE_HOST void reference_batched_gemm(const HostTensor<ADataType>& a_b_m_k,
     make_ParallelTensorFunctor(f, c_b_m_n.mDesc.get_lengths()[0], c_b_m_n.mDesc.get_lengths()[1])(
         std::thread::hardware_concurrency());
 }
+
+template <typename ADataType,
+          typename BDataType,
+          typename AccDataType,
+          typename CDataType,
+          typename AElementOp   = idx_identity,
+          typename BElementOp   = idx_identity,
+          typename ACCElementOp = idx_identity>
+CK_TILE_HOST void reference_batched_quant_gemm(const HostTensor<ADataType>& a_b_m_k,
+                                               const HostTensor<BDataType>& b_b_n_k,
+                                               HostTensor<CDataType>& c_b_m_n,
+                                               const AElementOp& a_element_op     = {},
+                                               const BElementOp& b_element_op     = {},
+                                               const ACCElementOp& acc_element_op = {})
+{
+    const int N = b_b_n_k.mDesc.get_lengths()[1];
+    const int K = b_b_n_k.mDesc.get_lengths()[2];
+
+    auto f = [&](auto batch, auto m) {
+        for(int n = 0; n < N; ++n)
+        {
+            AccDataType v_acc = 0;
+
+            for(int k = 0; k < K; ++k)
+            {
+                AccDataType v_a = ck_tile::type_convert<AccDataType>(
+                    a_element_op(std::make_tuple(batch, m, k), a_b_m_k(batch, m, k)));
+                AccDataType v_b = ck_tile::type_convert<AccDataType>(
+                    b_element_op(std::make_tuple(batch, n, k), b_b_n_k(batch, n, k)));
+
+                v_acc += v_a * v_b;
+            }
+
+            c_b_m_n(batch, m, n) = ck_tile::type_convert<CDataType>(
+                acc_element_op(std::make_tuple(batch, m, n), v_acc));
+        }
+    };
+
+    make_ParallelTensorFunctor(f, c_b_m_n.mDesc.get_lengths()[0], c_b_m_n.mDesc.get_lengths()[1])(
+        std::thread::hardware_concurrency());
+}
 } // namespace ck_tile
diff --git a/include/ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp b/include/ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp
index 3755a2bc71..7e0f704bef 100644
--- a/include/ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp
+++ b/include/ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp
@@ -12,6 +12,7 @@ enum class BlockAttentionQuantScaleEnum
 {
     NO_SCALE  = 0,
     PERTENSOR = 1,
+    BLOCKSCALE,
 };
 
 template <BlockAttentionQuantScaleEnum>
 struct BlockAttentionQuantScaleEnumToStr;
@@ -27,5 +28,10 @@ struct BlockAttentionQuantScaleEnumToStr<BlockAttentionQuantScaleEnum::PERTENSOR>
 {
     static constexpr const char* name = "pertensor";
 };
+template <>
+struct BlockAttentionQuantScaleEnumToStr<BlockAttentionQuantScaleEnum::BLOCKSCALE>
+{
+    static constexpr const char* name = "blockscale";
+};
 
 } // namespace ck_tile
diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
index adbedc5259..0039c57cfc 100644
--- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
+++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
@@ -168,6 +168,29 @@ struct FmhaFwdKernel
         const void* v_descale_ptr = nullptr;
     };
 
+    struct FmhaFwdCommonBlockScaleKargs : public FmhaFwdCommonQScaleKargs
+    {
+        ck_tile::index_t nhead_stride_q_descale;
+        ck_tile::index_t nhead_stride_k_descale;
+        ck_tile::index_t nhead_stride_v_descale;
+
+        ck_tile::index_t block_scale_size_q;
+        ck_tile::index_t block_scale_size_kv;
+    };
+
+    struct FmhaFwdBatchBlockScaleKargs : public FmhaFwdCommonBlockScaleKargs
+    {
+        ck_tile::index_t batch_stride_q_descale;
+        ck_tile::index_t batch_stride_k_descale;
+        ck_tile::index_t batch_stride_v_descale;
+    };
+
+    struct FmhaFwdGroupBlockScaleKargs : public FmhaFwdCommonBlockScaleKargs
+    {
+        const int32_t* block_scale_seqstart_q_ptr;
+        const int32_t* block_scale_seqstart_k_ptr;
+    };
+
     struct FmhaFwdCommonLSEKargs
     {
         void* lse_ptr = nullptr;
@@ -243,9 +266,12 @@ struct FmhaFwdKernel
                                  FmhaFwdEmptyKargs<0>>>,
std::conditional_t>, std::conditional_t>, - std::conditional_t>, + std::conditional_t< + QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR, + FmhaFwdCommonQScaleKargs, + std::conditional_t>>, std::conditional_t>, std::conditional_t> { @@ -269,9 +295,12 @@ struct FmhaFwdKernel FmhaFwdEmptyKargs<0>>>, std::conditional_t>, std::conditional_t>, - std::conditional_t>, + std::conditional_t< + QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR, + FmhaFwdCommonQScaleKargs, + std::conditional_t>>, std::conditional_t>, std::conditional_t>, std::conditional_t> @@ -328,6 +357,9 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, + ck_tile::index_t nhead_stride_q_descale, + ck_tile::index_t nhead_stride_k_descale, + ck_tile::index_t nhead_stride_v_descale, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, @@ -335,6 +367,9 @@ struct FmhaFwdKernel ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, + ck_tile::index_t batch_stride_q_descale, + ck_tile::index_t batch_stride_k_descale, + ck_tile::index_t batch_stride_v_descale, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t sink_size, @@ -343,6 +378,8 @@ struct FmhaFwdKernel bool s_randval, std::variant, std::pair> drop_seed_offset, + ck_tile::index_t block_scale_size_q, + ck_tile::index_t block_scale_size_kv, const void* cu_seqlen_q_ptr = nullptr, const void* cu_seqlen_k_ptr = nullptr, const void* sink_ptr = nullptr) @@ -413,6 +450,23 @@ struct FmhaFwdKernel kargs.k_descale_ptr = k_descale_ptr; kargs.v_descale_ptr = v_descale_ptr; } + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { + kargs.q_descale_ptr = q_descale_ptr; + kargs.k_descale_ptr = k_descale_ptr; + kargs.v_descale_ptr = v_descale_ptr; + + kargs.nhead_stride_q_descale = nhead_stride_q_descale; + kargs.nhead_stride_k_descale = nhead_stride_k_descale; + kargs.nhead_stride_v_descale = nhead_stride_v_descale; + + kargs.batch_stride_q_descale = batch_stride_q_descale; + kargs.batch_stride_k_descale = batch_stride_k_descale; + kargs.batch_stride_v_descale = batch_stride_v_descale; + + kargs.block_scale_size_q = block_scale_size_q; + kargs.block_scale_size_kv = block_scale_size_kv; + } if constexpr(kHasDropout) { if(drop_seed_offset.index() == 0) // seed & offset come from host @@ -478,6 +532,9 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, + ck_tile::index_t nhead_stride_q_descale, + ck_tile::index_t nhead_stride_k_descale, + ck_tile::index_t nhead_stride_v_descale, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, @@ -485,6 +542,9 @@ struct FmhaFwdKernel ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, + ck_tile::index_t batch_stride_q_descale, + ck_tile::index_t batch_stride_k_descale, + ck_tile::index_t batch_stride_v_descale, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t sink_size, @@ -492,6 +552,8 @@ struct FmhaFwdKernel float p_drop, bool s_randval, const std::tuple& drop_seed_offset, + ck_tile::index_t block_scale_size_q, + ck_tile::index_t block_scale_size_kv, const void* cu_seqlen_q_ptr = nullptr, const void* cu_seqlen_k_ptr = nullptr, const void* sink_ptr = nullptr) @@ -528,6 +590,9 @@ struct FmhaFwdKernel 
nhead_stride_randval, nhead_stride_lse, nhead_stride_o, + nhead_stride_q_descale, + nhead_stride_k_descale, + nhead_stride_v_descale, batch_stride_q, batch_stride_k, batch_stride_v, @@ -535,6 +600,9 @@ struct FmhaFwdKernel batch_stride_randval, batch_stride_lse, batch_stride_o, + batch_stride_q_descale, + batch_stride_k_descale, + batch_stride_v_descale, window_size_left, window_size_right, sink_size, @@ -542,6 +610,8 @@ struct FmhaFwdKernel p_drop, s_randval, std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)), + block_scale_size_q, + block_scale_size_kv, cu_seqlen_q_ptr, cu_seqlen_k_ptr, sink_ptr); @@ -581,6 +651,9 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, + ck_tile::index_t nhead_stride_q_descale, + ck_tile::index_t nhead_stride_k_descale, + ck_tile::index_t nhead_stride_v_descale, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, @@ -588,6 +661,9 @@ struct FmhaFwdKernel ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, + ck_tile::index_t batch_stride_q_descale, + ck_tile::index_t batch_stride_k_descale, + ck_tile::index_t batch_stride_v_descale, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t sink_size, @@ -595,6 +671,8 @@ struct FmhaFwdKernel float p_drop, bool s_randval, const std::tuple& drop_seed_offset, + ck_tile::index_t block_scale_size_q, + ck_tile::index_t block_scale_size_kv, const void* cu_seqlen_q_ptr = nullptr, const void* cu_seqlen_k_ptr = nullptr, const void* sink_ptr = nullptr) @@ -631,6 +709,9 @@ struct FmhaFwdKernel nhead_stride_randval, nhead_stride_lse, nhead_stride_o, + nhead_stride_q_descale, + nhead_stride_k_descale, + nhead_stride_v_descale, batch_stride_q, batch_stride_k, batch_stride_v, @@ -638,6 +719,9 @@ struct FmhaFwdKernel batch_stride_randval, batch_stride_lse, batch_stride_o, + batch_stride_q_descale, + batch_stride_k_descale, + batch_stride_v_descale, window_size_left, window_size_right, sink_size, @@ -645,6 +729,8 @@ struct FmhaFwdKernel p_drop, s_randval, std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)), + block_scale_size_q, + block_scale_size_kv, cu_seqlen_q_ptr, cu_seqlen_k_ptr, sink_ptr); @@ -666,6 +752,8 @@ struct FmhaFwdKernel const void* seqstart_k_ptr, const void* seqlen_q_ptr, const void* seqlen_k_ptr, + const void* block_scale_seqstart_q_ptr, + const void* block_scale_seqstart_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, @@ -685,6 +773,9 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, + ck_tile::index_t nhead_stride_q_descale, + ck_tile::index_t nhead_stride_k_descale, + ck_tile::index_t nhead_stride_v_descale, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t sink_size, @@ -694,6 +785,8 @@ struct FmhaFwdKernel bool s_randval, std::variant, std::pair> drop_seed_offset, + ck_tile::index_t block_scale_size_q, + ck_tile::index_t block_scale_size_kv, const void* cu_seqlen_q_ptr = nullptr, const void* cu_seqlen_k_ptr = nullptr, const void* sink_ptr = nullptr) @@ -763,6 +856,24 @@ struct FmhaFwdKernel kargs.k_descale_ptr = k_descale_ptr; kargs.v_descale_ptr = v_descale_ptr; } + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { + kargs.q_descale_ptr = q_descale_ptr; + 
kargs.k_descale_ptr = k_descale_ptr; + kargs.v_descale_ptr = v_descale_ptr; + + kargs.nhead_stride_q_descale = nhead_stride_q_descale; + kargs.nhead_stride_k_descale = nhead_stride_k_descale; + kargs.nhead_stride_v_descale = nhead_stride_v_descale; + + kargs.block_scale_size_q = block_scale_size_q; + kargs.block_scale_size_kv = block_scale_size_kv; + + kargs.block_scale_seqstart_q_ptr = + reinterpret_cast(block_scale_seqstart_q_ptr); + kargs.block_scale_seqstart_k_ptr = + reinterpret_cast(block_scale_seqstart_k_ptr); + } if constexpr(kHasDropout) { if(drop_seed_offset.index() == 0) // seed & offset come from host @@ -814,6 +925,8 @@ struct FmhaFwdKernel const void* seqstart_k_ptr, const void* seqlen_q_ptr, const void* seqlen_k_ptr, + const void* block_scale_seqstart_q_ptr, + const void* block_scale_seqstart_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, @@ -833,6 +946,9 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, + ck_tile::index_t nhead_stride_q_descale, + ck_tile::index_t nhead_stride_k_descale, + ck_tile::index_t nhead_stride_v_descale, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t sink_size, @@ -841,6 +957,8 @@ struct FmhaFwdKernel float p_drop, bool s_randval, const std::tuple& drop_seed_offset, + ck_tile::index_t block_scale_size_q, + ck_tile::index_t block_scale_size_kv, const void* cu_seqlen_q_ptr = nullptr, const void* cu_seqlen_k_ptr = nullptr, const void* sink_ptr = nullptr) @@ -860,6 +978,8 @@ struct FmhaFwdKernel seqstart_k_ptr, seqlen_q_ptr, seqlen_k_ptr, + block_scale_seqstart_q_ptr, + block_scale_seqstart_k_ptr, hdim_q, hdim_v, num_head_q, @@ -879,6 +999,9 @@ struct FmhaFwdKernel nhead_stride_randval, nhead_stride_lse, nhead_stride_o, + nhead_stride_q_descale, + nhead_stride_k_descale, + nhead_stride_v_descale, window_size_left, window_size_right, sink_size, @@ -887,6 +1010,8 @@ struct FmhaFwdKernel p_drop, s_randval, std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)), + block_scale_size_q, + block_scale_size_kv, cu_seqlen_q_ptr, cu_seqlen_k_ptr, sink_ptr); @@ -909,6 +1034,8 @@ struct FmhaFwdKernel const void* seqstart_k_ptr, const void* seqlen_q_ptr, const void* seqlen_k_ptr, + const void* block_scale_seqstart_q_ptr, + const void* block_scale_seqstart_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, @@ -928,6 +1055,9 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, + ck_tile::index_t nhead_stride_q_descale, + ck_tile::index_t nhead_stride_k_descale, + ck_tile::index_t nhead_stride_v_descale, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t sink_size, @@ -936,6 +1066,8 @@ struct FmhaFwdKernel float p_drop, bool s_randval, const std::tuple& drop_seed_offset, + ck_tile::index_t block_scale_size_q, + ck_tile::index_t block_scale_size_kv, const void* cu_seqlen_q_ptr = nullptr, const void* cu_seqlen_k_ptr = nullptr, const void* sink_ptr = nullptr) @@ -955,6 +1087,8 @@ struct FmhaFwdKernel seqstart_k_ptr, seqlen_q_ptr, seqlen_k_ptr, + block_scale_seqstart_q_ptr, + block_scale_seqstart_k_ptr, hdim_q, hdim_v, num_head_q, @@ -974,6 +1108,9 @@ struct FmhaFwdKernel nhead_stride_randval, nhead_stride_lse, nhead_stride_o, + nhead_stride_q_descale, + nhead_stride_k_descale, + nhead_stride_v_descale, window_size_left, window_size_right, 
sink_size, @@ -982,6 +1119,8 @@ struct FmhaFwdKernel p_drop, s_randval, std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)), + block_scale_size_q, + block_scale_size_kv, cu_seqlen_q_ptr, cu_seqlen_k_ptr, sink_ptr); @@ -1111,13 +1250,16 @@ struct FmhaFwdKernel const index_t i_m0 = amd_wave_read_first_lane(i_tile_m * FmhaPipeline::kM0); const index_t i_n1 = amd_wave_read_first_lane(i_tile_n * FmhaPipeline::kN1); - long_index_t batch_offset_q = 0; - long_index_t batch_offset_k = 0; - long_index_t batch_offset_v = 0; - long_index_t batch_offset_bias = 0; - long_index_t batch_offset_randval = 0; - long_index_t batch_offset_lse = 0; - long_index_t batch_offset_o = 0; + long_index_t batch_offset_q = 0; + long_index_t batch_offset_k = 0; + long_index_t batch_offset_v = 0; + long_index_t batch_offset_bias = 0; + long_index_t batch_offset_randval = 0; + long_index_t batch_offset_lse = 0; + long_index_t batch_offset_o = 0; + long_index_t batch_offset_q_descale = 0; + long_index_t batch_offset_k_descale = 0; + long_index_t batch_offset_v_descale = 0; const float sink_value = kargs.sink_ptr != nullptr ? (*(static_cast(kargs.sink_ptr) + i_nhead)) / kargs.scale_s @@ -1153,6 +1295,14 @@ struct FmhaFwdKernel { batch_offset_randval = query_start * kargs.stride_randval; } + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { + const long_index_t bquery_start = kargs.block_scale_seqstart_q_ptr[i_batch]; + const long_index_t bkey_start = kargs.block_scale_seqstart_k_ptr[i_batch]; + batch_offset_q_descale = bquery_start; + batch_offset_k_descale = bkey_start; + batch_offset_v_descale = bkey_start; + } batch_offset_o = query_start * kargs.stride_o; // real logical lengths (exclude PAD) @@ -1220,6 +1370,15 @@ struct FmhaFwdKernel batch_offset_randval = static_cast(i_batch) * kargs.batch_stride_randval; } + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { + batch_offset_q_descale = + static_cast(i_batch) * kargs.batch_stride_q_descale; + batch_offset_k_descale = + static_cast(i_batch) * kargs.batch_stride_k_descale; + batch_offset_v_descale = + static_cast(i_batch) * kargs.batch_stride_v_descale; + } batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; // If cumulative seqlen pointers are provided, override per-batch effective lengths @@ -1540,7 +1699,8 @@ struct FmhaFwdKernel }(); BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk}; - auto o_acc_tile = [&]() { + + auto o_acc_tile = [&, i_nhead_ = i_nhead]() { if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR) { // TODO - move global load of descale to pipeline @@ -1581,8 +1741,62 @@ struct FmhaFwdKernel block_indices, smem_ptr, dropout, + nullptr, + nullptr, + 1, sink_value); } + else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { + const float* q_descale_ptr = + reinterpret_cast(kargs.q_descale_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_q_descale + + batch_offset_q_descale; + const float* k_descale_ptr = + reinterpret_cast(kargs.k_descale_ptr) + + static_cast(i_nhead_ / kargs.nhead_ratio_qk) * + kargs.nhead_stride_k_descale + + batch_offset_k_descale; + const float* v_descale_ptr = + reinterpret_cast(kargs.v_descale_ptr) + + static_cast(i_nhead_ / kargs.nhead_ratio_qk) * + kargs.nhead_stride_v_descale + + batch_offset_v_descale; + + size_t idx = i_m0 / kargs.block_scale_size_q; + float q_descale = q_descale_ptr[idx]; + // BLOCKSCALE: P is scaled in exp2(x+shift) where shift=7 or 8 + // Both P and rowsum are 
scaled by 2^shift, canceling in normalization + // No additional scaling needed in p_compute_element_func or o_acc_element_func + + return FmhaPipeline{}( + q_dram_window, + identity{}, // q_element_func + k_dram_window, + identity{}, // k_element_func + v_dram_window, + identity{}, // v_element_func + bias_dram_window, + identity{}, // bias_element_func + randval_dram_window, + lse_dram_window, + identity{}, // lse_element_func + scales(q_descale), // s_acc_element_func + identity{}, // p_compute_element_func - No scaling (done in exp2) + identity{}, // o_acc_element_func - No dequant needed (canceled by rowsum) + mask, + position_encoding, + kargs.scale_s, + variant, + variant_params, + block_indices, + smem_ptr, + dropout, + k_descale_ptr, + v_descale_ptr, + kargs.block_scale_size_kv, + sink_value); + } else { return FmhaPipeline{}(q_dram_window, diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp index dcccdf541c..2fbc9fdb54 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -57,8 +57,13 @@ struct BlockFmhaPipelineQRKSVS static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr bool kHasDropout = Problem::kHasDropout; + static constexpr auto QScaleEnum = Problem::QScaleEnum; static constexpr bool kHasSink = Problem::kHasSink; + // For BLOCKSCALE: shift value for exp2(x + shift) to scale P to [0, 2^shift] + static constexpr float OCP_FP8_SHIFT = 8.0f; + static constexpr float FNUZ_FP8_SHIFT = 7.0f; + static constexpr uint32_t DS_READ = 0x100; // Barrier for DS (data share) read static constexpr uint32_t MFMA = 0x008; // Barrier for MFMA (matrix multiply-accumulate) @@ -167,6 +172,9 @@ struct BlockFmhaPipelineQRKSVS const BlockIndices& block_indices, void* smem_ptr, DropoutType& dropout, + const float* k_descale_ptr, + const float* v_descale_ptr, + const index_t block_scale_size_kv, const float sink_v) const { static_assert( @@ -358,6 +366,13 @@ struct BlockFmhaPipelineQRKSVS static_assert(1 <= k1_loops); do { + float k_descale = 1.0f; + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { + // K and V share the same seqlen_k position within a block + const index_t kv_idx = (kv_load_start + i_total_loops * kN0) / block_scale_size_kv; + k_descale = k_descale_ptr[kv_idx]; + } // STAGE 1, QK gemm auto k_dram_window = make_tile_window( k_dram_block_window.get_bottom_tensor_view(), @@ -427,11 +442,20 @@ struct BlockFmhaPipelineQRKSVS k_lds_window); schedule_gemm0(); } + // dequant + auto s_acc_element_func_ = [&s_acc_element_func, k_descale]() { + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { + return s_acc_element_func * k_descale; + } + else + return s_acc_element_func; + }(); // STAGE 2, scale_s, add bias, mask, softmax if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { - s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + s_acc = tile_elementwise_in(s_acc_element_func_, s_acc); tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); tile_elementwise_inout( [&](auto& x, const auto& y) { @@ -449,7 +473,7 @@ struct BlockFmhaPipelineQRKSVS { const auto k_origin = k_dram_block_window.get_window_origin(); constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); - s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + s_acc = 
tile_elementwise_in(s_acc_element_func_, s_acc); sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { const auto tile_idx = get_x_indices_from_distributed_indices( @@ -466,7 +490,7 @@ struct BlockFmhaPipelineQRKSVS } else { - s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + s_acc = tile_elementwise_in(s_acc_element_func_, s_acc); if constexpr(kHasLogitsSoftCap) { auto apply_logits_transform = @@ -571,7 +595,21 @@ struct BlockFmhaPipelineQRKSVS sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); #if CK_TILE_FMHA_FWD_FAST_EXP2 - auto row_max = scale_s * get_validated_m(m[i_idx]); + // For BLOCKSCALE: precompute (m - shift) once per row + // Bias/Alibi/SoftCap: exp2(s - m + shift) = exp2(s - (m - shift)) + // else: exp2(scale_s*s - scale_s*m + shift) = exp2(scale_s*s - (scale_s*m - shift)) + auto validated_m = get_validated_m(m[i_idx]); + auto row_max = scale_s * validated_m; + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { +#if CK_TILE_USE_OCP_FP8 + validated_m -= OCP_FP8_SHIFT; // for Bias/Alibi/SoftCap + row_max -= OCP_FP8_SHIFT; // for else branch +#else + validated_m -= FNUZ_FP8_SHIFT; + row_max -= FNUZ_FP8_SHIFT; +#endif + } #endif sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); @@ -579,13 +617,13 @@ struct BlockFmhaPipelineQRKSVS if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || BiasEnum == BlockAttentionBiasEnum::ALIBI) { - p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); + p_compute(i_j_idx) = exp2(s[i_j_idx] - validated_m); } else { if constexpr(kHasLogitsSoftCap) { - p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); + p_compute(i_j_idx) = exp2(s[i_j_idx] - validated_m); } else { @@ -676,18 +714,39 @@ struct BlockFmhaPipelineQRKSVS store_tile(v_lds_window, tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch } + move_tile_window(v_dram_window, {0, kK1}); const auto p = cast_tile(tile_elementwise_in(p_compute_element_func, p_compute)); + float v_descale = 1.0f; + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { + // K and V share the same seqlen_k position within a block + const index_t kv_idx = (kv_load_start + i_total_loops * kN0) / block_scale_size_kv; + v_descale = v_descale_ptr[kv_idx]; + } // STAGE 3, KV gemm + auto o_acc0 = decltype(o_acc){}; + clear_tile(o_acc0); + + auto& o_acc_ = [&o_acc0, &o_acc]() -> auto& { + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { + return o_acc0; + } + else + { + return o_acc; + } + }(); if constexpr(k1_loops > 1) { static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { const auto v = load_tile(v_dram_window); // load next v block_sync_lds(); - gemm_1(o_acc, + gemm_1(o_acc_, get_slice_tile( p, sequence<0, i_k1 * kK1>{}, sequence{}), v_lds_window); @@ -722,11 +781,16 @@ struct BlockFmhaPipelineQRKSVS // tail { block_sync_lds(); - gemm_1(o_acc, + gemm_1(o_acc_, get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence{}), v_lds_window); block_sync_lds(); } + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { + tile_elementwise_inout( + [&v_descale](auto& o, auto& o0) { o += o0 * v_descale; }, o_acc, o_acc0); + } } while(++i_total_loops < num_total_loop); // store lse @@ -846,6 +910,9 @@ struct BlockFmhaPipelineQRKSVS block_indices, smem_ptr, dropout, + nullptr, + nullptr, + 1, sink_v); } }; diff --git 
a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index 7224ed3a70..046a2f0b9e 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -46,6 +46,7 @@ struct BlockFmhaPipelineQRKSVSAsync static constexpr index_t kK1 = BlockFmhaShape::kK1; static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim; + static constexpr auto QScaleEnum = Problem::QScaleEnum; static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); @@ -64,6 +65,10 @@ struct BlockFmhaPipelineQRKSVSAsync static constexpr bool kHasDropout = Problem::kHasDropout; static constexpr bool kHasSink = Problem::kHasSink; + // For BLOCKSCALE: shift value for exp2(x + shift) to scale P to [0, 2^shift] + static constexpr float OCP_FP8_SHIFT = 8.0f; + static constexpr float FNUZ_FP8_SHIFT = 7.0f; + static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 && (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS || !kHasLogitsSoftCap)) || @@ -190,6 +195,9 @@ struct BlockFmhaPipelineQRKSVSAsync const BlockIndices& block_indices, void* smem_ptr, DropoutType& dropout, + const float* k_descale_ptr, + const float* v_descale_ptr, + const index_t block_scale_size_kv, const float sink_v) const { static_assert( @@ -403,6 +411,13 @@ struct BlockFmhaPipelineQRKSVSAsync // main loop do { + float k_descale = 1.0f; + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { + // K and V share the same seqlen_k position within a block + const index_t kv_idx = (kv_load_start + i_total_loops * kN0) / block_scale_size_kv; + k_descale = k_descale_ptr[kv_idx]; + } // STAGE 1, QK gemm clear_tile(s_acc); // initialize C if constexpr(k0_loops > 1) @@ -449,11 +464,20 @@ struct BlockFmhaPipelineQRKSVSAsync sequence<(LdsSeq.at(number{}) + 1) * kN0, kK0>{})); } __builtin_amdgcn_sched_barrier(1); + // dequant + auto s_acc_element_func_ = [&s_acc_element_func, k_descale]() { + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { + return s_acc_element_func * k_descale; + } + else + return s_acc_element_func; + }(); // STAGE 2, scale_s, add bias, mask, softmax if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { - s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + s_acc = tile_elementwise_in(s_acc_element_func_, s_acc); tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); tile_elementwise_inout( [&](auto& x, const auto& y) { @@ -471,7 +495,7 @@ struct BlockFmhaPipelineQRKSVSAsync { const auto k_origin = k_dram_block_window.get_window_origin(); constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); - s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + s_acc = tile_elementwise_in(s_acc_element_func_, s_acc); sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { const auto tile_idx = get_x_indices_from_distributed_indices( @@ -488,7 +512,7 @@ struct BlockFmhaPipelineQRKSVSAsync } else { - s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + s_acc = tile_elementwise_in(s_acc_element_func_, s_acc); if constexpr(kHasLogitsSoftCap) { auto apply_logits_transform = @@ -630,7 +654,21 @@ struct BlockFmhaPipelineQRKSVSAsync sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { constexpr auto i_idx = 
make_tuple(idx0); #if CK_TILE_FMHA_FWD_FAST_EXP2 - auto row_max = scale_s * get_validated_m(m[i_idx]); + // For BLOCKSCALE: precompute (m - shift) once per row + // Bias/Alibi/SoftCap: exp2(s - m + shift) = exp2(s - (m - shift)) + // else: exp2(scale_s*s - scale_s*m + shift) = exp2(scale_s*s - (scale_s*m - shift)) + auto validated_m = get_validated_m(m[i_idx]); + auto row_max = scale_s * validated_m; + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { +#if CK_TILE_USE_OCP_FP8 + validated_m -= OCP_FP8_SHIFT; // for Bias/Alibi/SoftCap + row_max -= OCP_FP8_SHIFT; // for else branch +#else + validated_m -= FNUZ_FP8_SHIFT; + row_max -= FNUZ_FP8_SHIFT; +#endif + } #endif sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); @@ -638,13 +676,13 @@ struct BlockFmhaPipelineQRKSVSAsync if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || BiasEnum == BlockAttentionBiasEnum::ALIBI) { - p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); + p_compute(i_j_idx) = exp2(s[i_j_idx] - validated_m); } else { if constexpr(kHasLogitsSoftCap) { - p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); + p_compute(i_j_idx) = exp2(s[i_j_idx] - validated_m); } else { @@ -735,7 +773,27 @@ struct BlockFmhaPipelineQRKSVSAsync #endif }(); + float v_descale = 1.0f; + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { + // K and V share the same seqlen_k position within a block + const index_t kv_idx = (kv_load_start + i_total_loops * kN0) / block_scale_size_kv; + v_descale = v_descale_ptr[kv_idx]; + } // STAGE 3, KV gemm + auto o_acc0 = decltype(o_acc){}; + clear_tile(o_acc0); + + auto& o_acc_ = [&o_acc0, &o_acc]() -> auto& { + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { + return o_acc0; + } + else + { + return o_acc; + } + }(); if constexpr(k1_loops > 1) { static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { @@ -745,7 +803,7 @@ struct BlockFmhaPipelineQRKSVSAsync v_dram_window, number<-1>{}, bool_constant{}); // load next v_buf } block_sync_lds(); - gemm_1(o_acc, + gemm_1(o_acc_, get_slice_tile( p, sequence<0, i_k1 * kK1>{}, sequence{}), get_slice_tile( @@ -808,13 +866,19 @@ struct BlockFmhaPipelineQRKSVSAsync { block_sync_lds(); gemm_1( - o_acc, + o_acc_, get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence{}), get_slice_tile( v_lds_window, sequence<(LdsSeq.at(number{})) * kN1, 0>{}, sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{})); } + + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { + tile_elementwise_inout( + [&v_descale](auto& o, auto& o0) { o += o0 * v_descale; }, o_acc, o_acc0); + } } while(i_total_loops < num_total_loop); // store lse @@ -922,6 +986,9 @@ struct BlockFmhaPipelineQRKSVSAsync block_indices, smem_ptr, dropout, + nullptr, + nullptr, + 1, sink_v); } };
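
Note on the group-mode descale indexing added in fmha_fwd_runner.hpp and fmha_fwd_kernel.hpp above: each descale value covers block_scale_size sequence positions, so a sequence of length len contributes ceil(len / block_scale_size) values, and group mode locates each sequence's run through a batch+1 prefix-sum array, exactly like seqstart_q/seqstart_k. A minimal host-side sketch of that construction (ceil_div and the function name are illustrative, not part of the patch):

#include <cstdint>
#include <vector>

static int32_t ceil_div(int32_t a, int32_t b) { return (a + b - 1) / b; }

// Prefix sum over the per-sequence descale-block counts, mirroring how the
// runner builds block_scale_seqstart_q_host / block_scale_seqstart_k_host.
std::vector<int32_t> build_block_scale_seqstart(const std::vector<int32_t>& seqlens,
                                                int32_t block_scale_size)
{
    std::vector<int32_t> seqstart = {0};
    int32_t total = 0;
    for(const int32_t len : seqlens)
    {
        total += ceil_div(len, block_scale_size); // descale values for this sequence
        seqstart.push_back(total);
    }
    return seqstart; // size == batch + 1, same convention as seqstart_q/k
}

With seqlens {100, 300} and a block size of 128 this yields {0, 1, 4}; the kernel then offsets each batch's descale pointer by seqstart[i_batch], which is what batch_offset_q_descale/batch_offset_k_descale do in the group-mode path.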
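
In batch mode the kernel instead reaches a tile's q_descale value through the new batch and nhead strides plus the tile's starting row divided by the block size. A hedged sketch of that address math (all names illustrative); it works because, per the runner's note, the block scale size must be at least the compute tile size (kM0/kN0), so a single value covers the whole tile:

#include <cstdint>

// Batch-mode lookup corresponding to the kernel's o_acc_tile lambda:
// q_descale is laid out [batch, nhead, num_block_scale_q], with
// batch_stride = num_block_scale_q * nhead and nhead_stride = num_block_scale_q.
float lookup_q_descale(const float* q_descale,
                       int64_t batch_stride,
                       int64_t nhead_stride,
                       int64_t i_batch,
                       int64_t i_nhead,
                       int64_t i_m0,               // first row of the query tile
                       int64_t block_scale_size_q)
{
    const float* base = q_descale + i_batch * batch_stride + i_nhead * nhead_stride;
    return base[i_m0 / block_scale_size_q]; // one descale value per query block
}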
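
The scales::operator* addition in math.hpp is what lets the pipelines write s_acc_element_func * k_descale: composing a scales functor with a scalar yields a new scales functor carrying the product, so the per-tensor q_descale passed in from the kernel and the per-block k_descale fetched inside the loop collapse into one multiply per element. A simplified, self-contained restatement (CK_TILE_HOST_DEVICE dropped, only the shape of the idiom is shown):

#include <cassert>

template <typename Scale>
struct scales
{
    explicit constexpr scales(Scale lhs) : lhs_(lhs) {}

    template <typename Right>
    constexpr auto operator()(const Right& rhs) const { return lhs_ * rhs; }

    // Composition: scales(a) * b == scales(a * b)
    template <typename OtherScale>
    constexpr auto operator*(OtherScale other) const
    {
        auto new_scale = lhs_ * other;
        return scales<decltype(new_scale)>(new_scale);
    }

private:
    Scale lhs_;
};

int main()
{
    auto q_descale = scales<float>(0.013f); // plays the role of s_acc_element_func
    auto both      = q_descale * 0.012f;    // folds in the per-block k_descale
    assert(both(2.0f) == 0.013f * 0.012f * 2.0f);
}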
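
The OCP_FP8_SHIFT/FNUZ_FP8_SHIFT comments in both pipelines compress to one identity: shifting the exp2 argument scales every softmax numerator and the row sum by the same power of two, so P occupies more of the FP8 range while the normalized output is unchanged. A sketch of the cancellation, with c = 8 for OCP FP8 and c = 7 for FNUZ:

% Subtracting (m_i - c) instead of the row max m_i gives
%   p'_{ij} = 2^{ s_{ij} - (m_i - c) } = 2^{c} \, p_{ij},
% and the same factor enters the row sum, so
\[
o_i = \frac{\sum_j p'_{ij}\, v_j}{\sum_j p'_{ij}}
    = \frac{2^{c} \sum_j p_{ij}\, v_j}{2^{c} \sum_j p_{ij}}
    = \frac{\sum_j p_{ij}\, v_j}{\sum_j p_{ij}} .
\]

This is why both pipelines leave p_compute_element_func and o_acc_element_func as identity in the BLOCKSCALE branch: no compensation for the shift is needed after normalization.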
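
Because v_descale varies per KV block, the BLOCKSCALE pipelines cannot fold it into a single epilogue the way PERTENSOR folds v_descale into o_acc_element_func; instead each iteration's P*V partial goes into a scratch accumulator (o_acc0) and is folded into the running output with that block's factor. A scalar model of the pattern (names simplified, not the real distributed-tile code):

#include <cstddef>
#include <vector>

// One PV partial per kN0-wide KV block, one v_descale per block; the running
// output row accumulates each partial scaled by its block's descale, which is
// exactly what the "o += o0 * v_descale" tile_elementwise_inout above does.
void accumulate_pv(const std::vector<std::vector<float>>& pv_per_block,
                   const std::vector<float>& v_descale,
                   std::vector<float>& o_acc)
{
    for(std::size_t blk = 0; blk < pv_per_block.size(); ++blk)
    {
        const std::vector<float>& o_acc0 = pv_per_block[blk]; // fresh per-block accumulator
        for(std::size_t d = 0; d < o_acc.size(); ++d)
            o_acc[d] += o_acc0[d] * v_descale[blk]; // dequantize while accumulating
    }
}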
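
Finally, the verification path: reference_batched_quant_gemm threads the (batch, row, col) index into each element op so the host reference can apply the same per-block descales as the kernel, with idx_identity filling the slots that need no scaling. A scalar model of the S-dequant lambda used in fmha_fwd_runner.hpp (row-major layout and names are simplifications):

#include <cstddef>
#include <vector>

// After the raw Q*K^T accumulation, each S element is scaled by the softmax
// scale and by the descale values of its query block (row / Bq) and key block
// (col / Bkv), matching the acc_element_op lambda passed to
// reference_batched_quant_gemm in the runner.
void dequantize_s(std::vector<float>& s, // [M, N] raw Q*K^T accumulators
                  std::size_t M, std::size_t N,
                  float scale_s,
                  const std::vector<float>& q_descale, // ceil(M / Bq) values
                  const std::vector<float>& k_descale, // ceil(N / Bkv) values
                  std::size_t Bq, std::size_t Bkv)
{
    for(std::size_t m = 0; m < M; ++m)
        for(std::size_t n = 0; n < N; ++n)
            s[m * N + n] *= scale_s * q_descale[m / Bq] * k_descale[n / Bkv];
}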