From 3db86598ef3079faead4345f15c3970e470a146a Mon Sep 17 00:00:00 2001
From: AviralGoelAMD
Date: Wed, 14 Jan 2026 17:53:16 +0000
Subject: [PATCH 1/8] WIP: host-level interwave pipeline compiles

---
 .../gemm_aquant_quantgrouped.cpp              |   5 +-
 .../38_block_scale_gemm/gemm_utils.hpp        |  36 +++
 .../block_universal_gemm_as_aquant_bs_cr.hpp  | 186 +++++++++++
 .../gemm_aquant_pipeline_ag_bg_cr_mem.hpp     | 306 ++++++++++++++++++
 4 files changed, 532 insertions(+), 1 deletion(-)

diff --git a/example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped.cpp b/example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped.cpp
index ad1a4e0d100..5e45a059ffa 100644
--- a/example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped.cpp
+++ b/example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped.cpp
@@ -3,8 +3,11 @@
 #include "run_gemm_quant_example.inc"
 
+// template <typename PrecType>
+// using GemmConfig = GemmConfigQuantDecode<PrecType>;
+
 template <typename PrecType>
-using GemmConfig = GemmConfigQuantDecode<PrecType>;
+using GemmConfig = GemmConfigQuantIntrawave<PrecType>;
 
 // GemmConfigQuantPrefill is also supported for aquant grouped quantization
 // template <typename PrecType>
diff --git a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp
index 37fc998e5ba..06fdbab4f59 100644
--- a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp
+++ b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp
@@ -95,6 +95,42 @@ struct GemmConfigQuantDecode : public GemmConfigBase
         ck_tile::get_k_warp_tile();
 };
 
+template <typename PrecType>
+struct GemmConfigQuantIntrawave : public GemmConfigBase
+{
+    static constexpr ck_tile::index_t M_Tile = 128;
+    static constexpr ck_tile::index_t N_Tile = 128;
+    static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType);
+
+    static constexpr ck_tile::index_t M_Warp = 4;
+    static constexpr ck_tile::index_t N_Warp = 1;
+    static constexpr ck_tile::index_t K_Warp = 1;
+
+    static constexpr ck_tile::index_t M_Warp_Tile = 32;
+    static constexpr ck_tile::index_t N_Warp_Tile = 32;
+    static constexpr ck_tile::index_t K_Warp_Tile =
+        ck_tile::get_k_warp_tile();
+};
+
+template <typename PrecType>
+struct GemmConfigQuantInterwave : public GemmConfigBase
+{
+    static constexpr ck_tile::index_t M_Tile = 128;
+    static constexpr ck_tile::index_t N_Tile = 128;
+    static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType);
+
+    static constexpr ck_tile::index_t M_Warp = 4;
+    static constexpr ck_tile::index_t N_Warp = 1;
+    static constexpr ck_tile::index_t K_Warp = 1;
+
+    static constexpr ck_tile::index_t M_Warp_Tile = 32;
+    static constexpr ck_tile::index_t N_Warp_Tile = 32;
+    static constexpr ck_tile::index_t K_Warp_Tile =
+        ck_tile::get_k_warp_tile();
+
+    static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
+};
+
 template <typename PrecType>
 struct GemmConfigRowColQuant : public GemmConfigBase
 {
diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp
index 705a992b526..0bf35820b93 100644
--- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp
+++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp
@@ -274,6 +274,7 @@ struct AQuantBlockUniversalGemmAsBsCr
         static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
             CWarpTensor c_warp_tensor;
 
+            // for every column (K quant group) of AQ
             static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) {
                 static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) {
                     constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale;
@@ -322,6 +323,191 @@ struct
AQuantBlockUniversalGemmAsBsCr } }; + // template + // struct BlockGemmImpl + // { + // static constexpr index_t KPerThread = GemmTraits::KPerThread; + // static constexpr index_t NumMacClusters = GemmTraits::InterWaveSchedulingMacClusters; + // static constexpr index_t QuantGroupSizeK = GemmTraits::QuantGroupSize::kK; + + // // For quantized GEMM with Interwave, use KPerThread as the chunk size to keep + // // quant-group boundaries aligned and keep the structure similar to the base Interwave + // loop. static constexpr index_t KPerInnerLoop = KPerThread; static constexpr index_t + // KRepeat = 1; static constexpr index_t KInnerLoopIter = KPerInnerLoop / + // WarpGemm::kKPerThread; + + // static constexpr index_t KIterPerQScale = GemmTraits::KIterPerQScale; + + // // For quantized interwave, correctness is governed by KIterPerQScale and the + // // existing GemmTraits_ assertions; we don't require KPerInnerLoop to match + // // QuantGroupSizeK. + + // static constexpr auto ALdsTileDistr = + // make_static_tile_distribution(MakeABlockDistributionEncode()); + // static constexpr auto BLdsTileDistr = + // make_static_tile_distribution(MakeBBlockDistributionEncode()); + + // using ALdsTile = + // decltype(make_static_distributed_tensor(ALdsTileDistr)); using BLdsTile + // = decltype(make_static_distributed_tensor(BLdsTileDistr)); + + // ALdsTile a_warp_tile_; + // BLdsTile b_warp_tile_; + + // template + // CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window, + // const BSmemBlockWindow& b_block_window, + // bool_constant = {}, + // bool_constant = {}) + // { + // constexpr auto a_lds_load_distr = [&]() { + // if constexpr(ALoadTranspose) + // return make_static_tile_distribution(typename InputTileDistributionTraits< + // decltype(MakeABlockDistributionEncode()), + // ADataType>::TransposedDstrEncode{}); + // else + // return make_static_tile_distribution(MakeABlockDistributionEncode()); + // }(); + // constexpr auto b_lds_load_distr = [&]() { + // if constexpr(BLoadTranspose) + // return make_static_tile_distribution(typename InputTileDistributionTraits< + // decltype(MakeBBlockDistributionEncode()), + // BDataType>::TransposedDstrEncode{}); + // else + // return make_static_tile_distribution(MakeBBlockDistributionEncode()); + // }(); + // constexpr auto a_lds_shape = []() { + // if constexpr(ALoadTranspose) + // return make_tuple(number{}, number{}); + // else + // return make_tuple(number{}, number{}); + // }(); + // constexpr auto b_lds_shape = []() { + // if constexpr(BLoadTranspose) + // return make_tuple(number{}, number{}); + // else + // return make_tuple(number{}, number{}); + // }(); + // constexpr auto k_idx_offset = KIdx * KPerInnerLoop; + // constexpr auto a_offset = + // ALoadTranspose ? multi_index<2>{k_idx_offset, 0} : multi_index<2>{0, + // k_idx_offset}; + // constexpr auto b_offset = + // BLoadTranspose ? 
multi_index<2>{k_idx_offset, 0} : multi_index<2>{0, + // k_idx_offset}; + + // auto a_lds_gemm_window = make_tile_window( + // a_block_window.get_bottom_tensor_view(), a_lds_shape, a_offset, + // a_lds_load_distr); + // auto b_lds_gemm_window = make_tile_window( + // b_block_window.get_bottom_tensor_view(), b_lds_shape, b_offset, + // b_lds_load_distr); + + // load_int4_tile( + // a_warp_tile_, a_lds_gemm_window); + // load_int4_tile( + // b_warp_tile_, b_lds_gemm_window); + // } + + // // C += A * B with quantization scales + // template + // CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + // AQBlockTensor& aq_block_tensor, + // const ASmemBlockWindow& a_block_window, + // const BSmemBlockWindow& b_block_window, + // bool_constant a_load_tr = {}, + // bool_constant b_load_tr = {}) + // { + // static_assert(std::is_same_v, + // "The CDataType as defined in traits should be the same as corresponding + // " "C block tensor data type!"); + // constexpr auto warp_size = get_warp_size(); + + // static_for<0, KRepeat, 1>{}([&](auto kIter) { + // LocalPrefetch(a_block_window, b_block_window, a_load_tr, b_load_tr); + // __builtin_amdgcn_sched_barrier(0); + + // if constexpr(kIter.value != 0 || KRepeat == 1) + // { + // __builtin_amdgcn_s_barrier(); + // __builtin_amdgcn_sched_barrier(0); + // } + + // static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // CWarpTensor c_warp_tensor; + + // static_for<0, KInnerLoopIter, 1>{}([&](auto kInnerIter) { + // constexpr index_t global_k_iter = + // kIter.value * KInnerLoopIter + kInnerIter.value; + // constexpr index_t q_scale_idx = global_k_iter / + // Traits::KIterPerQScale; constexpr index_t k_iter_in_qscale = + // global_k_iter % Traits::KIterPerQScale; + + // AWarpTensor a_warp_tensor; + // a_warp_tensor.get_thread_buffer() = + // a_warp_tile_.get_y_sliced_thread_data( + // merge_sequences(sequence{}, + // a_warp_y_index_zeros), + // merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + + // BWarpTensor b_warp_tensor; + // b_warp_tensor.get_thread_buffer() = + // b_warp_tile_.get_y_sliced_thread_data( + // merge_sequences(sequence{}, + // b_warp_y_index_zeros), + // merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); + + // if constexpr(k_iter_in_qscale == 0) + // { + // c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor); + // } + // else + // { + // WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + // } + + // if constexpr(k_iter_in_qscale == Traits::KIterPerQScale - 1) + // { + // constexpr auto tbuf_offset = number< + // typename CBlockTensor::ThreadTensorDesc{}.calculate_offset( + // merge_sequences(sequence{}, + // c_warp_y_index_zeros)) / + // CBlockTensor::PackedSize>{}; + + // AQPickerCommon{}> + // aq_picker(aq_block_tensor); + + // static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}( + // [&](auto c_row) { + // float scale_reg_f = aq_picker.template pick(); + // c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] + // += + // (c_warp_tensor.get_thread_buffer()[c_row] * + // scale_reg_f); + // }); + // } + // }); + // }); + // }); + + // __builtin_amdgcn_sched_barrier(0); + // __builtin_amdgcn_s_setprio(0); + // __builtin_amdgcn_sched_barrier(0); + // }); + // } + // }; + public: CK_TILE_DEVICE static constexpr auto MakeCBlockTile() { diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp index 650cd947f73..a5d87195f12 100644 
--- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp @@ -486,6 +486,312 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem } }; + // template <> + // struct PipelineImpl : public PipelineImplBase + // { + // using Base = PipelineImplBase; + + // template + // CK_TILE_DEVICE static void + // LoadAndConvertATile(ABlockTile_& a_block_tile, + // ADramWindow& a_dram_window, + // const DramTileWindowStep& dram_tile_window_step) + // { + // using DestDataType = typename ABlockTile_::DataType; + // using SrcDataType = typename ADramWindow::Base::TileWindowBase::DataType; + // constexpr index_t UnaryOpSize = 8; + // load_int4_tile(a_block_tile, a_dram_window); + // move_tile_window(a_dram_window, dram_tile_window_step); + // } + + // template + // CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + // const AElementFunction& a_element_func, + // const BDramBlockWindowTmp& b_dram_block_window_tmp, + // const BElementFunction& b_element_func, + // const AQDramBlockWindowTmp& aq_dram_block_window_tmp, + // [[maybe_unused]] index_t m, + // index_t num_loop, + // void* p_smem) const + // { + // static_assert( + // std::is_same_v> + // && + // std::is_same_v> && + // std::is_same_v>, + // "A/B/AQ Dram block window should have the same data type as appropriate " + // "([A|B|AQ]DataType) defined in Problem definition!"); + + // constexpr bool is_a_col_major = + // std::is_same_v; + // constexpr bool is_aq_col_major = + // std::is_same_v; + // constexpr bool is_b_row_major = std::is_same_v; + + // static_assert(!PreshuffleQuant, "Memory pipeline does not support PreshuffleQuant!"); + + // static_assert(is_a_col_major + // ? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && + // MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]) + // : (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && + // KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]), + // "A block window has incorrect lengths for defined ALayout!"); + // static_assert(is_b_row_major + // ? 
(KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + // NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]) + // : (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + // KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]), + // "B block window has incorrect lengths for defined BLayout!"); + + // auto ab_lds_blocks = Base::template GetABLdsTensorViews(p_smem); auto& a_lds_block = ab_lds_blocks.at(I0{}); auto& b_lds_block = + // ab_lds_blocks.at(I1{}); + + // constexpr auto a_lds_load_tile_distr = + // make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); + // constexpr auto b_lds_load_tile_distr = + // make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode()); + + // auto a_windows = + // Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr); + // auto& a_copy_dram_window = a_windows.at(I0{}); + // auto& a_copy_lds_window = a_windows.at(I1{}); + // auto& a_lds_gemm_window = a_windows.at(I2{}); + + // auto b_windows = + // Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr); + // auto& b_copy_dram_window = b_windows.at(I0{}); + // auto& b_copy_lds_window = b_windows.at(I1{}); + // auto& b_lds_gemm_window = b_windows.at(I2{}); + + // auto aq_copy_dram_window = Base::GetAQDramLoadWindow(aq_dram_block_window_tmp); + + // auto block_gemm = BlockGemm(); + // auto c_block_tile = block_gemm.MakeCBlockTile(); + + // using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution()); + // using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution()); + // using AQBlockTileDistr = decltype(aq_copy_dram_window.get_tile_distribution()); + + // using ABlockTile = + // decltype(make_static_distributed_tensor(ABlockTileDistr{})); + // using BBlockTile = + // decltype(make_static_distributed_tensor(BBlockTileDistr{})); + // using AQBlockTile = + // decltype(make_static_distributed_tensor(AQBlockTileDistr{})); + + // tuple_array a_block_tiles; + // tuple_array b_block_tiles; + // tuple_array aq_block_tiles; + + // using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex; + // using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex; + // using AQDramTileWindowStep = typename AQDramBlockWindowTmp::BottomTensorIndex; + + // constexpr ADramTileWindowStep a_dram_tile_window_step = + // is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock); + // constexpr BDramTileWindowStep b_dram_tile_window_step = + // is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock); + // constexpr AQDramTileWindowStep aq_dram_tile_window_step = + // is_aq_col_major ? 
make_array(KPerBlockAQ, 0) : make_array(0, KPerBlockAQ); + + // LoadAndConvertATile( + // a_block_tiles.get(I0{}), a_copy_dram_window, a_dram_tile_window_step); + // Base::GlobalPrefetch( + // b_block_tiles.get(I0{}), b_copy_dram_window, b_dram_tile_window_step); + // Base::GlobalPrefetch( + // aq_block_tiles.get(I0{}), aq_copy_dram_window, aq_dram_tile_window_step); + + // tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + + // if constexpr(is_a_col_major && !is_a_load_tr_v()) + // { + // auto a_shuffle_tmp = make_static_distributed_tensor( + // Policy::template MakeShuffledARegTileDistribution()); + // transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(I0{})); + // Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + // } + // else + // { + // Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func); + // } + // if constexpr(is_b_row_major && !is_b_load_tr_v()) + // { + // auto b_shuffle_tmp = make_static_distributed_tensor( + // Policy::template MakeShuffledBRegTileDistribution()); + // transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(I0{})); + // Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + // } + // else + // { + // Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func); + // } + + // static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) { + // LoadAndConvertATile(a_block_tiles.get(number{}), + // a_copy_dram_window, + // a_dram_tile_window_step); + // Base::GlobalPrefetch(b_block_tiles.get(number{}), + // b_copy_dram_window, + // b_dram_tile_window_step); + // Base::GlobalPrefetch(aq_block_tiles.get(number{}), + // aq_copy_dram_window, + // aq_dram_tile_window_step); + // }); + + // if constexpr(HasHotLoop) + // { + // index_t i = 0; + // do + // { + // static_for<0, PrefetchStages, 1>{}([&](auto prefetch_idx) { + // block_sync_lds(); + // block_gemm(c_block_tile, + // aq_block_tiles.get(number{}), + // a_lds_gemm_window, + // b_lds_gemm_window); + + // if constexpr(is_a_col_major && !is_a_load_tr_v()) + // { + // auto a_shuffle_tmp = make_static_distributed_tensor( + // Policy::template MakeShuffledARegTileDistribution()); + // transpose_tile2d( + // a_shuffle_tmp, + // a_block_tiles.get(number<(prefetch_idx + 1) % + // PrefetchStages>{})); + // Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + // } + // else + // { + // Base::LocalPrefill( + // a_copy_lds_window, + // a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}), + // a_element_func); + // } + // if constexpr(is_b_row_major && !is_b_load_tr_v()) + // { + // auto b_shuffle_tmp = make_static_distributed_tensor( + // Policy::template MakeShuffledBRegTileDistribution()); + // transpose_tile2d( + // b_shuffle_tmp, + // b_block_tiles.get(number<(prefetch_idx + 1) % + // PrefetchStages>{})); + // Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + // } + // else + // { + // Base::LocalPrefill( + // b_copy_lds_window, + // b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}), + // b_element_func); + // } + + // LoadAndConvertATile(a_block_tiles.get(number{}), + // a_copy_dram_window, + // a_dram_tile_window_step); + // Base::GlobalPrefetch(b_block_tiles.get(number{}), + // b_copy_dram_window, + // b_dram_tile_window_step); + // Base::GlobalPrefetch(aq_block_tiles.get(number{}), + // aq_copy_dram_window, + // aq_dram_tile_window_step); + // }); + + // i += PrefetchStages; + // } while(i < (num_loop - PrefetchStages)); + // } + + // auto HotLoopTail = [&](auto 
tail_num) { + // static_for<1, tail_num, 1>{}([&](auto prefetch_idx) { + // block_sync_lds(); + // block_gemm(c_block_tile, + // aq_block_tiles.get(number{}), + // a_lds_gemm_window, + // b_lds_gemm_window); + + // if constexpr(is_a_col_major && !is_a_load_tr_v()) + // { + // auto a_shuffle_tmp = make_static_distributed_tensor( + // Policy::template MakeShuffledARegTileDistribution()); + // transpose_tile2d(a_shuffle_tmp, + // a_block_tiles.get(number{})); + // Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp); + // } + // else + // { + // Base::LocalPrefill(a_copy_lds_window, + // a_block_tiles.get(number{})); + // } + // if constexpr(is_b_row_major && !is_b_load_tr_v()) + // { + // auto b_shuffle_tmp = make_static_distributed_tensor( + // Policy::template MakeShuffledBRegTileDistribution()); + // transpose_tile2d(b_shuffle_tmp, + // b_block_tiles.get(number{})); + // Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp); + // } + // else + // { + // Base::LocalPrefill(b_copy_lds_window, + // b_block_tiles.get(number{})); + // } + // }); + + // block_sync_lds(); + // block_gemm(c_block_tile, + // aq_block_tiles.get(number{}), + // a_lds_gemm_window, + // b_lds_gemm_window); + // }; + + // if constexpr(TailNum == TailNumber::One) + // { + // block_sync_lds(); + // block_gemm( + // c_block_tile, aq_block_tiles.get(I0{}), a_lds_gemm_window, + // b_lds_gemm_window); + // } + // else if constexpr(TailNum == TailNumber::Two) + // { + // HotLoopTail(number<2>{}); + // } + // else if constexpr(TailNum == TailNumber::Three) + // { + // HotLoopTail(number<3>{}); + // } + // else if constexpr(TailNum == TailNumber::Four) + // { + // HotLoopTail(number<4>{}); + // } + // else if constexpr(TailNum == TailNumber::Five) + // { + // HotLoopTail(number<5>{}); + // } + // else if constexpr(TailNum == TailNumber::Six) + // { + // HotLoopTail(number<6>{}); + // } + // else if constexpr(TailNum == TailNumber::Seven) + // { + // HotLoopTail(number<7>{}); + // } + // else if constexpr(TailNum == TailNumber::Full) + // { + // HotLoopTail(number{}); + // } + // return c_block_tile; + // } + // }; + template From 3f9a29b450fd6a79dda7652fa0515739841d013f Mon Sep 17 00:00:00 2001 From: AviralGoelAMD Date: Thu, 15 Jan 2026 17:01:31 +0000 Subject: [PATCH 2/8] WIP: interwave implementation computes correct GEMM result when no aquant --- .../gemm_aquant_quantgrouped.cpp | 2 +- .../run_gemm_quant_example.inc | 10 +- .../block_universal_gemm_as_aquant_bs_cr.hpp | 357 +++++------ .../gemm_aquant_pipeline_ag_bg_cr_mem.hpp | 603 +++++++++--------- 4 files changed, 477 insertions(+), 495 deletions(-) diff --git a/example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped.cpp b/example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped.cpp index 5e45a059ffa..ca5f75cee5c 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped.cpp @@ -7,7 +7,7 @@ // using GemmConfig = GemmConfigQuantDecode; template -using GemmConfig = GemmConfigQuantIntrawave; +using GemmConfig = GemmConfigQuantInterwave; // GemmConfigQuantPrefill is also supported for aquant grouped quantization // template diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index 607c53d9afd..63d94a0ba94 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -557,7 +557,7 @@ int 
run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, std::mt19937 gen(rd()); std::uniform_int_distribution fill_seed(0, 500); - if(init_method == 0) + if(init_method == 0) // uniform distribution { if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) { @@ -594,7 +594,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, { ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(a_m_k); } - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( + ck_tile::FillUniformDistribution{1.0f, 1.0f, fill_seed(gen)}( *aq_tensor_ptr); ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(b_k_n); } @@ -627,12 +627,12 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, *bq_tensor_ptr); } } - else if(init_method == 1) + else if(init_method == 1) // monotonic initialization { std::cout << "Monotonic initialization is not supported." << std::endl; return 0; } - else if(init_method == 2) + else if(init_method == 2) // constant initialization { if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) { @@ -650,7 +650,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, else { ck_tile::FillConstant{static_cast(0x22)}(a_m_k); - ck_tile::FillConstant{static_cast(0.5f)}(*aq_tensor_ptr); + ck_tile::FillConstant{static_cast(1.0f)}(*aq_tensor_ptr); ck_tile::FillConstant{static_cast(0x38)}(b_k_n); if constexpr(QuantMode == ck_tile::QuantType::RowColQuant) diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp index 0bf35820b93..a34a854218e 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp @@ -323,190 +323,179 @@ struct AQuantBlockUniversalGemmAsBsCr } }; - // template - // struct BlockGemmImpl - // { - // static constexpr index_t KPerThread = GemmTraits::KPerThread; - // static constexpr index_t NumMacClusters = GemmTraits::InterWaveSchedulingMacClusters; - // static constexpr index_t QuantGroupSizeK = GemmTraits::QuantGroupSize::kK; - - // // For quantized GEMM with Interwave, use KPerThread as the chunk size to keep - // // quant-group boundaries aligned and keep the structure similar to the base Interwave - // loop. static constexpr index_t KPerInnerLoop = KPerThread; static constexpr index_t - // KRepeat = 1; static constexpr index_t KInnerLoopIter = KPerInnerLoop / - // WarpGemm::kKPerThread; - - // static constexpr index_t KIterPerQScale = GemmTraits::KIterPerQScale; - - // // For quantized interwave, correctness is governed by KIterPerQScale and the - // // existing GemmTraits_ assertions; we don't require KPerInnerLoop to match - // // QuantGroupSizeK. 
- - // static constexpr auto ALdsTileDistr = - // make_static_tile_distribution(MakeABlockDistributionEncode()); - // static constexpr auto BLdsTileDistr = - // make_static_tile_distribution(MakeBBlockDistributionEncode()); - - // using ALdsTile = - // decltype(make_static_distributed_tensor(ALdsTileDistr)); using BLdsTile - // = decltype(make_static_distributed_tensor(BLdsTileDistr)); - - // ALdsTile a_warp_tile_; - // BLdsTile b_warp_tile_; - - // template - // CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window, - // const BSmemBlockWindow& b_block_window, - // bool_constant = {}, - // bool_constant = {}) - // { - // constexpr auto a_lds_load_distr = [&]() { - // if constexpr(ALoadTranspose) - // return make_static_tile_distribution(typename InputTileDistributionTraits< - // decltype(MakeABlockDistributionEncode()), - // ADataType>::TransposedDstrEncode{}); - // else - // return make_static_tile_distribution(MakeABlockDistributionEncode()); - // }(); - // constexpr auto b_lds_load_distr = [&]() { - // if constexpr(BLoadTranspose) - // return make_static_tile_distribution(typename InputTileDistributionTraits< - // decltype(MakeBBlockDistributionEncode()), - // BDataType>::TransposedDstrEncode{}); - // else - // return make_static_tile_distribution(MakeBBlockDistributionEncode()); - // }(); - // constexpr auto a_lds_shape = []() { - // if constexpr(ALoadTranspose) - // return make_tuple(number{}, number{}); - // else - // return make_tuple(number{}, number{}); - // }(); - // constexpr auto b_lds_shape = []() { - // if constexpr(BLoadTranspose) - // return make_tuple(number{}, number{}); - // else - // return make_tuple(number{}, number{}); - // }(); - // constexpr auto k_idx_offset = KIdx * KPerInnerLoop; - // constexpr auto a_offset = - // ALoadTranspose ? multi_index<2>{k_idx_offset, 0} : multi_index<2>{0, - // k_idx_offset}; - // constexpr auto b_offset = - // BLoadTranspose ? 
multi_index<2>{k_idx_offset, 0} : multi_index<2>{0, - // k_idx_offset}; - - // auto a_lds_gemm_window = make_tile_window( - // a_block_window.get_bottom_tensor_view(), a_lds_shape, a_offset, - // a_lds_load_distr); - // auto b_lds_gemm_window = make_tile_window( - // b_block_window.get_bottom_tensor_view(), b_lds_shape, b_offset, - // b_lds_load_distr); - - // load_int4_tile( - // a_warp_tile_, a_lds_gemm_window); - // load_int4_tile( - // b_warp_tile_, b_lds_gemm_window); - // } - - // // C += A * B with quantization scales - // template - // CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, - // AQBlockTensor& aq_block_tensor, - // const ASmemBlockWindow& a_block_window, - // const BSmemBlockWindow& b_block_window, - // bool_constant a_load_tr = {}, - // bool_constant b_load_tr = {}) - // { - // static_assert(std::is_same_v, - // "The CDataType as defined in traits should be the same as corresponding - // " "C block tensor data type!"); - // constexpr auto warp_size = get_warp_size(); - - // static_for<0, KRepeat, 1>{}([&](auto kIter) { - // LocalPrefetch(a_block_window, b_block_window, a_load_tr, b_load_tr); - // __builtin_amdgcn_sched_barrier(0); - - // if constexpr(kIter.value != 0 || KRepeat == 1) - // { - // __builtin_amdgcn_s_barrier(); - // __builtin_amdgcn_sched_barrier(0); - // } - - // static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - // static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // CWarpTensor c_warp_tensor; - - // static_for<0, KInnerLoopIter, 1>{}([&](auto kInnerIter) { - // constexpr index_t global_k_iter = - // kIter.value * KInnerLoopIter + kInnerIter.value; - // constexpr index_t q_scale_idx = global_k_iter / - // Traits::KIterPerQScale; constexpr index_t k_iter_in_qscale = - // global_k_iter % Traits::KIterPerQScale; - - // AWarpTensor a_warp_tensor; - // a_warp_tensor.get_thread_buffer() = - // a_warp_tile_.get_y_sliced_thread_data( - // merge_sequences(sequence{}, - // a_warp_y_index_zeros), - // merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); - - // BWarpTensor b_warp_tensor; - // b_warp_tensor.get_thread_buffer() = - // b_warp_tile_.get_y_sliced_thread_data( - // merge_sequences(sequence{}, - // b_warp_y_index_zeros), - // merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); - - // if constexpr(k_iter_in_qscale == 0) - // { - // c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor); - // } - // else - // { - // WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); - // } - - // if constexpr(k_iter_in_qscale == Traits::KIterPerQScale - 1) - // { - // constexpr auto tbuf_offset = number< - // typename CBlockTensor::ThreadTensorDesc{}.calculate_offset( - // merge_sequences(sequence{}, - // c_warp_y_index_zeros)) / - // CBlockTensor::PackedSize>{}; - - // AQPickerCommon{}> - // aq_picker(aq_block_tensor); - - // static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}( - // [&](auto c_row) { - // float scale_reg_f = aq_picker.template pick(); - // c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] - // += - // (c_warp_tensor.get_thread_buffer()[c_row] * - // scale_reg_f); - // }); - // } - // }); - // }); - // }); - - // __builtin_amdgcn_sched_barrier(0); - // __builtin_amdgcn_s_setprio(0); - // __builtin_amdgcn_sched_barrier(0); - // }); - // } - // }; + template + struct BlockGemmImpl + { + static constexpr index_t KPerThread = GemmTraits::KPerThread; + static constexpr index_t NumMacClusters = GemmTraits::InterWaveSchedulingMacClusters; + // static constexpr index_t QuantGroupSizeK = GemmTraits::QuantGroupSize::kK; + + // 
Match the base Interwave loop structure; quantization handling will be reintroduced + // later. + static constexpr index_t KPerInnerLoop = + ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread); + static constexpr index_t KRepeat = KPerThread / KPerInnerLoop; + static constexpr index_t KInnerLoopIter = KPerInnerLoop / WarpGemm::kKPerThread; + + // static constexpr index_t KIterPerQScale = GemmTraits::KIterPerQScale; + + static constexpr auto ALdsTileDistr = + make_static_tile_distribution(MakeABlockDistributionEncode()); + static constexpr auto BLdsTileDistr = + make_static_tile_distribution(MakeBBlockDistributionEncode()); + + using ALdsTile = decltype(make_static_distributed_tensor(ALdsTileDistr)); + using BLdsTile = decltype(make_static_distributed_tensor(BLdsTileDistr)); + + ALdsTile a_warp_tile_; + BLdsTile b_warp_tile_; + + template + CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window, + const BSmemBlockWindow& b_block_window, + bool_constant = {}, + bool_constant = {}) + { + constexpr auto a_lds_load_distr = [&]() { + if constexpr(ALoadTranspose) + return make_static_tile_distribution(typename InputTileDistributionTraits< + decltype(MakeABlockDistributionEncode()), + ADataType>::TransposedDstrEncode{}); + else + return make_static_tile_distribution(MakeABlockDistributionEncode()); + }(); + constexpr auto b_lds_load_distr = [&]() { + if constexpr(BLoadTranspose) + return make_static_tile_distribution(typename InputTileDistributionTraits< + decltype(MakeBBlockDistributionEncode()), + BDataType>::TransposedDstrEncode{}); + else + return make_static_tile_distribution(MakeBBlockDistributionEncode()); + }(); + constexpr auto a_lds_shape = []() { + if constexpr(ALoadTranspose) + return make_tuple(number{}, number{}); + else + return make_tuple(number{}, number{}); + }(); + constexpr auto b_lds_shape = []() { + if constexpr(BLoadTranspose) + return make_tuple(number{}, number{}); + else + return make_tuple(number{}, number{}); + }(); + constexpr auto k_idx_offset = KIdx * KPerInnerLoop; + constexpr auto a_offset = + ALoadTranspose ? multi_index<2>{k_idx_offset, 0} : multi_index<2>{0, k_idx_offset}; + constexpr auto b_offset = + BLoadTranspose ? 
multi_index<2>{k_idx_offset, 0} : multi_index<2>{0, k_idx_offset}; + + auto a_lds_gemm_window = make_tile_window( + a_block_window.get_bottom_tensor_view(), a_lds_shape, a_offset, a_lds_load_distr); + auto b_lds_gemm_window = make_tile_window( + b_block_window.get_bottom_tensor_view(), b_lds_shape, b_offset, b_lds_load_distr); + + load_int4_tile( + a_warp_tile_, a_lds_gemm_window); + load_int4_tile( + b_warp_tile_, b_lds_gemm_window); + } + + // C += A * B (quantization/scaling paths are intentionally commented for now) + template + CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + [[maybe_unused]] AQBlockTensor& aq_block_tensor, + const ASmemBlockWindow& a_block_window, + const BSmemBlockWindow& b_block_window, + bool_constant a_load_tr = {}, + bool_constant b_load_tr = {}) + { + static_assert(std::is_same_v, + "The CDataType as defined in traits should be the same as corresponding " + "C block tensor data type!"); + // constexpr auto warp_size = get_warp_size(); + + static_for<0, KRepeat, 1>{}([&](auto kIter) { + if(get_thread_id() == 0 && get_block_id() == 0) + { + printf("kIter: %d\n", kIter.value); + } + LocalPrefetch(a_block_window, b_block_window, a_load_tr, b_load_tr); + __builtin_amdgcn_sched_barrier(0); + + if constexpr(kIter.value != 0 || KRepeat == 1) + { + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + } + + static_for<0, KInnerLoopIter, 1>{}([&](auto kInnerIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + AWarpTensor a_warp_tensor; + a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + BWarpTensor b_warp_tensor; + b_warp_tensor.get_thread_buffer() = + b_warp_tile_.get_y_sliced_thread_data( + merge_sequences(sequence{}, + b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); + + CWarpTensor c_warp_tensor; + c_warp_tensor.get_thread_buffer() = + c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // Quantization accumulation path (AQPicker, qscale application) to be + // re-enabled later. 
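+                        // Scheduling notes (this mirrors the base Interwave
+                        // BlockGemmImpl; see the "Match the base Interwave loop
+                        // structure" comment above):
+                        //  * sched_barrier(0) fences keep the compiler from moving
+                        //    instructions across the synchronization points;
+                        //  * block_sync_lds() is issued just before the last warp
+                        //    GEMM of the block so the LDS barrier overlaps the tail
+                        //    MFMAs instead of stalling after them;
+                        //  * s_setprio(1) after the first warp GEMM raises this
+                        //    wave's priority so its MFMA stream overlaps the other
+                        //    wave's LDS loads; priority returns to 0 at the end of
+                        //    each KRepeat chunk.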
+ + if constexpr(kIter.value == KRepeat - 1 && + kInnerIter.value == KInnerLoopIter - 1 && + mIter.value == MIterPerWarp - 1 && + nIter.value == NIterPerWarp - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + + WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + + if constexpr(kInnerIter.value == 0 && mIter.value == 0 && + nIter.value == 0) + { + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } + }); + }); + }); + + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(0); + __builtin_amdgcn_sched_barrier(0); + }); + } + }; public: CK_TILE_DEVICE static constexpr auto MakeCBlockTile() diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp index a5d87195f12..4c251187f51 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp @@ -486,311 +486,304 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem } }; - // template <> - // struct PipelineImpl : public PipelineImplBase - // { - // using Base = PipelineImplBase; - - // template - // CK_TILE_DEVICE static void - // LoadAndConvertATile(ABlockTile_& a_block_tile, - // ADramWindow& a_dram_window, - // const DramTileWindowStep& dram_tile_window_step) - // { - // using DestDataType = typename ABlockTile_::DataType; - // using SrcDataType = typename ADramWindow::Base::TileWindowBase::DataType; - // constexpr index_t UnaryOpSize = 8; - // load_int4_tile(a_block_tile, a_dram_window); - // move_tile_window(a_dram_window, dram_tile_window_step); - // } - - // template - // CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, - // const AElementFunction& a_element_func, - // const BDramBlockWindowTmp& b_dram_block_window_tmp, - // const BElementFunction& b_element_func, - // const AQDramBlockWindowTmp& aq_dram_block_window_tmp, - // [[maybe_unused]] index_t m, - // index_t num_loop, - // void* p_smem) const - // { - // static_assert( - // std::is_same_v> - // && - // std::is_same_v> && - // std::is_same_v>, - // "A/B/AQ Dram block window should have the same data type as appropriate " - // "([A|B|AQ]DataType) defined in Problem definition!"); - - // constexpr bool is_a_col_major = - // std::is_same_v; - // constexpr bool is_aq_col_major = - // std::is_same_v; - // constexpr bool is_b_row_major = std::is_same_v; - - // static_assert(!PreshuffleQuant, "Memory pipeline does not support PreshuffleQuant!"); - - // static_assert(is_a_col_major - // ? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && - // MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]) - // : (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && - // KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]), - // "A block window has incorrect lengths for defined ALayout!"); - // static_assert(is_b_row_major - // ? 
(KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && - // NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]) - // : (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && - // KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]), - // "B block window has incorrect lengths for defined BLayout!"); - - // auto ab_lds_blocks = Base::template GetABLdsTensorViews(p_smem); auto& a_lds_block = ab_lds_blocks.at(I0{}); auto& b_lds_block = - // ab_lds_blocks.at(I1{}); - - // constexpr auto a_lds_load_tile_distr = - // make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); - // constexpr auto b_lds_load_tile_distr = - // make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode()); - - // auto a_windows = - // Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr); - // auto& a_copy_dram_window = a_windows.at(I0{}); - // auto& a_copy_lds_window = a_windows.at(I1{}); - // auto& a_lds_gemm_window = a_windows.at(I2{}); - - // auto b_windows = - // Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr); - // auto& b_copy_dram_window = b_windows.at(I0{}); - // auto& b_copy_lds_window = b_windows.at(I1{}); - // auto& b_lds_gemm_window = b_windows.at(I2{}); - - // auto aq_copy_dram_window = Base::GetAQDramLoadWindow(aq_dram_block_window_tmp); - - // auto block_gemm = BlockGemm(); - // auto c_block_tile = block_gemm.MakeCBlockTile(); - - // using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution()); - // using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution()); - // using AQBlockTileDistr = decltype(aq_copy_dram_window.get_tile_distribution()); - - // using ABlockTile = - // decltype(make_static_distributed_tensor(ABlockTileDistr{})); - // using BBlockTile = - // decltype(make_static_distributed_tensor(BBlockTileDistr{})); - // using AQBlockTile = - // decltype(make_static_distributed_tensor(AQBlockTileDistr{})); - - // tuple_array a_block_tiles; - // tuple_array b_block_tiles; - // tuple_array aq_block_tiles; - - // using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex; - // using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex; - // using AQDramTileWindowStep = typename AQDramBlockWindowTmp::BottomTensorIndex; - - // constexpr ADramTileWindowStep a_dram_tile_window_step = - // is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock); - // constexpr BDramTileWindowStep b_dram_tile_window_step = - // is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock); - // constexpr AQDramTileWindowStep aq_dram_tile_window_step = - // is_aq_col_major ? 
make_array(KPerBlockAQ, 0) : make_array(0, KPerBlockAQ); - - // LoadAndConvertATile( - // a_block_tiles.get(I0{}), a_copy_dram_window, a_dram_tile_window_step); - // Base::GlobalPrefetch( - // b_block_tiles.get(I0{}), b_copy_dram_window, b_dram_tile_window_step); - // Base::GlobalPrefetch( - // aq_block_tiles.get(I0{}), aq_copy_dram_window, aq_dram_tile_window_step); - - // tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); - - // if constexpr(is_a_col_major && !is_a_load_tr_v()) - // { - // auto a_shuffle_tmp = make_static_distributed_tensor( - // Policy::template MakeShuffledARegTileDistribution()); - // transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(I0{})); - // Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); - // } - // else - // { - // Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func); - // } - // if constexpr(is_b_row_major && !is_b_load_tr_v()) - // { - // auto b_shuffle_tmp = make_static_distributed_tensor( - // Policy::template MakeShuffledBRegTileDistribution()); - // transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(I0{})); - // Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); - // } - // else - // { - // Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func); - // } - - // static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) { - // LoadAndConvertATile(a_block_tiles.get(number{}), - // a_copy_dram_window, - // a_dram_tile_window_step); - // Base::GlobalPrefetch(b_block_tiles.get(number{}), - // b_copy_dram_window, - // b_dram_tile_window_step); - // Base::GlobalPrefetch(aq_block_tiles.get(number{}), - // aq_copy_dram_window, - // aq_dram_tile_window_step); - // }); - - // if constexpr(HasHotLoop) - // { - // index_t i = 0; - // do - // { - // static_for<0, PrefetchStages, 1>{}([&](auto prefetch_idx) { - // block_sync_lds(); - // block_gemm(c_block_tile, - // aq_block_tiles.get(number{}), - // a_lds_gemm_window, - // b_lds_gemm_window); - - // if constexpr(is_a_col_major && !is_a_load_tr_v()) - // { - // auto a_shuffle_tmp = make_static_distributed_tensor( - // Policy::template MakeShuffledARegTileDistribution()); - // transpose_tile2d( - // a_shuffle_tmp, - // a_block_tiles.get(number<(prefetch_idx + 1) % - // PrefetchStages>{})); - // Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); - // } - // else - // { - // Base::LocalPrefill( - // a_copy_lds_window, - // a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}), - // a_element_func); - // } - // if constexpr(is_b_row_major && !is_b_load_tr_v()) - // { - // auto b_shuffle_tmp = make_static_distributed_tensor( - // Policy::template MakeShuffledBRegTileDistribution()); - // transpose_tile2d( - // b_shuffle_tmp, - // b_block_tiles.get(number<(prefetch_idx + 1) % - // PrefetchStages>{})); - // Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); - // } - // else - // { - // Base::LocalPrefill( - // b_copy_lds_window, - // b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}), - // b_element_func); - // } - - // LoadAndConvertATile(a_block_tiles.get(number{}), - // a_copy_dram_window, - // a_dram_tile_window_step); - // Base::GlobalPrefetch(b_block_tiles.get(number{}), - // b_copy_dram_window, - // b_dram_tile_window_step); - // Base::GlobalPrefetch(aq_block_tiles.get(number{}), - // aq_copy_dram_window, - // aq_dram_tile_window_step); - // }); - - // i += PrefetchStages; - // } while(i < (num_loop - PrefetchStages)); - // } - - // auto HotLoopTail = [&](auto 
tail_num) { - // static_for<1, tail_num, 1>{}([&](auto prefetch_idx) { - // block_sync_lds(); - // block_gemm(c_block_tile, - // aq_block_tiles.get(number{}), - // a_lds_gemm_window, - // b_lds_gemm_window); - - // if constexpr(is_a_col_major && !is_a_load_tr_v()) - // { - // auto a_shuffle_tmp = make_static_distributed_tensor( - // Policy::template MakeShuffledARegTileDistribution()); - // transpose_tile2d(a_shuffle_tmp, - // a_block_tiles.get(number{})); - // Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp); - // } - // else - // { - // Base::LocalPrefill(a_copy_lds_window, - // a_block_tiles.get(number{})); - // } - // if constexpr(is_b_row_major && !is_b_load_tr_v()) - // { - // auto b_shuffle_tmp = make_static_distributed_tensor( - // Policy::template MakeShuffledBRegTileDistribution()); - // transpose_tile2d(b_shuffle_tmp, - // b_block_tiles.get(number{})); - // Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp); - // } - // else - // { - // Base::LocalPrefill(b_copy_lds_window, - // b_block_tiles.get(number{})); - // } - // }); - - // block_sync_lds(); - // block_gemm(c_block_tile, - // aq_block_tiles.get(number{}), - // a_lds_gemm_window, - // b_lds_gemm_window); - // }; - - // if constexpr(TailNum == TailNumber::One) - // { - // block_sync_lds(); - // block_gemm( - // c_block_tile, aq_block_tiles.get(I0{}), a_lds_gemm_window, - // b_lds_gemm_window); - // } - // else if constexpr(TailNum == TailNumber::Two) - // { - // HotLoopTail(number<2>{}); - // } - // else if constexpr(TailNum == TailNumber::Three) - // { - // HotLoopTail(number<3>{}); - // } - // else if constexpr(TailNum == TailNumber::Four) - // { - // HotLoopTail(number<4>{}); - // } - // else if constexpr(TailNum == TailNumber::Five) - // { - // HotLoopTail(number<5>{}); - // } - // else if constexpr(TailNum == TailNumber::Six) - // { - // HotLoopTail(number<6>{}); - // } - // else if constexpr(TailNum == TailNumber::Seven) - // { - // HotLoopTail(number<7>{}); - // } - // else if constexpr(TailNum == TailNumber::Full) - // { - // HotLoopTail(number{}); - // } - // return c_block_tile; - // } - // }; + template <> + struct PipelineImpl : public PipelineImplBase + { + using Base = PipelineImplBase; + + template + CK_TILE_DEVICE static void + LoadAndConvertATile(ABlockTile_& a_block_tile, + ADramWindow& a_dram_window, + const DramTileWindowStep& dram_tile_window_step) + { + using DestDataType = typename ABlockTile_::DataType; + using SrcDataType = typename ADramWindow::Base::TileWindowBase::DataType; + constexpr index_t UnaryOpSize = 8; + load_int4_tile(a_block_tile, a_dram_window); + move_tile_window(a_dram_window, dram_tile_window_step); + } + + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + const AQDramBlockWindowTmp& aq_dram_block_window_tmp, + [[maybe_unused]] index_t m, + index_t num_loop, + void* p_smem) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "A/B/AQ Dram block window should have the same data type as appropriate " + "([A|B|AQ]DataType) defined in Problem definition!"); + + constexpr bool is_a_col_major = + std::is_same_v; + constexpr bool is_aq_col_major = + std::is_same_v; + constexpr bool is_b_row_major = std::is_same_v; + + static_assert(!PreshuffleQuant, "Memory pipeline does not support PreshuffleQuant!"); + + static_assert(is_a_col_major + ? 
(KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && + MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]) + : (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && + KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]), + "A block window has incorrect lengths for defined ALayout!"); + static_assert(is_b_row_major + ? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]) + : (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]), + "B block window has incorrect lengths for defined BLayout!"); + + auto ab_lds_blocks = Base::template GetABLdsTensorViews(p_smem); + auto& a_lds_block = ab_lds_blocks.at(I0{}); + auto& b_lds_block = ab_lds_blocks.at(I1{}); + + constexpr auto a_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); + constexpr auto b_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode()); + + auto a_windows = + Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr); + auto& a_copy_dram_window = a_windows.at(I0{}); + auto& a_copy_lds_window = a_windows.at(I1{}); + auto& a_lds_gemm_window = a_windows.at(I2{}); + + auto b_windows = + Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr); + auto& b_copy_dram_window = b_windows.at(I0{}); + auto& b_copy_lds_window = b_windows.at(I1{}); + auto& b_lds_gemm_window = b_windows.at(I2{}); + + auto aq_copy_dram_window = Base::GetAQDramLoadWindow(aq_dram_block_window_tmp); + + auto block_gemm = BlockGemm(); + auto c_block_tile = block_gemm.MakeCBlockTile(); + + using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution()); + using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution()); + using AQBlockTileDistr = decltype(aq_copy_dram_window.get_tile_distribution()); + + using ABlockTile = + decltype(make_static_distributed_tensor(ABlockTileDistr{})); + using BBlockTile = + decltype(make_static_distributed_tensor(BBlockTileDistr{})); + using AQBlockTile = + decltype(make_static_distributed_tensor(AQBlockTileDistr{})); + + tuple_array a_block_tiles; + tuple_array b_block_tiles; + tuple_array aq_block_tiles; + + using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex; + using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex; + using AQDramTileWindowStep = typename AQDramBlockWindowTmp::BottomTensorIndex; + + constexpr ADramTileWindowStep a_dram_tile_window_step = + is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock); + constexpr BDramTileWindowStep b_dram_tile_window_step = + is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock); + constexpr AQDramTileWindowStep aq_dram_tile_window_step = + is_aq_col_major ? 
make_array(KPerBlockAQ, 0) : make_array(0, KPerBlockAQ); + + LoadAndConvertATile( + a_block_tiles.get(I0{}), a_copy_dram_window, a_dram_tile_window_step); + Base::GlobalPrefetch( + b_block_tiles.get(I0{}), b_copy_dram_window, b_dram_tile_window_step); + Base::GlobalPrefetch( + aq_block_tiles.get(I0{}), aq_copy_dram_window, aq_dram_tile_window_step); + + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + + if constexpr(is_a_col_major && !is_a_load_tr_v()) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(I0{})); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + } + else + { + Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func); + } + if constexpr(is_b_row_major && !is_b_load_tr_v()) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(I0{})); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + } + else + { + Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func); + } + + static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) { + LoadAndConvertATile(a_block_tiles.get(number{}), + a_copy_dram_window, + a_dram_tile_window_step); + Base::GlobalPrefetch(b_block_tiles.get(number{}), + b_copy_dram_window, + b_dram_tile_window_step); + Base::GlobalPrefetch(aq_block_tiles.get(number{}), + aq_copy_dram_window, + aq_dram_tile_window_step); + }); + + if constexpr(HasHotLoop) + { + index_t i = 0; + do + { + static_for<0, PrefetchStages, 1>{}([&](auto prefetch_idx) { + block_sync_lds(); + block_gemm(c_block_tile, + aq_block_tiles.get(number{}), + a_lds_gemm_window, + b_lds_gemm_window); + + if constexpr(is_a_col_major && !is_a_load_tr_v()) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d( + a_shuffle_tmp, + a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{})); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + } + else + { + Base::LocalPrefill( + a_copy_lds_window, + a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}), + a_element_func); + } + if constexpr(is_b_row_major && !is_b_load_tr_v()) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d( + b_shuffle_tmp, + b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{})); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + } + else + { + Base::LocalPrefill( + b_copy_lds_window, + b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}), + b_element_func); + } + + LoadAndConvertATile(a_block_tiles.get(number{}), + a_copy_dram_window, + a_dram_tile_window_step); + Base::GlobalPrefetch(b_block_tiles.get(number{}), + b_copy_dram_window, + b_dram_tile_window_step); + Base::GlobalPrefetch(aq_block_tiles.get(number{}), + aq_copy_dram_window, + aq_dram_tile_window_step); + }); + + i += PrefetchStages; + } while(i < (num_loop - PrefetchStages)); + } + + auto HotLoopTail = [&](auto tail_num) { + static_for<1, tail_num, 1>{}([&](auto prefetch_idx) { + block_sync_lds(); + block_gemm(c_block_tile, + aq_block_tiles.get(number{}), + a_lds_gemm_window, + b_lds_gemm_window); + + if constexpr(is_a_col_major && !is_a_load_tr_v()) + { + auto a_shuffle_tmp = make_static_distributed_tensor( 
+ Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(number{})); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp); + } + else + { + Base::LocalPrefill(a_copy_lds_window, + a_block_tiles.get(number{})); + } + if constexpr(is_b_row_major && !is_b_load_tr_v()) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(number{})); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp); + } + else + { + Base::LocalPrefill(b_copy_lds_window, + b_block_tiles.get(number{})); + } + }); + + block_sync_lds(); + block_gemm(c_block_tile, + aq_block_tiles.get(number{}), + a_lds_gemm_window, + b_lds_gemm_window); + }; + + if constexpr(TailNum == TailNumber::One) + { + block_sync_lds(); + block_gemm( + c_block_tile, aq_block_tiles.get(I0{}), a_lds_gemm_window, b_lds_gemm_window); + } + else if constexpr(TailNum == TailNumber::Two) + { + HotLoopTail(number<2>{}); + } + else if constexpr(TailNum == TailNumber::Three) + { + HotLoopTail(number<3>{}); + } + else if constexpr(TailNum == TailNumber::Four) + { + HotLoopTail(number<4>{}); + } + else if constexpr(TailNum == TailNumber::Five) + { + HotLoopTail(number<5>{}); + } + else if constexpr(TailNum == TailNumber::Six) + { + HotLoopTail(number<6>{}); + } + else if constexpr(TailNum == TailNumber::Seven) + { + HotLoopTail(number<7>{}); + } + else if constexpr(TailNum == TailNumber::Full) + { + HotLoopTail(number{}); + } + return c_block_tile; + } + }; template Date: Mon, 19 Jan 2026 14:32:46 +0000 Subject: [PATCH 3/8] WIP: quantization works for subset of problem shapes --- .../38_block_scale_gemm/CMakeLists.txt | 2 +- .../gemm_aquant_quantgrouped.cpp | 5 +- .../38_block_scale_gemm/gemm_quant.cpp | 6 +- .../38_block_scale_gemm/gemm_utils.hpp | 34 ----- .../run_gemm_quant_example.inc | 97 +++++++++++++ .../block_universal_gemm_as_aquant_bs_cr.hpp | 127 ++++++++++++------ 6 files changed, 186 insertions(+), 85 deletions(-) diff --git a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt index ec536f72878..f007806eb28 100644 --- a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt +++ b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt @@ -12,7 +12,7 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") set(EXE_NAME tile_example_gemm_quant) add_executable(${EXE_NAME} gemm_quant.cpp - gemm_abquant_quantgrouped.cpp + # gemm_abquant_quantgrouped.cpp gemm_aquant_quantgrouped.cpp gemm_aquant_quantgrouped_preshufflequant.cpp gemm_bquant_quantgrouped_bf8i4.cpp diff --git a/example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped.cpp b/example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped.cpp index ca5f75cee5c..ad1a4e0d100 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped.cpp @@ -3,11 +3,8 @@ #include "run_gemm_quant_example.inc" -// template -// using GemmConfig = GemmConfigQuantDecode; - template -using GemmConfig = GemmConfigQuantInterwave; +using GemmConfig = GemmConfigQuantDecode; // GemmConfigQuantPrefill is also supported for aquant grouped quantization // template diff --git a/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp b/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp index 8de58b0a309..2bb7a0725cf 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp @@ -95,8 +95,8 @@ 
auto gen_lut_key(const ck_tile::ArgParser& arg_parser) return hash_multiple_strings(params); } -void abquant_quantgrouped_instance_factory( - std::unordered_map>& lut); +// void abquant_quantgrouped_instance_factory( +// std::unordered_map>& lut); void aquant_quantgrouped_instance_factory( std::unordered_map>& lut); void aquant_quantgrouped_preshufflequant_instance_factory( @@ -154,7 +154,7 @@ int main(int argc, char* argv[]) ck_tile::hip_check_error(hipSetDevice(device_id)); std::unordered_map> lut; - abquant_quantgrouped_instance_factory(lut); + // abquant_quantgrouped_instance_factory(lut); aquant_quantgrouped_instance_factory(lut); aquant_quantgrouped_preshufflequant_instance_factory(lut); bquant_quantgrouped_fp8_instance_factory(lut); diff --git a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp index 06fdbab4f59..57755ae28ac 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp @@ -93,40 +93,6 @@ struct GemmConfigQuantDecode : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = ck_tile::get_k_warp_tile(); -}; - -template -struct GemmConfigQuantIntrawave : public GemmConfigBase -{ - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 128; - static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType); - - static constexpr ck_tile::index_t M_Warp = 4; - static constexpr ck_tile::index_t N_Warp = 1; - static constexpr ck_tile::index_t K_Warp = 1; - - static constexpr ck_tile::index_t M_Warp_Tile = 32; - static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = - ck_tile::get_k_warp_tile(); -}; - -template -struct GemmConfigQuantInterwave : public GemmConfigBase -{ - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 128; - static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType); - - static constexpr ck_tile::index_t M_Warp = 4; - static constexpr ck_tile::index_t N_Warp = 1; - static constexpr ck_tile::index_t K_Warp = 1; - - static constexpr ck_tile::index_t M_Warp_Tile = 32; - static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = - ck_tile::get_k_warp_tile(); static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; }; diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index 63d94a0ba94..2c95c658d6b 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -659,6 +659,103 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, } } } + else if(init_method == 3) // constant initialization + { + if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) + { + ck_tile::FillConstant{static_cast(0x38)}(a_m_k); + ck_tile::FillConstant{static_cast(0x22)}(b_k_n); + ck_tile::FillConstant{static_cast(0.5f)}(*bq_tensor_ptr); + } + else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped) + { + ck_tile::FillConstant{static_cast(0x38)}(a_m_k); + ck_tile::FillConstant{static_cast(0x22)}(b_k_n); + ck_tile::FillConstant{static_cast(0.5f)}(*aq_tensor_ptr); + ck_tile::FillConstant{static_cast(0.5f)}(*bq_tensor_ptr); + } + else + { + ck_tile::FillConstant{static_cast(0x22)}(a_m_k); + 
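// With the constant fills in this init branch, every reference output
// element has a closed form, which makes a failing tile easy to localize.
// The numeric values are an assumption: if ADataType is FP8 E4M3, bit
// pattern 0x22 decodes to 0.15625 and 0x38 to 1.0 (other FP8 formats decode
// differently).
#include <cstdio>

int main()
{
    const double a = 0.15625; // 0x22 read as e4m3
    const double b = 1.0;     // 0x38 read as e4m3
    const double s = 2.0;     // constant aq scale
    const int    K = 512;     // assumed GEMM depth

    // AQuantGrouped reference: C(m, n) = s * a * b * K for every (m, n)
    std::printf("expected C = %f\n", s * a * b * K);
    return 0;
}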
ck_tile::FillConstant{static_cast(2.0f)}(*aq_tensor_ptr); + ck_tile::FillConstant{static_cast(0x38)}(b_k_n); + + if constexpr(QuantMode == ck_tile::QuantType::RowColQuant) + { + ck_tile::FillConstant{static_cast(0.5f)}(*bq_tensor_ptr); + } + } + } + else if(init_method == 4) // uniform distribution + { + if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) + { + if constexpr(std::is_same_v) + { + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( + b_k_n); + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( + *bq_tensor_ptr); + } + else if constexpr(std::is_same_v) + { + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(b_k_n); + ck_tile::FillUniformDistribution{125.f, 130.f, fill_seed(gen)}( + *bq_tensor_ptr); + } + else + { + ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(b_k_n); + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( + *bq_tensor_ptr); + } + + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(a_m_k); + } + else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped) + { + if constexpr(std::is_same_v) + { + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( + a_m_k); + } + else + { + ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(a_m_k); + } + ck_tile::FillUniformDistribution{2.0f, 2.0f, fill_seed(gen)}( + *aq_tensor_ptr); + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(b_k_n); + } + else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped) + { + if constexpr(std::is_same_v) + { + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( + a_m_k); + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( + b_k_n); + } + else + { + ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(a_m_k); + ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(b_k_n); + } + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( + *aq_tensor_ptr); + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( + *bq_tensor_ptr); + } + else + { + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}(a_m_k); + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}(b_k_n); + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( + *aq_tensor_ptr); + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( + *bq_tensor_ptr); + } + } else { a_m_k.SetZero(); diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp index a34a854218e..cd9a9a877b6 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp @@ -405,7 +405,7 @@ struct AQuantBlockUniversalGemmAsBsCr b_warp_tile_, b_lds_gemm_window); } - // C += A * B (quantization/scaling paths are intentionally commented for now) + // C += A * B with quantization support template CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, - [[maybe_unused]] AQBlockTensor& aq_block_tensor, + AQBlockTensor& aq_block_tensor, const ASmemBlockWindow& a_block_window, const BSmemBlockWindow& b_block_window, bool_constant a_load_tr = {}, @@ -422,48 +422,66 @@ struct AQuantBlockUniversalGemmAsBsCr static_assert(std::is_same_v, "The CDataType as defined in traits should be the same as corresponding " "C block tensor data type!"); - // constexpr auto warp_size = get_warp_size(); - - static_for<0, KRepeat, 1>{}([&](auto kIter) { - 
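// The {2.0f, 2.0f} bounds handed to FillUniformDistribution for the scale
// tensor in the init-4 branch above collapse the range to a single point, so
// the "random" fill is effectively a constant 2.0 while A and B stay random;
// that isolates the scale path from data randomness. Sketch of the same
// trick with the standard library (implementations return the bound for a
// degenerate range):
#include <cstdio>
#include <random>

int main()
{
    std::mt19937 gen(42);
    std::uniform_real_distribution<float> scale(2.0f, 2.0f); // zero-width
    for(int i = 0; i < 4; ++i)
        std::printf("%f\n", scale(gen)); // always 2.000000
    return 0;
}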
if(get_thread_id() == 0 && get_block_id() == 0) - { - printf("kIter: %d\n", kIter.value); - } - LocalPrefetch(a_block_window, b_block_window, a_load_tr, b_load_tr); - __builtin_amdgcn_sched_barrier(0); + constexpr auto warp_size = get_warp_size(); + + // Track which KRepeat chunk is currently loaded + index_t current_k_repeat_loaded = -1; + + // Restructured loop: M → N → QScale → KIterPerQScale + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // Iterate over quantization groups + static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) { + CWarpTensor c_warp_tensor; + + // Accumulate K iterations for this quantization group + static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) { + // Map quantization indices to global K iteration + constexpr auto kIterGlobal = + kQScale * Traits::KIterPerQScale + kIterInQScale; - if constexpr(kIter.value != 0 || KRepeat == 1) - { - __builtin_amdgcn_s_barrier(); - __builtin_amdgcn_sched_barrier(0); - } + // Map to KRepeat chunk and KInnerLoopIter offset + constexpr auto kRepeatIdx = kIterGlobal / KInnerLoopIter; + constexpr auto kInnerIdx = kIterGlobal % KInnerLoopIter; - static_for<0, KInnerLoopIter, 1>{}([&](auto kInnerIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - AWarpTensor a_warp_tensor; - a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data( - merge_sequences(sequence{}, a_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + // Prefetch new chunk if needed + if constexpr(kInnerIdx == 0) + { + if(current_k_repeat_loaded != kRepeatIdx) + { + LocalPrefetch( + a_block_window, b_block_window, a_load_tr, b_load_tr); + __builtin_amdgcn_sched_barrier(0); + + if constexpr(kRepeatIdx != 0 || KRepeat == 1) + { + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + } + + current_k_repeat_loaded = kRepeatIdx; + } + } - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // Load A warp tensor + AWarpTensor a_warp_tensor; + a_warp_tensor.get_thread_buffer() = + a_warp_tile_.get_y_sliced_thread_data( + merge_sequences(sequence{}, + a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + + // Load B warp tensor BWarpTensor b_warp_tensor; b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data( - merge_sequences(sequence{}, + merge_sequences(sequence{}, b_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); - CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = - c_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - - // Quantization accumulation path (AQPicker, qscale application) to be - // re-enabled later. 
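// Host-side enumeration of the restructured M -> N -> QScale -> K loop
// introduced above: the flat K iteration is recovered from
// (kQScale, kIterInQScale) and then split into the KRepeat chunk that must
// sit in LDS plus the offset inside that chunk, which decides when a new
// LocalPrefetch is issued. One (m, n) warp tile is shown; the constants are
// assumed example values consistent with the Traits_ relationships.
#include <cstdio>

int main()
{
    constexpr int KIterPerQScale     = 2; // K iterations sharing one scale
    constexpr int QScalesPerBlockRow = 4; // scales along K per block row
    constexpr int KInnerLoopIter     = 4; // K iterations per LDS chunk

    int loaded = -1; // mirrors current_k_repeat_loaded
    for(int q = 0; q < QScalesPerBlockRow; ++q)
        for(int k = 0; k < KIterPerQScale; ++k)
        {
            const int kGlobal = q * KIterPerQScale + k;
            const int kRepeat = kGlobal / KInnerLoopIter;
            const int kInner  = kGlobal % KInnerLoopIter;
            if(kInner == 0 && loaded != kRepeat)
            {
                std::printf("  LocalPrefetch of chunk %d\n", kRepeat);
                loaded = kRepeat;
            }
            std::printf("scale %d, kGlobal %d -> chunk %d, offset %d\n",
                        q, kGlobal, kRepeat, kInner);
        }
    return 0;
}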
- - if constexpr(kIter.value == KRepeat - 1 && - kInnerIter.value == KInnerLoopIter - 1 && + // Synchronization barrier at the end of last iteration + if constexpr(kQScale == Traits::QScalesPerBlockRow - 1 && + kIterInQScale == Traits::KIterPerQScale - 1 && mIter.value == MIterPerWarp - 1 && nIter.value == NIterPerWarp - 1) { @@ -472,24 +490,47 @@ struct AQuantBlockUniversalGemmAsBsCr __builtin_amdgcn_sched_barrier(0); } - WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); - - c_block_tensor.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); + // Accumulate: first iteration initializes, rest accumulate + if constexpr(kIterInQScale == 0) + { + c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor); + } + else + { + WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + } - if constexpr(kInnerIter.value == 0 && mIter.value == 0 && - nIter.value == 0) + // Set priority for scheduling + if constexpr(kInnerIdx == 0 && mIter.value == 0 && nIter.value == 0) { __builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_s_setprio(1); __builtin_amdgcn_sched_barrier(0); } }); + + // Apply quantization scale after accumulating all K iterations for this + // group + constexpr auto tbuf_offset = + number{}, + c_warp_y_index_zeros)) / + CBlockTensor::PackedSize>{}; + + AQPickerCommon aq_picker( + aq_block_tensor); + + static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}( + [&](auto c_row) { + float scale_reg_f = aq_picker.template pick(); + + c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] += + (c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f); + }); }); }); + // Reset scheduling priority after completing M iteration __builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_s_setprio(0); __builtin_amdgcn_sched_barrier(0); From d10a5c8ebe6b736fa252ede3c317f787b09d4e85 Mon Sep 17 00:00:00 2001 From: AviralGoelAMD Date: Mon, 19 Jan 2026 14:32:46 +0000 Subject: [PATCH 4/8] WIP: quantization works for subset of problem shapes --- .../block/block_universal_gemm_as_aquant_bs_cr.hpp | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp index cd9a9a877b6..1c094c9dc5d 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp @@ -276,6 +276,7 @@ struct AQuantBlockUniversalGemmAsBsCr // for every column in AQ static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) { + // for every warp corresponding to a quantization scale static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) { constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale; @@ -328,17 +329,12 @@ struct AQuantBlockUniversalGemmAsBsCr { static constexpr index_t KPerThread = GemmTraits::KPerThread; static constexpr index_t NumMacClusters = GemmTraits::InterWaveSchedulingMacClusters; - // static constexpr index_t QuantGroupSizeK = GemmTraits::QuantGroupSize::kK; - // Match the base Interwave loop structure; quantization handling will be reintroduced - // later. 
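// Scalar model of the accumulate-then-scale order used by the quantized
// block GEMM in the hunk above: partial sums stay unscaled inside one
// quantization group, and the group's scale is applied once before folding
// into C. Mathematically C = sum_g s_g * ( sum_{k in g} A_k * B_k ), which
// equals scaling every product individually. Values are arbitrary.
#include <cstdio>

int main()
{
    const float a[4] = {1.f, 2.f, 3.f, 4.f};
    const float b[4] = {5.f, 6.f, 7.f, 8.f};
    const float s[2] = {0.5f, 2.f}; // one scale per group of 2 k-steps

    float c = 0.f;
    for(int g = 0; g < 2; ++g)
    {
        float acc = 0.f; // c_warp_tensor analogue: first iter sets, rest add
        for(int k = 2 * g; k < 2 * g + 2; ++k)
            acc += a[k] * b[k];
        c += s[g] * acc; // scale applied once per group
    }
    std::printf("C = %f\n", c); // 0.5 * 17 + 2 * 53 = 114.5
    return 0;
}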
static constexpr index_t KPerInnerLoop = ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread); static constexpr index_t KRepeat = KPerThread / KPerInnerLoop; static constexpr index_t KInnerLoopIter = KPerInnerLoop / WarpGemm::kKPerThread; - // static constexpr index_t KIterPerQScale = GemmTraits::KIterPerQScale; - static constexpr auto ALdsTileDistr = make_static_tile_distribution(MakeABlockDistributionEncode()); static constexpr auto BLdsTileDistr = From 361dae1a9717e649ba5ea95fd1200f1a00f3766c Mon Sep 17 00:00:00 2001 From: AviralGoelAMD Date: Tue, 20 Jan 2026 13:28:29 +0000 Subject: [PATCH 5/8] WIP: interwave memory pipeline passes local test --- .../38_block_scale_gemm/gemm_utils.hpp | 2 + .../run_gemm_quant_example.inc | 83 ++++++++++++++++++- .../block_universal_gemm_as_aquant_bs_cr.hpp | 5 +- .../gemm_aquant_pipeline_ag_bg_cr_mem.hpp | 53 ++++++++++-- 4 files changed, 136 insertions(+), 7 deletions(-) diff --git a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp index 57755ae28ac..7a9a172ac81 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp @@ -209,6 +209,8 @@ struct GemmConfigQuantPrefill : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = ck_tile::get_k_warp_tile(); + + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; }; template diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index 2c95c658d6b..8965e552604 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -594,7 +594,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, { ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(a_m_k); } - ck_tile::FillUniformDistribution{1.0f, 1.0f, fill_seed(gen)}( + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( *aq_tensor_ptr); ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(b_k_n); } @@ -756,6 +756,87 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, *bq_tensor_ptr); } } + else if(init_method == 5) // uniform distribution + { + if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) + { + if constexpr(std::is_same_v) + { + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( + b_k_n); + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( + *bq_tensor_ptr); + } + else if constexpr(std::is_same_v) + { + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(b_k_n); + ck_tile::FillUniformDistribution{125.f, 130.f, fill_seed(gen)}( + *bq_tensor_ptr); + } + else + { + ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(b_k_n); + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( + *bq_tensor_ptr); + } + + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(a_m_k); + } + else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped) + { + if constexpr(std::is_same_v) + { + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( + a_m_k); + } + else + { + ck_tile::FillUniformDistribution{1.0f, 1.0f, fill_seed(gen)}(a_m_k); + } + // Fill aquant such that column j has value 2^j (1, 2, 4, 8, ...) 
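// NOTE: as written, the loop below stores col + 1, i.e. a linear ramp
// 1, 2, 3, ..., not powers of two as the comment above says; a literal 2^j
// fill would be e.g. static_cast<float>(1 << col). Either variant gives
// every quantization group a distinct, easy-to-trace scale, which is the
// point of this init mode.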
+ for(ck_tile::index_t row = 0; + row < static_cast(aq_tensor_ptr->get_length(0)); + ++row) + { + for(ck_tile::index_t col = 0; + col < static_cast(aq_tensor_ptr->get_length(1)); + ++col) + { + (*aq_tensor_ptr)(row, col) = static_cast(col + 1); + } + } + // std::cout << "aq_tensor_ptr: " << *aq_tensor_ptr << std::endl; + ck_tile::FillUniformDistribution{1.0f, 1.0f, fill_seed(gen)}(b_k_n); + } + else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped) + { + if constexpr(std::is_same_v) + { + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( + a_m_k); + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( + b_k_n); + } + else + { + ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(a_m_k); + ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(b_k_n); + } + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( + *aq_tensor_ptr); + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( + *bq_tensor_ptr); + } + else + { + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}(a_m_k); + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}(b_k_n); + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( + *aq_tensor_ptr); + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( + *bq_tensor_ptr); + } + } else { a_m_k.SetZero(); diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp index 1c094c9dc5d..05df65ea097 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp @@ -519,7 +519,10 @@ struct AQuantBlockUniversalGemmAsBsCr static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}( [&](auto c_row) { float scale_reg_f = aq_picker.template pick(); - + // if(get_thread_id() == 0 && get_block_id() == 0) + // { + // printf("scale_reg_f: %f\n", scale_reg_f); + // } c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] += (c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f); }); diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp index 4c251187f51..8b4b8b11a9e 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp @@ -520,6 +520,10 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem index_t num_loop, void* p_smem) const { + // if(get_thread_id() == 0 && get_block_id() == 0) + // { + // printf("HasHotLoop: %d\n", HasHotLoop); + // } static_assert( std::is_same_v> && std::is_same_v constexpr AQDramTileWindowStep aq_dram_tile_window_step = is_aq_col_major ? 
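// Expected reference output for this init mode (A = 1, B = 1, and scale
// column g holding g + 1), assuming the scale tensor has one column per K
// quantization group of size QuantGroupSize::kK: every C(m, n) equals
// sum_g (g + 1) * GroupSize = GroupSize * G * (G + 1) / 2, so a wrong or
// misaligned group scale shows up as a recognizable offset.
#include <cstdio>

int main()
{
    const int GroupSize = 128;  // assumed QuantGroupSize::kK
    const int K         = 1024; // assumed GEMM depth
    const int G         = K / GroupSize;

    long long expected = 0;
    for(int g = 0; g < G; ++g)
        expected += static_cast<long long>(g + 1) * GroupSize;
    std::printf("expected C = %lld (closed form %d)\n",
                expected, GroupSize * G * (G + 1) / 2);
    return 0;
}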
make_array(KPerBlockAQ, 0) : make_array(0, KPerBlockAQ); + // if(get_thread_id() == 0 && get_block_id() == 0) + // { + // printf("First prefetch\n"); + // } LoadAndConvertATile( a_block_tiles.get(I0{}), a_copy_dram_window, a_dram_tile_window_step); Base::GlobalPrefetch( @@ -609,6 +617,11 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem Base::GlobalPrefetch( aq_block_tiles.get(I0{}), aq_copy_dram_window, aq_dram_tile_window_step); + // if(get_thread_id() == 0 && get_block_id() == 0) + // { + // printf("First prefetch done\n"); + // } + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); if constexpr(is_a_col_major && !is_a_load_tr_v()) @@ -620,6 +633,10 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem } else { + // if(get_thread_id() == 0 && get_block_id() == 0) + // { + // printf("Local prefill A done\n"); + // } Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func); } if constexpr(is_b_row_major && !is_b_load_tr_v()) @@ -631,10 +648,18 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem } else { + // if(get_thread_id() == 0 && get_block_id() == 0) + // { + // printf("Local prefill B done\n"); + // } Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func); } static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) { + // if(get_thread_id() == 0 && get_block_id() == 0) + // { + // printf("Second Prefetch %d\n", static_cast(prefetch_idx)); + // } LoadAndConvertATile(a_block_tiles.get(number{}), a_copy_dram_window, a_dram_tile_window_step); @@ -644,6 +669,10 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem Base::GlobalPrefetch(aq_block_tiles.get(number{}), aq_copy_dram_window, aq_dram_tile_window_step); + // if(get_thread_id() == 0 && get_block_id() == 0) + // { + // printf("Second Prefetch %d done\n", static_cast(prefetch_idx)); + // } }); if constexpr(HasHotLoop) @@ -707,7 +736,11 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem } auto HotLoopTail = [&](auto tail_num) { - static_for<1, tail_num, 1>{}([&](auto prefetch_idx) { + // if(get_thread_id() == 0 && get_block_id() == 0) + // { + // printf("HotLoopTail %d\n", static_cast(tail_num)); + // } + static_for<0, tail_num - 1, 1>{}([&](auto prefetch_idx) { block_sync_lds(); block_gemm(c_block_tile, aq_block_tiles.get(number{}), @@ -718,25 +751,27 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem { auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); - transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(number{})); + transpose_tile2d(a_shuffle_tmp, + a_block_tiles.get(number{})); Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp); } else { Base::LocalPrefill(a_copy_lds_window, - a_block_tiles.get(number{})); + a_block_tiles.get(number{})); } if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); - transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(number{})); + transpose_tile2d(b_shuffle_tmp, + b_block_tiles.get(number{})); Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp); } else { Base::LocalPrefill(b_copy_lds_window, - b_block_tiles.get(number{})); + b_block_tiles.get(number{})); } }); @@ -749,12 +784,20 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem if constexpr(TailNum == TailNumber::One) { + // if(get_thread_id() == 0 && get_block_id() == 0) + // { + // 
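// Both tail-loop shapes in the hunk above drain tail_num - 1 register
// stages; the rewrite makes prefetch_idx zero-based and shifts the tile
// index by one, so the same stages are visited in the same order while a
// zero-based prefetch_idx is available where needed. The AQ stage passed to
// block_gemm has its template argument stripped in this diff, so treating
// it as the zero-based prefetch_idx is an assumption. Quick enumeration
// under assumed values:
#include <cstdio>

int main()
{
    const int PrefetchStages = 4, tail_num = 3;
    for(int p = 1; p < tail_num; ++p) // old: static_for<1, tail_num, 1>
        std::printf("old step: tile stage %d\n", p);
    for(int p = 0; p < tail_num - 1; ++p) // new: static_for<0, tail_num - 1, 1>
        std::printf("new step: tile stage %d\n", (p + 1) % PrefetchStages);
    return 0;
}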
printf("TailNum: One\n"); + // } block_sync_lds(); block_gemm( c_block_tile, aq_block_tiles.get(I0{}), a_lds_gemm_window, b_lds_gemm_window); } else if constexpr(TailNum == TailNumber::Two) { + // if(get_thread_id() == 0 && get_block_id() == 0) + // { + // printf("TailNum: Two\n"); + // } HotLoopTail(number<2>{}); } else if constexpr(TailNum == TailNumber::Three) From d0ff0db11e2bf48fa03549edf78209f229e15c40 Mon Sep 17 00:00:00 2001 From: AviralGoelAMD Date: Tue, 20 Jan 2026 16:35:23 +0000 Subject: [PATCH 6/8] feat: Add interwave pipeline implementation for memory pipline in aquant --- example/ck_tile/38_block_scale_gemm/CMakeLists.txt | 2 +- example/ck_tile/38_block_scale_gemm/gemm_quant.cpp | 6 +++--- example/ck_tile/38_block_scale_gemm/gemm_utils.hpp | 4 ++-- .../38_block_scale_gemm/run_gemm_quant_example.inc | 12 ++++++------ 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt index f007806eb28..ec536f72878 100644 --- a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt +++ b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt @@ -12,7 +12,7 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") set(EXE_NAME tile_example_gemm_quant) add_executable(${EXE_NAME} gemm_quant.cpp - # gemm_abquant_quantgrouped.cpp + gemm_abquant_quantgrouped.cpp gemm_aquant_quantgrouped.cpp gemm_aquant_quantgrouped_preshufflequant.cpp gemm_bquant_quantgrouped_bf8i4.cpp diff --git a/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp b/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp index 2bb7a0725cf..8de58b0a309 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp @@ -95,8 +95,8 @@ auto gen_lut_key(const ck_tile::ArgParser& arg_parser) return hash_multiple_strings(params); } -// void abquant_quantgrouped_instance_factory( -// std::unordered_map>& lut); +void abquant_quantgrouped_instance_factory( + std::unordered_map>& lut); void aquant_quantgrouped_instance_factory( std::unordered_map>& lut); void aquant_quantgrouped_preshufflequant_instance_factory( @@ -154,7 +154,7 @@ int main(int argc, char* argv[]) ck_tile::hip_check_error(hipSetDevice(device_id)); std::unordered_map> lut; - // abquant_quantgrouped_instance_factory(lut); + abquant_quantgrouped_instance_factory(lut); aquant_quantgrouped_instance_factory(lut); aquant_quantgrouped_preshufflequant_instance_factory(lut); bquant_quantgrouped_fp8_instance_factory(lut); diff --git a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp index 7a9a172ac81..e26ffa0fcbe 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp @@ -94,7 +94,7 @@ struct GemmConfigQuantDecode : public GemmConfigBase static constexpr ck_tile::index_t K_Warp_Tile = ck_tile::get_k_warp_tile(); - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; + // static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; }; template @@ -210,7 +210,7 @@ struct GemmConfigQuantPrefill : public GemmConfigBase static constexpr ck_tile::index_t K_Warp_Tile = ck_tile::get_k_warp_tile(); - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; + // static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; }; template diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc 
b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index 8965e552604..6f9b984ec17 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -557,7 +557,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, std::mt19937 gen(rd()); std::uniform_int_distribution fill_seed(0, 500); - if(init_method == 0) // uniform distribution + if(init_method == 0) { if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) { @@ -627,12 +627,12 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, *bq_tensor_ptr); } } - else if(init_method == 1) // monotonic initialization + else if(init_method == 1) { std::cout << "Monotonic initialization is not supported." << std::endl; return 0; } - else if(init_method == 2) // constant initialization + else if(init_method == 2) { if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) { @@ -659,7 +659,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, } } } - else if(init_method == 3) // constant initialization + else if(init_method == 3) { if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) { @@ -686,7 +686,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, } } } - else if(init_method == 4) // uniform distribution + else if(init_method == 4) { if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) { @@ -756,7 +756,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, *bq_tensor_ptr); } } - else if(init_method == 5) // uniform distribution + else if(init_method == 5) { if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) { From 35486bb3a8373bd5462889dcabc4e42270ab446f Mon Sep 17 00:00:00 2001 From: AviralGoelAMD Date: Tue, 20 Jan 2026 20:30:59 +0000 Subject: [PATCH 7/8] fix: compilation error on gfx950 --- .../gemm_aquant_quantgrouped.cpp | 2 +- .../38_block_scale_gemm/gemm_utils.hpp | 19 +++++++++++++++++++ .../block_universal_gemm_as_aquant_bs_cr.hpp | 13 +++++++++++-- 3 files changed, 31 insertions(+), 3 deletions(-) diff --git a/example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped.cpp b/example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped.cpp index ad1a4e0d100..e037be5a183 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped.cpp @@ -4,7 +4,7 @@ #include "run_gemm_quant_example.inc" template -using GemmConfig = GemmConfigQuantDecode; +using GemmConfig = GemmConfigQuantDecodeInterwave; // GemmConfigQuantPrefill is also supported for aquant grouped quantization // template diff --git a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp index e26ffa0fcbe..9caf413f501 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp @@ -97,6 +97,25 @@ struct GemmConfigQuantDecode : public GemmConfigBase // static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; }; +template +struct GemmConfigQuantDecodeInterwave : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 16; + static constexpr ck_tile::index_t N_Tile = 64; + static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t 
M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); + + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; +}; + template struct GemmConfigRowColQuant : public GemmConfigBase { diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp index 05df65ea097..1f2855e8545 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp @@ -544,7 +544,8 @@ struct AQuantBlockUniversalGemmAsBsCr MakeCBlockTile(); } - template @@ -553,7 +554,15 @@ struct AQuantBlockUniversalGemmAsBsCr bool_constant a_load_tr = {}, bool_constant b_load_tr = {}) { - block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window, a_load_tr, b_load_tr); + if constexpr(Scheduler == GemmPipelineScheduler::Interwave) + { + block_gemm_impl_.template LocalPrefetch( + a_block_window, b_block_window, a_load_tr, b_load_tr); + } + else + { + block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window, a_load_tr, b_load_tr); + } } // C += A * B From a4b5fcc53c9dc8cd86ce15e15c6c84d5abd51c5f Mon Sep 17 00:00:00 2001 From: AviralGoelAMD Date: Wed, 21 Jan 2026 11:52:13 +0000 Subject: [PATCH 8/8] chore: remove debug statements from the code --- .../38_block_scale_gemm/gemm_utils.hpp | 4 - .../run_gemm_quant_example.inc | 178 ------------------ .../block_universal_gemm_as_aquant_bs_cr.hpp | 5 - .../gemm_aquant_pipeline_ag_bg_cr_mem.hpp | 41 ---- 4 files changed, 228 deletions(-) diff --git a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp index 9caf413f501..2cc7a79c552 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp @@ -93,8 +93,6 @@ struct GemmConfigQuantDecode : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = ck_tile::get_k_warp_tile(); - - // static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; }; template @@ -228,8 +226,6 @@ struct GemmConfigQuantPrefill : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = ck_tile::get_k_warp_tile(); - - // static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; }; template diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index 6f9b984ec17..e73248e7a49 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -659,184 +659,6 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, } } } - else if(init_method == 3) - { - if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) - { - ck_tile::FillConstant{static_cast(0x38)}(a_m_k); - ck_tile::FillConstant{static_cast(0x22)}(b_k_n); - ck_tile::FillConstant{static_cast(0.5f)}(*bq_tensor_ptr); - } - else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped) - { - ck_tile::FillConstant{static_cast(0x38)}(a_m_k); - ck_tile::FillConstant{static_cast(0x22)}(b_k_n); - ck_tile::FillConstant{static_cast(0.5f)}(*aq_tensor_ptr); - ck_tile::FillConstant{static_cast(0.5f)}(*bq_tensor_ptr); - } - else 
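// Minimal model of the compile-time dispatch added in PATCH 7 above: when
// the scheduler is Interwave, LocalPrefetch takes an extra compile-time
// chunk index, so the call goes through the .template member-template form;
// otherwise the plain overload is used. The enum and method names mirror
// the patch, but the exact template parameters are stripped in this diff,
// so KIdx is an assumed stand-in and the bodies are illustrative.
#include <cstdio>

enum class GemmPipelineScheduler { Intrawave, Interwave };

template <GemmPipelineScheduler Scheduler>
struct BlockGemm
{
    template <int KIdx>
    static void LocalPrefetch() { std::printf("interwave chunk %d\n", KIdx); }
    static void LocalPrefetch() { std::printf("single-shot prefetch\n"); }

    template <int KIdx = 0>
    static void Run()
    {
        if constexpr(Scheduler == GemmPipelineScheduler::Interwave)
            LocalPrefetch<KIdx>(); // chunked, per-KRepeat prefetch
        else
            LocalPrefetch();       // whole-KPerBlock prefetch
    }
};

int main()
{
    BlockGemm<GemmPipelineScheduler::Interwave>::Run<1>();
    BlockGemm<GemmPipelineScheduler::Intrawave>::Run();
    return 0;
}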
- { - ck_tile::FillConstant{static_cast(0x22)}(a_m_k); - ck_tile::FillConstant{static_cast(2.0f)}(*aq_tensor_ptr); - ck_tile::FillConstant{static_cast(0x38)}(b_k_n); - - if constexpr(QuantMode == ck_tile::QuantType::RowColQuant) - { - ck_tile::FillConstant{static_cast(0.5f)}(*bq_tensor_ptr); - } - } - } - else if(init_method == 4) - { - if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) - { - if constexpr(std::is_same_v) - { - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( - b_k_n); - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *bq_tensor_ptr); - } - else if constexpr(std::is_same_v) - { - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(b_k_n); - ck_tile::FillUniformDistribution{125.f, 130.f, fill_seed(gen)}( - *bq_tensor_ptr); - } - else - { - ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(b_k_n); - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *bq_tensor_ptr); - } - - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(a_m_k); - } - else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped) - { - if constexpr(std::is_same_v) - { - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( - a_m_k); - } - else - { - ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(a_m_k); - } - ck_tile::FillUniformDistribution{2.0f, 2.0f, fill_seed(gen)}( - *aq_tensor_ptr); - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(b_k_n); - } - else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped) - { - if constexpr(std::is_same_v) - { - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( - a_m_k); - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( - b_k_n); - } - else - { - ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(a_m_k); - ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(b_k_n); - } - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *aq_tensor_ptr); - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *bq_tensor_ptr); - } - else - { - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}(a_m_k); - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}(b_k_n); - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *aq_tensor_ptr); - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *bq_tensor_ptr); - } - } - else if(init_method == 5) - { - if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) - { - if constexpr(std::is_same_v) - { - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( - b_k_n); - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *bq_tensor_ptr); - } - else if constexpr(std::is_same_v) - { - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(b_k_n); - ck_tile::FillUniformDistribution{125.f, 130.f, fill_seed(gen)}( - *bq_tensor_ptr); - } - else - { - ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(b_k_n); - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *bq_tensor_ptr); - } - - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(a_m_k); - } - else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped) - { - if constexpr(std::is_same_v) - { - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( - a_m_k); - } - else - { - ck_tile::FillUniformDistribution{1.0f, 1.0f, fill_seed(gen)}(a_m_k); - } - // Fill aquant such that column j has value 2^j (1, 2, 4, 8, ...) 
- for(ck_tile::index_t row = 0; - row < static_cast(aq_tensor_ptr->get_length(0)); - ++row) - { - for(ck_tile::index_t col = 0; - col < static_cast(aq_tensor_ptr->get_length(1)); - ++col) - { - (*aq_tensor_ptr)(row, col) = static_cast(col + 1); - } - } - // std::cout << "aq_tensor_ptr: " << *aq_tensor_ptr << std::endl; - ck_tile::FillUniformDistribution{1.0f, 1.0f, fill_seed(gen)}(b_k_n); - } - else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped) - { - if constexpr(std::is_same_v) - { - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( - a_m_k); - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( - b_k_n); - } - else - { - ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(a_m_k); - ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(b_k_n); - } - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *aq_tensor_ptr); - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *bq_tensor_ptr); - } - else - { - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}(a_m_k); - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}(b_k_n); - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *aq_tensor_ptr); - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *bq_tensor_ptr); - } - } else { a_m_k.SetZero(); diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp index 1f2855e8545..9d19e902e5c 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp @@ -519,17 +519,12 @@ struct AQuantBlockUniversalGemmAsBsCr static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}( [&](auto c_row) { float scale_reg_f = aq_picker.template pick(); - // if(get_thread_id() == 0 && get_block_id() == 0) - // { - // printf("scale_reg_f: %f\n", scale_reg_f); - // } c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] += (c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f); }); }); }); - // Reset scheduling priority after completing M iteration __builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_s_setprio(0); __builtin_amdgcn_sched_barrier(0); diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp index 8b4b8b11a9e..442d1d4ae14 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp @@ -520,10 +520,6 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem index_t num_loop, void* p_smem) const { - // if(get_thread_id() == 0 && get_block_id() == 0) - // { - // printf("HasHotLoop: %d\n", HasHotLoop); - // } static_assert( std::is_same_v> && std::is_same_v constexpr AQDramTileWindowStep aq_dram_tile_window_step = is_aq_col_major ? 
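// The __builtin_amdgcn_s_setprio / __builtin_amdgcn_sched_barrier pairs kept
// in the block gemm above implement the interwave priority idiom: raise this
// wave's issue priority while its MFMA cluster executes, then restore it,
// with sched_barrier(0) stopping the compiler from moving instructions
// across the transition. Minimal HIP sketch of the idiom; the loop body is a
// stand-in for the real MFMA cluster, not the pipeline's code.
#include <hip/hip_runtime.h>

__global__ void interwave_prio_demo(float* out, const float* a, const float* b)
{
    float acc = 0.f;

    __builtin_amdgcn_sched_barrier(0); // fence instruction reordering
    __builtin_amdgcn_s_setprio(1);     // favor this wave in issue arbitration
    __builtin_amdgcn_sched_barrier(0);

    for(int k = 0; k < 16; ++k) // stand-in for the MFMA cluster
        acc += a[k] * b[k];

    __builtin_amdgcn_sched_barrier(0);
    __builtin_amdgcn_s_setprio(0); // back to default priority
    __builtin_amdgcn_sched_barrier(0);

    out[threadIdx.x] = acc;
}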
make_array(KPerBlockAQ, 0) : make_array(0, KPerBlockAQ); - // if(get_thread_id() == 0 && get_block_id() == 0) - // { - // printf("First prefetch\n"); - // } LoadAndConvertATile( a_block_tiles.get(I0{}), a_copy_dram_window, a_dram_tile_window_step); Base::GlobalPrefetch( @@ -617,11 +609,6 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem Base::GlobalPrefetch( aq_block_tiles.get(I0{}), aq_copy_dram_window, aq_dram_tile_window_step); - // if(get_thread_id() == 0 && get_block_id() == 0) - // { - // printf("First prefetch done\n"); - // } - tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); if constexpr(is_a_col_major && !is_a_load_tr_v()) @@ -633,10 +620,6 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem } else { - // if(get_thread_id() == 0 && get_block_id() == 0) - // { - // printf("Local prefill A done\n"); - // } Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func); } if constexpr(is_b_row_major && !is_b_load_tr_v()) @@ -648,18 +631,10 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem } else { - // if(get_thread_id() == 0 && get_block_id() == 0) - // { - // printf("Local prefill B done\n"); - // } Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func); } static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) { - // if(get_thread_id() == 0 && get_block_id() == 0) - // { - // printf("Second Prefetch %d\n", static_cast(prefetch_idx)); - // } LoadAndConvertATile(a_block_tiles.get(number{}), a_copy_dram_window, a_dram_tile_window_step); @@ -669,10 +644,6 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem Base::GlobalPrefetch(aq_block_tiles.get(number{}), aq_copy_dram_window, aq_dram_tile_window_step); - // if(get_thread_id() == 0 && get_block_id() == 0) - // { - // printf("Second Prefetch %d done\n", static_cast(prefetch_idx)); - // } }); if constexpr(HasHotLoop) @@ -736,10 +707,6 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem } auto HotLoopTail = [&](auto tail_num) { - // if(get_thread_id() == 0 && get_block_id() == 0) - // { - // printf("HotLoopTail %d\n", static_cast(tail_num)); - // } static_for<0, tail_num - 1, 1>{}([&](auto prefetch_idx) { block_sync_lds(); block_gemm(c_block_tile, @@ -784,20 +751,12 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem if constexpr(TailNum == TailNumber::One) { - // if(get_thread_id() == 0 && get_block_id() == 0) - // { - // printf("TailNum: One\n"); - // } block_sync_lds(); block_gemm( c_block_tile, aq_block_tiles.get(I0{}), a_lds_gemm_window, b_lds_gemm_window); } else if constexpr(TailNum == TailNumber::Two) { - // if(get_thread_id() == 0 && get_block_id() == 0) - // { - // printf("TailNum: Two\n"); - // } HotLoopTail(number<2>{}); } else if constexpr(TailNum == TailNumber::Three)