diff --git a/example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped.cpp b/example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped.cpp
index ad1a4e0d10..e037be5a18 100644
--- a/example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped.cpp
+++ b/example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped.cpp
@@ -4,7 +4,7 @@
 #include "run_gemm_quant_example.inc"
 
 template
-using GemmConfig = GemmConfigQuantDecode;
+using GemmConfig = GemmConfigQuantDecodeInterwave;
 
 // GemmConfigQuantPrefill is also supported for aquant grouped quantization
 // template
diff --git a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp
index 37fc998e5b..2cc7a79c55 100644
--- a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp
+++ b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp
@@ -95,6 +95,25 @@ struct GemmConfigQuantDecode : public GemmConfigBase
         ck_tile::get_k_warp_tile();
 };
 
+template
+struct GemmConfigQuantDecodeInterwave : public GemmConfigBase
+{
+    static constexpr ck_tile::index_t M_Tile = 16;
+    static constexpr ck_tile::index_t N_Tile = 64;
+    static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType);
+
+    static constexpr ck_tile::index_t M_Warp = 1;
+    static constexpr ck_tile::index_t N_Warp = 4;
+    static constexpr ck_tile::index_t K_Warp = 1;
+
+    static constexpr ck_tile::index_t M_Warp_Tile = 16;
+    static constexpr ck_tile::index_t N_Warp_Tile = 16;
+    static constexpr ck_tile::index_t K_Warp_Tile =
+        ck_tile::get_k_warp_tile();
+
+    static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
+};
+
 template
 struct GemmConfigRowColQuant : public GemmConfigBase
 {
diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc
index 607c53d9af..e73248e7a4 100644
--- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc
+++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc
@@ -650,7 +650,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
     else
     {
         ck_tile::FillConstant{static_cast(0x22)}(a_m_k);
-        ck_tile::FillConstant{static_cast(0.5f)}(*aq_tensor_ptr);
+        ck_tile::FillConstant{static_cast(1.0f)}(*aq_tensor_ptr);
         ck_tile::FillConstant{static_cast(0x38)}(b_k_n);
 
         if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp
index 705a992b52..9d19e902e5 100644
--- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp
+++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp
@@ -274,7 +274,9 @@ struct AQuantBlockUniversalGemmAsBsCr
             static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
                 CWarpTensor c_warp_tensor;
 
+                // for every column in AQ
                 static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) {
+                    // for every warp corresponding to a quantization scale
                     static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) {
                         constexpr auto kIter =
                             kQScale * Traits::KIterPerQScale + kIterInQScale;
@@ -322,6 +324,214 @@ struct AQuantBlockUniversalGemmAsBsCr
         }
     };
 
+    template
+    struct BlockGemmImpl
+    {
+        static constexpr index_t KPerThread = GemmTraits::KPerThread;
+        static constexpr index_t NumMacClusters = GemmTraits::InterWaveSchedulingMacClusters;
+
+        static constexpr index_t KPerInnerLoop =
+            ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread);
+        static constexpr index_t KRepeat = KPerThread / KPerInnerLoop;
+        static constexpr index_t KInnerLoopIter = KPerInnerLoop / WarpGemm::kKPerThread;
+
+        static constexpr auto ALdsTileDistr =
+            make_static_tile_distribution(MakeABlockDistributionEncode());
+        static constexpr auto BLdsTileDistr =
+            make_static_tile_distribution(MakeBBlockDistributionEncode());
+
+        using ALdsTile = decltype(make_static_distributed_tensor(ALdsTileDistr));
+        using BLdsTile = decltype(make_static_distributed_tensor(BLdsTileDistr));
+
+        ALdsTile a_warp_tile_;
+        BLdsTile b_warp_tile_;
+
+        template
+        CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
+                                          const BSmemBlockWindow& b_block_window,
+                                          bool_constant = {},
+                                          bool_constant = {})
+        {
+            constexpr auto a_lds_load_distr = [&]() {
+                if constexpr(ALoadTranspose)
+                    return make_static_tile_distribution(typename InputTileDistributionTraits<
+                        decltype(MakeABlockDistributionEncode()),
+                        ADataType>::TransposedDstrEncode{});
+                else
+                    return make_static_tile_distribution(MakeABlockDistributionEncode());
+            }();
+            constexpr auto b_lds_load_distr = [&]() {
+                if constexpr(BLoadTranspose)
+                    return make_static_tile_distribution(typename InputTileDistributionTraits<
+                        decltype(MakeBBlockDistributionEncode()),
+                        BDataType>::TransposedDstrEncode{});
+                else
+                    return make_static_tile_distribution(MakeBBlockDistributionEncode());
+            }();
+            constexpr auto a_lds_shape = []() {
+                if constexpr(ALoadTranspose)
+                    return make_tuple(number{}, number{});
+                else
+                    return make_tuple(number{}, number{});
+            }();
+            constexpr auto b_lds_shape = []() {
+                if constexpr(BLoadTranspose)
+                    return make_tuple(number{}, number{});
+                else
+                    return make_tuple(number{}, number{});
+            }();
+            constexpr auto k_idx_offset = KIdx * KPerInnerLoop;
+            constexpr auto a_offset =
+                ALoadTranspose ? multi_index<2>{k_idx_offset, 0} : multi_index<2>{0, k_idx_offset};
+            constexpr auto b_offset =
+                BLoadTranspose ? multi_index<2>{k_idx_offset, 0} : multi_index<2>{0, k_idx_offset};
+
+            auto a_lds_gemm_window = make_tile_window(
+                a_block_window.get_bottom_tensor_view(), a_lds_shape, a_offset, a_lds_load_distr);
+            auto b_lds_gemm_window = make_tile_window(
+                b_block_window.get_bottom_tensor_view(), b_lds_shape, b_offset, b_lds_load_distr);
+
+            load_int4_tile(
+                a_warp_tile_, a_lds_gemm_window);
+            load_int4_tile(
+                b_warp_tile_, b_lds_gemm_window);
+        }
+
+        // C += A * B with quantization support
+        template
+        CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
+                                       AQBlockTensor& aq_block_tensor,
+                                       const ASmemBlockWindow& a_block_window,
+                                       const BSmemBlockWindow& b_block_window,
+                                       bool_constant a_load_tr = {},
+                                       bool_constant b_load_tr = {})
+        {
+            static_assert(std::is_same_v,
+                          "The CDataType as defined in traits should be the same as corresponding "
+                          "C block tensor data type!");
+            constexpr auto warp_size = get_warp_size();
+
+            // Track which KRepeat chunk is currently loaded
+            index_t current_k_repeat_loaded = -1;
+
+            // Restructured loop: M → N → QScale → KIterPerQScale
+            static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
+                static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
+                    // Iterate over quantization groups
+                    static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) {
+                        CWarpTensor c_warp_tensor;
+
+                        // Accumulate K iterations for this quantization group
+                        static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) {
+                            // Map quantization indices to global K iteration
+                            constexpr auto kIterGlobal =
+                                kQScale * Traits::KIterPerQScale + kIterInQScale;
+
+                            // Map to KRepeat chunk and KInnerLoopIter offset
+                            constexpr auto kRepeatIdx = kIterGlobal / KInnerLoopIter;
+                            constexpr auto kInnerIdx = kIterGlobal % KInnerLoopIter;
+
+                            // Prefetch new chunk if needed
+                            if constexpr(kInnerIdx == 0)
+                            {
+                                if(current_k_repeat_loaded != kRepeatIdx)
+                                {
+                                    LocalPrefetch(
+                                        a_block_window, b_block_window, a_load_tr, b_load_tr);
+                                    __builtin_amdgcn_sched_barrier(0);
+
+                                    if constexpr(kRepeatIdx != 0 || KRepeat == 1)
+                                    {
+                                        __builtin_amdgcn_s_barrier();
+                                        __builtin_amdgcn_sched_barrier(0);
+                                    }
+
+                                    current_k_repeat_loaded = kRepeatIdx;
+                                }
+                            }
+
+                            // Load A warp tensor
+                            AWarpTensor a_warp_tensor;
+                            a_warp_tensor.get_thread_buffer() =
+                                a_warp_tile_.get_y_sliced_thread_data(
+                                    merge_sequences(sequence{},
+                                                    a_warp_y_index_zeros),
+                                    merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
+
+                            // Load B warp tensor
+                            BWarpTensor b_warp_tensor;
+                            b_warp_tensor.get_thread_buffer() =
+                                b_warp_tile_.get_y_sliced_thread_data(
+                                    merge_sequences(sequence{},
+                                                    b_warp_y_index_zeros),
+                                    merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
+
+                            // Synchronization barrier at the end of last iteration
+                            if constexpr(kQScale == Traits::QScalesPerBlockRow - 1 &&
+                                         kIterInQScale == Traits::KIterPerQScale - 1 &&
+                                         mIter.value == MIterPerWarp - 1 &&
+                                         nIter.value == NIterPerWarp - 1)
+                            {
+                                __builtin_amdgcn_sched_barrier(0);
+                                block_sync_lds();
+                                __builtin_amdgcn_sched_barrier(0);
+                            }
+
+                            // Accumulate: first iteration initializes, rest accumulate
+                            if constexpr(kIterInQScale == 0)
+                            {
+                                c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor);
+                            }
+                            else
+                            {
+                                WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
+                            }
+
+                            // Set priority for scheduling
+                            if constexpr(kInnerIdx == 0 && mIter.value == 0 && nIter.value == 0)
+                            {
+                                __builtin_amdgcn_sched_barrier(0);
+                                __builtin_amdgcn_s_setprio(1);
+                                __builtin_amdgcn_sched_barrier(0);
+                            }
+                        });
+
+                        // Apply quantization scale after accumulating all K iterations for this
+                        // group
+                        constexpr auto tbuf_offset =
+                            number{},
+                                       c_warp_y_index_zeros)) /
+                                   CBlockTensor::PackedSize>{};
+
+                        AQPickerCommon aq_picker(
+                            aq_block_tensor);
+
+                        static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}(
+                            [&](auto c_row) {
+                                float scale_reg_f = aq_picker.template pick();
+                                c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
+                                    (c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f);
+                            });
+                    });
+                });
+
+                __builtin_amdgcn_sched_barrier(0);
+                __builtin_amdgcn_s_setprio(0);
+                __builtin_amdgcn_sched_barrier(0);
+            });
+        }
+    };
+
     public:
     CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
     {
@@ -329,7 +539,8 @@ struct AQuantBlockUniversalGemmAsBsCr
         MakeCBlockTile();
     }
 
-    template
@@ -338,7 +549,15 @@ struct AQuantBlockUniversalGemmAsBsCr
                                   bool_constant a_load_tr = {},
                                   bool_constant b_load_tr = {})
     {
-        block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window, a_load_tr, b_load_tr);
+        if constexpr(Scheduler == GemmPipelineScheduler::Interwave)
+        {
+            block_gemm_impl_.template LocalPrefetch(
+                a_block_window, b_block_window, a_load_tr, b_load_tr);
+        }
+        else
+        {
+            block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window, a_load_tr, b_load_tr);
+        }
     }
 
     // C += A * B
diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp
index 650cd947f7..442d1d4ae1 100644
--- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp
+++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp
@@ -486,6 +486,307 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem
         }
     };
 
+    template <>
+    struct PipelineImpl : public PipelineImplBase
+    {
+        using Base = PipelineImplBase;
+
+        template
+        CK_TILE_DEVICE static void
+        LoadAndConvertATile(ABlockTile_& a_block_tile,
+                            ADramWindow& a_dram_window,
+                            const DramTileWindowStep& dram_tile_window_step)
+        {
+            using DestDataType = typename ABlockTile_::DataType;
+            using SrcDataType = typename ADramWindow::Base::TileWindowBase::DataType;
+            constexpr index_t UnaryOpSize = 8;
+            load_int4_tile(a_block_tile, a_dram_window);
+            move_tile_window(a_dram_window, dram_tile_window_step);
+        }
+
+        template
+        CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
+                                       const AElementFunction& a_element_func,
+                                       const BDramBlockWindowTmp& b_dram_block_window_tmp,
+                                       const BElementFunction& b_element_func,
+                                       const AQDramBlockWindowTmp& aq_dram_block_window_tmp,
+                                       [[maybe_unused]] index_t m,
+                                       index_t num_loop,
+                                       void* p_smem) const
+        {
+            static_assert(
+                std::is_same_v> &&
+                    std::is_same_v> &&
+                    std::is_same_v>,
+                "A/B/AQ Dram block window should have the same data type as appropriate "
+                "([A|B|AQ]DataType) defined in Problem definition!");
+
+            constexpr bool is_a_col_major =
+                std::is_same_v;
+            constexpr bool is_aq_col_major =
+                std::is_same_v;
+            constexpr bool is_b_row_major = std::is_same_v;
+
+            static_assert(!PreshuffleQuant, "Memory pipeline does not support PreshuffleQuant!");
+
+            static_assert(is_a_col_major
+                              ? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
+                                 MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
+                              : (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
+                                 KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
+                          "A block window has incorrect lengths for defined ALayout!");
+            static_assert(is_b_row_major
+                              ? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
+                                 NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
+                              : (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
+                                 KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
+                          "B block window has incorrect lengths for defined BLayout!");
+
+            auto ab_lds_blocks = Base::template GetABLdsTensorViews(p_smem);
+            auto& a_lds_block = ab_lds_blocks.at(I0{});
+            auto& b_lds_block = ab_lds_blocks.at(I1{});
+
+            constexpr auto a_lds_load_tile_distr =
+                make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
+            constexpr auto b_lds_load_tile_distr =
+                make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());
+
+            auto a_windows =
+                Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr);
+            auto& a_copy_dram_window = a_windows.at(I0{});
+            auto& a_copy_lds_window = a_windows.at(I1{});
+            auto& a_lds_gemm_window = a_windows.at(I2{});
+
+            auto b_windows =
+                Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr);
+            auto& b_copy_dram_window = b_windows.at(I0{});
+            auto& b_copy_lds_window = b_windows.at(I1{});
+            auto& b_lds_gemm_window = b_windows.at(I2{});
+
+            auto aq_copy_dram_window = Base::GetAQDramLoadWindow(aq_dram_block_window_tmp);
+
+            auto block_gemm = BlockGemm();
+            auto c_block_tile = block_gemm.MakeCBlockTile();
+
+            using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
+            using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
+            using AQBlockTileDistr = decltype(aq_copy_dram_window.get_tile_distribution());
+
+            using ABlockTile =
+                decltype(make_static_distributed_tensor(ABlockTileDistr{}));
+            using BBlockTile =
+                decltype(make_static_distributed_tensor(BBlockTileDistr{}));
+            using AQBlockTile =
+                decltype(make_static_distributed_tensor(AQBlockTileDistr{}));
+
+            tuple_array a_block_tiles;
+            tuple_array b_block_tiles;
+            tuple_array aq_block_tiles;
+
+            using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
+            using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
+            using AQDramTileWindowStep = typename AQDramBlockWindowTmp::BottomTensorIndex;
+
+            constexpr ADramTileWindowStep a_dram_tile_window_step =
+                is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
+            constexpr BDramTileWindowStep b_dram_tile_window_step =
+                is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
+            constexpr AQDramTileWindowStep aq_dram_tile_window_step =
+                is_aq_col_major ? make_array(KPerBlockAQ, 0) : make_array(0, KPerBlockAQ);
+
+            LoadAndConvertATile(
+                a_block_tiles.get(I0{}), a_copy_dram_window, a_dram_tile_window_step);
+            Base::GlobalPrefetch(
+                b_block_tiles.get(I0{}), b_copy_dram_window, b_dram_tile_window_step);
+            Base::GlobalPrefetch(
+                aq_block_tiles.get(I0{}), aq_copy_dram_window, aq_dram_tile_window_step);
+
+            tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
+
+            if constexpr(is_a_col_major && !is_a_load_tr_v())
+            {
+                auto a_shuffle_tmp = make_static_distributed_tensor(
+                    Policy::template MakeShuffledARegTileDistribution());
+                transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(I0{}));
+                Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
+            }
+            else
+            {
+                Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func);
+            }
+            if constexpr(is_b_row_major && !is_b_load_tr_v())
+            {
+                auto b_shuffle_tmp = make_static_distributed_tensor(
+                    Policy::template MakeShuffledBRegTileDistribution());
+                transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(I0{}));
+                Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
+            }
+            else
+            {
+                Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func);
+            }
+
+            static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) {
+                LoadAndConvertATile(a_block_tiles.get(number{}),
+                                    a_copy_dram_window,
+                                    a_dram_tile_window_step);
+                Base::GlobalPrefetch(b_block_tiles.get(number{}),
+                                     b_copy_dram_window,
+                                     b_dram_tile_window_step);
+                Base::GlobalPrefetch(aq_block_tiles.get(number{}),
+                                     aq_copy_dram_window,
+                                     aq_dram_tile_window_step);
+            });
+
+            if constexpr(HasHotLoop)
+            {
+                index_t i = 0;
+                do
+                {
+                    static_for<0, PrefetchStages, 1>{}([&](auto prefetch_idx) {
+                        block_sync_lds();
+                        block_gemm(c_block_tile,
+                                   aq_block_tiles.get(number{}),
+                                   a_lds_gemm_window,
+                                   b_lds_gemm_window);
+
+                        if constexpr(is_a_col_major && !is_a_load_tr_v())
+                        {
+                            auto a_shuffle_tmp = make_static_distributed_tensor(
+                                Policy::template MakeShuffledARegTileDistribution());
+                            transpose_tile2d(
+                                a_shuffle_tmp,
+                                a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}));
+                            Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
+                        }
+                        else
+                        {
+                            Base::LocalPrefill(
+                                a_copy_lds_window,
+                                a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
+                                a_element_func);
+                        }
+                        if constexpr(is_b_row_major && !is_b_load_tr_v())
+                        {
+                            auto b_shuffle_tmp = make_static_distributed_tensor(
+                                Policy::template MakeShuffledBRegTileDistribution());
+                            transpose_tile2d(
+                                b_shuffle_tmp,
+                                b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}));
+                            Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
+                        }
+                        else
+                        {
+                            Base::LocalPrefill(
+                                b_copy_lds_window,
+                                b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
+                                b_element_func);
+                        }
+
+                        LoadAndConvertATile(a_block_tiles.get(number{}),
+                                            a_copy_dram_window,
+                                            a_dram_tile_window_step);
+                        Base::GlobalPrefetch(b_block_tiles.get(number{}),
+                                             b_copy_dram_window,
+                                             b_dram_tile_window_step);
+                        Base::GlobalPrefetch(aq_block_tiles.get(number{}),
+                                             aq_copy_dram_window,
+                                             aq_dram_tile_window_step);
+                    });
+
+                    i += PrefetchStages;
+                } while(i < (num_loop - PrefetchStages));
+            }
+
+            auto HotLoopTail = [&](auto tail_num) {
+                static_for<0, tail_num - 1, 1>{}([&](auto prefetch_idx) {
+                    block_sync_lds();
+                    block_gemm(c_block_tile,
+                               aq_block_tiles.get(number{}),
+                               a_lds_gemm_window,
+                               b_lds_gemm_window);
+
+                    if constexpr(is_a_col_major && !is_a_load_tr_v())
+                    {
+                        auto a_shuffle_tmp = make_static_distributed_tensor(
+                            Policy::template MakeShuffledARegTileDistribution());
+                        transpose_tile2d(a_shuffle_tmp,
+                                         a_block_tiles.get(number{}));
+                        Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
+                    }
+                    else
+                    {
+                        Base::LocalPrefill(a_copy_lds_window,
+                                           a_block_tiles.get(number{}));
+                    }
+                    if constexpr(is_b_row_major && !is_b_load_tr_v())
+                    {
+                        auto b_shuffle_tmp = make_static_distributed_tensor(
+                            Policy::template MakeShuffledBRegTileDistribution());
+                        transpose_tile2d(b_shuffle_tmp,
+                                         b_block_tiles.get(number{}));
+                        Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
+                    }
+                    else
+                    {
+                        Base::LocalPrefill(b_copy_lds_window,
+                                           b_block_tiles.get(number{}));
+                    }
+                });
+
+                block_sync_lds();
+                block_gemm(c_block_tile,
+                           aq_block_tiles.get(number{}),
+                           a_lds_gemm_window,
+                           b_lds_gemm_window);
+            };
+
+            if constexpr(TailNum == TailNumber::One)
+            {
+                block_sync_lds();
+                block_gemm(
+                    c_block_tile, aq_block_tiles.get(I0{}), a_lds_gemm_window, b_lds_gemm_window);
+            }
+            else if constexpr(TailNum == TailNumber::Two)
+            {
+                HotLoopTail(number<2>{});
+            }
+            else if constexpr(TailNum == TailNumber::Three)
+            {
+                HotLoopTail(number<3>{});
+            }
+            else if constexpr(TailNum == TailNumber::Four)
+            {
+                HotLoopTail(number<4>{});
+            }
+            else if constexpr(TailNum == TailNumber::Five)
+            {
+                HotLoopTail(number<5>{});
+            }
+            else if constexpr(TailNum == TailNumber::Six)
+            {
+                HotLoopTail(number<6>{});
+            }
+            else if constexpr(TailNum == TailNumber::Seven)
+            {
+                HotLoopTail(number<7>{});
+            }
+            else if constexpr(TailNum == TailNumber::Full)
+            {
+                HotLoopTail(number{});
+            }
+            return c_block_tile;
+        }
+    };
+
     template