Skip to content

Commit 160bd91

Browse files
committed
[WIP] Partial attempt at implementing RunGemm using RunGemmDesc
1 parent 26cdb3e commit 160bd91

1 file changed

Lines changed: 73 additions & 61 deletions

File tree

include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp

Lines changed: 73 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -936,75 +936,28 @@ struct UniversalGemmKernel
936936
return make_tuple(as_block_window, bs_block_window, ds_block_window, e_block_window);
937937
}
938938

939-
/**
940-
* @brief Runs single GEMM problem cooperatively by whole workgroup.
941-
*
942-
* @param as_ptr input As pointer
943-
* @param bs_ptr input Bs pointer
944-
* @param ds_ptr input Ds pointer
945-
* @param e_ptr output E pointer
946-
* @param smem_ptr_0 The start memory pointer of the shared memory block.
947-
* @param kargs GEMM kernel arguments
948-
* @param splitk_batch_offset splitk_batch_offset Utility structure used to calculate k batch.
949-
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
950-
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
951-
*
952-
*/
953-
template <bool UseDefaultScheduler = true>
954-
CK_TILE_DEVICE static void RunGemm(const std::array<const ADataType*, NumATensor>& as_ptr,
955-
const std::array<const BDataType*, NumBTensor>& bs_ptr,
956-
const std::array<const void*, NumDTensor>& ds_ptr,
957-
EDataType* e_ptr,
958-
void* smem_ptr_0,
959-
const KernelArgs& kargs,
960-
const SplitKBatchOffset& splitk_batch_offset,
961-
const index_t block_idx_m,
962-
const index_t block_idx_n)
963-
{
964-
// Create Gemm tensor views, pad views and tile windows
965-
const auto& gemm_tensor_views_tuple =
966-
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
967-
as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset.splitted_k);
968-
969-
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
970-
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
971-
972-
const index_t num_loop =
973-
amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
974-
975-
// Run GEMM cooperatively by whole workgroup.
976-
const auto& as_block_window = gemm_tile_windows.at(I0);
977-
const auto& bs_block_window = gemm_tile_windows.at(I1);
978-
const auto& ds_block_window = gemm_tile_windows.at(I2);
979-
980-
const auto& c_block_tile = GemmPipeline{}.template operator()(
981-
as_block_window, AElementWise{}, bs_block_window, BElementWise{}, num_loop, smem_ptr_0);
982-
983-
if(UseDefaultScheduler || (get_warp_id() == 0))
984-
{
985-
// Run Epilogue Pipeline
986-
auto& c_block_window = gemm_tile_windows.at(I3);
987-
988-
EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
989-
}
990-
}
991-
992939
// Version of RunGemm using descriptors
993-
template <typename AGridDesc,
994-
typename BGridDesc,
940+
// FIXME: Currently Templated to XsList to allow both arrays and tuples for convenience, which
941+
// doesn't enforce same size nor matching types (as with arrays)
942+
template <typename AsList,
943+
typename BsList,
944+
typename DsList,
945+
typename AGridDescs,
946+
typename BGridDescs,
947+
typename DGridDescs,
995948
typename EGridDesc,
996949
bool UseDefaultScheduler = true>
997-
CK_TILE_DEVICE static void RunGemmDesc(const std::array<const ADataType*, NumATensor>& as_ptr,
998-
const std::array<const BDataType*, NumBTensor>& bs_ptr,
999-
const std::array<const void*, NumDTensor>& ds_ptr,
950+
CK_TILE_DEVICE static void RunGemmDesc(const AsList& as_ptr,
951+
const BsList& bs_ptr,
952+
const DsList& ds_ptr,
1000953
EDataType* e_ptr,
1001954
void* smem_ptr_0,
1002955
const SplitKBatchOffset& splitk_batch_offset,
1003956
const index_t block_idx_m,
1004957
const index_t block_idx_n,
1005-
const std::array<AGridDesc, NumATensor>& as_desc,
1006-
const std::array<BGridDesc, NumBTensor>& bs_desc,
1007-
const std::array<EGridDesc, NumDTensor>& ds_desc,
958+
const AGridDescs& as_desc,
959+
const BGridDescs& bs_desc,
960+
const DGridDescs& ds_desc,
1008961
const EGridDesc& e_desc)
1009962
{
1010963
// Create tensor views from descriptors (supports arbitrary stride patterns)
@@ -1061,6 +1014,65 @@ struct UniversalGemmKernel
10611014
}
10621015
}
10631016

1017+
/**
1018+
* @brief Runs single GEMM problem cooperatively by whole workgroup.
1019+
*
1020+
* @param as_ptr input As pointer
1021+
* @param bs_ptr input Bs pointer
1022+
* @param ds_ptr input Ds pointer
1023+
* @param e_ptr output E pointer
1024+
* @param smem_ptr_0 The start memory pointer of the shared memory block.
1025+
* @param kargs GEMM kernel arguments
1026+
* @param splitk_batch_offset splitk_batch_offset Utility structure used to calculate k batch.
1027+
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
1028+
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
1029+
*
1030+
*/
1031+
template <bool UseDefaultScheduler = true>
1032+
CK_TILE_DEVICE static void RunGemm(const std::array<const ADataType*, NumATensor>& as_ptr,
1033+
const std::array<const BDataType*, NumBTensor>& bs_ptr,
1034+
const std::array<const void*, NumDTensor>& ds_ptr,
1035+
EDataType* e_ptr,
1036+
void* smem_ptr_0,
1037+
const KernelArgs& kargs,
1038+
const SplitKBatchOffset& splitk_batch_offset,
1039+
const index_t block_idx_m,
1040+
const index_t block_idx_n)
1041+
{
1042+
const auto& gemm_tensor_views_tuple =
1043+
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
1044+
as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset.splitted_k);
1045+
1046+
// FIXME: Refactor to generate descriptors and views separately, then rework signatures
1047+
// FIXME: pointers need to be extracted as well
1048+
// FIXME: Fails (at least) 1024x1024x256_splitk2 and 1024x1024x256_splitk4 in
1049+
// test_gemm_tile_engine_fp16_rcr_quick_coverage_config_compv3_cshuffle_intrawave_False_False_False_False_32x64x16_2x2x1_16x16x16
1050+
1051+
auto as_desc = generate_tuple(
1052+
[&](auto i) { return gemm_tensor_views_tuple.at(I0)[i].get_tensor_descriptor(); },
1053+
number<NumATensor>{});
1054+
auto bs_desc = generate_tuple(
1055+
[&](auto i) { return gemm_tensor_views_tuple.at(I1)[i].get_tensor_descriptor(); },
1056+
number<NumBTensor>{});
1057+
auto ds_desc = generate_tuple(
1058+
[&](auto i) { return gemm_tensor_views_tuple.at(I2)[i].get_tensor_descriptor(); },
1059+
number<NumDTensor>{});
1060+
auto e_desc = gemm_tensor_views_tuple.at(I3).get_tensor_descriptor();
1061+
1062+
RunGemmDesc(_as_ptr,
1063+
_bs_ptr,
1064+
_ds_ptr,
1065+
_e_ptr,
1066+
smem_ptr_0,
1067+
splitk_batch_offset,
1068+
block_idx_m,
1069+
block_idx_n,
1070+
as_desc,
1071+
bs_desc,
1072+
ds_desc,
1073+
e_desc);
1074+
}
1075+
10641076
/**
10651077
* @brief Runs single GEMM problem cooperatively by whole workgroup.
10661078
*

0 commit comments

Comments
 (0)