diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp
index 8eff0e7469f..337e3d268d5 100644
--- a/example/ck_tile/03_gemm/gemm_utils.hpp
+++ b/example/ck_tile/03_gemm/gemm_utils.hpp
@@ -241,6 +241,31 @@ struct GemmConfigComputeV6 : public GemmConfigBase
     static constexpr ck_tile::index_t NumWaveGroups = 1;
 };
 
+template
+struct GemmConfigComputeV7 : public GemmConfigBase
+{
+    static constexpr bool kPadM = true;
+    static constexpr bool kPadN = true;
+    static constexpr bool kPadK = true;
+
+    static constexpr ck_tile::index_t M_Tile = 128;
+    static constexpr ck_tile::index_t N_Tile = 128;
+    static constexpr ck_tile::index_t K_Tile = 64;
+
+    static constexpr ck_tile::index_t M_Warp = 2;
+    static constexpr ck_tile::index_t N_Warp = 2;
+    static constexpr ck_tile::index_t K_Warp = 1;
+
+    static constexpr ck_tile::index_t M_Warp_Tile = 16;
+    static constexpr ck_tile::index_t N_Warp_Tile = 16;
+    static constexpr ck_tile::index_t K_Warp_Tile = 32;
+
+    static constexpr bool DoubleSmemBuffer = false;
+    static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V7;
+
+    static constexpr int kBlockPerCu = 2;
+};
+
 template
 struct GemmConfigPreshuffleDecode : public GemmConfigBase
 {
@@ -423,6 +448,15 @@ struct PipelineTypeTraits
     using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV6;
 };
 
+template <>
+struct PipelineTypeTraits
+{
+    template
+    using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV7;
+    template
+    using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV7;
+};
+
 template <>
 struct PipelineTypeTraits
 {
diff --git a/include/ck_tile/core/tensor/load_tile.hpp b/include/ck_tile/core/tensor/load_tile.hpp
index af0f81e832c..8d327b64036 100644
--- a/include/ck_tile/core/tensor/load_tile.hpp
+++ b/include/ck_tile/core/tensor/load_tile.hpp
@@ -63,6 +63,22 @@ CK_TILE_DEVICE auto load_tile_with_elementwise(const TileWindow_& tile_window,
         tile_window, elementwise, number{}, bool_constant{});
 }
 
+template
+CK_TILE_DEVICE auto
+load_tile_with_elementwise_vectorload1(const TileWindow_& tile_window,
+                                       ElementWise_ elementwise,
+                                       number = {},
+                                       bool_constant = {})
+{
+    // TODO: Tile windows should work with an unknown number of params
+    // The load elementwise API works only when the input type is a tuple type
+    return tile_window[number<0>{}].load_vectorload1(
+        tile_window, elementwise, number{}, bool_constant{});
+}
+
 // Per-lane read-offset tweaks allow swizzling patterns not representable by tile_distribution.
 template
+    CK_TILE_DEVICE auto load_vectorload1(const TileWindow_& tile_window,
+                                         ElementWise_ elementwise,
+                                         number = {},
+                                         bool_constant = {}) const
+    {
+        constexpr auto tile_dstr = typename Base::TileDstr{};
+        auto dst_tensor = make_static_distributed_tensor(tile_dstr);
+        load_vectorload1(dst_tensor,
+                         tile_window,
+                         elementwise,
+                         number{},
+                         bool_constant{});
+        return dst_tensor;
+    }
+
+    template
+    CK_TILE_DEVICE void load_vectorload1(DistributedTensor& dst_tensor,
+                                         const TileWindow_& tile_window,
+                                         ElementWise_ elementwise,
+                                         number = {},
+                                         bool_constant = {}) const
+    {
+        using OldTraits = typename Base::Traits;
+        using Traits    = typename Base::TraitsVectorload1;
+        using vector_t  = thread_buffer;
+        using SFC_Ys    = typename Traits::SFC_Ys;
+
+        constexpr auto tile_dstr   = typename Base::TileDstr{};
+        constexpr auto sizeOfTuple = TileWindow_::size();
+        // loop over thread tensor space [y0, y1, ...]
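+        // Unlike the regular load path, which reads OldTraits::ScalarPerVector elements
+        // per access, this variant walks the same space-filling curve one scalar at a
+        // time (TraitsVectorload1::ScalarPerVector == 1). The tail iteration therefore
+        // stays correct when the window extent along the vector dimension is not a
+        // multiple of the vector size, at the cost of more, narrower global loads.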
+ static_for<0, NumCoord, 1>{}([&](auto iCoord) { + /// TODO: use structure binding (to be captured later) if compiled in C++20 + auto window_adaptor_thread_coord = + tile_window[number<0>{}].pre_computed_coords_[iCoord][I0]; + auto bottom_tensor_thread_coord = + tile_window[number<0>{}].pre_computed_coords_[iCoord][I1]; + + static_for<0, NumAccessPerCoord * OldTraits::ScalarPerVector, 1>{}( + [&](auto iCoordAccess) { + constexpr auto iAccess = + number{}; + + // data index [y0, y1, ...] + constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); + + // read from bottom tensor + const auto idx_vec_value = generate_tuple( + [&](auto jj) { + return tile_window[number{}] + .get_bottom_tensor_view() + .template get_vectorized_elements( + bottom_tensor_thread_coord, + 0, // linear offset + bool_constant{}); + }, + number{}); + + static_for<0, 1, Traits::PackedSize>{}([&](auto j) { + // write into distributed tensor + constexpr auto idx_ys = generate_tuple( + [&](auto jj) { + return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) + : idx_ys_start[jj]; + }, + number{}); + + constexpr index_t d = + tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / + Traits::PackedSize; + + ck_tile::apply( + [&](auto&&... t) { + elementwise(dst_tensor.get_thread_buffer().template at(), + t.template get_as()[0]...); + }, + idx_vec_value); + }); + // move thread coordinate + if constexpr(iCoordAccess != + (NumAccessPerCoord * OldTraits::ScalarPerVector - 1)) + { + constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); + + constexpr auto idx_diff_ps_ys = + container_concat(generate_tuple([&](auto) { return number<0>{}; }, + number{}), + idx_diff_ys); + + Base::move_window_adaptor_and_bottom_tensor_thread_coordinate( + window_adaptor_thread_coord, + bottom_tensor_thread_coord, + idx_diff_ps_ys); + } + }); + }); + } + template diff --git a/include/ck_tile/core/tensor/tile_window_base.hpp b/include/ck_tile/core/tensor/tile_window_base.hpp index 2f5eaf40b14..c7bcce3963b 100644 --- a/include/ck_tile/core/tensor/tile_window_base.hpp +++ b/include/ck_tile/core/tensor/tile_window_base.hpp @@ -206,6 +206,68 @@ struct tile_window_with_tile_dstr_base static_assert(0 < NumAccess, "Wrong! NumAccess should be larger than 0"); }; + struct TraitsVectorload1 + { + public: + static constexpr index_t PackedSize = + ck_tile::numeric_traits>::PackedSize; + + static constexpr auto get_vector_dim_y_scalar_per_vector() + { + const auto [ys_vector_lengths, ys_vector_strides] = + tile_window_with_tile_dstr_base::get_window_adaptor_ys_safe_vector_length_strides(); + + index_t VectorDimY_ = 0; + index_t ScalarPerVector_ = 1; + + for(index_t i = 0; i < NDimY; ++i) + { + if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector_) + { + ScalarPerVector_ = ys_vector_lengths[i]; + VectorDimY_ = i; + } + } + + return make_tuple(VectorDimY_, ScalarPerVector_); + } + + static constexpr index_t VectorDimY = get_vector_dim_y_scalar_per_vector().template at<0>(); + static constexpr index_t ScalarPerVector = 1; + using vector_t = + thread_buffer; + + static constexpr auto scalars_per_access_ = [] { + constexpr auto scalars_per_access_arr = generate_array( + [&](auto i) { return (i == VectorDimY) ? 
ScalarPerVector : 1; }, number{}); + + /// TODO: add non-automatic storage argument support to macro TO_SEQUENCE() + constexpr auto NDimY_ = NDimY; + + return TO_SEQUENCE(scalars_per_access_arr, NDimY_); + }(); + + static constexpr auto get_space_filling_curve() + { + constexpr auto thread_tensor_lengths_ys = + to_sequence(TileDstr{}.get_ys_to_d_descriptor().get_lengths()); + + // FIXME: need logic to judge dim access order + using DimAccessOrder = typename arithmetic_sequence_gen<0, NDimY, 1>::type; + + return space_filling_curve{}; + } + + using SFC_Ys = decltype(get_space_filling_curve()); + + static constexpr index_t NumAccess = SFC_Ys::get_num_of_access(); + + static_assert(0 < NumAccess, "Wrong! NumAccess should be larger than 0"); + }; + // return vector dimension among [y0, y1, ...] CK_TILE_DEVICE static constexpr auto get_window_adaptor_ys_safe_vector_length_strides() { diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index 2c3a1611216..a8a033e65b5 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -49,6 +49,7 @@ #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v6.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v6_default_policy.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v7.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp" diff --git a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp index 628f5f7dc89..baf492b04aa 100644 --- a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp @@ -421,6 +421,19 @@ struct UniversalGemmKernel return false; } + if(GemmPipeline::GetPipelineName() == "COMPUTE_V7") + { + if(GemmPipeline::kPadK == false || GemmPipeline::kPadM == false || + GemmPipeline::kPadN == false) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("Compute pipeline v7 needs all paddings enabled!"); + } + return false; + } + } + const auto vectorSizeA = is_wave32() ? 
GemmPipeline::template GetVectorSizeA() : GemmPipeline::template GetVectorSizeA(); bool AsTensorIsValid = {true}; @@ -439,7 +452,7 @@ struct UniversalGemmKernel } AsTensorIsValid = false; } - if(kargs.K % vectorSizeA != 0) + if(kargs.K % vectorSizeA != 0 && GemmPipeline::GetPipelineName() != "COMPUTE_V7") { const auto remainder = kargs.K % vectorSizeA; constexpr ck_tile::index_t APackedSize = @@ -471,7 +484,7 @@ struct UniversalGemmKernel } AsTensorIsValid = false; } - if(kargs.M % vectorSizeA != 0) + if(kargs.M % vectorSizeA != 0 && GemmPipeline::GetPipelineName() != "COMPUTE_V7") { const auto remainder = kargs.M % vectorSizeA; constexpr ck_tile::index_t APackedSize = @@ -511,7 +524,7 @@ struct UniversalGemmKernel } BsTensorIsValid = false; } - if(kargs.N % vectorSizeB != 0) + if(kargs.N % vectorSizeB != 0 && GemmPipeline::GetPipelineName() != "COMPUTE_V7") { const auto remainder = kargs.N % vectorSizeB; constexpr ck_tile::index_t BPackedSize = @@ -544,7 +557,8 @@ struct UniversalGemmKernel } BsTensorIsValid = false; } - if(kargs.K % vectorSizeB != 0) + if(kargs.K % vectorSizeB != 0 && + GemmPipeline::GetPipelineName() != "COMPUTE_V7") { const auto remainder = kargs.K % vectorSizeB; constexpr ck_tile::index_t BPackedSize = diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v7.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v7.hpp new file mode 100644 index 00000000000..06f95ce5bce --- /dev/null +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v7.hpp @@ -0,0 +1,780 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp" + +namespace ck_tile { + +// A Tile Window: global memory +// B Tile Window: global memory +// C Distributed tensor: register +template +struct BaseGemmPipelineAgBgCrCompV7 +{ + static constexpr index_t PrefetchStages = 1; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + static constexpr bool UsePersistentKernel = Problem::Traits::UsePersistentKernel; + + CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + CK_TILE_HOST_DEVICE static constexpr TailNumber + GetBlockLoopTailNum([[maybe_unused]] index_t num_loop) + { + return TailNumber::Odd; + } + + template + CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, + bool has_hot_loop, + [[maybe_unused]] TailNumber tail_number) + { + // Handle all the valid cases. 
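+        // This pipeline always reports TailNumber::Odd (see GetBlockLoopTailNum above),
+        // so the only runtime dispatch needed here is hot loop vs. no hot loop; both
+        // branches forward the same Odd tail constant to run_func.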
+ if(has_hot_loop) + { + return run_func( + ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else + { + return run_func( + ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } +}; + +// Compute optimized pipeline +// GlobalPrefetchStages: 1 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 +template +struct GemmPipelineAgBgCrCompV7 : public BaseGemmPipelineAgBgCrCompV7 +{ + using Base = BaseGemmPipelineAgBgCrCompV7; + using PipelineImplBase = GemmPipelineAgBgCrImplBase; + + using AsDataType = remove_cvref_t; + using BsDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + + using AElementWise = remove_cvref_t; + using BElementWise = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + using AsLayout = remove_cvref_t; + using BsLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + using ALayout = remove_cvref_t>; + using BLayout = remove_cvref_t>; + + using ADataType = remove_cvref_t>; + using BDataType = remove_cvref_t>; + + using BlockGemm = remove_cvref_t())>; + using I0 = number<0>; + using I1 = number<1>; + using I2 = number<2>; + + static constexpr index_t BlockSize = Problem::kBlockSize; + + static constexpr index_t MPerBlock = BlockGemmShape::kM; + static constexpr index_t NPerBlock = BlockGemmShape::kN; + static constexpr index_t KPerBlock = BlockGemmShape::kK; + + template + static constexpr index_t GetVectorSizeA() + { + return Policy::template GetVectorSizeA(); + } + template + static constexpr index_t GetVectorSizeB() + { + return Policy::template GetVectorSizeB(); + } + static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC(); } + + static constexpr index_t APackedSize = + ck_tile::numeric_traits>::PackedSize; + static constexpr index_t BPackedSize = + ck_tile::numeric_traits>::PackedSize; + + static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA(); } + static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB(); } + + static constexpr bool kPadM = Problem::kPadM; + static constexpr bool kPadN = Problem::kPadN; + static constexpr bool kPadK = Problem::kPadK; + + static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer; + static constexpr index_t NumWaveGroups = Problem::NumWaveGroups; + static constexpr index_t Preshuffle = Problem::Preshuffle; + + static constexpr auto Scheduler = Problem::Scheduler; + + static constexpr auto is_a_load_tr_v = bool_constant{}; + static constexpr auto is_b_load_tr_v = bool_constant{}; + + using Base::PrefetchStages; + using Base::UsePersistentKernel; + + [[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName() + { + // clang-format off + return "COMPUTE_V7"; + // clang-format on + } + + [[nodiscard]] CK_TILE_HOST static const std::string GetName() + { + // clang-format off + constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{}); + constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{}); + return concat('_', "pipeline_AgBgCrCompV7", + concat('x', MPerBlock, NPerBlock, KPerBlock), BlockSize, + concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()), + concat('x', WaveNumM, WaveNumN), + concat('x', kPadM, kPadN, kPadK), + Problem::GetName()); + // clang-format on + } + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + CK_TILE_HOST static std::string Print() + { + constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM; + constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN; 
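+        // K extent of a single MFMA, taken from the warp-gemm attribute implementation.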
+ constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK; + + constexpr index_t WaveSize = get_warp_size(); + constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{}); + constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{}); + + // Below should be equal to AK1|BK1 + constexpr index_t A_LDS_Read_Width = GetSmemPackA(); + constexpr index_t B_LDS_Read_Width = GetSmemPackB(); + + constexpr index_t A_LDS_Write_Width = GetSmemPackA(); + constexpr index_t B_LDS_Write_Width = GetSmemPackB(); + + constexpr index_t A_Buffer_Load_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA()); + constexpr index_t B_Buffer_Load_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB()); + + constexpr index_t A_LDS_Write_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * A_LDS_Write_Width); + constexpr index_t B_LDS_Write_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * B_LDS_Write_Width); + + constexpr index_t A_LDS_Read_Inst_Num = + WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width); + constexpr index_t B_LDS_Read_Inst_Num = + WaveNumM * NPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width); + + constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock / + (BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL); + + auto str = std::stringstream{}; + + str << "A/B vector size: " << GetVectorSizeA() << ", " << GetVectorSizeB() << "\n" + << "A/B LDS read/write width: " << A_LDS_Read_Width << ", " << B_LDS_Read_Width << "\n" + << "A/B buffer load inst: " << A_Buffer_Load_Inst_Num << ", " << B_Buffer_Load_Inst_Num + << "\n" + << "A/B LDS write inst: " << A_LDS_Write_Inst_Num << ", " << B_LDS_Write_Inst_Num + << "\n" + << "A/B LDS read inst: " << A_LDS_Read_Inst_Num << ", " << B_LDS_Read_Inst_Num << "\n" + << "C MFMA inst: " << C_MFMA_Inst_Num << "\n" + << "KPack: " << BlockGemm::Traits::KPack << "\n" + << "PrefetchStages: " << PrefetchStages << "\n"; + return str.str(); + } + + template + struct PipelineImpl : public PipelineImplBase + { + }; + + template <> + struct PipelineImpl : public PipelineImplBase + { + using Base = PipelineImplBase; + + CK_TILE_DEVICE static constexpr auto HotLoopScheduler() + { + constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM; + constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN; + constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK; + + constexpr index_t WaveSize = get_warp_size(); + constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{}); + constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{}); + + // Below should be equal to AK1|BK1 + constexpr index_t A_LDS_Read_Width = GetSmemPackA(); + constexpr index_t B_LDS_Read_Width = GetSmemPackB(); + + constexpr index_t A_LDS_Write_Width = GetSmemPackA(); + constexpr index_t B_LDS_Write_Width = GetSmemPackB(); + + constexpr index_t A_Buffer_Load_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA()); + constexpr index_t B_Buffer_Load_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB()); + + constexpr index_t A_LDS_Write_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * A_LDS_Write_Width); + constexpr index_t B_LDS_Write_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * B_LDS_Write_Width); + + constexpr index_t A_LDS_Read_Inst_Num = + WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width); + constexpr index_t B_LDS_Read_Inst_Num = + WaveNumM * NPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width); + + constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * 
KPerBlock / + (BlockSize / WaveSize) / + (MPerXDL * NPerXDL * KPerXDL); + + // A/B split schedule + // compiler is likely to use ds_read2 when instruction width smaller than 16bytes + constexpr auto num_ds_read_inst_a = + A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? A_LDS_Read_Inst_Num + : A_LDS_Read_Inst_Num / 2; + constexpr auto num_ds_read_inst_b = + B_LDS_Read_Width * sizeof(BDataType) / BPackedSize == 16 ? B_LDS_Read_Inst_Num + : B_LDS_Read_Inst_Num / 2; + + constexpr auto num_ds_write_inst_a = A_LDS_Write_Inst_Num; + constexpr auto num_ds_write_inst_b = B_LDS_Write_Inst_Num; + + constexpr auto num_buffer_load_inst_a = A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = B_Buffer_Load_Inst_Num; + + constexpr auto num_mfma_inst = C_MFMA_Inst_Num; + + constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32; + constexpr auto ds_read_a_issue_cycle = + A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? 8 : 4; + constexpr auto ds_read_b_issue_cycle = + B_LDS_Read_Width * sizeof(BDataType) / BPackedSize == 16 ? 8 : 4; + constexpr auto ds_read_a_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle); + constexpr auto ds_read_b_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle); + + constexpr auto num_dsread_a_mfma = + (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate; + constexpr auto num_dsread_b_mfma = + (num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate; + + // stage 1 + // Separate this part? + // constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) > + // sizeof(ComputeDataType) / + // sizeof(BDataType) + // ? sizeof(ComputeDataType) / + // sizeof(ADataType) : sizeof(ComputeDataType) + // / sizeof(BDataType); + constexpr auto num_mfma_stage1 = + num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma); + constexpr auto num_mfma_per_issue = + num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b); + constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a; + constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b; + + static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); // MFMA + }); + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_mfma_per_issue - num_dswrite_per_issue_b, 0); // MFMA + }); + + // stage 2 + static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) { + if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >= + ds_read_a_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier( + 0x100, + num_ds_read_inst_a - (num_dsread_a_mfma - 1) * ds_read_a_mfma_rate, + 0); // DS read + } + 
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + + static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) { + if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >= + ds_read_b_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier( + 0x100, + num_ds_read_inst_b - (num_dsread_b_mfma - 1) * ds_read_b_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + } + + template ::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + index_t num_loop, + void* p_smem) const + { + using ADramBlockWindowTmp = + remove_cvref_t{}, AsDramBlockWindowTmp>>; + using BDramBlockWindowTmp = + remove_cvref_t{}, BsDramBlockWindowTmp>>; + + static_assert( + std::is_same_v> && + std::is_same_v>, + "A/B Dram block window should have the same data type as appropriate " + "([A|B]DataType) defined in Problem definition!"); + + constexpr bool is_a_col_major = + std::is_same_v; + constexpr bool is_b_row_major = std::is_same_v; + + static_assert(is_a_col_major + ? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && + MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]) + : (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && + KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]), + "A block window has incorrect lengths for defined ALayout!"); + static_assert(is_b_row_major + ? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]) + : (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]), + "B block window has incorrect lengths for defined BLayout!"); + + // ------------------------------------------------------------------------------------ + // Definitions of all needed tiles + + // A/B tiles in LDS + auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem); + + // Tile distribution for load from lds + constexpr auto a_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); + constexpr auto b_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode()); + + // A DRAM tile window for load + // A LDS tile window for store + // A LDS tile for block GEMM + auto&& [a_copy_dram_window, a_copy_lds_window, a_lds_gemm_window] = + Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr); + + // B DRAM tile window for load + // B LDS tile window for store + // B LDS tile for block GEMM + auto&& [b_copy_dram_window, b_copy_lds_window, b_lds_gemm_window] = + Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr); + + // Block GEMM + auto block_gemm = BlockGemm(); + auto c_block_tile = block_gemm.MakeCBlockTile(); + + using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex; + using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex; + + constexpr ADramTileWindowStep a_dram_tile_window_step = + is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock); + constexpr BDramTileWindowStep b_dram_tile_window_step = + is_b_row_major ? 
make_array(KPerBlock, 0) : make_array(0, KPerBlock); + + // ----------------------------------------------------------------------------------------- + // Gemm pipeline start + // initialize C + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + + auto first_a_load_tile_with_elementwise = [&]() { + if constexpr(HasHotLoop) + return load_tile_with_elementwise(a_copy_dram_window, a_element_func); + else + return load_tile_with_elementwise_vectorload1(a_copy_dram_window, + a_element_func); + }(); + + auto first_b_load_tile_with_elementwise = [&]() { + if constexpr(HasHotLoop) + return load_tile_with_elementwise(b_copy_dram_window, b_element_func); + else + return load_tile_with_elementwise_vectorload1(b_copy_dram_window, + b_element_func); + }(); + + // Load tile — during value loading, an elementwise function is executed for each A0, + // A1, … AN. The values A0, A1, … AN are read by the same thread. + auto elementwise_As_res = first_a_load_tile_with_elementwise; + + // Move each A — the enhanced function move_tile_window is executed, which takes a tuple + // as input. + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); + + // Load tile — during value loading, an elementwise function is executed for each B0, + // B1, … BN. The values B0, B1, … BN are read by the same thread. + auto elementwise_Bs_res = first_b_load_tile_with_elementwise; + + // Move each B — the enhanced function move_tile_window is executed, which takes a tuple + // as input. + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); + + // LDS write 0 + if constexpr(is_a_col_major && !is_a_load_tr_v()) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp, elementwise_As_res); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp); + } + else + { + Base::LocalPrefill(a_copy_lds_window, elementwise_As_res); + } + if constexpr(is_b_row_major && !is_b_load_tr_v()) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp); + } + else + { + Base::LocalPrefill(b_copy_lds_window, elementwise_Bs_res); + } + + block_sync_lds(); + block_gemm.LocalPrefetch( + a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); + + __builtin_amdgcn_sched_barrier(0); + + // main body + if constexpr(HasHotLoop) + { + index_t i = 1; + while(i < (num_loop - 1)) + { + elementwise_As_res = + load_tile_with_elementwise(a_copy_dram_window, a_element_func); + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); + + block_sync_lds(); + + elementwise_Bs_res = + load_tile_with_elementwise(b_copy_dram_window, b_element_func); + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); + + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + + block_sync_lds(); + + if constexpr(is_a_col_major && !is_a_load_tr_v()) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp, elementwise_As_res); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp); + } + else + { + Base::LocalPrefill(a_copy_lds_window, elementwise_As_res); + } + if constexpr(is_b_row_major && !is_b_load_tr_v()) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res); 
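+                    // Row-major B is transposed in registers before the LDS prefill so it
+                    // lands in LDS in the layout the block GEMM reads back via
+                    // LocalPrefetch; the A path above does the same for column-major A.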
+ Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp); + } + else + { + Base::LocalPrefill(b_copy_lds_window, elementwise_Bs_res); + } + + block_sync_lds(); + block_gemm.LocalPrefetch( + a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + + i += 1; + } + + elementwise_As_res = + load_tile_with_elementwise_vectorload1(a_copy_dram_window, a_element_func); + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); + + block_sync_lds(); + + elementwise_Bs_res = + load_tile_with_elementwise_vectorload1(b_copy_dram_window, b_element_func); + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); + + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + + block_sync_lds(); + + if constexpr(is_a_col_major && !is_a_load_tr_v()) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp, elementwise_As_res); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp); + } + else + { + Base::LocalPrefill(a_copy_lds_window, elementwise_As_res); + } + if constexpr(is_b_row_major && !is_b_load_tr_v()) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp); + } + else + { + Base::LocalPrefill(b_copy_lds_window, elementwise_Bs_res); + } + + block_sync_lds(); + block_gemm.LocalPrefetch( + a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + } + + // tail + { + // Leak last MFMA block to epilogue region, cover the potential lds-shuffle + // latency + block_sync_lds(); + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + } + + // __builtin_amdgcn_sched_barrier(0); + return c_block_tile; + } + }; + + template ::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + index_t num_loop, + void* p_smem) const + { + const bool has_hot_loop = Base::BlockHasHotloop(num_loop); + const auto tail_number = Base::GetBlockLoopTailNum(num_loop); + + const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) { + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + a_element_func, + b_dram_block_window_tmp, + b_element_func, + num_loop, + p_smem); + }; + + return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); + } + + /** + * @brief This function runs the pipeline by wrapping it with the tail handler. + * + * @note This is used by the persistent gemm kernel variants that don't determine + * hot loop and tail number on the host side, e.g. grouped gemm kernel. 
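+     * @note For this pipeline the tail_number argument is effectively unused:
+     *       GetBlockLoopTailNum() always returns TailNumber::Odd and TailHandler
+     *       dispatches on has_hot_loop only.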
+ */ + template ::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, + index_t num_loop, + bool has_hot_loop, + TailNumber tail_number, + void* p_smem) const + { + const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) { + constexpr bool hot_loop = hot_loop_.value; + constexpr auto tail_num = tail_num_.value; + constexpr auto PassThrough = [](auto& e, const auto& x) { e = x; }; + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + PassThrough, + b_dram_block_window_tmp, + PassThrough, + num_loop, + p_smem); + }; + return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); + } + + /** + * @brief This function runs the pipeline using compile-time known hot loop and tail number. + * @param num_loop The number of loop iterations. This is determined at runtime due to e.g. + * SplitK. + * @note This is used by the kernel variants that are able to determine + * hot loop and tail number on the host side, e.g. non-persistent gemm kernel. + */ + template ::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, + index_t num_loop, + void* p_smem) const + { + const bool has_hot_loop = Base::BlockHasHotloop(num_loop); + const auto tail_number = Base::GetBlockLoopTailNum(num_loop); + + return operator()(a_dram_block_window_tmp, + b_dram_block_window_tmp, + num_loop, + has_hot_loop, + tail_number, + p_smem); + } + + template ::value && + !is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + index_t num_loop, + void* p_smem) const + { + return operator()(ck_tile::make_tuple(a_dram_block_window_tmp), + a_element_func, + ck_tile::make_tuple(b_dram_block_window_tmp), + b_element_func, + num_loop, + p_smem); + } + + /** + * @brief Quant operator(), single input: This function runs the pipeline by wrapping it with + * the tail handler. + * + * @note This is used by the persistent gemm kernel variants that don't determine + * hot loop and tail number on the host side, e.g. grouped gemm kernel. + */ + template ::value && + !is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + index_t num_loop, + bool has_hot_loop, + TailNumber tail_number, + void* p_smem) const + { + return operator()(ck_tile::make_tuple(a_dram_block_window_tmp), + ck_tile::make_tuple(b_dram_block_window_tmp), + num_loop, + has_hot_loop, + tail_number, + p_smem); + } + + /** + * @brief Quant operator(), single input: This function runs the pipeline using compile-time + * known hot loop and tail number. + * @param num_loop The number of loop iterations. This is determined at runtime due to e.g. + * SplitK. + * @note This is used by the kernel variants that are able to determine + * hot loop and tail number on the host side, e.g. non-persistent gemm kernel. 
+     */
+    template ::value &&
+                                    !is_detected::value,
+                                bool>* = nullptr>
+    CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
+                                   const BDramBlockWindowTmp& b_dram_block_window_tmp,
+                                   index_t num_loop,
+                                   void* p_smem) const
+    {
+        return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
+                          ck_tile::make_tuple(b_dram_block_window_tmp),
+                          num_loop,
+                          p_smem);
+    }
+};
+
+} // namespace ck_tile
diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipelines.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipelines.hpp
index b8ba584ef8f..9e88a554c2a 100644
--- a/include/ck_tile/ops/gemm/pipeline/gemm_pipelines.hpp
+++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipelines.hpp
@@ -12,6 +12,7 @@ enum struct GemmPipeline
     COMPUTE_V4,
     COMPUTE_V5,
     COMPUTE_V6,
+    COMPUTE_V7,
     MEMORY,
     BASIC_V1,
     BASIC_V2,
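Usage sketch (not part of the patch): a minimal illustration of how the new pieces fit together, assuming the config-driven instantiation used by the existing 03_gemm example. The precision template argument and the surrounding driver code are assumptions, not something this diff defines.

// Illustrative only: map the new enum value to its pipeline types and make sure the
// chosen config satisfies the padding requirement enforced by the kernel check above.
#include "ck_tile/ops/gemm.hpp"
#include "gemm_utils.hpp"

using GemmConfig = GemmConfigComputeV7<ck_tile::half_t>; // precision argument assumed
using Traits     = PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V7>;
// Traits::GemmPipeline / Traits::UniversalGemmPipeline now resolve to
// ck_tile::GemmPipelineAgBgCrCompV7 / ck_tile::BaseGemmPipelineAgBgCrCompV7.

// COMPUTE_V7 requires padded M/N/K tiles because its tail K iteration falls back to
// scalar (vector-size-1) global loads; the runtime check added to UniversalGemmKernel
// rejects unpadded configs with "Compute pipeline v7 needs all paddings enabled!".
static_assert(GemmConfig::kPadM && GemmConfig::kPadN && GemmConfig::kPadK,
              "COMPUTE_V7 needs kPadM/kPadN/kPadK all true");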