diff --git a/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp
index 155f19881e..b1cd1a52a7 100644
--- a/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp
+++ b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp
@@ -4,7 +4,13 @@
 #include "run_gemm_quant_example.inc"
 
 template <typename PrecType>
-using GemmConfig = GemmConfigQuantPrefill<PrecType>;
+using GemmConfig = GemmConfigABQuantPrefill<PrecType>;
+
+template <typename PrecType>
+using GemmConfigPreshuffleB = GemmConfigPreshuffleB_ABQuant_Prefill<PrecType>;
+
+// template <typename PrecType>
+// using GemmConfigPreshuffleB = GemmConfigPreshuffleB_ABQuant_Decode<PrecType>;
 
 void abquant_quantgrouped_instance_factory(
     std::unordered_map>& lut)
@@ -78,7 +84,7 @@ void abquant_quantgrouped_instance_factory(
         using BQuantGroupSize = ck_tile::QuantGroupShape>;
         using TypeConfig = decltype(GemmQuantTypeConfig{});
 
-        return run_gemm_example_prec_type,
+        return run_gemm_example_prec_type,
                                           TypeConfig,
                                           AQuantGroupSize,
                                           BQuantGroupSize,
@@ -93,7 +99,7 @@ void abquant_quantgrouped_instance_factory(
         using BQuantGroupSize = ck_tile::QuantGroupShape>;
         using TypeConfig = decltype(GemmQuantTypeConfig{});
 
-        return run_gemm_example_prec_type,
+        return run_gemm_example_prec_type,
                                           TypeConfig,
                                           AQuantGroupSize,
                                           BQuantGroupSize,
@@ -108,7 +114,7 @@ void abquant_quantgrouped_instance_factory(
         using BQuantGroupSize = ck_tile::QuantGroupShape>;
         using TypeConfig = decltype(GemmQuantTypeConfig{});
 
-        return run_gemm_example_prec_type,
+        return run_gemm_example_prec_type,
                                           TypeConfig,
                                           AQuantGroupSize,
                                           BQuantGroupSize,
@@ -123,7 +129,7 @@ void abquant_quantgrouped_instance_factory(
         using BQuantGroupSize = ck_tile::QuantGroupShape>;
         using TypeConfig = decltype(GemmQuantTypeConfig{});
 
-        return run_gemm_example_prec_type,
+        return run_gemm_example_prec_type,
                                           TypeConfig,
                                           AQuantGroupSize,
                                           BQuantGroupSize,
diff --git a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp
index 37fc998e5b..a95ca4862c 100644
--- a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp
+++ b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp
@@ -192,6 +192,28 @@ struct GemmConfigPreshuffleB_PreshuffleBQuant_Prefill
     static constexpr bool PreshuffleQuant = true;
 };
 
+template <typename PrecType>
+struct GemmConfigPreshuffleB_ABQuant_Prefill : public GemmConfigPreshuffleB_BQuant_Prefill<PrecType>
+{
+    static constexpr ck_tile::index_t M_Warp = 2;
+    static constexpr ck_tile::index_t N_Warp = 2;
+    static constexpr ck_tile::index_t K_Warp = 1;
+
+    static constexpr bool kPadK      = false;
+    static constexpr bool TransposeC = true;
+};
+
+template <typename PrecType>
+struct GemmConfigPreshuffleB_ABQuant_Decode : public GemmConfigPreshuffleB_BQuant_Prefill<PrecType>
+{
+    static constexpr ck_tile::index_t M_Tile = 16;
+    static constexpr ck_tile::index_t N_Tile = 128;
+    static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType);
+
+    static constexpr bool kPadK      = false;
+    static constexpr bool TransposeC = true;
+};
+
 template <typename PrecType>
 struct GemmConfigQuantPrefill : public GemmConfigBase
 {
@@ -209,6 +231,13 @@ struct GemmConfigQuantPrefill : public GemmConfigBase
                                               ck_tile::get_k_warp_tile();
 };
 
+template <typename PrecType>
+struct GemmConfigABQuantPrefill : public GemmConfigQuantPrefill<PrecType>
+{
+    static constexpr bool kPadK      = false;
+    static constexpr bool TransposeC = true;
+};
+
 template <typename PrecType>
 struct GemmConfigPreshuffleBQuantPrefill : public GemmConfigQuantPrefill<PrecType>
 {
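For readers outside ck_tile: the config structs added above work through C++ static-member shadowing, where a derived config re-declares a `static constexpr` member to override the base's value at compile time. A minimal standalone sketch of that pattern (type and member names here are illustrative, not the actual ck_tile definitions):

```cpp
#include <cstdio>

// Hypothetical stand-ins for the config structs above: a base config
// publishes compile-time knobs, and a derived config overrides a knob
// simply by re-declaring (shadowing) the static constexpr member.
template <typename PrecType>
struct ConfigBase
{
    static constexpr int M_Tile      = 128;
    static constexpr int K_Tile      = 256 / sizeof(PrecType);
    static constexpr bool kPadK      = true;
    static constexpr bool TransposeC = false;
};

template <typename PrecType>
struct ConfigABQuant : public ConfigBase<PrecType>
{
    // Shadowed members take effect for every lookup through ConfigABQuant.
    static constexpr bool kPadK      = false;
    static constexpr bool TransposeC = true;
};

int main()
{
    using Cfg = ConfigABQuant<float>;
    std::printf("M_Tile=%d K_Tile=%d kPadK=%d TransposeC=%d\n",
                Cfg::M_Tile, Cfg::K_Tile, int(Cfg::kPadK), int(Cfg::TransposeC));
    // Prints: M_Tile=128 K_Tile=64 kPadK=0 TransposeC=1
}
```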
diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc
index 607c53d9af..912527c929 100644
--- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc
+++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc
@@ -33,6 +33,7 @@ template
                        );
+    constexpr bool transpose_c = QuantMode == ck_tile::QuantType::ABQuantGrouped;
 
     using ComputeDataType = std::conditional_t;
 
     using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase
diff --git a/include/ck_tile/core/tensor/sweep_tile.hpp b/include/ck_tile/core/tensor/sweep_tile.hpp
--- a/include/ck_tile/core/tensor/sweep_tile.hpp
+++ b/include/ck_tile/core/tensor/sweep_tile.hpp
@@ ... @@ template
          >
 CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F& f)
 {
-    using DstrSpan = remove_cvref_t;
-
-    static_ford{}([&](auto dstr_idx_impl) {
-        constexpr auto dstr_idx = detail::make_tile_distributed_index(dstr_idx_impl);
+    using DstrSpanImpl = typename remove_cvref_t::Impl;
 
-        f(dstr_idx);
-    });
+    if constexpr(DstrSpanImpl::size() == 0) // handle the 0-dim span case
+        f(detail::make_tile_distributed_index(sequence<>{}));
+    else
+        static_ford{}(
+            [&](auto dstr_idx_impl) { f(detail::make_tile_distributed_index(dstr_idx_impl)); });
 }
 
 // unpacked span, this version support span with unpack(multi-arg) functor
diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp
index 63a5151108..c95b980b8b 100644
--- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp
+++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp
@@ -213,6 +213,19 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg
             });
         });
     };
+
+    auto q_block_tensor = aq_block_tensor;
+    if constexpr(Traits::NQPerBlock == 1)
+    {
+        constexpr auto aq_spans = AQBlockTensor::get_distributed_spans();
+        sweep_tile_span(aq_spans[I0], [&](auto im) {
+            sweep_tile_span(aq_spans[I1], [&](auto ik) {
+                q_block_tensor(make_tuple(im, ik)) *=
+                    bq_block_tensor(make_tuple(tile_distributed_index<0>{}, ik));
+            });
+        });
+    }
+
     // hot loop:
     static_for<0, QScalesPerBlockRow, 1>{}([&](auto kQScale) {
         zero_accumulators();
         static_for<0, KIterPerQScale, 1>{}([&](auto kIterInQScale) {
@@ -243,9 +256,29 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg
                 }
             });
         });
-        static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
-            AQPickerCommon aq_picker(aq_block_tensor);
-            static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
+        static_for_product, number>{}([&](auto mIter,
+                                                                auto nIter) {
+            if constexpr(Traits::NQPerBlock == 1)
+            {
+                constexpr auto tbuf_offset =
+                    number{},
+                                                             c_warp_y_index_zeros)) /
+                           CBlockTensor::PackedSize>{};
+
+                constexpr auto block_idx_m  = tile_distributed_index{};
+                constexpr auto block_idx_kq = tile_distributed_index{};
+
+                static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) {
+                    auto& c_ref = c_block_tensor.get_thread_buffer()[tbuf_offset + c_row];
+                    const auto acc_val = c_acc(mIter)(nIter).get_thread_buffer()[c_row];
+                    c_ref += acc_val * q_block_tensor(make_tuple(block_idx_m, block_idx_kq));
+                });
+            }
+            else
+            {
+                AQPickerCommon aq_picker(
+                    aq_block_tensor);
                 constexpr auto tbuf_offset =
                     number{},
@@ -273,7 +306,7 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg
                     const auto acc_val = c_acc(mIter)(nIter).get_thread_buffer()[c_row];
                     c_ref = c_ref + acc_val * b_scale_reg_f * a_scale_reg_f;
                 });
-            });
+            }
         });
     });
 }
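The `q_block_tensor` fold added above exploits a simple identity: when `Traits::NQPerBlock == 1` there is a single B scale per k-quant group, so `(acc * aq) * bq == acc * (aq * bq)` and the product `aq * bq` can be hoisted out of the hot loop, leaving one multiply per accumulator element instead of two. A scalar sketch of that identity, with made-up values:

```cpp
#include <cassert>

// Scalar picture of the q_block_tensor fold: with one B scale per k-quant
// group, the combined scale aq*bq is computed once per row, outside the
// per-element accumulate. Powers of two keep the float comparison exact.
int main()
{
    const float aq[2]  = {0.5f, 2.0f};  // per-row A scales
    const float bq     = 0.25f;         // the single B scale for this group
    const float acc[2] = {8.0f, 12.0f}; // raw accumulator values

    for(int m = 0; m < 2; ++m)
    {
        float two_muls = acc[m] * aq[m] * bq; // original: two multiplies per element
        float q        = aq[m] * bq;          // combined scale, hoisted
        float one_mul  = acc[m] * q;          // fast path: one multiply per element
        assert(two_muls == one_mul);
    }
}
```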
diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp
index c44d330d13..15e87c4e50 100644
--- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp
+++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp
@@ -285,37 +285,64 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase
                       "C block tensor data type!");
 
         constexpr auto warp_size = get_warp_size();
 
+        // Start from the AQ block tensor and then scale it using BQ; this
+        // represents the combined A/B quantization scales for the block.
+        auto q_block_tensor = aq_block_tensor;
+        if constexpr(Traits::NQPerBlock == 1)
+        {
+            constexpr auto aq_spans = AQBlockTensor::get_distributed_spans();
+            sweep_tile_span(aq_spans[I0{}], [&](auto im) {
+                sweep_tile_span(aq_spans[I1{}], [&](auto ik) {
+                    q_block_tensor(make_tuple(im, ik)) *=
+                        bq_block_tensor(make_tuple(tile_distributed_index<0>{}, ik));
+                });
+            });
+        }
+
         // hot loop:
-        static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
-            static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
+        static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) {
+            static_for_product, number>{}([&](auto mIter,
+                                                                    auto nIter) {
                 CWarpTensor c_warp_tensor;
+                static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) {
+                    constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale;
+
+                    AWarpTensor a_warp_tensor;
+                    a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data(
+                        merge_sequences(sequence{}, a_warp_y_index_zeros),
+                        merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
+
+                    BWarpTensor b_warp_tensor;
+                    b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data(
+                        merge_sequences(sequence{}, b_warp_y_index_zeros),
+                        merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
+
+                    if constexpr(kIterInQScale == 0)
+                    {
+                        c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor);
+                    }
+                    else
+                    {
+                        WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
+                    }
+                });
 
-                static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) {
-                    static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) {
-                        constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale;
-
-                        AWarpTensor a_warp_tensor;
-                        a_warp_tensor.get_thread_buffer() =
-                            a_warp_tile_.get_y_sliced_thread_data(
-                                merge_sequences(sequence{}, a_warp_y_index_zeros),
-                                merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
-
-                        BWarpTensor b_warp_tensor;
-                        b_warp_tensor.get_thread_buffer() =
-                            b_warp_tile_.get_y_sliced_thread_data(
-                                merge_sequences(sequence{}, b_warp_y_index_zeros),
-                                merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
-
-                        if constexpr(kIterInQScale == 0)
-                        {
-                            c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor);
-                        }
-                        else
-                        {
-                            WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
-                        }
+                if constexpr(Traits::NQPerBlock == 1)
+                {
+                    constexpr auto cw_spans = CWarpTensor::get_distributed_spans();
+                    static_assert(cw_spans[I0{}].impl_.size() == 0);
+                    sweep_tile_span(cw_spans[I1{}], [&](auto in) {
+                        constexpr auto block_idx_m = tile_distributed_index{};
+                        constexpr auto block_idx_n = detail::make_tile_distributed_index(
+                            merge_sequences(sequence{}, in.impl_));
+                        constexpr auto block_idx_kq = tile_distributed_index{};
+                        constexpr auto empty_idx    = tile_distributed_index<>{};
+                        c_block_tensor(make_tuple(block_idx_m, block_idx_n)) +=
+                            c_warp_tensor(make_tuple(empty_idx, in)) *
+                            q_block_tensor(make_tuple(block_idx_m, block_idx_kq));
                     });
-
+                }
+                else
+                {
                     constexpr auto tbuf_offset =
                         number{},
@@ -387,7 +414,7 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase
                                 b_scale_reg_f);
                         });
                     }
-                });
+                }
             });
         });
     }
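The loop reordering above keeps `kQScale` outermost so that a fresh warp accumulator runs unscaled across the `KIterPerQScale` iterations of one quant group, and the (combined) scale is applied exactly once when the group is folded into C. The same idea at scalar level, with hypothetical sizes and values:

```cpp
#include <array>
#include <cassert>

// Scalar version of the reordered hot loop: k iterations are grouped by
// quant scale, each group accumulates unscaled, and the combined scale is
// applied once per group when folding into C. Values stay exact in float.
int main()
{
    constexpr int K = 8, KIterPerQScale = 4, QScalesPerRow = K / KIterPerQScale;
    std::array<float, K> a{1, 2, 3, 4, 5, 6, 7, 8};
    std::array<float, K> b{1, 1, 1, 1, 2, 2, 2, 2};
    std::array<float, QScalesPerRow> q{0.5f, 0.25f}; // combined aq*bq per group

    float c = 0.0f;
    for(int kq = 0; kq < QScalesPerRow; ++kq) // outermost: quant-scale groups
    {
        float acc = 0.0f; // fresh accumulator per group (the c_warp_tensor role)
        for(int ki = 0; ki < KIterPerQScale; ++ki)
            acc += a[kq * KIterPerQScale + ki] * b[kq * KIterPerQScale + ki];
        c += acc * q[kq]; // scale applied once per group
    }
    assert(c == 0.5f * 10.0f + 0.25f * 52.0f);
}
```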
diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp
index 0f3951ffcc..566f0b6153 100644
--- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp
+++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp
@@ -101,10 +101,14 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
             concat('x', kPadM, kPadN, kPadK), AQuantGroupSize::GetName(), BQuantGroupSize::GetName());
         // clang-format on
     }
-
+    /**
+     * @tparam nloop The number of iterations in the hot loop,
+     *               used to normalize scheduling costs.
+     */
    template <index_t nloop>
    CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler()
    {
+        static_assert(nloop > 0, "nloop must be greater than 0");
        // Estimated number of VMEM vector loads for A per block:
        // total A bytes / (threads per block * vector width)
        constexpr index_t Aload_inst =
@@ -127,12 +131,13 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
        // Total VMEM load instructions (A + B + quant data)
        constexpr index_t buffer_load_inst = Aload_inst + Bload_inst + BQload_inst;
        // Approximate number of LDS reads per block
-        constexpr index_t ds_read_inst = kMPerBlock / kLdsInstCycle;
+        constexpr index_t ds_read_inst = kMPerBlock / kLdsInstCycle / nloop;
        // Approximate number of LDS writes per block
        // (e.g., writing A from VMEM into LDS once per A load)
        constexpr index_t ds_write_inst = Aload_inst;
        // Number of MFMA instructions per wave for one block tile:
-        constexpr index_t mfma_inst = (kMPerBlock / WG::kM) * (kNPerBlock / WG::kN);
+        constexpr index_t mfma_inst =
+            ((kMPerBlock / WG::kM) / nloop) * ((kNPerBlock / WG::kN) / nloop);
        // How often (in MFMA units) we should insert DS (LDS) operations.
        constexpr index_t ds_rep = mfma_inst / (ds_read_inst + ds_write_inst);
        // How often (in MFMA units) we should insert VMEM buffer loads.
@@ -169,7 +174,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
                }
                // Always mark some VALU work in the loop to reflect auxiliary scalar
                // or vector ALU instructions that coexist with MFMA (Blockscale calculation).
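For intuition, the scheduler hints above boil down to integer ratios: `ds_rep` and `buffer_load_rep` say how many MFMA instructions to issue between consecutive LDS operations and between consecutive VMEM loads. A constexpr sketch with invented per-block counts (the kernel computes the real ones from the tile configuration and data types):

```cpp
#include <cstdio>

// The interleaving ratios reduce to integer divisions over per-block
// instruction estimates. All counts below are hypothetical.
int main()
{
    constexpr int Aload_inst       = 8;  // VMEM loads for A
    constexpr int Bload_inst       = 16; // VMEM loads for B
    constexpr int BQload_inst      = 2;  // VMEM loads for quant scales
    constexpr int buffer_load_inst = Aload_inst + Bload_inst + BQload_inst;

    constexpr int ds_read_inst  = 16;         // LDS reads per block
    constexpr int ds_write_inst = Aload_inst; // LDS writes per block
    constexpr int mfma_inst     = 128;        // MFMA ops per block tile

    // Issue one DS op per ds_rep MFMAs, one VMEM load per buffer_load_rep MFMAs.
    constexpr int ds_rep          = mfma_inst / (ds_read_inst + ds_write_inst);
    constexpr int buffer_load_rep = mfma_inst / buffer_load_inst;

    std::printf("ds_rep=%d buffer_load_rep=%d\n", ds_rep, buffer_load_rep); // 5 and 4
}
```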
-                __builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::VALU, 2, 0); // VALU
+                __builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::VALU, 4, 0); // VALU
             });
         });
         __builtin_amdgcn_sched_barrier(0);
@@ -380,7 +385,6 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
 
         // Prefetch A1
         a_block_tile = load_tile(a_copy_dram_window);
-        // move A window to next k
         move_tile_window(a_copy_dram_window, {0, kKPerBlock});
 
         // initialize C
@@ -407,7 +411,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
         while(iCounter > 0)
         {
             __builtin_amdgcn_sched_barrier(0);
-            // Prefill A(2i+1)
+            // Prefill A(2i+1) ds_write
             a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
             store_tile(a_copy_lds_window_pong, a_block_tile_tmp);
 
@@ -435,10 +439,14 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
                 });
             });
             move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
+
+            // prefetch Q(2i+1)
             aq_block_tile_2 = load_tile(aq_copy_dram_window);
             move_tile_window(aq_copy_dram_window, {0, KPerBlockAQ});
             bq_block_tile_2 = load_tile(bq_copy_dram_window);
             move_tile_window(bq_copy_dram_window, {0, KPerBlockBQ});
+
+            // Preload A(2i+1) ds_read
             static_for<0, m_preload, 1>{}([&](auto loadIter) {
                 constexpr auto mIter = loadIter % MIterPerWarp;
                 constexpr auto kIter = loadIter / MIterPerWarp;
@@ -460,6 +468,8 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
                 });
             });
             move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
+
+            // prefetch Q(2i+2)
             aq_block_tile = load_tile(aq_copy_dram_window);
             move_tile_window(aq_copy_dram_window, {0, KPerBlockAQ});
             bq_block_tile = load_tile(bq_copy_dram_window);
@@ -481,7 +491,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
                                           aq_block_tile_2,
                                           bq_block_tile_2,
                                           a_warp_windows_pong);
-
+            // Preload A(2i+2) ds_read
             static_for<0, m_preload, 1>{}([&](auto loadIter) {
                 constexpr auto mIter = loadIter % MIterPerWarp;
                 constexpr auto kIter = loadIter / MIterPerWarp;
@@ -521,7 +531,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
                                       aq_block_tile,
                                       bq_block_tile,
                                       a_warp_windows_ping);
-
+        // Preload A ds_read
         static_for<0, m_preload, 1>{}([&](auto loadIter) {
             constexpr auto mIter = loadIter % MIterPerWarp;
             constexpr auto kIter = loadIter / MIterPerWarp;
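The comments added in the hunks above annotate the ping/pong structure of the hot loop: while the block GEMM consumes one buffer set, the A/B/Q tiles for a later iteration are prefetched into the other set. A minimal host-side sketch of that interleaving (on the GPU these phases overlap; here they are printed sequentially just to show the ordering, and all names are illustrative):

```cpp
#include <cstdio>

// Ping/pong double buffering: two buffer sets alternate between being
// consumed by the GEMM and being refilled by prefetch.
static void prefetch(int iter, const char* buf) { std::printf("prefetch tile %d -> %s\n", iter, buf); }
static void gemm(int iter, const char* buf) { std::printf("gemm on tile %d (%s)\n", iter, buf); }

int main()
{
    const int num_loop = 6; // assume an even iteration count for simplicity
    prefetch(0, "ping");
    prefetch(1, "pong");
    for(int i = 0; i + 2 < num_loop; i += 2)
    {
        gemm(i, "ping");         // consume ping ...
        prefetch(i + 2, "ping"); // ... and refill it for iteration i+2
        gemm(i + 1, "pong");     // consume pong ...
        prefetch(i + 3, "pong"); // ... and refill it for iteration i+3
    }
    gemm(num_loop - 2, "ping"); // epilogue: drain the last two buffers
    gemm(num_loop - 1, "pong");
}
```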