From 90425564221683154de51e294e95594e4837e71f Mon Sep 17 00:00:00 2001 From: Bartlomiej Kocot Date: Mon, 19 Jan 2026 23:27:21 +0000 Subject: [PATCH 1/7] [CK TILE] Fix basic pipelines --- .../20_grouped_convolution/conv_configs.hpp | 18 + .../gemm_pipeline_agmem_bgmem_creg_v1.hpp | 401 +++++++++--------- .../gemm_pipeline_agmem_bgmem_creg_v2.hpp | 360 ++++++++-------- 3 files changed, 382 insertions(+), 397 deletions(-) diff --git a/example/ck_tile/20_grouped_convolution/conv_configs.hpp b/example/ck_tile/20_grouped_convolution/conv_configs.hpp index 620b5058202..847030fffb4 100644 --- a/example/ck_tile/20_grouped_convolution/conv_configs.hpp +++ b/example/ck_tile/20_grouped_convolution/conv_configs.hpp @@ -257,6 +257,24 @@ struct ConvTypeConfig template struct PipelineTypeTraits; +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAGmemBGmemCRegV1; +}; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV2; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAGmemBGmemCRegV2; +}; + template <> struct PipelineTypeTraits { diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp index 936c38ddf33..ab1fca85495 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp @@ -39,6 +39,8 @@ struct BaseGemmPipelineAGmemBGmemCRegV1 template struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1 { + using PipelineImplBase = GemmPipelineAgBgCrImplBase; + using AsDataType = remove_cvref_t; using BsDataType = remove_cvref_t; using CDataType = remove_cvref_t; @@ -121,228 +123,207 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1(); } - template ::value && - is_detected::value, - bool>* = nullptr> - CK_TILE_HOST_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, - const AElementFunction& a_element_func, - const BsDramBlockWindowTmp& b_dram_block_window_tmp, - const BElementFunction& b_element_func, - index_t num_loop, - void* p_smem) const + struct PipelineImpl : public PipelineImplBase { - using ADramBlockWindowTmp = - remove_cvref_t{}, AsDramBlockWindowTmp>>; - using BDramBlockWindowTmp = - remove_cvref_t{}, BsDramBlockWindowTmp>>; - - static_assert( - std::is_same_v> && - std::is_same_v>, - "wrong!"); - - constexpr bool is_a_col_major = std::is_same_v; - constexpr bool is_b_row_major = std::is_same_v; - - static_assert(is_a_col_major - ? (kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && - kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]) - : (kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && - kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]), - "A block window has incorrect lengths for defined ALayout!"); - static_assert(is_b_row_major - ? (kKPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && - kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]) - : (kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && - kKPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]), - "B block window has incorrect lengths for defined BLayout!"); - // A tile in LDS - ADataType* p_a_lds = static_cast(p_smem); - - constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor(); - - auto a_lds_block = make_tensor_view(p_a_lds, a_lds_block_desc); - - constexpr index_t a_lds_block_space_size_aligned = - integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), - kLdsAlignmentInBytes) * - kLdsAlignmentInBytes; - - // B tile in LDS - BDataType* p_b_lds = static_cast( - static_cast(static_cast(p_smem) + a_lds_block_space_size_aligned)); - - constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor(); - - auto b_lds_block = make_tensor_view(p_b_lds, b_lds_block_desc); - - // A DRAM tile window for load - auto as_copy_dram_window = generate_tuple( - [&](auto idx) { - return make_tile_window( - a_dram_block_window_tmp[number{}].get_bottom_tensor_view(), - make_tuple(number{}, number{}), - a_dram_block_window_tmp[number{}].get_window_origin(), - Policy::template MakeADramTileDistribution()); - }, - number{}); - - // A LDS tile window for store - auto a_copy_lds_window = make_tile_window( - a_lds_block, make_tuple(number{}, number{}), {0, 0}); - - // B DRAM tile window for load - auto bs_copy_dram_window = generate_tuple( - [&](auto idx) { - return make_tile_window( - b_dram_block_window_tmp[number{}].get_bottom_tensor_view(), - make_tuple(number{}, number{}), - b_dram_block_window_tmp[number{}].get_window_origin(), - Policy::template MakeBDramTileDistribution()); - }, - number{}); - - // B LDS tile window for store - auto b_copy_lds_window = make_tile_window( - b_lds_block, make_tuple(number{}, number{}), {0, 0}); - - // Tile distribution for load from lds - constexpr auto a_lds_load_tile_distr = - make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); - constexpr auto b_lds_load_tile_distr = - make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode()); - - // A LDS tile for block GEMM - auto a_lds_gemm_window = - make_tile_window(a_lds_block, - make_tuple(number{}, number{}), - {0, 0}, - a_lds_load_tile_distr); - - // B LDS tile for block GEMM - auto b_lds_gemm_window = - make_tile_window(b_lds_block, - make_tuple(number{}, number{}), - {0, 0}, - b_lds_load_tile_distr); - - // Block GEMM - auto block_gemm = BlockGemm(); - - // Acc register tile - auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){}; - - // prefetch - // global read 0 - // Load tile — during value loading, an elementwise function is executed for each A0, - // A1, … AN. The values A0, A1, … AN are read by the same thread. - auto elementwise_As_res = load_tile_with_elementwise(as_copy_dram_window, a_element_func); - - // Load tile — during value loading, an elementwise function is executed for each B0, - // B1, … BN. The values B0, B1, … BN are read by the same thread. - auto elementwise_Bs_res = load_tile_with_elementwise(bs_copy_dram_window, b_element_func); - + using Base = GemmPipelineAgBgCrImplBase; + + template ::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_HOST_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + index_t num_loop, + void* p_smem) const { - // move to 1 - // Move each A — the enhanced function move_tile_window is executed, which takes a tuple - // as input. - move_tile_window(as_copy_dram_window, {0, kKPerBlock}); - // Move each B — the enhanced function move_tile_window is executed, which takes a tuple - // as input. - move_tile_window(bs_copy_dram_window, {0, kKPerBlock}); - - // initialize C - tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); - - // LDS write 0 - if constexpr(is_a_col_major) - { - auto a_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledARegTileDistribution()); - transpose_tile2d(a_shuffle_tmp, elementwise_As_res); - store_tile(a_copy_lds_window, a_shuffle_tmp); - } - else - { - store_tile(a_copy_lds_window, elementwise_As_res); - } + using ADramBlockWindowTmp = + remove_cvref_t{}, AsDramBlockWindowTmp>>; + using BDramBlockWindowTmp = + remove_cvref_t{}, BsDramBlockWindowTmp>>; + + static_assert( + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + constexpr bool is_a_col_major = + std::is_same_v; + constexpr bool is_b_row_major = std::is_same_v; + + static_assert(is_a_col_major + ? (kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && + kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]) + : (kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && + kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]), + "A block window has incorrect lengths for defined ALayout!"); + static_assert(is_b_row_major + ? (kKPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]) + : (kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + kKPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]), + "B block window has incorrect lengths for defined BLayout!"); + // A tile in LDS + ADataType* p_a_lds = static_cast(p_smem); + + constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor(); + + auto a_lds_block = make_tensor_view(p_a_lds, a_lds_block_desc); + + constexpr index_t a_lds_block_space_size_aligned = + integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), + kLdsAlignmentInBytes) * + kLdsAlignmentInBytes; + + // B tile in LDS + BDataType* p_b_lds = static_cast( + static_cast(static_cast(p_smem) + a_lds_block_space_size_aligned)); + + constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor(); + + auto b_lds_block = make_tensor_view(p_b_lds, b_lds_block_desc); + + // // Tile distribution for load from lds + constexpr auto a_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); + constexpr auto b_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode()); + + // A DRAM tile window for load + // A LDS tile window for store + // A LDS tile for block GEMM + auto&& [as_copy_dram_window, a_copy_lds_window, a_lds_gemm_window] = + Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr); + + // B DRAM tile window for load + // B LDS tile window for store + // B LDS tile for block GEMM + auto&& [bs_copy_dram_window, b_copy_lds_window, b_lds_gemm_window] = + Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr); + + // Block GEMM + auto block_gemm = BlockGemm(); + + // Acc register tile + auto c_block_tile = block_gemm.MakeCBlockTile(); + + // prefetch + // global read 0 + // Load tile — during value loading, an elementwise function is executed for each A0, + // A1, … AN. The values A0, A1, … AN are read by the same thread. + auto elementwise_As_res = + load_tile_with_elementwise(as_copy_dram_window, a_element_func); + + // Load tile — during value loading, an elementwise function is executed for each B0, + // B1, … BN. The values B0, B1, … BN are read by the same thread. + auto elementwise_Bs_res = + load_tile_with_elementwise(bs_copy_dram_window, b_element_func); - // LDS write 0 - if constexpr(is_b_row_major) - { - auto b_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledBRegTileDistribution()); - transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res); - store_tile(b_copy_lds_window, b_shuffle_tmp); - } - else { - store_tile(b_copy_lds_window, elementwise_Bs_res); + // move to 1 + // Move each A — the enhanced function move_tile_window is executed, which takes a + // tuple as input. + move_tile_window(as_copy_dram_window, {0, kKPerBlock}); + // Move each B — the enhanced function move_tile_window is executed, which takes a + // tuple as input. + move_tile_window(bs_copy_dram_window, {0, kKPerBlock}); + + // initialize C + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + + // LDS write 0 + if constexpr(is_a_col_major) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp, elementwise_As_res); + store_tile(a_copy_lds_window, a_shuffle_tmp); + } + else + { + store_tile(a_copy_lds_window, elementwise_As_res); + } + + // LDS write 0 + if constexpr(is_b_row_major) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res); + store_tile(b_copy_lds_window, b_shuffle_tmp); + } + else + { + store_tile(b_copy_lds_window, elementwise_Bs_res); + } } - } - - index_t iCounter = num_loop - 1; - while(iCounter > 0) - { - // global read i + 1 - elementwise_As_res = load_tile_with_elementwise(as_copy_dram_window, a_element_func); - elementwise_Bs_res = load_tile_with_elementwise(bs_copy_dram_window, b_element_func); - block_sync_lds(); - - // GEMM i - block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); - - block_sync_lds(); - - // move to i + 2 - move_tile_window(as_copy_dram_window, {0, kKPerBlock}); - move_tile_window(bs_copy_dram_window, {0, kKPerBlock}); - - // LDS write i + 1 - if constexpr(is_a_col_major) + index_t iCounter = num_loop - 1; + while(iCounter > 0) { - auto a_shuffle_tmp_loop = make_static_distributed_tensor( - Policy::template MakeShuffledARegTileDistribution()); - transpose_tile2d(a_shuffle_tmp_loop, elementwise_As_res); - store_tile(a_copy_lds_window, a_shuffle_tmp_loop); - } - else - { - store_tile(a_copy_lds_window, elementwise_As_res); + // global read i + 1 + elementwise_As_res = + load_tile_with_elementwise(as_copy_dram_window, a_element_func); + block_sync_lds(); + elementwise_Bs_res = + load_tile_with_elementwise(bs_copy_dram_window, b_element_func); + + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + + // GEMM i + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + + block_sync_lds(); + + // move to i + 2 + move_tile_window(as_copy_dram_window, {0, kKPerBlock}); + move_tile_window(bs_copy_dram_window, {0, kKPerBlock}); + + // LDS write i + 1 + if constexpr(is_a_col_major) + { + auto a_shuffle_tmp_loop = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp_loop, elementwise_As_res); + store_tile(a_copy_lds_window, a_shuffle_tmp_loop); + } + else + { + store_tile(a_copy_lds_window, elementwise_As_res); + } + + // LDS write i + 1 + if constexpr(is_b_row_major) + { + auto b_shuffle_tmp_loop = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp_loop, elementwise_Bs_res); + store_tile(b_copy_lds_window, b_shuffle_tmp_loop); + } + else + { + store_tile(b_copy_lds_window, elementwise_Bs_res); + } + + iCounter--; } - // LDS write i + 1 - if constexpr(is_b_row_major) - { - auto b_shuffle_tmp_loop = make_static_distributed_tensor( - Policy::template MakeShuffledBRegTileDistribution()); - transpose_tile2d(b_shuffle_tmp_loop, elementwise_Bs_res); - store_tile(b_copy_lds_window, b_shuffle_tmp_loop); - } - else + // tail { - store_tile(b_copy_lds_window, elementwise_Bs_res); + block_sync_lds(); + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + // GEMM num_loop - 1 + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); } - iCounter--; + return c_block_tile; } - - // tail - { - block_sync_lds(); - - // GEMM num_loop - 1 - block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); - } - - return c_block_tile; - } - + }; template ::value && @@ -353,7 +334,7 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1 struct GemmPipelineAGmemBGmemCRegV2 : public BaseGemmPipelineAGmemBGmemCRegV2 { + using PipelineImplBase = GemmPipelineAgBgCrImplBase; + using AsDataType = remove_cvref_t; using BsDataType = remove_cvref_t; using CDataType = remove_cvref_t; @@ -56,6 +58,8 @@ struct GemmPipelineAGmemBGmemCRegV2 : public BaseGemmPipelineAGmemBGmemCRegV2>; using BDataType = remove_cvref_t>; + using BlockGemm = remove_cvref_t())>; + static constexpr index_t APackedSize = ck_tile::numeric_traits>::PackedSize; static constexpr index_t BPackedSize = @@ -127,205 +131,187 @@ struct GemmPipelineAGmemBGmemCRegV2 : public BaseGemmPipelineAGmemBGmemCRegV2(); } - template ::value && - is_detected::value, - bool>* = nullptr> - CK_TILE_HOST_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, - const AElementFunction& a_element_func, - const BsDramBlockWindowTmp& b_dram_block_window_tmp, - const BElementFunction& b_element_func, - index_t num_loop, - void* p_smem) const + struct PipelineImpl : public PipelineImplBase { - - using ADramBlockWindowTmp = - remove_cvref_t{}, AsDramBlockWindowTmp>>; - using BDramBlockWindowTmp = - remove_cvref_t{}, BsDramBlockWindowTmp>>; - - static_assert( - std::is_same_v> && - std::is_same_v>, - "wrong!"); - - static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}], - "wrong!"); - - // A tile in LDS - ADataType* p_a_lds = static_cast(p_smem); - - constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor(); - - auto a_lds_block = make_tensor_view(p_a_lds, a_lds_block_desc); - - constexpr index_t a_lds_block_space_size_aligned = - integer_divide_ceil( - sizeof(ADataType) * a_lds_block_desc.get_element_space_size() / APackedSize, 16) * - 16; - - // B tile in LDS - BDataType* p_b_lds = static_cast( - static_cast(static_cast(p_smem) + a_lds_block_space_size_aligned)); - - constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor(); - - auto b_lds_block = make_tensor_view(p_b_lds, b_lds_block_desc); - - // A DRAM tile window for load - auto as_copy_dram_window = generate_tuple( - [&](auto idx) { - return make_tile_window( - a_dram_block_window_tmp[number{}].get_bottom_tensor_view(), - make_tuple(number{}, number{}), - a_dram_block_window_tmp[number{}].get_window_origin(), - Policy::template MakeADramTileDistribution()); - }, - number{}); - - // A LDS tile window for store - auto a_copy_lds_window = - make_tile_window(a_lds_block, - make_tuple(number{}, number{}), - {0, 0}, - as_copy_dram_window[number<0>{}].get_tile_distribution()); - - // B DRAM tile window for load - auto bs_copy_dram_window = generate_tuple( - [&](auto idx) { - return make_tile_window( - b_dram_block_window_tmp[number{}].get_bottom_tensor_view(), - make_tuple(number{}, number{}), - b_dram_block_window_tmp[number{}].get_window_origin(), - Policy::template MakeBDramTileDistribution()); - }, - number{}); - - // B LDS tile window for store - auto b_copy_lds_window = - make_tile_window(b_lds_block, - make_tuple(number{}, number{}), - {0, 0}, - bs_copy_dram_window[number<0>{}].get_tile_distribution()); - - // Block GEMM - constexpr auto block_gemm = Policy::template GetBlockGemm(); - - // Tile distribution for load from lds - constexpr auto a_lds_load_tile_distr = - make_static_tile_distribution(decltype(block_gemm)::MakeABlockDistributionEncode()); - constexpr auto b_lds_load_tile_distr = - make_static_tile_distribution(decltype(block_gemm)::MakeBBlockDistributionEncode()); - - // A LDS tile for block GEMM - auto a_lds_gemm_window = - make_tile_window(a_lds_block, - make_tuple(number{}, number{}), - {0, 0}, - a_lds_load_tile_distr); - - // B LDS tile for block GEMM - auto b_lds_gemm_window = - make_tile_window(b_lds_block, - make_tuple(number{}, number{}), - {0, 0}, - b_lds_load_tile_distr); - - // Acc register tile - auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){}; - - // prefetch - // global read 0 - // Load tile — during value loading, an elementwise function is executed for each A0, - // A1, … AN. The values A0, A1, … AN are read by the same thread. - auto elementwise_As_res = load_tile_with_elementwise(as_copy_dram_window, a_element_func); - // Load tile — during value loading, an elementwise function is executed for each B0, - // B1, … BN. The values B0, B1, … BN are read by the same thread. - auto elementwise_Bs_res = load_tile_with_elementwise(bs_copy_dram_window, b_element_func); - - { - // move to 1 - move_tile_window(as_copy_dram_window, {0, kKPerBlock}); - move_tile_window(bs_copy_dram_window, {0, kKPerBlock}); - - // initialize C - tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); - - // LDS write 0 - store_tile(a_copy_lds_window, elementwise_As_res); - // global read 1 - elementwise_As_res = load_tile_with_elementwise(as_copy_dram_window, a_element_func); - - // LDS write 0 - store_tile(b_copy_lds_window, elementwise_Bs_res); - // global read 1 - elementwise_Bs_res = load_tile_with_elementwise(bs_copy_dram_window, b_element_func); - } - - index_t iCounter = num_loop - 2; - - do - { - block_sync_lds(); - - // GEMM i - block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); - - block_sync_lds(); - - // move to i + 2 - move_tile_window(as_copy_dram_window, {0, kKPerBlock}); - move_tile_window(bs_copy_dram_window, {0, kKPerBlock}); - - // LDS write i + 1 - store_tile(a_copy_lds_window, elementwise_As_res); - // global read i + 2 - elementwise_As_res = load_tile_with_elementwise(as_copy_dram_window, a_element_func); - - // LDS write i + 1 - store_tile(b_copy_lds_window, elementwise_Bs_res); - // global read i + 2 - elementwise_Bs_res = load_tile_with_elementwise(bs_copy_dram_window, b_element_func); - - iCounter--; - - } while(iCounter > 0); - - // tail + using Base = GemmPipelineAgBgCrImplBase; + + template ::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_HOST_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + index_t num_loop, + void* p_smem) const { - block_sync_lds(); - // GEMM num_loop - 2 - block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + using ADramBlockWindowTmp = + remove_cvref_t{}, AsDramBlockWindowTmp>>; + using BDramBlockWindowTmp = + remove_cvref_t{}, BsDramBlockWindowTmp>>; + + static_assert( + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kNPerBlock == + BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + // A tile in LDS + ADataType* p_a_lds = static_cast(p_smem); + + constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor(); + + auto a_lds_block = make_tensor_view(p_a_lds, a_lds_block_desc); + + constexpr index_t a_lds_block_space_size_aligned = + integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size() / + APackedSize, + 16) * + 16; + + // B tile in LDS + BDataType* p_b_lds = static_cast( + static_cast(static_cast(p_smem) + a_lds_block_space_size_aligned)); + + constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor(); + + auto b_lds_block = make_tensor_view(p_b_lds, b_lds_block_desc); + + // // Tile distribution for load from lds + constexpr auto a_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); + constexpr auto b_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode()); + + // A DRAM tile window for load + // A LDS tile window for store + // A LDS tile for block GEMM + auto&& [as_copy_dram_window, a_copy_lds_window, a_lds_gemm_window] = + Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr); + + // B DRAM tile window for load + // B LDS tile window for store + // B LDS tile for block GEMM + auto&& [bs_copy_dram_window, b_copy_lds_window, b_lds_gemm_window] = + Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr); + + // Block GEMM + auto block_gemm = BlockGemm(); + + // Acc register tile + auto c_block_tile = block_gemm.MakeCBlockTile(); + + // prefetch + // global read 0 + // Load tile — during value loading, an elementwise function is executed for each A0, + // A1, … AN. The values A0, A1, … AN are read by the same thread. + auto elementwise_As_res = + load_tile_with_elementwise(as_copy_dram_window, a_element_func); + // Load tile — during value loading, an elementwise function is executed for each B0, + // B1, … BN. The values B0, B1, … BN are read by the same thread. + auto elementwise_Bs_res = + load_tile_with_elementwise(bs_copy_dram_window, b_element_func); + + { + // move to 1 + move_tile_window(as_copy_dram_window, {0, kKPerBlock}); + move_tile_window(bs_copy_dram_window, {0, kKPerBlock}); + + // initialize C + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + + // LDS write 0 + store_tile(a_copy_lds_window, elementwise_As_res); + // global read 1 + elementwise_As_res = + load_tile_with_elementwise(as_copy_dram_window, a_element_func); + + // LDS write 0 + store_tile(b_copy_lds_window, elementwise_Bs_res); + // global read 1 + elementwise_Bs_res = + load_tile_with_elementwise(bs_copy_dram_window, b_element_func); + } + + index_t iCounter = num_loop - 2; + + do + { + block_sync_lds(); + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + + // GEMM i + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + + block_sync_lds(); + + // move to i + 2 + move_tile_window(as_copy_dram_window, {0, kKPerBlock}); + move_tile_window(bs_copy_dram_window, {0, kKPerBlock}); - block_sync_lds(); + // LDS write i + 1 + store_tile(a_copy_lds_window, elementwise_As_res); + // global read i + 2 + elementwise_As_res = + load_tile_with_elementwise(as_copy_dram_window, a_element_func); - // LDS write num_loop - 1 - store_tile(a_copy_lds_window, elementwise_As_res); + // LDS write i + 1 + store_tile(b_copy_lds_window, elementwise_Bs_res); + // global read i + 2 + elementwise_Bs_res = + load_tile_with_elementwise(bs_copy_dram_window, b_element_func); - store_tile(b_copy_lds_window, elementwise_Bs_res); + iCounter--; + + } while(iCounter > 0); + + // tail + { + block_sync_lds(); + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + + // GEMM num_loop - 2 + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); - block_sync_lds(); - - // GEMM num_loop - 1 - block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + block_sync_lds(); + + // LDS write num_loop - 1 + store_tile(a_copy_lds_window, elementwise_As_res); + + store_tile(b_copy_lds_window, elementwise_Bs_res); + + block_sync_lds(); + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + // GEMM num_loop - 1 + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + } + + return c_block_tile; } + }; - return c_block_tile; - } - - template - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, - const BDramBlockWindowTmp& b_dram_block_window_tmp, + template ::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, index_t num_loop, void* p_smem) const { - return operator()( + return PipelineImpl{}.operator()( a_dram_block_window_tmp, [](auto& e, const ADataType & a) { e = a; }, b_dram_block_window_tmp, From a616e872addeae7dd72fcb0b83a98c3f164ab433 Mon Sep 17 00:00:00 2001 From: Bartlomiej Kocot Date: Tue, 20 Jan 2026 11:24:57 +0000 Subject: [PATCH 2/7] fixes --- .../ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp | 4 ++-- .../ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp index ab1fca85495..f58bc456d6b 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp @@ -125,7 +125,7 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1; + using Base = PipelineImplBase; template (p_b_lds, b_lds_block_desc); - // // Tile distribution for load from lds + // Tile distribution for load from lds constexpr auto a_lds_load_tile_distr = make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); constexpr auto b_lds_load_tile_distr = diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp index 0bab8d11191..35ae2085ca6 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp @@ -133,7 +133,7 @@ struct GemmPipelineAGmemBGmemCRegV2 : public BaseGemmPipelineAGmemBGmemCRegV2; + using Base = PipelineImplBase; template (p_b_lds, b_lds_block_desc); - // // Tile distribution for load from lds + // Tile distribution for load from lds constexpr auto a_lds_load_tile_distr = make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); constexpr auto b_lds_load_tile_distr = From f95333bd9cf8bf2ce42e17cf10a3583d22a7f8df Mon Sep 17 00:00:00 2001 From: Bartlomiej Kocot Date: Tue, 20 Jan 2026 18:46:46 -0500 Subject: [PATCH 3/7] fix basic example --- .../gemm_pipeline_agmem_bgmem_creg_v1.hpp | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp index eb8e565a60c..cc8fad1fad4 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp @@ -360,6 +360,28 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_HOST_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + index_t num_loop, + void* p_smem) const + { + return PipelineImpl{}.operator()(a_dram_block_window_tmp, + a_element_func, + b_dram_block_window_tmp, + b_element_func, + num_loop, + p_smem); + } }; } // namespace ck_tile From 9ba455d5c013c63277fd42f5eafacf22cfec5933 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Wed, 21 Jan 2026 12:16:41 +0100 Subject: [PATCH 4/7] Add interwave pipeline --- .../gemm_pipeline_agmem_bgmem_creg_v1.hpp | 229 ++++++++++++++++-- 1 file changed, 206 insertions(+), 23 deletions(-) diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp index cc8fad1fad4..3454c2aa81b 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp @@ -125,7 +125,13 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1(); } + template struct PipelineImpl : public PipelineImplBase + { + }; + + template <> + struct PipelineImpl : public PipelineImplBase { using Base = PipelineImplBase; @@ -326,6 +332,205 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1 + struct PipelineImpl : public PipelineImplBase + { + using Base = PipelineImplBase; + + template ::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_HOST_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + index_t num_loop, + void* p_smem) const + { + using ADramBlockWindowTmp = + remove_cvref_t{}, AsDramBlockWindowTmp>>; + using BDramBlockWindowTmp = + remove_cvref_t{}, BsDramBlockWindowTmp>>; + + static_assert( + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + constexpr bool is_a_col_major = + std::is_same_v; + constexpr bool is_b_row_major = std::is_same_v; + + static_assert(is_a_col_major + ? (kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && + kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]) + : (kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && + kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]), + "A block window has incorrect lengths for defined ALayout!"); + static_assert(is_b_row_major + ? (kKPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]) + : (kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + kKPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]), + "B block window has incorrect lengths for defined BLayout!"); + // A tile in LDS + ADataType* p_a_lds = static_cast(p_smem); + + constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor(); + + auto a_lds_block = make_tensor_view(p_a_lds, a_lds_block_desc); + + constexpr index_t a_lds_block_space_size_aligned = + integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), + kLdsAlignmentInBytes) * + kLdsAlignmentInBytes; + + // B tile in LDS + BDataType* p_b_lds = static_cast( + static_cast(static_cast(p_smem) + a_lds_block_space_size_aligned)); + + constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor(); + + auto b_lds_block = make_tensor_view(p_b_lds, b_lds_block_desc); + + // // Tile distribution for load from lds + constexpr auto a_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); + constexpr auto b_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode()); + + // A DRAM tile window for load + // A LDS tile window for store + // A LDS tile for block GEMM + auto&& [as_copy_dram_window, a_copy_lds_window, a_lds_gemm_window] = + Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr); + + // B DRAM tile window for load + // B LDS tile window for store + // B LDS tile for block GEMM + auto&& [bs_copy_dram_window, b_copy_lds_window, b_lds_gemm_window] = + Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr); + + // Block GEMM + auto block_gemm = BlockGemm(); + + // Acc register tile + auto c_block_tile = block_gemm.MakeCBlockTile(); + + // prefetch + // global read 0 + // Load tile — during value loading, an elementwise function is executed for each A0, + // A1, … AN. The values A0, A1, … AN are read by the same thread. + auto elementwise_As_res = + load_tile_with_elementwise(as_copy_dram_window, a_element_func); + + // Load tile — during value loading, an elementwise function is executed for each B0, + // B1, … BN. The values B0, B1, … BN are read by the same thread. + auto elementwise_Bs_res = + load_tile_with_elementwise(bs_copy_dram_window, b_element_func); + + { + // move to 1 + // Move each A — the enhanced function move_tile_window is executed, which takes a + // tuple as input. + move_tile_window(as_copy_dram_window, {0, kKPerBlock}); + // Move each B — the enhanced function move_tile_window is executed, which takes a + // tuple as input. + move_tile_window(bs_copy_dram_window, {0, kKPerBlock}); + + // initialize C + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + + // LDS write 0 + if constexpr(is_a_col_major) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp, elementwise_As_res); + store_tile(a_copy_lds_window, a_shuffle_tmp); + } + else + { + store_tile(a_copy_lds_window, elementwise_As_res); + } + + // LDS write 0 + if constexpr(is_b_row_major) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res); + store_tile(b_copy_lds_window, b_shuffle_tmp); + } + else + { + store_tile(b_copy_lds_window, elementwise_Bs_res); + } + } + + index_t iCounter = num_loop - 1; + while(iCounter > 0) + { + // global read i + 1 + elementwise_As_res = + load_tile_with_elementwise(as_copy_dram_window, a_element_func); + block_sync_lds(); + elementwise_Bs_res = + load_tile_with_elementwise(bs_copy_dram_window, b_element_func); + + // GEMM i + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + + // move to i + 2 + move_tile_window(as_copy_dram_window, {0, kKPerBlock}); + move_tile_window(bs_copy_dram_window, {0, kKPerBlock}); + + // LDS write i + 1 + if constexpr(is_a_col_major) + { + auto a_shuffle_tmp_loop = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp_loop, elementwise_As_res); + store_tile(a_copy_lds_window, a_shuffle_tmp_loop); + } + else + { + store_tile(a_copy_lds_window, elementwise_As_res); + } + + // LDS write i + 1 + if constexpr(is_b_row_major) + { + auto b_shuffle_tmp_loop = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp_loop, elementwise_Bs_res); + store_tile(b_copy_lds_window, b_shuffle_tmp_loop); + } + else + { + store_tile(b_copy_lds_window, elementwise_Bs_res); + } + + iCounter--; + } + + // tail + { + block_sync_lds(); + // GEMM num_loop - 1 + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + } + + return c_block_tile; + } + }; + template ::value && @@ -336,7 +541,7 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1{}.operator()( a_dram_block_window_tmp, [](auto& e, const ADataType & a) { e = a; }, b_dram_block_window_tmp, @@ -360,28 +565,6 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1::value && - is_detected::value, - bool>* = nullptr> - CK_TILE_HOST_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, - const AElementFunction& a_element_func, - const BsDramBlockWindowTmp& b_dram_block_window_tmp, - const BElementFunction& b_element_func, - index_t num_loop, - void* p_smem) const - { - return PipelineImpl{}.operator()(a_dram_block_window_tmp, - a_element_func, - b_dram_block_window_tmp, - b_element_func, - num_loop, - p_smem); - } }; } // namespace ck_tile From 11d310b0c634c9aa94674851a501748b6b9e7d69 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Wed, 21 Jan 2026 12:18:13 +0100 Subject: [PATCH 5/7] Update gemm_pipeline_agmem_bgmem_creg_v1.hpp --- .../gemm_pipeline_agmem_bgmem_creg_v1.hpp | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp index 3454c2aa81b..442b87280ac 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp @@ -565,6 +565,29 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_HOST_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + index_t num_loop, + void* p_smem) const + { + return PipelineImpl{}.operator()(a_dram_block_window_tmp, + a_element_func, + b_dram_block_window_tmp, + b_element_func, + num_loop, + p_smem); + } + }; } // namespace ck_tile From 387f3a993224f2aca369c3d04abeb30cd321dd63 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Wed, 21 Jan 2026 12:18:57 +0100 Subject: [PATCH 6/7] Update gemm_pipeline_agmem_bgmem_creg_v1.hpp --- .../ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp | 1 - 1 file changed, 1 deletion(-) diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp index 442b87280ac..9fe1f4dbc0a 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp @@ -587,7 +587,6 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1 Date: Wed, 21 Jan 2026 16:55:39 +0100 Subject: [PATCH 7/7] Update gemm_pipeline_agmem_bgmem_creg_v1.hpp --- .../ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp index 9fe1f4dbc0a..7885b98d5cd 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp @@ -580,7 +580,7 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1{}.operator()(a_dram_block_window_tmp, a_element_func, b_dram_block_window_tmp, b_element_func,