From 651e612f5d6b4bbbda8e62a8df916b7f9c749a16 Mon Sep 17 00:00:00 2001 From: KenSCLin Date: Tue, 20 Jan 2026 15:54:11 +0000 Subject: [PATCH 1/6] GEMM Blockscale ABQuant Optimization --- .../gemm_abquant_quantgrouped.cpp | 16 +++- .../38_block_scale_gemm/gemm_utils.hpp | 29 ++++++ .../run_gemm_quant_example.inc | 5 +- include/ck_tile/core/tensor/sweep_tile.hpp | 12 +-- ...versal_gemm_ar_aquant_flatbr_bquant_cr.hpp | 95 +++++++++++++------ ..._universal_gemm_as_aquant_bs_bquant_cr.hpp | 82 ++++++++++------ .../gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp | 19 ++-- 7 files changed, 182 insertions(+), 76 deletions(-) diff --git a/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp index 155f19881ea..b1cd1a52a71 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp @@ -4,7 +4,13 @@ #include "run_gemm_quant_example.inc" template -using GemmConfig = GemmConfigQuantPrefill; +using GemmConfig = GemmConfigABQuantPrefill; + +template +using GemmConfigPreshuffleB = GemmConfigPreshuffleB_ABQuant_Prefill; + +// template +// using GemmConfigPreshuffleB = GemmConfigPreshuffleB_ABQuant_Decode; void abquant_quantgrouped_instance_factory( std::unordered_map>& lut) @@ -78,7 +84,7 @@ void abquant_quantgrouped_instance_factory( using BQuantGroupSize = ck_tile::QuantGroupShape>; using TypeConfig = decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, + return run_gemm_example_prec_type, TypeConfig, AQuantGroupSize, BQuantGroupSize, @@ -93,7 +99,7 @@ void abquant_quantgrouped_instance_factory( using BQuantGroupSize = ck_tile::QuantGroupShape>; using TypeConfig = decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, + return run_gemm_example_prec_type, TypeConfig, AQuantGroupSize, BQuantGroupSize, @@ -108,7 +114,7 @@ void abquant_quantgrouped_instance_factory( using BQuantGroupSize = ck_tile::QuantGroupShape>; using TypeConfig = decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, + return run_gemm_example_prec_type, TypeConfig, AQuantGroupSize, BQuantGroupSize, @@ -123,7 +129,7 @@ void abquant_quantgrouped_instance_factory( using BQuantGroupSize = ck_tile::QuantGroupShape>; using TypeConfig = decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, + return run_gemm_example_prec_type, TypeConfig, AQuantGroupSize, BQuantGroupSize, diff --git a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp index 37fc998e5ba..a95ca4862cf 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp @@ -192,6 +192,28 @@ struct GemmConfigPreshuffleB_PreshuffleBQuant_Prefill static constexpr bool PreshuffleQuant = true; }; +template +struct GemmConfigPreshuffleB_ABQuant_Prefill : public GemmConfigPreshuffleB_BQuant_Prefill +{ + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr bool kPadK = false; + static constexpr bool TransposeC = true; +}; + +template +struct GemmConfigPreshuffleB_ABQuant_Decode : public GemmConfigPreshuffleB_BQuant_Prefill +{ + static constexpr ck_tile::index_t M_Tile = 16; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType); + + static constexpr bool kPadK = false; + static constexpr bool 
TransposeC = true; +}; + template struct GemmConfigQuantPrefill : public GemmConfigBase { @@ -209,6 +231,13 @@ struct GemmConfigQuantPrefill : public GemmConfigBase ck_tile::get_k_warp_tile(); }; +template +struct GemmConfigABQuantPrefill : public GemmConfigQuantPrefill +{ + static constexpr bool kPadK = false; + static constexpr bool TransposeC = true; +}; + template struct GemmConfigPreshuffleBQuantPrefill : public GemmConfigQuantPrefill { diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index 607c53d9afd..0a9a92c3e5a 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -33,6 +33,7 @@ template ); + constexpr bool transpose_c = QuantMode == ck_tile::QuantType::ABQuantGrouped; using ComputeDataType = std::conditional_t; using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase > CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F& f) { - using DstrSpan = remove_cvref_t; - - static_ford{}([&](auto dstr_idx_impl) { - constexpr auto dstr_idx = detail::make_tile_distributed_index(dstr_idx_impl); + using DstrSpanImpl = typename remove_cvref_t::Impl; - f(dstr_idx); - }); + if constexpr(DstrSpanImpl::size() == 0) // handle the 0-dim span case + f(detail::make_tile_distributed_index(sequence<>{})); + else + static_ford{}( + [&](auto dstr_idx_impl) { f(detail::make_tile_distributed_index(dstr_idx_impl)); }); } // unpacked span, this version support span with unpack(multi-arg) functor diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp index 63a51511088..7b1002e2052 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp @@ -213,6 +213,22 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg }); }); }; + + using I0 = number<0>; + using I1 = number<1>; + auto q_block_tensor = std::move(aq_block_tensor); + if constexpr(Traits::NQPerBlock == 1) + { + constexpr auto aq_spans = AQBlockTensor::get_distributed_spans(); + // CK_PRINT(); + sweep_tile_span(aq_spans[I0{}], [&](auto im) { + sweep_tile_span(aq_spans[I1{}], [&](auto ik) { + q_block_tensor(make_tuple(im, ik)) *= + bq_block_tensor(make_tuple(tile_distributed_index<0>{}, ik)); + }); + }); + } + // hot loop: static_for<0, QScalesPerBlockRow, 1>{}([&](auto kQScale) { zero_accumulators(); static_for<0, KIterPerQScale, 1>{}([&](auto kIterInQScale) { @@ -244,35 +260,58 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg }); }); static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - AQPickerCommon aq_picker(aq_block_tensor); static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - constexpr auto tbuf_offset = - number{}, - c_warp_y_index_zeros)) / - CBlockTensor::PackedSize>{}; - - index_t reg_offset = [&]() { - if constexpr(QuantGroupSize::kN >= (NWarp * WG::kN)) - { - return (nIter * NWarp * WG::kN) / QuantGroupSize::kN * KPerBlockBQ + - kQScale; - } - else - { - return nIter * KPerBlockBQ + kQScale; - } - }(); - auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset]; - float b_scale_reg_f = - aq_picker.template cvt_scale_to_fp32(scale_reg); - - static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) { - float a_scale_reg_f = aq_picker.template pick(); - auto& 
c_ref = c_block_tensor.get_thread_buffer()[tbuf_offset + c_row]; - const auto acc_val = c_acc(mIter)(nIter).get_thread_buffer()[c_row]; - c_ref = c_ref + acc_val * b_scale_reg_f * a_scale_reg_f; - }); + if constexpr(Traits::NQPerBlock == 1) + { + constexpr auto tbuf_offset = + number{}, + c_warp_y_index_zeros)) / + CBlockTensor::PackedSize>{}; + + constexpr auto block_idx_m = tile_distributed_index{}; + constexpr auto block_idx_kq = tile_distributed_index{}; + + static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) { + auto& c_ref = c_block_tensor.get_thread_buffer()[tbuf_offset + c_row]; + const auto acc_val = c_acc(mIter)(nIter).get_thread_buffer()[c_row]; + c_ref = + c_ref + acc_val * q_block_tensor(make_tuple( + block_idx_m, block_idx_kq)); // b_scale_reg_f + }); + } + else + { + AQPickerCommon aq_picker( + aq_block_tensor); + constexpr auto tbuf_offset = + number{}, + c_warp_y_index_zeros)) / + CBlockTensor::PackedSize>{}; + + index_t reg_offset = [&]() { + if constexpr(QuantGroupSize::kN >= (NWarp * WG::kN)) + { + return (nIter * NWarp * WG::kN) / QuantGroupSize::kN * KPerBlockBQ + + kQScale; + } + else + { + return nIter * KPerBlockBQ + kQScale; + } + }(); + auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset]; + float b_scale_reg_f = + aq_picker.template cvt_scale_to_fp32(scale_reg); + + static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) { + float a_scale_reg_f = aq_picker.template pick(); + auto& c_ref = c_block_tensor.get_thread_buffer()[tbuf_offset + c_row]; + const auto acc_val = c_acc(mIter)(nIter).get_thread_buffer()[c_row]; + c_ref = c_ref + acc_val * b_scale_reg_f * a_scale_reg_f; + }); + } }); }); }); diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp index c44d330d139..10e8931ab6f 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp @@ -285,37 +285,63 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase "C block tensor data type!"); constexpr auto warp_size = get_warp_size(); + auto q_block_tensor = std::move(aq_block_tensor); + if constexpr(Traits::NQPerBlock == 1) + { + constexpr auto aq_spans = AQBlockTensor::get_distributed_spans(); + // CK_PRINT(); + sweep_tile_span(aq_spans[I0{}], [&](auto im) { + sweep_tile_span(aq_spans[I1{}], [&](auto ik) { + q_block_tensor(make_tuple(im, ik)) *= + bq_block_tensor(make_tuple(tile_distributed_index<0>{}, ik)); + }); + }); + } + // hot loop: - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) { + static_for_product, number>{}([&](auto mIter, + auto nIter) { CWarpTensor c_warp_tensor; + static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) { + constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale; + + AWarpTensor a_warp_tensor; + a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + BWarpTensor b_warp_tensor; + b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); + + if constexpr(kIterInQScale == 0) + { + 
c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor); + } + else + { + WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + } + }); - static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) { - static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) { - constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale; - - AWarpTensor a_warp_tensor; - a_warp_tensor.get_thread_buffer() = - a_warp_tile_.get_y_sliced_thread_data( - merge_sequences(sequence{}, a_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); - - BWarpTensor b_warp_tensor; - b_warp_tensor.get_thread_buffer() = - b_warp_tile_.get_y_sliced_thread_data( - merge_sequences(sequence{}, b_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); - - if constexpr(kIterInQScale == 0) - { - c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor); - } - else - { - WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); - } + if constexpr(Traits::NQPerBlock == 1) + { + constexpr auto cw_spans = CWarpTensor::get_distributed_spans(); + static_assert(cw_spans[I0{}].impl_.size() == 0); + sweep_tile_span(cw_spans[I1{}], [&](auto in) { + constexpr auto block_idx_m = tile_distributed_index{}; + constexpr auto block_idx_n = detail::make_tile_distributed_index( + merge_sequences(sequence{}, in.impl_)); + constexpr auto block_idx_kq = tile_distributed_index{}; + constexpr auto empty_idx = tile_distributed_index<>{}; + c_block_tensor(make_tuple(block_idx_m, block_idx_n)) += + c_warp_tensor(make_tuple(empty_idx, in)) * + q_block_tensor(make_tuple(block_idx_m, block_idx_kq)); }); - + } + else + { constexpr auto tbuf_offset = number{}, @@ -387,7 +413,7 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase b_scale_reg_f); }); } - }); + } }); }); } diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp index 0f3951ffccc..5eea4b9ff5c 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp @@ -127,12 +127,12 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe // Total VMEM load instructions (A + B + quant data) constexpr index_t buffer_load_inst = Aload_inst + Bload_inst + BQload_inst; // Approximate number of LDS reads per block - constexpr index_t ds_read_inst = kMPerBlock / kLdsInstCycle; + constexpr index_t ds_read_inst = kMPerBlock / kLdsInstCycle / nloop; // Approximate number of LDS writes per block // (e.g., writing A from VMEM into LDS once per A load) constexpr index_t ds_write_inst = Aload_inst; // Number of MFMA instructions per wave for one block tile: - constexpr index_t mfma_inst = (kMPerBlock / WG::kM) * (kNPerBlock / WG::kN); + constexpr index_t mfma_inst = (kMPerBlock / WG::kM) / nloop * (kNPerBlock / WG::kN) / nloop; // How often (in MFMA units) we should insert DS (LDS) operations. constexpr index_t ds_rep = mfma_inst / (ds_read_inst + ds_write_inst); // How often (in MFMA units) we should insert VMEM buffer loads. @@ -169,7 +169,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe } // Always mark some VALU work in the loop to reflect auxiliary scalar // or vector ALU instructions that coexist with MFMA (Blockscale calculation). 
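        // (Presumably the extra per-iteration scale conversion and multiplication
        // that ABQuant performs on the VALU is what motivates the larger VALU
        // group size used below; the count is a scheduling heuristic.)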
- __builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::VALU, 2, 0); // VALU + __builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::VALU, 4, 0); // VALU }); }); __builtin_amdgcn_sched_barrier(0); @@ -380,7 +380,6 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe // Prefetch A1 a_block_tile = load_tile(a_copy_dram_window); - // move A window to next k move_tile_window(a_copy_dram_window, {0, kKPerBlock}); // initialize C @@ -407,7 +406,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe while(iCounter > 0) { __builtin_amdgcn_sched_barrier(0); - // Prefill A(2i+1) + // Prefill A(2i+1) ds_write a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); store_tile(a_copy_lds_window_pong, a_block_tile_tmp); @@ -435,10 +434,14 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe }); }); move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); + + // prefetch Q(2i+1) aq_block_tile_2 = load_tile(aq_copy_dram_window); move_tile_window(aq_copy_dram_window, {0, KPerBlockAQ}); bq_block_tile_2 = load_tile(bq_copy_dram_window); move_tile_window(bq_copy_dram_window, {0, KPerBlockBQ}); + + // Preload A(2i+1) ds_read static_for<0, m_preload, 1>{}([&](auto loadIter) { constexpr auto mIter = loadIter % MIterPerWarp; constexpr auto kIter = loadIter / MIterPerWarp; @@ -460,6 +463,8 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe }); }); move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); + + // prefetch Q(2i+1) aq_block_tile = load_tile(aq_copy_dram_window); move_tile_window(aq_copy_dram_window, {0, KPerBlockAQ}); bq_block_tile = load_tile(bq_copy_dram_window); @@ -481,7 +486,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe aq_block_tile_2, bq_block_tile_2, a_warp_windows_pong); - + // Preload A(2i+2) ds_read static_for<0, m_preload, 1>{}([&](auto loadIter) { constexpr auto mIter = loadIter % MIterPerWarp; constexpr auto kIter = loadIter / MIterPerWarp; @@ -521,7 +526,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe aq_block_tile, bq_block_tile, a_warp_windows_ping); - + // Preload A ds_read static_for<0, m_preload, 1>{}([&](auto loadIter) { constexpr auto mIter = loadIter % MIterPerWarp; constexpr auto kIter = loadIter / MIterPerWarp; From 4a320457b9fb50480285b781f63730996a48ffe2 Mon Sep 17 00:00:00 2001 From: Thomas Ning Date: Tue, 20 Jan 2026 14:38:16 -0800 Subject: [PATCH 2/6] Apply suggestion from @Copilot Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc | 1 - 1 file changed, 1 deletion(-) diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index 0a9a92c3e5a..912527c929a 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -89,7 +89,6 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { constexpr bool has_hot_loop_v = has_hot_loop_.value; constexpr auto tail_number_v = tail_number_.value; - // constexpr bool transpose_c = false; // row-col and tensor quants use the regular pipeline, A/B/AB quants use their own using PipelineProblem = std::conditional_t< From 
b39c391b083a9709a7dd19c705d3d23ef75cde46 Mon Sep 17 00:00:00 2001 From: Thomas Ning Date: Tue, 20 Jan 2026 14:39:59 -0800 Subject: [PATCH 3/6] Apply suggestion from @Copilot Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp index 10e8931ab6f..87f9635039c 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp @@ -285,14 +285,18 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase "C block tensor data type!"); constexpr auto warp_size = get_warp_size(); - auto q_block_tensor = std::move(aq_block_tensor); + // Start from AQ block tensor and then scale it using BQ; this represents + // the combined A/B quantization scales for the block. + auto aq_scaled_block_tensor = std::move(aq_block_tensor); + // Backward-compatible alias for any existing uses of q_block_tensor in this scope. + auto& q_block_tensor = aq_scaled_block_tensor; if constexpr(Traits::NQPerBlock == 1) { constexpr auto aq_spans = AQBlockTensor::get_distributed_spans(); // CK_PRINT(); sweep_tile_span(aq_spans[I0{}], [&](auto im) { sweep_tile_span(aq_spans[I1{}], [&](auto ik) { - q_block_tensor(make_tuple(im, ik)) *= + aq_scaled_block_tensor(make_tuple(im, ik)) *= bq_block_tensor(make_tuple(tile_distributed_index<0>{}, ik)); }); }); From 05d0f8587f905261c5b18400903eb2a9f8a7f9c2 Mon Sep 17 00:00:00 2001 From: Thomas Ning Date: Tue, 20 Jan 2026 14:40:32 -0800 Subject: [PATCH 4/6] Apply suggestion from @Copilot Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp index 7b1002e2052..abbf3aaa52b 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp @@ -216,14 +216,14 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg using I0 = number<0>; using I1 = number<1>; - auto q_block_tensor = std::move(aq_block_tensor); + auto aq_scaled_block_tensor = std::move(aq_block_tensor); if constexpr(Traits::NQPerBlock == 1) { constexpr auto aq_spans = AQBlockTensor::get_distributed_spans(); // CK_PRINT(); sweep_tile_span(aq_spans[I0{}], [&](auto im) { sweep_tile_span(aq_spans[I1{}], [&](auto ik) { - q_block_tensor(make_tuple(im, ik)) *= + aq_scaled_block_tensor(make_tuple(im, ik)) *= bq_block_tensor(make_tuple(tile_distributed_index<0>{}, ik)); }); }); From 54e355129b5717c9f5cfabfe5a0a1d324b1b0894 Mon Sep 17 00:00:00 2001 From: KenSCLin Date: Wed, 21 Jan 2026 02:27:02 +0000 Subject: [PATCH 5/6] fix precommit error --- ...block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp | 10 +++++++--- .../block_universal_gemm_as_aquant_bs_bquant_cr.hpp | 4 ++-- .../pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp | 9 +++++++-- 3 files changed, 16 insertions(+), 7 deletions(-) diff --git 
a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp index abbf3aaa52b..9adac7fa8c9 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp @@ -214,16 +214,20 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg }); }; - using I0 = number<0>; - using I1 = number<1>; + using I0 = number<0>; + using I1 = number<1>; + // Start from AQ block tensor and then scale it using BQ; this represents + // the combined A/B quantization scales for the block. auto aq_scaled_block_tensor = std::move(aq_block_tensor); + // Backward-compatible alias for any existing uses of q_block_tensor in this scope. + auto& q_block_tensor = aq_scaled_block_tensor; if constexpr(Traits::NQPerBlock == 1) { constexpr auto aq_spans = AQBlockTensor::get_distributed_spans(); // CK_PRINT(); sweep_tile_span(aq_spans[I0{}], [&](auto im) { sweep_tile_span(aq_spans[I1{}], [&](auto ik) { - aq_scaled_block_tensor(make_tuple(im, ik)) *= + q_block_tensor(make_tuple(im, ik)) *= bq_block_tensor(make_tuple(tile_distributed_index<0>{}, ik)); }); }); diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp index 87f9635039c..cdd6fc1666d 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp @@ -289,14 +289,14 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase // the combined A/B quantization scales for the block. auto aq_scaled_block_tensor = std::move(aq_block_tensor); // Backward-compatible alias for any existing uses of q_block_tensor in this scope. - auto& q_block_tensor = aq_scaled_block_tensor; + auto& q_block_tensor = aq_scaled_block_tensor; if constexpr(Traits::NQPerBlock == 1) { constexpr auto aq_spans = AQBlockTensor::get_distributed_spans(); // CK_PRINT(); sweep_tile_span(aq_spans[I0{}], [&](auto im) { sweep_tile_span(aq_spans[I1{}], [&](auto ik) { - aq_scaled_block_tensor(make_tuple(im, ik)) *= + q_block_tensor(make_tuple(im, ik)) *= bq_block_tensor(make_tuple(tile_distributed_index<0>{}, ik)); }); }); diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp index 5eea4b9ff5c..566f0b6153f 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp @@ -101,10 +101,14 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe concat('x', kPadM, kPadN, kPadK), AQuantGroupSize::GetName(), BQuantGroupSize::GetName()); // clang-format on } - + /** + * @tparam nloop The number of iterations in the hot loop, + * used to normalize scheduling costs. 
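+     * Must be greater than zero (enforced by the static_assert at function entry).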
+ */ template CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler() { + static_assert(nloop > 0, "nloop must be greater than 0"); // Estimated number of VMEM vector loads for A per block: // total A bytes / (threads per block * vector width) constexpr index_t Aload_inst = @@ -132,7 +136,8 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe // (e.g., writing A from VMEM into LDS once per A load) constexpr index_t ds_write_inst = Aload_inst; // Number of MFMA instructions per wave for one block tile: - constexpr index_t mfma_inst = (kMPerBlock / WG::kM) / nloop * (kNPerBlock / WG::kN) / nloop; + constexpr index_t mfma_inst = + ((kMPerBlock / WG::kM) / nloop) * ((kNPerBlock / WG::kN) / nloop); // How often (in MFMA units) we should insert DS (LDS) operations. constexpr index_t ds_rep = mfma_inst / (ds_read_inst + ds_write_inst); // How often (in MFMA units) we should insert VMEM buffer loads. From 9058ac46fcac61b6ba455089272b8438d796c47b Mon Sep 17 00:00:00 2001 From: "Ding, Yi" Date: Wed, 21 Jan 2026 03:05:59 +0000 Subject: [PATCH 6/6] clean --- ...versal_gemm_ar_aquant_flatbr_bquant_cr.hpp | 118 ++++++++---------- ..._universal_gemm_as_aquant_bs_bquant_cr.hpp | 5 +- 2 files changed, 55 insertions(+), 68 deletions(-) diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp index 9adac7fa8c9..c95b980b8bc 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp @@ -214,19 +214,12 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg }); }; - using I0 = number<0>; - using I1 = number<1>; - // Start from AQ block tensor and then scale it using BQ; this represents - // the combined A/B quantization scales for the block. - auto aq_scaled_block_tensor = std::move(aq_block_tensor); - // Backward-compatible alias for any existing uses of q_block_tensor in this scope. 
- auto& q_block_tensor = aq_scaled_block_tensor; + auto q_block_tensor = aq_block_tensor; if constexpr(Traits::NQPerBlock == 1) { constexpr auto aq_spans = AQBlockTensor::get_distributed_spans(); - // CK_PRINT(); - sweep_tile_span(aq_spans[I0{}], [&](auto im) { - sweep_tile_span(aq_spans[I1{}], [&](auto ik) { + sweep_tile_span(aq_spans[I0], [&](auto im) { + sweep_tile_span(aq_spans[I1], [&](auto ik) { q_block_tensor(make_tuple(im, ik)) *= bq_block_tensor(make_tuple(tile_distributed_index<0>{}, ik)); }); @@ -263,60 +256,57 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg } }); }); - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - if constexpr(Traits::NQPerBlock == 1) - { - constexpr auto tbuf_offset = - number{}, - c_warp_y_index_zeros)) / - CBlockTensor::PackedSize>{}; - - constexpr auto block_idx_m = tile_distributed_index{}; - constexpr auto block_idx_kq = tile_distributed_index{}; - - static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) { - auto& c_ref = c_block_tensor.get_thread_buffer()[tbuf_offset + c_row]; - const auto acc_val = c_acc(mIter)(nIter).get_thread_buffer()[c_row]; - c_ref = - c_ref + acc_val * q_block_tensor(make_tuple( - block_idx_m, block_idx_kq)); // b_scale_reg_f - }); - } - else - { - AQPickerCommon aq_picker( - aq_block_tensor); - constexpr auto tbuf_offset = - number{}, - c_warp_y_index_zeros)) / - CBlockTensor::PackedSize>{}; - - index_t reg_offset = [&]() { - if constexpr(QuantGroupSize::kN >= (NWarp * WG::kN)) - { - return (nIter * NWarp * WG::kN) / QuantGroupSize::kN * KPerBlockBQ + - kQScale; - } - else - { - return nIter * KPerBlockBQ + kQScale; - } - }(); - auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset]; - float b_scale_reg_f = - aq_picker.template cvt_scale_to_fp32(scale_reg); - - static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) { - float a_scale_reg_f = aq_picker.template pick(); - auto& c_ref = c_block_tensor.get_thread_buffer()[tbuf_offset + c_row]; - const auto acc_val = c_acc(mIter)(nIter).get_thread_buffer()[c_row]; - c_ref = c_ref + acc_val * b_scale_reg_f * a_scale_reg_f; - }); - } - }); + static_for_product, number>{}([&](auto mIter, + auto nIter) { + if constexpr(Traits::NQPerBlock == 1) + { + constexpr auto tbuf_offset = + number{}, + c_warp_y_index_zeros)) / + CBlockTensor::PackedSize>{}; + + constexpr auto block_idx_m = tile_distributed_index{}; + constexpr auto block_idx_kq = tile_distributed_index{}; + + static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) { + auto& c_ref = c_block_tensor.get_thread_buffer()[tbuf_offset + c_row]; + const auto acc_val = c_acc(mIter)(nIter).get_thread_buffer()[c_row]; + c_ref += acc_val * q_block_tensor(make_tuple(block_idx_m, block_idx_kq)); + }); + } + else + { + AQPickerCommon aq_picker( + aq_block_tensor); + constexpr auto tbuf_offset = + number{}, + c_warp_y_index_zeros)) / + CBlockTensor::PackedSize>{}; + + index_t reg_offset = [&]() { + if constexpr(QuantGroupSize::kN >= (NWarp * WG::kN)) + { + return (nIter * NWarp * WG::kN) / QuantGroupSize::kN * KPerBlockBQ + + kQScale; + } + else + { + return nIter * KPerBlockBQ + kQScale; + } + }(); + auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset]; + float b_scale_reg_f = + aq_picker.template cvt_scale_to_fp32(scale_reg); + + static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) { + float a_scale_reg_f = aq_picker.template pick(); + auto& c_ref = c_block_tensor.get_thread_buffer()[tbuf_offset + c_row]; + const auto 
acc_val = c_acc(mIter)(nIter).get_thread_buffer()[c_row]; + c_ref = c_ref + acc_val * b_scale_reg_f * a_scale_reg_f; + }); + } }); }); } diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp index cdd6fc1666d..15e87c4e50b 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp @@ -287,13 +287,10 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase // Start from AQ block tensor and then scale it using BQ; this represents // the combined A/B quantization scales for the block. - auto aq_scaled_block_tensor = std::move(aq_block_tensor); - // Backward-compatible alias for any existing uses of q_block_tensor in this scope. - auto& q_block_tensor = aq_scaled_block_tensor; + auto q_block_tensor = aq_block_tensor; if constexpr(Traits::NQPerBlock == 1) { constexpr auto aq_spans = AQBlockTensor::get_distributed_spans(); - // CK_PRINT(); sweep_tile_span(aq_spans[I0{}], [&](auto im) { sweep_tile_span(aq_spans[I1{}], [&](auto ik) { q_block_tensor(make_tuple(im, ik)) *=
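                         bq_block_tensor(make_tuple(tile_distributed_index<0>{}, ik));
                 });
             });

Illustrative sketch (hypothetical standalone code, not the CK Tile API) of the
combined-scale fast path these patches add for NQPerBlock == 1: the A and B
quantization scales are folded into one combined scale once per block, so the
hot epilogue applies a single multiply per accumulator element instead of two.

    // Hypothetical helper; mirrors q_block_tensor(im, ik) *= bq_block_tensor(0, ik)
    // in the hunks above, with plain float arrays standing in for the
    // distributed tile tensors.
    void fold_scales(int M, int KQ, const float* a_scale, const float* b_scale, float* q)
    {
        for(int m = 0; m < M; ++m)
            for(int kq = 0; kq < KQ; ++kq)
                q[m * KQ + kq] = a_scale[m * KQ + kq] * b_scale[kq];
    }
    // Epilogue then needs one multiply per accumulator element:
    //     c_ref += acc_val * q[m * KQ + kq];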