diff --git a/example/ck_tile/20_grouped_convolution/conv_configs.hpp b/example/ck_tile/20_grouped_convolution/conv_configs.hpp index 620b5058202..847030fffb4 100644 --- a/example/ck_tile/20_grouped_convolution/conv_configs.hpp +++ b/example/ck_tile/20_grouped_convolution/conv_configs.hpp @@ -257,6 +257,24 @@ struct ConvTypeConfig template struct PipelineTypeTraits; +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAGmemBGmemCRegV1; +}; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV2; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAGmemBGmemCRegV2; +}; + template <> struct PipelineTypeTraits { diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp index 9b7213837a1..cc8fad1fad4 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp @@ -39,6 +39,8 @@ struct BaseGemmPipelineAGmemBGmemCRegV1 template struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1 { + using PipelineImplBase = GemmPipelineAgBgCrImplBase; + using AsDataType = remove_cvref_t; using BsDataType = remove_cvref_t; using CDataType = remove_cvref_t; @@ -123,228 +125,207 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1(); } - template ::value && - is_detected::value, - bool>* = nullptr> - CK_TILE_HOST_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, - const AElementFunction& a_element_func, - const BsDramBlockWindowTmp& b_dram_block_window_tmp, - const BElementFunction& b_element_func, - index_t num_loop, - void* p_smem) const + struct PipelineImpl : public PipelineImplBase { - using ADramBlockWindowTmp = - remove_cvref_t{}, AsDramBlockWindowTmp>>; - using BDramBlockWindowTmp = - remove_cvref_t{}, BsDramBlockWindowTmp>>; - - static_assert( - std::is_same_v> && - std::is_same_v>, - "wrong!"); - - constexpr bool is_a_col_major = std::is_same_v; - constexpr bool is_b_row_major = std::is_same_v; - - static_assert(is_a_col_major - ? (kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && - kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]) - : (kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && - kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]), - "A block window has incorrect lengths for defined ALayout!"); - static_assert(is_b_row_major - ? (kKPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && - kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]) - : (kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && - kKPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]), - "B block window has incorrect lengths for defined BLayout!"); - // A tile in LDS - ADataType* p_a_lds = static_cast(p_smem); - - constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor(); - - auto a_lds_block = make_tensor_view(p_a_lds, a_lds_block_desc); - - constexpr index_t a_lds_block_space_size_aligned = - integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), - kLdsAlignmentInBytes) * - kLdsAlignmentInBytes; - - // B tile in LDS - BDataType* p_b_lds = static_cast( - static_cast(static_cast(p_smem) + a_lds_block_space_size_aligned)); - - constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor(); - - auto b_lds_block = make_tensor_view(p_b_lds, b_lds_block_desc); - - // A DRAM tile window for load - auto as_copy_dram_window = generate_tuple( - [&](auto idx) { - return make_tile_window( - a_dram_block_window_tmp[number{}].get_bottom_tensor_view(), - make_tuple(number{}, number{}), - a_dram_block_window_tmp[number{}].get_window_origin(), - Policy::template MakeADramTileDistribution()); - }, - number{}); - - // A LDS tile window for store - auto a_copy_lds_window = make_tile_window( - a_lds_block, make_tuple(number{}, number{}), {0, 0}); - - // B DRAM tile window for load - auto bs_copy_dram_window = generate_tuple( - [&](auto idx) { - return make_tile_window( - b_dram_block_window_tmp[number{}].get_bottom_tensor_view(), - make_tuple(number{}, number{}), - b_dram_block_window_tmp[number{}].get_window_origin(), - Policy::template MakeBDramTileDistribution()); - }, - number{}); - - // B LDS tile window for store - auto b_copy_lds_window = make_tile_window( - b_lds_block, make_tuple(number{}, number{}), {0, 0}); - - // Tile distribution for load from lds - constexpr auto a_lds_load_tile_distr = - make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); - constexpr auto b_lds_load_tile_distr = - make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode()); - - // A LDS tile for block GEMM - auto a_lds_gemm_window = - make_tile_window(a_lds_block, - make_tuple(number{}, number{}), - {0, 0}, - a_lds_load_tile_distr); - - // B LDS tile for block GEMM - auto b_lds_gemm_window = - make_tile_window(b_lds_block, - make_tuple(number{}, number{}), - {0, 0}, - b_lds_load_tile_distr); - - // Block GEMM - auto block_gemm = BlockGemm(); - - // Acc register tile - auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){}; - - // prefetch - // global read 0 - // Load tile — during value loading, an elementwise function is executed for each A0, - // A1, … AN. The values A0, A1, … AN are read by the same thread. - auto elementwise_As_res = load_tile_with_elementwise(as_copy_dram_window, a_element_func); - - // Load tile — during value loading, an elementwise function is executed for each B0, - // B1, … BN. The values B0, B1, … BN are read by the same thread. - auto elementwise_Bs_res = load_tile_with_elementwise(bs_copy_dram_window, b_element_func); - + using Base = PipelineImplBase; + + template ::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_HOST_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + index_t num_loop, + void* p_smem) const { - // move to 1 - // Move each A — the enhanced function move_tile_window is executed, which takes a tuple - // as input. - move_tile_window(as_copy_dram_window, {0, kKPerBlock}); - // Move each B — the enhanced function move_tile_window is executed, which takes a tuple - // as input. - move_tile_window(bs_copy_dram_window, {0, kKPerBlock}); - - // initialize C - tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); - - // LDS write 0 - if constexpr(is_a_col_major) - { - auto a_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledARegTileDistribution()); - transpose_tile2d(a_shuffle_tmp, elementwise_As_res); - store_tile(a_copy_lds_window, a_shuffle_tmp); - } - else - { - store_tile(a_copy_lds_window, elementwise_As_res); - } + using ADramBlockWindowTmp = + remove_cvref_t{}, AsDramBlockWindowTmp>>; + using BDramBlockWindowTmp = + remove_cvref_t{}, BsDramBlockWindowTmp>>; + + static_assert( + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + constexpr bool is_a_col_major = + std::is_same_v; + constexpr bool is_b_row_major = std::is_same_v; + + static_assert(is_a_col_major + ? (kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && + kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]) + : (kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && + kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]), + "A block window has incorrect lengths for defined ALayout!"); + static_assert(is_b_row_major + ? (kKPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]) + : (kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + kKPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]), + "B block window has incorrect lengths for defined BLayout!"); + // A tile in LDS + ADataType* p_a_lds = static_cast(p_smem); + + constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor(); + + auto a_lds_block = make_tensor_view(p_a_lds, a_lds_block_desc); + + constexpr index_t a_lds_block_space_size_aligned = + integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), + kLdsAlignmentInBytes) * + kLdsAlignmentInBytes; + + // B tile in LDS + BDataType* p_b_lds = static_cast( + static_cast(static_cast(p_smem) + a_lds_block_space_size_aligned)); + + constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor(); + + auto b_lds_block = make_tensor_view(p_b_lds, b_lds_block_desc); + + // Tile distribution for load from lds + constexpr auto a_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); + constexpr auto b_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode()); + + // A DRAM tile window for load + // A LDS tile window for store + // A LDS tile for block GEMM + auto&& [as_copy_dram_window, a_copy_lds_window, a_lds_gemm_window] = + Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr); + + // B DRAM tile window for load + // B LDS tile window for store + // B LDS tile for block GEMM + auto&& [bs_copy_dram_window, b_copy_lds_window, b_lds_gemm_window] = + Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr); + + // Block GEMM + auto block_gemm = BlockGemm(); + + // Acc register tile + auto c_block_tile = block_gemm.MakeCBlockTile(); + + // prefetch + // global read 0 + // Load tile — during value loading, an elementwise function is executed for each A0, + // A1, … AN. The values A0, A1, … AN are read by the same thread. + auto elementwise_As_res = + load_tile_with_elementwise(as_copy_dram_window, a_element_func); + + // Load tile — during value loading, an elementwise function is executed for each B0, + // B1, … BN. The values B0, B1, … BN are read by the same thread. + auto elementwise_Bs_res = + load_tile_with_elementwise(bs_copy_dram_window, b_element_func); - // LDS write 0 - if constexpr(is_b_row_major) - { - auto b_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledBRegTileDistribution()); - transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res); - store_tile(b_copy_lds_window, b_shuffle_tmp); - } - else { - store_tile(b_copy_lds_window, elementwise_Bs_res); + // move to 1 + // Move each A — the enhanced function move_tile_window is executed, which takes a + // tuple as input. + move_tile_window(as_copy_dram_window, {0, kKPerBlock}); + // Move each B — the enhanced function move_tile_window is executed, which takes a + // tuple as input. + move_tile_window(bs_copy_dram_window, {0, kKPerBlock}); + + // initialize C + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + + // LDS write 0 + if constexpr(is_a_col_major) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp, elementwise_As_res); + store_tile(a_copy_lds_window, a_shuffle_tmp); + } + else + { + store_tile(a_copy_lds_window, elementwise_As_res); + } + + // LDS write 0 + if constexpr(is_b_row_major) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res); + store_tile(b_copy_lds_window, b_shuffle_tmp); + } + else + { + store_tile(b_copy_lds_window, elementwise_Bs_res); + } } - } - - index_t iCounter = num_loop - 1; - while(iCounter > 0) - { - // global read i + 1 - elementwise_As_res = load_tile_with_elementwise(as_copy_dram_window, a_element_func); - elementwise_Bs_res = load_tile_with_elementwise(bs_copy_dram_window, b_element_func); - - block_sync_lds(); - - // GEMM i - block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); - - block_sync_lds(); - - // move to i + 2 - move_tile_window(as_copy_dram_window, {0, kKPerBlock}); - move_tile_window(bs_copy_dram_window, {0, kKPerBlock}); - // LDS write i + 1 - if constexpr(is_a_col_major) + index_t iCounter = num_loop - 1; + while(iCounter > 0) { - auto a_shuffle_tmp_loop = make_static_distributed_tensor( - Policy::template MakeShuffledARegTileDistribution()); - transpose_tile2d(a_shuffle_tmp_loop, elementwise_As_res); - store_tile(a_copy_lds_window, a_shuffle_tmp_loop); - } - else - { - store_tile(a_copy_lds_window, elementwise_As_res); + // global read i + 1 + elementwise_As_res = + load_tile_with_elementwise(as_copy_dram_window, a_element_func); + block_sync_lds(); + elementwise_Bs_res = + load_tile_with_elementwise(bs_copy_dram_window, b_element_func); + + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + + // GEMM i + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + + block_sync_lds(); + + // move to i + 2 + move_tile_window(as_copy_dram_window, {0, kKPerBlock}); + move_tile_window(bs_copy_dram_window, {0, kKPerBlock}); + + // LDS write i + 1 + if constexpr(is_a_col_major) + { + auto a_shuffle_tmp_loop = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp_loop, elementwise_As_res); + store_tile(a_copy_lds_window, a_shuffle_tmp_loop); + } + else + { + store_tile(a_copy_lds_window, elementwise_As_res); + } + + // LDS write i + 1 + if constexpr(is_b_row_major) + { + auto b_shuffle_tmp_loop = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp_loop, elementwise_Bs_res); + store_tile(b_copy_lds_window, b_shuffle_tmp_loop); + } + else + { + store_tile(b_copy_lds_window, elementwise_Bs_res); + } + + iCounter--; } - // LDS write i + 1 - if constexpr(is_b_row_major) - { - auto b_shuffle_tmp_loop = make_static_distributed_tensor( - Policy::template MakeShuffledBRegTileDistribution()); - transpose_tile2d(b_shuffle_tmp_loop, elementwise_Bs_res); - store_tile(b_copy_lds_window, b_shuffle_tmp_loop); - } - else + // tail { - store_tile(b_copy_lds_window, elementwise_Bs_res); + block_sync_lds(); + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + // GEMM num_loop - 1 + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); } - iCounter--; - } - - // tail - { - block_sync_lds(); - - // GEMM num_loop - 1 - block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + return c_block_tile; } - - return c_block_tile; - } - + }; template ::value && @@ -355,7 +336,7 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_HOST_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + index_t num_loop, + void* p_smem) const + { + return PipelineImpl{}.operator()(a_dram_block_window_tmp, + a_element_func, + b_dram_block_window_tmp, + b_element_func, + num_loop, + p_smem); + } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp index c711c768ec8..35ae2085ca6 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp @@ -38,6 +38,8 @@ struct BaseGemmPipelineAGmemBGmemCRegV2 template struct GemmPipelineAGmemBGmemCRegV2 : public BaseGemmPipelineAGmemBGmemCRegV2 { + using PipelineImplBase = GemmPipelineAgBgCrImplBase; + using AsDataType = remove_cvref_t; using BsDataType = remove_cvref_t; using CDataType = remove_cvref_t; @@ -56,6 +58,8 @@ struct GemmPipelineAGmemBGmemCRegV2 : public BaseGemmPipelineAGmemBGmemCRegV2>; using BDataType = remove_cvref_t>; + using BlockGemm = remove_cvref_t())>; + static constexpr index_t APackedSize = ck_tile::numeric_traits>::PackedSize; static constexpr index_t BPackedSize = @@ -127,205 +131,187 @@ struct GemmPipelineAGmemBGmemCRegV2 : public BaseGemmPipelineAGmemBGmemCRegV2(); } - template ::value && - is_detected::value, - bool>* = nullptr> - CK_TILE_HOST_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, - const AElementFunction& a_element_func, - const BsDramBlockWindowTmp& b_dram_block_window_tmp, - const BElementFunction& b_element_func, - index_t num_loop, - void* p_smem) const + struct PipelineImpl : public PipelineImplBase { - - using ADramBlockWindowTmp = - remove_cvref_t{}, AsDramBlockWindowTmp>>; - using BDramBlockWindowTmp = - remove_cvref_t{}, BsDramBlockWindowTmp>>; - - static_assert( - std::is_same_v> && - std::is_same_v>, - "wrong!"); - - static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}], - "wrong!"); - - // A tile in LDS - ADataType* p_a_lds = static_cast(p_smem); - - constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor(); - - auto a_lds_block = make_tensor_view(p_a_lds, a_lds_block_desc); - - constexpr index_t a_lds_block_space_size_aligned = - integer_divide_ceil( - sizeof(ADataType) * a_lds_block_desc.get_element_space_size() / APackedSize, 16) * - 16; - - // B tile in LDS - BDataType* p_b_lds = static_cast( - static_cast(static_cast(p_smem) + a_lds_block_space_size_aligned)); - - constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor(); - - auto b_lds_block = make_tensor_view(p_b_lds, b_lds_block_desc); - - // A DRAM tile window for load - auto as_copy_dram_window = generate_tuple( - [&](auto idx) { - return make_tile_window( - a_dram_block_window_tmp[number{}].get_bottom_tensor_view(), - make_tuple(number{}, number{}), - a_dram_block_window_tmp[number{}].get_window_origin(), - Policy::template MakeADramTileDistribution()); - }, - number{}); - - // A LDS tile window for store - auto a_copy_lds_window = - make_tile_window(a_lds_block, - make_tuple(number{}, number{}), - {0, 0}, - as_copy_dram_window[number<0>{}].get_tile_distribution()); - - // B DRAM tile window for load - auto bs_copy_dram_window = generate_tuple( - [&](auto idx) { - return make_tile_window( - b_dram_block_window_tmp[number{}].get_bottom_tensor_view(), - make_tuple(number{}, number{}), - b_dram_block_window_tmp[number{}].get_window_origin(), - Policy::template MakeBDramTileDistribution()); - }, - number{}); - - // B LDS tile window for store - auto b_copy_lds_window = - make_tile_window(b_lds_block, - make_tuple(number{}, number{}), - {0, 0}, - bs_copy_dram_window[number<0>{}].get_tile_distribution()); - - // Block GEMM - constexpr auto block_gemm = Policy::template GetBlockGemm(); - - // Tile distribution for load from lds - constexpr auto a_lds_load_tile_distr = - make_static_tile_distribution(decltype(block_gemm)::MakeABlockDistributionEncode()); - constexpr auto b_lds_load_tile_distr = - make_static_tile_distribution(decltype(block_gemm)::MakeBBlockDistributionEncode()); - - // A LDS tile for block GEMM - auto a_lds_gemm_window = - make_tile_window(a_lds_block, - make_tuple(number{}, number{}), - {0, 0}, - a_lds_load_tile_distr); - - // B LDS tile for block GEMM - auto b_lds_gemm_window = - make_tile_window(b_lds_block, - make_tuple(number{}, number{}), - {0, 0}, - b_lds_load_tile_distr); - - // Acc register tile - auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){}; - - // prefetch - // global read 0 - // Load tile — during value loading, an elementwise function is executed for each A0, - // A1, … AN. The values A0, A1, … AN are read by the same thread. - auto elementwise_As_res = load_tile_with_elementwise(as_copy_dram_window, a_element_func); - // Load tile — during value loading, an elementwise function is executed for each B0, - // B1, … BN. The values B0, B1, … BN are read by the same thread. - auto elementwise_Bs_res = load_tile_with_elementwise(bs_copy_dram_window, b_element_func); - - { - // move to 1 - move_tile_window(as_copy_dram_window, {0, kKPerBlock}); - move_tile_window(bs_copy_dram_window, {0, kKPerBlock}); - - // initialize C - tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); - - // LDS write 0 - store_tile(a_copy_lds_window, elementwise_As_res); - // global read 1 - elementwise_As_res = load_tile_with_elementwise(as_copy_dram_window, a_element_func); - - // LDS write 0 - store_tile(b_copy_lds_window, elementwise_Bs_res); - // global read 1 - elementwise_Bs_res = load_tile_with_elementwise(bs_copy_dram_window, b_element_func); - } - - index_t iCounter = num_loop - 2; - - do - { - block_sync_lds(); - - // GEMM i - block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); - - block_sync_lds(); - - // move to i + 2 - move_tile_window(as_copy_dram_window, {0, kKPerBlock}); - move_tile_window(bs_copy_dram_window, {0, kKPerBlock}); - - // LDS write i + 1 - store_tile(a_copy_lds_window, elementwise_As_res); - // global read i + 2 - elementwise_As_res = load_tile_with_elementwise(as_copy_dram_window, a_element_func); - - // LDS write i + 1 - store_tile(b_copy_lds_window, elementwise_Bs_res); - // global read i + 2 - elementwise_Bs_res = load_tile_with_elementwise(bs_copy_dram_window, b_element_func); - - iCounter--; - - } while(iCounter > 0); - - // tail + using Base = PipelineImplBase; + + template ::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_HOST_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + index_t num_loop, + void* p_smem) const { - block_sync_lds(); - // GEMM num_loop - 2 - block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + using ADramBlockWindowTmp = + remove_cvref_t{}, AsDramBlockWindowTmp>>; + using BDramBlockWindowTmp = + remove_cvref_t{}, BsDramBlockWindowTmp>>; + + static_assert( + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kNPerBlock == + BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + // A tile in LDS + ADataType* p_a_lds = static_cast(p_smem); + + constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor(); + + auto a_lds_block = make_tensor_view(p_a_lds, a_lds_block_desc); + + constexpr index_t a_lds_block_space_size_aligned = + integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size() / + APackedSize, + 16) * + 16; + + // B tile in LDS + BDataType* p_b_lds = static_cast( + static_cast(static_cast(p_smem) + a_lds_block_space_size_aligned)); + + constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor(); + + auto b_lds_block = make_tensor_view(p_b_lds, b_lds_block_desc); + + // Tile distribution for load from lds + constexpr auto a_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); + constexpr auto b_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode()); + + // A DRAM tile window for load + // A LDS tile window for store + // A LDS tile for block GEMM + auto&& [as_copy_dram_window, a_copy_lds_window, a_lds_gemm_window] = + Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr); + + // B DRAM tile window for load + // B LDS tile window for store + // B LDS tile for block GEMM + auto&& [bs_copy_dram_window, b_copy_lds_window, b_lds_gemm_window] = + Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr); + + // Block GEMM + auto block_gemm = BlockGemm(); + + // Acc register tile + auto c_block_tile = block_gemm.MakeCBlockTile(); + + // prefetch + // global read 0 + // Load tile — during value loading, an elementwise function is executed for each A0, + // A1, … AN. The values A0, A1, … AN are read by the same thread. + auto elementwise_As_res = + load_tile_with_elementwise(as_copy_dram_window, a_element_func); + // Load tile — during value loading, an elementwise function is executed for each B0, + // B1, … BN. The values B0, B1, … BN are read by the same thread. + auto elementwise_Bs_res = + load_tile_with_elementwise(bs_copy_dram_window, b_element_func); + + { + // move to 1 + move_tile_window(as_copy_dram_window, {0, kKPerBlock}); + move_tile_window(bs_copy_dram_window, {0, kKPerBlock}); + + // initialize C + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + + // LDS write 0 + store_tile(a_copy_lds_window, elementwise_As_res); + // global read 1 + elementwise_As_res = + load_tile_with_elementwise(as_copy_dram_window, a_element_func); + + // LDS write 0 + store_tile(b_copy_lds_window, elementwise_Bs_res); + // global read 1 + elementwise_Bs_res = + load_tile_with_elementwise(bs_copy_dram_window, b_element_func); + } + + index_t iCounter = num_loop - 2; + + do + { + block_sync_lds(); + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + + // GEMM i + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + + block_sync_lds(); + + // move to i + 2 + move_tile_window(as_copy_dram_window, {0, kKPerBlock}); + move_tile_window(bs_copy_dram_window, {0, kKPerBlock}); - block_sync_lds(); + // LDS write i + 1 + store_tile(a_copy_lds_window, elementwise_As_res); + // global read i + 2 + elementwise_As_res = + load_tile_with_elementwise(as_copy_dram_window, a_element_func); - // LDS write num_loop - 1 - store_tile(a_copy_lds_window, elementwise_As_res); + // LDS write i + 1 + store_tile(b_copy_lds_window, elementwise_Bs_res); + // global read i + 2 + elementwise_Bs_res = + load_tile_with_elementwise(bs_copy_dram_window, b_element_func); - store_tile(b_copy_lds_window, elementwise_Bs_res); + iCounter--; + + } while(iCounter > 0); + + // tail + { + block_sync_lds(); + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + + // GEMM num_loop - 2 + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); - block_sync_lds(); - - // GEMM num_loop - 1 - block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + block_sync_lds(); + + // LDS write num_loop - 1 + store_tile(a_copy_lds_window, elementwise_As_res); + + store_tile(b_copy_lds_window, elementwise_Bs_res); + + block_sync_lds(); + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + // GEMM num_loop - 1 + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + } + + return c_block_tile; } + }; - return c_block_tile; - } - - template - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, - const BDramBlockWindowTmp& b_dram_block_window_tmp, + template ::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, index_t num_loop, void* p_smem) const { - return operator()( + return PipelineImpl{}.operator()( a_dram_block_window_tmp, [](auto& e, const ADataType & a) { e = a; }, b_dram_block_window_tmp,