@@ -936,75 +936,28 @@ struct UniversalGemmKernel
936936 return make_tuple (as_block_window, bs_block_window, ds_block_window, e_block_window);
937937 }
938938
939- /* *
940- * @brief Runs single GEMM problem cooperatively by whole workgroup.
941- *
942- * @param as_ptr input As pointer
943- * @param bs_ptr input Bs pointer
944- * @param ds_ptr input Ds pointer
945- * @param e_ptr output E pointer
946- * @param smem_ptr_0 The start memory pointer of the shared memory block.
947- * @param kargs GEMM kernel arguments
948- * @param splitk_batch_offset splitk_batch_offset Utility structure used to calculate k batch.
949- * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
950- * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
951- *
952- */
953- template <bool UseDefaultScheduler = true >
954- CK_TILE_DEVICE static void RunGemm (const std::array<const ADataType*, NumATensor>& as_ptr,
955- const std::array<const BDataType*, NumBTensor>& bs_ptr,
956- const std::array<const void *, NumDTensor>& ds_ptr,
957- EDataType* e_ptr,
958- void * smem_ptr_0,
959- const KernelArgs& kargs,
960- const SplitKBatchOffset& splitk_batch_offset,
961- const index_t block_idx_m,
962- const index_t block_idx_n)
963- {
964- // Create Gemm tensor views, pad views and tile windows
965- const auto & gemm_tensor_views_tuple =
966- MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
967- as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset.splitted_k );
968-
969- const auto & gemm_pad_views = MakeGemmPadViews (gemm_tensor_views_tuple);
970- auto gemm_tile_windows = MakeGemmTileWindows (gemm_pad_views, block_idx_m, block_idx_n);
971-
972- const index_t num_loop =
973- amd_wave_read_first_lane (TilePartitioner::GetLoopNum (splitk_batch_offset.splitted_k ));
974-
975- // Run GEMM cooperatively by whole workgroup.
976- const auto & as_block_window = gemm_tile_windows.at (I0);
977- const auto & bs_block_window = gemm_tile_windows.at (I1);
978- const auto & ds_block_window = gemm_tile_windows.at (I2);
979-
980- const auto & c_block_tile = GemmPipeline{}.template operator ()(
981- as_block_window, AElementWise{}, bs_block_window, BElementWise{}, num_loop, smem_ptr_0);
982-
983- if (UseDefaultScheduler || (get_warp_id () == 0 ))
984- {
985- // Run Epilogue Pipeline
986- auto & c_block_window = gemm_tile_windows.at (I3);
987-
988- EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
989- }
990- }
991-
992939 // Version of RunGemm using descriptors
993- template <typename AGridDesc,
994- typename BGridDesc,
940+ // FIXME: Currently Templated to XsList to allow both arrays and tuples for convenience, which
941+ // doesn't enforce same size nor matching types (as with arrays)
942+ template <typename AsList,
943+ typename BsList,
944+ typename DsList,
945+ typename AGridDescs,
946+ typename BGridDescs,
947+ typename DGridDescs,
995948 typename EGridDesc,
996949 bool UseDefaultScheduler = true >
997- CK_TILE_DEVICE static void RunGemmDesc (const std::array< const ADataType*, NumATensor> & as_ptr,
998- const std::array< const BDataType*, NumBTensor> & bs_ptr,
999- const std::array< const void *, NumDTensor> & ds_ptr,
950+ CK_TILE_DEVICE static void RunGemmDesc (const AsList & as_ptr,
951+ const BsList & bs_ptr,
952+ const DsList & ds_ptr,
1000953 EDataType* e_ptr,
1001954 void * smem_ptr_0,
1002955 const SplitKBatchOffset& splitk_batch_offset,
1003956 const index_t block_idx_m,
1004957 const index_t block_idx_n,
1005- const std::array<AGridDesc, NumATensor> & as_desc,
1006- const std::array<BGridDesc, NumBTensor> & bs_desc,
1007- const std::array<EGridDesc, NumDTensor> & ds_desc,
958+ const AGridDescs & as_desc,
959+ const BGridDescs & bs_desc,
960+ const DGridDescs & ds_desc,
1008961 const EGridDesc& e_desc)
1009962 {
1010963 // Create tensor views from descriptors (supports arbitrary stride patterns)
@@ -1061,6 +1014,65 @@ struct UniversalGemmKernel
10611014 }
10621015 }
10631016
1017+ /* *
1018+ * @brief Runs single GEMM problem cooperatively by whole workgroup.
1019+ *
1020+ * @param as_ptr input As pointer
1021+ * @param bs_ptr input Bs pointer
1022+ * @param ds_ptr input Ds pointer
1023+ * @param e_ptr output E pointer
1024+ * @param smem_ptr_0 The start memory pointer of the shared memory block.
1025+ * @param kargs GEMM kernel arguments
1026+ * @param splitk_batch_offset Utility structure used to calculate k batch.
1027+ * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
1028+ * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
1029+ *
1030+ */
1031+ template <bool UseDefaultScheduler = true >
1032+ CK_TILE_DEVICE static void RunGemm (const std::array<const ADataType*, NumATensor>& as_ptr,
1033+ const std::array<const BDataType*, NumBTensor>& bs_ptr,
1034+ const std::array<const void *, NumDTensor>& ds_ptr,
1035+ EDataType* e_ptr,
1036+ void * smem_ptr_0,
1037+ const KernelArgs& kargs,
1038+ const SplitKBatchOffset& splitk_batch_offset,
1039+ const index_t block_idx_m,
1040+ const index_t block_idx_n)
1041+ {
1042+ const auto & gemm_tensor_views_tuple =
1043+ MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
1044+ as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset.splitted_k );
1045+
1046+ // FIXME: Refactor to generate descriptors and views separately, then rework signatures
1047+ // FIXME: pointers need to be extracted as well
1048+ // FIXME: Fails (at least) 1024x1024x256_splitk2 and 1024x1024x256_splitk4 in
1049+ // test_gemm_tile_engine_fp16_rcr_quick_coverage_config_compv3_cshuffle_intrawave_False_False_False_False_32x64x16_2x2x1_16x16x16
1050+ // FIXME: UseDefaultScheduler is not forwarded to RunGemmDesc (which defaults it to true)
1051+ auto as_desc = generate_tuple (
1052+ [&](auto i) { return gemm_tensor_views_tuple.at (I0)[i].get_tensor_descriptor (); },
1053+ number<NumATensor>{});
1054+ auto bs_desc = generate_tuple (
1055+ [&](auto i) { return gemm_tensor_views_tuple.at (I1)[i].get_tensor_descriptor (); },
1056+ number<NumBTensor>{});
1057+ auto ds_desc = generate_tuple (
1058+ [&](auto i) { return gemm_tensor_views_tuple.at (I2)[i].get_tensor_descriptor (); },
1059+ number<NumDTensor>{});
1060+ auto e_desc = gemm_tensor_views_tuple.at (I3).get_tensor_descriptor ();
1061+
1062+ RunGemmDesc (as_ptr,
1063+ bs_ptr,
1064+ ds_ptr,
1065+ e_ptr,
1066+ smem_ptr_0,
1067+ splitk_batch_offset,
1068+ block_idx_m,
1069+ block_idx_n,
1070+ as_desc,
1071+ bs_desc,
1072+ ds_desc,
1073+ e_desc);
1074+ }
1075+
1075+
10641076 /* *
10651077 * @brief Runs single GEMM problem cooperatively by whole workgroup.
10661078 *
0 commit comments