From 2f13c7a5b05b1ff120de57f318342908f61ec86d Mon Sep 17 00:00:00 2001 From: Kevin Abraham Date: Fri, 16 Jan 2026 09:13:45 +0000 Subject: [PATCH 01/12] added reflection for conv_fwd_multiple_d_wmma_cshuffle.hpp --- .../ck_tile/builder/reflect/conv_traits.hpp | 2 + .../builder/reflect/conv_traits_device | 0 ...ped_conv_fwd_multiple_abd_xdl_cshuffle.hpp | 4 +- ..._conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp | 4 +- ...uped_conv_fwd_multiple_d_wmma_cshuffle.hpp | 46 +++ ...d_multiple_d_xdl_large_tensor_cshuffle.hpp | 4 +- .../builder/reflect/conv_traits_helpers.hpp | 345 +++++++++++------- .../reflect/instance_to_conv_traits.hpp | 1 + ...uped_conv_fwd_multiple_d_wmma_cshuffle.hpp | 7 + .../builder/test/conv/ck/test_conv_traits.cpp | 128 +++++++ 10 files changed, 393 insertions(+), 148 deletions(-) create mode 100644 experimental/builder/include/ck_tile/builder/reflect/conv_traits_device create mode 100644 experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp index 451a74be342..2bc45efe01c 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp @@ -102,6 +102,8 @@ struct ConvTraits OutputTileTransferInfo c_tile_transfer; + std::optional num_gemm_prefetch_stage = std::nullopt; + builder::PipelineVersion pipeline_version; builder::PipelineScheduler pipeline_scheduler; }; diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device new file mode 100644 index 00000000000..e69de29bb2d diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp 
b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index cdd238f36a1..2f5d84a4a80 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -24,8 +24,8 @@ constexpr ConvTraits instance_to_conv_traits() return ConvTraits{ .spatial_dim = InstTraits::kSpatialDim, .direction = conv_direction(), - .layout = conv_layout(), - .data_type = conv_data_type(), + .layout = fwd_conv_layout(), + .data_type = conv_data_type(), .input_element_op = elementwise_op(), .weight_element_op = elementwise_op(), .output_element_op = elementwise_op(), diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp index 28c43c342fc..2108c790548 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp @@ -24,8 +24,8 @@ constexpr ConvTraits instance_to_conv_traits() return ConvTraits{ .spatial_dim = InstTraits::kSpatialDim, .direction = conv_direction(), - .layout = conv_layout(), - .data_type = conv_data_type(), + .layout = fwd_conv_layout(), + .data_type = conv_data_type(), .input_element_op = elementwise_op(), .weight_element_op = elementwise_op(), .output_element_op = elementwise_op(), diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp 
b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp new file mode 100644 index 00000000000..5e88ee17839 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp @@ -0,0 +1,46 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include + +#include "ck_tile/builder/reflect/conv_traits.hpp" +#include "ck_tile/builder/reflect/conv_traits_helpers.hpp" +#include "ck_tile/builder/reflect/instance_traits.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp" + +namespace ck_tile::reflect::conv { + +/// @brief Tag dispatch implementation for DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle +template + requires HasInstanceTraits && + std::same_as::device_kernel_tag, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_Tag> +constexpr ConvTraits instance_to_conv_traits() +{ + using InstTraits = InstanceTraits; + + return ConvTraits{ + .spatial_dim = InstTraits::kSpatialDim, + .direction = conv_direction(), + .layout = fwd_conv_layout(), + .data_type = conv_data_type(), + .input_element_op = elementwise_op(), + .weight_element_op = elementwise_op(), + .output_element_op = elementwise_op(), + .gemm_padding = gemm_spec(), + .conv_specialization = conv_spec(), + .thread_block_size = InstTraits::kBlockSize, + .tile_dims = conv_traits_data_tile(), + .a_tile_transfer = conv_traits_a_transfer_params(InstTraits::kK1), + .b_tile_transfer = conv_traits_b_transfer_params(InstTraits::kK1), + .warp_gemm = conv_traits_wmma_warp_gemm_params(), + .c_tile_transfer = conv_traits_wmma_c_tile_transfer(), + .num_gemm_prefetch_stage = InstTraits::kNumGemmKPrefetchStage, + .pipeline_version = get_pipeline_version(), + .pipeline_scheduler = get_pipeline_scheduler(), + }; +} + +} // namespace ck_tile::reflect::conv diff --git 
a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp index c4bed850ebc..0cce3bf5130 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp @@ -24,8 +24,8 @@ constexpr ConvTraits instance_to_conv_traits() return ConvTraits{ .spatial_dim = InstTraits::kSpatialDim, .direction = conv_direction(), - .layout = conv_layout(), - .data_type = conv_data_type(), + .layout = fwd_conv_layout(), + .data_type = conv_data_type(), .input_element_op = elementwise_op(), .weight_element_op = elementwise_op(), .output_element_op = elementwise_op(), diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_helpers.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_helpers.hpp index 46c196e95ad..c1bf3ad7635 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_helpers.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_helpers.hpp @@ -80,6 +80,22 @@ namespace ck_tile::reflect::conv { // SECTION 1: ENUM CONVERSIONS // ============================================================================ +// Forward convolution layout concept - checks for A/B/E layout types +template +concept HasFwdConvLayouts = requires { + typename T::ALayout; + typename T::BLayout; + typename T::ELayout; +}; + +// Backwards weight layout concept - checks for In, wei and out layouts +template +concept HasBwdWeiLayouts = requires { + typename T::InLayout; + typename T::WeiLayout; + typename T::OutLayout; +}; + /// @brief Converts a CK BlockGemmPipelineVersion enum to a builder PipelineVersion enum. 
/// @tparam ck_ver The CK BlockGemmPipelineVersion enum value to convert. /// @return The corresponding builder::PipelineVersion enum value. @@ -322,12 +338,23 @@ constexpr builder::ConvSpecialization conv_spec() // Tensor Layouts // ---------------------------------------------------------------------------- +// Helper variable template to check if CK layout enums match +template +inline constexpr bool layouts_are = + std::is_same_v && std::is_same_v && std::is_same_v; + /// @brief Helper function to report unsupported layout combinations with a clear error message. -/// @details This consteval function uses throw (not static_assert) to ensure the error is not -/// silently ignored during SFINAE. The thrown string becomes part of the compiler error message. +/// @details This consteval function is designed to fail at compile time with a descriptive +/// error message when an unsupported layout combination is encountered. template [[noreturn]] consteval void report_unsupported_layout_error() { + // This will produce a compile-time error with the exception message throw "Unsupported convolution layout combination detected!\n" "The combination of ALayout, BLayout, and ELayout template parameters\n" "is not recognized for the given spatial dimension.\n" @@ -335,111 +362,99 @@ template "Check the conv_layout() function for the list of supported layout combinations."; } -/// @brief Derives the grouped convolution layout from a device kernel Instance type. -/// @tparam Instance The device kernel instance type. -/// @return An std::array containing the layouts for: -/// - [0] Input tensor layout -/// - [1] Weight tensor layout -/// - [2] Output tensor layout -/// @details This function examines the Instance's ALayout, BLayout, and ELayout types -/// along with the spatial dimension to determine the appropriate layout configuration. -/// -/// Supported layout combinations vary by spatial dimension (1D, 2D, 3D convolutions). 
-/// Common patterns include GNHWC (grouped, batch, spatial, channels) and variants. -/// -/// @note Compilation will fail with a clear error message if the layout combination -/// is not supported for the given spatial dimension. -/// -/// TODO: If we don't check for supported layouts, this function can be simplified. -template -constexpr std::array conv_layout() +template +constexpr auto conv_layout() { - using InstTraits = InstanceTraits; - using A = typename InstTraits::ALayout; - using B = typename InstTraits::BLayout; - using E = typename InstTraits::ELayout; - namespace ctl = ck::tensor_layout::convolution; - using enum builder::TensorLayout; - - // Helper to check if layouts match expected types - constexpr auto layouts_match = []() { - return std::is_same_v && std::is_same_v && std::is_same_v; - }; - // Helper to construct layout array - constexpr auto make_layouts = [](auto in, auto weight, auto out) { - return std::array{in, weight, out}; - }; + // Helper lambda to construct layout array + auto layouts = [](auto... 
Ls) { return std::array{Ls...}; }; - constexpr int spatial_dim = InstTraits::kSpatialDim; + namespace ctl = ck::tensor_layout::convolution; + using enum builder::TensorLayout; - if constexpr(spatial_dim == 1) - { - if constexpr(layouts_match.template operator()()) - return make_layouts(GNWC, GKXC, GNWK); - else if constexpr(layouts_match - .template operator()()) - return make_layouts(GNWC, GKXC, GNWK); - else if constexpr(layouts_match.template operator()()) - return make_layouts(NWGC, GKXC, NWGK); - else if constexpr(layouts_match.template operator()()) - return make_layouts(NGCW, GKXC, NGKW); - else if constexpr(layouts_match.template operator()()) - return make_layouts(NGCW, GKCX, NGKW); - else - { - report_unsupported_layout_error(); - return make_layouts(GNWC, GKXC, GNWK); // Unreachable - } - } - else if constexpr(spatial_dim == 2) - { - if constexpr(layouts_match.template operator()()) - return make_layouts(GNHWC, GKYXC, GNHWK); - else if constexpr(layouts_match - .template operator()()) - return make_layouts(GNHWC, GKYXC, GNHWK); - else if constexpr(layouts_match.template operator()()) - return make_layouts(NHWGC, GKYXC, NHWGK); - else if constexpr(layouts_match.template operator()()) - return make_layouts(NHWGC, GKYXC, NHWGK); - else if constexpr(layouts_match.template operator()()) - return make_layouts(NGCHW, GKYXC, NGKHW); - else if constexpr(layouts_match.template operator()()) - return make_layouts(NGCHW, GKCYX, NGKHW); - else - { - report_unsupported_layout_error(); - return make_layouts(GNHWC, GKYXC, GNHWK); // Unreachable - } - } - else if constexpr(spatial_dim == 3) + switch(kSpatialDim) { - if constexpr(layouts_match.template operator()()) - return make_layouts(GNDHWC, GKZYXC, GNDHWK); - else if constexpr(layouts_match - .template operator()()) - return make_layouts(GNDHWC, GKZYXC, GNDHWK); - else if constexpr(layouts_match - .template operator()()) - return make_layouts(NDHWGC, GKZYXC, NDHWGK); - else if constexpr(layouts_match - .template 
operator()()) - return make_layouts(NGCDHW, GKZYXC, NGKDHW); - else if constexpr(layouts_match - .template operator()()) - return make_layouts(NGCDHW, GKCZYX, NGKDHW); - else - { - report_unsupported_layout_error(); - return make_layouts(GNDHWC, GKZYXC, GNDHWK); // Unreachable - } - } - else - { - report_unsupported_layout_error(); - return make_layouts(GNHWC, GKYXC, GNHWK); // Unreachable + case 1: + if constexpr(layouts_are) + return layouts(GNWC, GKXC, GNWK); + if constexpr(layouts_are) + return layouts(GNWC, GKXC, GNWK); + if constexpr(layouts_are) + return layouts(NWGC, GKXC, NWGK); + if constexpr(layouts_are) + return layouts(NGCW, GKXC, NGKW); + if constexpr(layouts_are) + return layouts(NGCW, GKCX, NGKW); + break; + case 2: + if constexpr(layouts_are) + return layouts(GNHWC, GKYXC, GNHWK); + if constexpr(layouts_are) + return layouts(GNHWC, GKYXC, GNHWK); + if constexpr(layouts_are) + return layouts(NHWGC, GKYXC, NHWGK); + if constexpr(layouts_are) + return layouts(NHWGC, GKYXC, NHWGK); + if constexpr(layouts_are) + return layouts(NGCHW, GKYXC, NGKHW); + if constexpr(layouts_are) + return layouts(NGCHW, GKCYX, NGKHW); + break; + case 3: + if constexpr(layouts_are) + return layouts(GNDHWC, GKZYXC, GNDHWK); + if constexpr(layouts_are) + return layouts(GNDHWC, GKZYXC, GNDHWK); + if constexpr(layouts_are) + return layouts(NDHWGC, GKZYXC, NDHWGK); + if constexpr(layouts_are) + return layouts(NGCDHW, GKZYXC, NGKDHW); + if constexpr(layouts_are) + return layouts(NGCDHW, GKCZYX, NGKDHW); + break; } + + // If we reach here, the layout combination is not supported + // Call consteval function to trigger a compile-time error with a clear message + report_unsupported_layout_error(); + + // This return is unreachable but needed to satisfy the compiler + return layouts(GNHWC, GKYXC, GNHWK); +} + +/// @brief Derives the grouped convolution layout from a device kernel `Instance` type. +/// @tparam Instance The device kernel instance type. 
+/// @return An std::array corresponding to the tensor layouts: +/// index 0 -> Input layout +/// index 1 -> Weight layout +/// index 2 -> Output layout + +template +constexpr auto fwd_conv_layout() + requires HasFwdConvLayouts> +{ + + using A = typename InstanceTraits::ALayout; + using B = typename InstanceTraits::BLayout; + using E = typename InstanceTraits::ELayout; + return conv_layout::kSpatialDim>(); +} + +/// @brief Derives the grouped convolution layout from a device kernel `Instance` type. +/// @tparam Instance The device kernel instance type. +/// @return An std::array corresponding to the tensor layouts: +/// index 0 -> Input layout +/// index 1 -> Weight layout +/// index 2 -> Output layout +template +constexpr auto bwd_wei_conv_layout() + requires HasBwdWeiLayouts> +{ + + using A = typename InstanceTraits::InLayout; + using B = typename InstanceTraits::WeiLayout; + using E = typename InstanceTraits::OutLayout; + return conv_layout::kSpatialDim>(); } // ---------------------------------------------------------------------------- @@ -447,13 +462,11 @@ constexpr std::array conv_layout() // ---------------------------------------------------------------------------- /// @brief Helper function to report unsupported data type with a clear error message. -/// @details This consteval function uses throw (not static_assert) to ensure the error is not -/// silently ignored during SFINAE. The thrown string becomes part of the compiler error message. 
-template +template [[noreturn]] consteval void report_unsupported_data_type_error() { throw "Unsupported data type detected!\n" - "The ADataType is not recognized.\n" + "The DataTypeFromInstance is not recognized.\n" "Supported types are: ck::half_t (FP16), ck::Tuple (FP16_FP16), " "ck::bhalf_t (BF16), ck::Tuple (BF16_BF16), float (FP32), " "ck::Tuple (FP32_FP32), double (FP64), ck::f8_t (FP8), ck::bf8_fnuz_t " @@ -462,62 +475,44 @@ template "Please verify that your kernel instance uses a supported data type."; } -/// @brief Derives the data type from a device kernel Instance type. -/// @tparam Instance The device kernel instance type. -/// @return A builder::DataType enum value representing the input data type. -/// @details This function examines the Instance's ADataType to determine the data type -/// used for the input tensor. The function supports various floating-point and integer -/// types, including tuple types for mixed-precision operations. -/// -/// Supported data types include: -/// - FP16 (ck::half_t) -/// - FP16_FP16 (ck::Tuple) -/// - BF16 (ck::bhalf_t) -/// - BF16_BF16 (ck::Tuple) -/// - FP32 (float) -/// - FP32_FP32 (ck::Tuple) -/// - FP64 (double) -/// - FP8 (ck::f8_t) -/// - BF8 (ck::bf8_fnuz_t, ck::bf8_ocp_t) -/// - I8 (int8_t) -/// - I8_I8 (ck::Tuple) -/// - U8 (uint8_t) -template +/// @brief Derives the data type from a device kernel `Instance` type. +/// Returns a `builder::DataType` enum value (e.g., FP16, BF16, FP32, BF8). +// Note: maybe move to types.hpp? 
+template constexpr builder::DataType conv_data_type() + { - using InstTraits = InstanceTraits; - using ADataType = typename InstTraits::ADataType; using enum builder::DataType; - if constexpr(std::is_same_v) + if constexpr(std::is_same_v) return FP16; - else if constexpr(std::is_same_v>) + else if constexpr(std::is_same_v>) return FP16_FP16; - else if constexpr(std::is_same_v) + else if constexpr(std::is_same_v) return BF16; - else if constexpr(std::is_same_v>) + else if constexpr(std::is_same_v>) return BF16_BF16; - else if constexpr(std::is_same_v) + else if constexpr(std::is_same_v) return FP32; - else if constexpr(std::is_same_v>) + else if constexpr(std::is_same_v>) return FP32_FP32; - else if constexpr(std::is_same_v) + else if constexpr(std::is_same_v) return FP64; - else if constexpr(std::is_same_v) + else if constexpr(std::is_same_v) return FP8; - else if constexpr(std::is_same_v) + else if constexpr(std::is_same_v) return BF8; - else if constexpr(std::is_same_v) + else if constexpr(std::is_same_v) return BF8; - else if constexpr(std::is_same_v) + else if constexpr(std::is_same_v) return I8; - else if constexpr(std::is_same_v>) + else if constexpr(std::is_same_v>) return I8_I8; - else if constexpr(std::is_same_v) + else if constexpr(std::is_same_v) return U8; else { - report_unsupported_data_type_error(); + report_unsupported_data_type_error(); return FP32; // Unreachable } } @@ -736,4 +731,70 @@ constexpr builder::PipelineScheduler get_pipeline_scheduler() } } +// ============================================================================ +// SECTION 4: Helper functions for common structures often used in reflection +// ============================================================================ + +template +constexpr DataTileInfo conv_traits_data_tile(int k_or_k0 = InstTraits::kKPerBlock) +{ + return DataTileInfo{.m = InstTraits::kMPerBlock, .n = InstTraits::kNPerBlock, .k = k_or_k0}; +} + +template +constexpr InputTileTransferInfo 
+conv_traits_a_transfer_params(int _k1, int kPerBlock = InstTraits::kKPerBlock) +{ + return InputTileTransferInfo{ + .tile_dimensions = {.k0 = kPerBlock / _k1, .m_or_n = InstTraits::kMPerBlock, .k1 = _k1}, + .transfer_params = {.k1 = _k1, + .thread_cluster_dims = InstTraits::kAThreadClusterLengths, + .thread_cluster_order = InstTraits::kAThreadClusterArrangeOrder, + .src_access_order = InstTraits::kABlockTransferSrcAccessOrder, + .src_vector_dim = InstTraits::kABlockTransferSrcVectorDim, + .src_scalar_per_vector = InstTraits::kABlockTransferSrcScalarPerVector, + .dst_scalar_per_vector_k1 = + InstTraits::kABlockTransferDstScalarPerVectorK1, + .lds_padding = static_cast(InstTraits::kABlockLdsExtraM)}}; +} + +template +constexpr InputTileTransferInfo +conv_traits_b_transfer_params(int _k1, int kPerBlock = InstTraits::kKPerBlock) +{ + return InputTileTransferInfo{ + .tile_dimensions = {.k0 = kPerBlock / _k1, .m_or_n = InstTraits::kNPerBlock, .k1 = _k1}, + .transfer_params = {.k1 = _k1, + .thread_cluster_dims = InstTraits::kBThreadClusterLengths, + .thread_cluster_order = InstTraits::kBThreadClusterArrangeOrder, + .src_access_order = InstTraits::kBBlockTransferSrcAccessOrder, + .src_vector_dim = InstTraits::kBBlockTransferSrcVectorDim, + .src_scalar_per_vector = InstTraits::kBBlockTransferSrcScalarPerVector, + .dst_scalar_per_vector_k1 = + InstTraits::kBBlockTransferDstScalarPerVectorK1, + .lds_padding = static_cast(InstTraits::kBBlockLdsExtraN)}}; +} + +template +constexpr WarpGemmParams conv_traits_wmma_warp_gemm_params() +{ + return WarpGemmParams{.gemm_m = InstTraits::kMPerWmma, + .gemm_n = InstTraits::kNPerWmma, + .m_iter = InstTraits::kMRepeat, + .n_iter = InstTraits::kNRepeat}; +} + +template +constexpr OutputTileTransferInfo conv_traits_wmma_c_tile_transfer() +{ + return OutputTileTransferInfo{ + .shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMRepeatPerShuffle, + .n_gemms_per_shuffle = InstTraits::kCShuffleNRepeatPerShuffle}, + 
.thread_cluster_dims = {InstTraits::kCDEThreadClusterLengths[0], + InstTraits::kCDEThreadClusterLengths[1], + InstTraits::kCDEThreadClusterLengths[2], + InstTraits::kCDEThreadClusterLengths[3]}, + .scalar_per_vector = InstTraits::kCDEBlockTransferScalarPerVector}; +} + } // namespace ck_tile::reflect::conv diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_to_conv_traits.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_to_conv_traits.hpp index 00010e2d48b..a37efff0c7c 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_to_conv_traits.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_to_conv_traits.hpp @@ -6,3 +6,4 @@ #include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" #include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp" #include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp" +#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp" diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp index 782fd158c53..645d75258e0 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp @@ -79,6 +79,11 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle; } // namespace ck::tensor_operation::device +/// @brief Tag type for DeviceGroupedConvFwdMultipleD_Wmma_CShuffle device kernel +struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_Tag +{ +}; + namespace ck_tile::reflect { // Specialization for 
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle @@ -176,6 +181,8 @@ struct InstanceTraits> { + /// @brief Tag type identifying this device kernel variant + using device_kernel_tag = DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_Tag; // Spatial dimension static constexpr int kSpatialDim = NDimSpatial; diff --git a/experimental/builder/test/conv/ck/test_conv_traits.cpp b/experimental/builder/test/conv/ck/test_conv_traits.cpp index 42235df2fe0..2c1f492dc22 100644 --- a/experimental/builder/test/conv/ck/test_conv_traits.cpp +++ b/experimental/builder/test/conv/ck/test_conv_traits.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include namespace { @@ -26,6 +27,133 @@ class ConvTraitsTest : public ::testing::Test { }; +// test conv traits device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp +TEST_F(ConvTraitsTest, ConvFwdTraitsMultipleDCshuffleWmmaExtraction) +{ + using DeviceInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< + 2, // NDimSpatial + ck::tensor_layout::convolution::GNHWC, // ALayout + ck::tensor_layout::convolution::GKYXC, // BLayout + ck::Tuple<>, // DsLayout + ck::tensor_layout::convolution::GNHWK, // ELayout + ck::half_t, // ADataType + ck::half_t, // BDataType + float, // AccDataType + ck::half_t, // CShuffleDataType + ck::Tuple<>, // DsDataType + ck::half_t, // EDataType + ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation + ck::tensor_operation::device::ConvolutionForwardSpecialization:: + Default, // ConvForwardSpecialization + ck::tensor_operation::device::GemmSpecialization::Default, // GemmSpec + 1, // NummGemmKPrefetchStage + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // KPerBlock + 8, // K1 + 32, // MPerWmma + 32, // NPerWmma + 4, // MRepeat + 4, // NRepeat + ck::Sequence<4, 64, 1>, // 
ABlockTransferThreadClusterLengths_AK0_M_AK1 + ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, // CShuffleMRepeatPerShuffle + 1, // CShuffleNRepeatPerShuffle + ck::Sequence< + 1, + 32, + 1, + 8>, // CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8, // CDEShuffleBlockTransferScalarPerVector_NPerBlock + ck::LoopScheduler::Default, // BlkGemmPipeSched + ck::PipelineVersion::v1>; // BlkGemmPipelineVer + + // Use ConvTraitsTmpl to extract compile-time information + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + + // Verify signature information + EXPECT_EQ(traits.spatial_dim, 2); + EXPECT_EQ(traits.direction, ConvDirection::FORWARD); + EXPECT_THAT(traits.layout, + ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK)); + EXPECT_EQ(traits.data_type, DataType::FP16); + EXPECT_EQ(traits.input_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.weight_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.output_element_op, ElementwiseOperation::PASS_THROUGH); + + // Verify specializations + EXPECT_EQ(traits.gemm_padding, ck_tile::builder::GemmPadding::DEFAULT); + EXPECT_EQ(traits.conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT); + EXPECT_EQ(traits.num_gemm_prefetch_stage, 1); + + // Verify algorithm information + EXPECT_EQ(traits.thread_block_size, 256); + + // Verify tile dimensions + 
EXPECT_EQ(traits.tile_dims.m, 128); + EXPECT_EQ(traits.tile_dims.n, 128); + EXPECT_EQ(traits.tile_dims.k, 16); + + // Verify A tile transfer info + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(traits.a_tile_transfer.transfer_params.lds_padding); + + // Verify B tile transfer info + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(traits.b_tile_transfer.transfer_params.lds_padding); + + // Verify warp GEMM params + EXPECT_EQ(traits.warp_gemm.gemm_m, 32); + EXPECT_EQ(traits.warp_gemm.gemm_n, 32); + EXPECT_EQ(traits.warp_gemm.m_iter, 4); + 
EXPECT_EQ(traits.warp_gemm.n_iter, 4); + + // Verify output tile transfer info + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.m_gemms_per_shuffle, 1); + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.n_gemms_per_shuffle, 1); + EXPECT_THAT(traits.c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8)); + EXPECT_EQ(traits.c_tile_transfer.scalar_per_vector, 8); + + // Verify pipeline configuration + EXPECT_EQ(traits.pipeline_scheduler, + ck_tile::reflect::conv::convert_pipeline_scheduler()); + EXPECT_EQ(traits.pipeline_version, PipelineVersion::V1); +} + // Test ConvTraits with DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 TEST_F(ConvTraitsTest, ConvFwdTraitsExtraction) { From e945dbbb786723012b64101d6ed4d632cbf3cdd0 Mon Sep 17 00:00:00 2001 From: Kevin Abraham Date: Fri, 16 Jan 2026 09:58:20 +0000 Subject: [PATCH 02/12] added reflection for device_grouped_conv_bwd_weight_xdl_cshuffle --- .../ck_tile/builder/reflect/conv_traits.hpp | 2 +- ...e_grouped_conv_bwd_weight_xdl_cshuffle.hpp | 53 ++++++++ .../builder/reflect/conv_traits_helpers.hpp | 22 ++++ .../reflect/instance_to_conv_traits.hpp | 4 + ...e_grouped_conv_bwd_weight_xdl_cshuffle.hpp | 88 ++++++++----- .../builder/test/conv/ck/test_conv_traits.cpp | 121 ++++++++++++++++++ 6 files changed, 258 insertions(+), 32 deletions(-) create mode 100644 experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp index 2bc45efe01c..0041e4352db 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp @@ -88,7 +88,7 @@ struct ConvTraits builder::ElementwiseOperation weight_element_op; builder::ElementwiseOperation output_element_op; - builder::GemmPadding gemm_padding; + std::optional gemm_padding = std::nullopt; 
builder::ConvSpecialization conv_specialization; // --- Algorithm Information --- diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp new file mode 100644 index 00000000000..0aa19c64a3e --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp @@ -0,0 +1,53 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include + +#include "ck_tile/builder/reflect/conv_traits.hpp" +#include "ck_tile/builder/reflect/conv_traits_helpers.hpp" +#include "ck_tile/builder/reflect/instance_traits.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp" + +namespace ck_tile::reflect::conv { + +/// @brief Tag dispatch implementation for DeviceGroupedConvBwdWeight_Xdl_CShuffle_Tag +template + requires HasInstanceTraits && + std::same_as::device_kernel_tag, + DeviceGroupedConvBwdWeight_Xdl_CShuffle_Tag> +constexpr ConvTraits instance_to_conv_traits() +{ + using InstTraits = InstanceTraits; + + return ConvTraits{ + .spatial_dim = InstTraits::kSpatialDim, + .direction = conv_direction(), + .layout = bwd_wei_conv_layout(), + .data_type = conv_data_type(), + .input_element_op = elementwise_op(), + .weight_element_op = elementwise_op(), + .output_element_op = elementwise_op(), + .conv_specialization = conv_spec(), + .thread_block_size = InstTraits::kBlockSize, + .tile_dims = conv_traits_data_tile(InstTraits::kK0PerBlock), + .a_tile_transfer = + conv_traits_a_transfer_params(InstTraits::kK1, InstTraits::kK0PerBlock), + .b_tile_transfer = + conv_traits_b_transfer_params(InstTraits::kK1, InstTraits::kK0PerBlock), + .warp_gemm = conv_traits_xdl_warp_gemm_params(), + .c_tile_transfer = + {.shuffle_params = 
{.m_gemms_per_shuffle = InstTraits::kCShuffleMXdlPerWavePerShuffle, + .n_gemms_per_shuffle = InstTraits::kCShuffleNXdlPerWavePerShuffle}, + .thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0], + InstTraits::kCThreadClusterLengths[1], + InstTraits::kCThreadClusterLengths[2], + InstTraits::kCThreadClusterLengths[3]}, + .scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector_NWaveNPerXdl}, + .pipeline_version = get_pipeline_version(), + .pipeline_scheduler = get_pipeline_scheduler(), + }; +} + +} // namespace ck_tile::reflect::conv diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_helpers.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_helpers.hpp index c1bf3ad7635..c17284417d4 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_helpers.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_helpers.hpp @@ -784,6 +784,15 @@ constexpr WarpGemmParams conv_traits_wmma_warp_gemm_params() .n_iter = InstTraits::kNRepeat}; } +template +constexpr WarpGemmParams conv_traits_xdl_warp_gemm_params() +{ + return WarpGemmParams{.gemm_m = InstTraits::kMPerXDL, + .gemm_n = InstTraits::kNPerXDL, + .m_iter = InstTraits::kMXdlPerWave, + .n_iter = InstTraits::kNXdlPerWave}; +} + template constexpr OutputTileTransferInfo conv_traits_wmma_c_tile_transfer() { @@ -797,4 +806,17 @@ constexpr OutputTileTransferInfo conv_traits_wmma_c_tile_transfer() .scalar_per_vector = InstTraits::kCDEBlockTransferScalarPerVector}; } +template +constexpr OutputTileTransferInfo conv_traits_xdl_c_tile_transfer() +{ + return OutputTileTransferInfo{ + .shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMXdlPerWavePerShuffle, + .n_gemms_per_shuffle = InstTraits::kCShuffleNXdlPerWavePerShuffle}, + .thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0], + InstTraits::kCThreadClusterLengths[1], + InstTraits::kCThreadClusterLengths[2], + InstTraits::kCThreadClusterLengths[3]}, + 
.scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector}; +} + } // namespace ck_tile::reflect::conv diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_to_conv_traits.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_to_conv_traits.hpp index a37efff0c7c..586715eb35d 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_to_conv_traits.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_to_conv_traits.hpp @@ -3,7 +3,11 @@ #pragma once +// Fwd instances #include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" #include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp" #include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp" #include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp" + +// Bwd instances +#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp" diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp index 2c893b9c1dd..1edf03740f3 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp @@ -61,6 +61,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle; namespace ck_tile { namespace reflect { +/// @brief Tag type for DeviceGroupedConvBwdWeight_Xdl_CShuffle device kernel +struct DeviceGroupedConvBwdWeight_Xdl_CShuffle_Tag +{ +}; + template ::value; + static constexpr auto kAThreadClusterArrangeOrder = + detail::SequenceToArray::value; + static constexpr auto 
kABlockTransferSrcAccessOrder = + detail::SequenceToArray::value; + static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim; static constexpr ck::index_t kABlockTransferSrcScalarPerVector = ABlockTransferSrcScalarPerVector; - static constexpr ck::index_t kABlockTransferDstScalarPerVector_K1 = + static constexpr ck::index_t kABlockTransferDstScalarPerVectorK1 = ABlockTransferDstScalarPerVector_K1; - static constexpr bool kABlockLdsAddExtraM = ABlockLdsAddExtraM; + static constexpr bool kABlockLdsExtraM = ABlockLdsAddExtraM; using BBlockTransferThreadClusterLengths_K0_N_K1 = BBlockTransferThreadClusterLengths_K0_N_K1_; using BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_; using BBlockTransferSrcAccessOrder = BBlockTransferSrcAccessOrder_; + + // B block transfer thread cluster dimensions (converted to std::array) + static constexpr auto kBThreadClusterLengths = + detail::SequenceToArray::value; + static constexpr auto kBThreadClusterArrangeOrder = + detail::SequenceToArray::value; + static constexpr auto kBBlockTransferSrcAccessOrder = + detail::SequenceToArray::value; + static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim; static constexpr ck::index_t kBBlockTransferSrcScalarPerVector = BBlockTransferSrcScalarPerVector; - static constexpr ck::index_t kBBlockTransferDstScalarPerVector_K1 = + static constexpr ck::index_t kBBlockTransferDstScalarPerVectorK1 = BBlockTransferDstScalarPerVector_K1; - static constexpr bool kBBlockLdsAddExtraN = BBlockLdsAddExtraN; + static constexpr bool kBBlockLdsExtraN = BBlockLdsAddExtraN; static constexpr ck::index_t kCShuffleMXdlPerWavePerShuffle = CShuffleMXdlPerWavePerShuffle; static constexpr ck::index_t kCShuffleNXdlPerWavePerShuffle = CShuffleNXdlPerWavePerShuffle; using CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_; + static constexpr auto 
kCThreadClusterLengths = detail::SequenceToArray< + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value; + static constexpr ck::index_t kCBlockTransferScalarPerVector_NWaveNPerXdl = CBlockTransferScalarPerVector_NWaveNPerXdl; @@ -224,7 +250,7 @@ struct InstanceTraits(); // 2. InLayout oss << "," << detail::layout_name(); // 3. WeiLayout oss << "," << detail::layout_name(); // 4. OutLayout @@ -242,30 +268,30 @@ struct InstanceTraits(); // 22. oss << "," << detail::sequence_name(); // 23. oss << "," << detail::sequence_name(); // 24. oss << "," << kABlockTransferSrcVectorDim; // 25. oss << "," << kABlockTransferSrcScalarPerVector; // 26. - oss << "," << kABlockTransferDstScalarPerVector_K1; // 27. - oss << "," << (kABlockLdsAddExtraM ? "true" : "false"); // 28. + oss << "," << kABlockTransferDstScalarPerVectorK1; // 27. + oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 28. oss << "," << detail::sequence_name(); // 29. oss << "," << detail::sequence_name(); // 30. oss << "," << detail::sequence_name(); // 31. oss << "," << kBBlockTransferSrcVectorDim; // 32. oss << "," << kBBlockTransferSrcScalarPerVector; // 33. - oss << "," << kBBlockTransferDstScalarPerVector_K1; // 34. - oss << "," << (kBBlockLdsAddExtraN ? "true" : "false"); // 35. + oss << "," << kBBlockTransferDstScalarPerVectorK1; // 34. + oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 35. oss << "," << kCShuffleMXdlPerWavePerShuffle; // 36. oss << "," << kCShuffleNXdlPerWavePerShuffle; // 37. 
oss << "," diff --git a/experimental/builder/test/conv/ck/test_conv_traits.cpp b/experimental/builder/test/conv/ck/test_conv_traits.cpp index 2c1f492dc22..100f8c438fd 100644 --- a/experimental/builder/test/conv/ck/test_conv_traits.cpp +++ b/experimental/builder/test/conv/ck/test_conv_traits.cpp @@ -12,6 +12,7 @@ #include #include +#include namespace { using ck_tile::builder::ConvDirection; @@ -27,6 +28,126 @@ class ConvTraitsTest : public ::testing::Test { }; +// Test ConvTraits with DeviceGroupedConvBwdWeightMultipleDXdlCshuffle +TEST_F(ConvTraitsTest, ConvBwdWeightMultipleDCshuffleTraitsExtraction) +{ + // Define a concrete instance type with specific template parameters + using DeviceInstance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle< + 2, // NDimSpatial + ck::tensor_layout::convolution::GNHWC, // InLayout + ck::tensor_layout::convolution::GKYXC, // WeiLayout + ck::tensor_layout::convolution::GNHWK, // OutLayout + ck::half_t, // InDataType + ck::half_t, // WeiDataType + ck::half_t, // OutDataType + float, // AccDataType + ck::tensor_operation::element_wise::PassThrough, // InElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // WeiElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // OutElementwiseOperation + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization:: + Default, // ConvBackwardWeightSpecialization + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // K0PerBlock + 8, // K1 + 32, // MPerXDL + 32, // NPerXDL + 4, // MXdlPerWave + 4, // NXdlPerWave + ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder_ + ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_K1 + 1, // ABlockLdsAddExtraM + ck::Sequence<4, 64, 1>, // 
BBlockTransferThreadClusterLengths_K0_N_K1 + ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder_ + ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder_ + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_K1 + 1, // BBlockLdsAddExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + ck::Sequence<1, + 32, + 1, + 8>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_ + 8, // CDEBlockTransferScalarPerVector_NPerBlock_ + ck::half_t, // AComputeDataType + ck::half_t, // BComputeDataType + 1, // MaxTransposeTransferSrcScalarPerVector + 1>; // MaxTransposeTransferDstScalarPerVector> + + // Use ConvTraitsTmpl to extract compile-time information + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + + // Verify signature information + EXPECT_EQ(traits.spatial_dim, 2); + EXPECT_EQ(traits.direction, ConvDirection::BACKWARD_WEIGHT); + EXPECT_THAT(traits.layout, + ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK)); + EXPECT_EQ(traits.data_type, DataType::FP16); + EXPECT_EQ(traits.input_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.weight_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.output_element_op, ElementwiseOperation::PASS_THROUGH); + + // Verify specializations + EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT); + + // Verify algorithm information + EXPECT_EQ(traits.thread_block_size, 256); + + // Verify tile dimensions + EXPECT_EQ(traits.tile_dims.m, 128); + EXPECT_EQ(traits.tile_dims.n, 128); + EXPECT_EQ(traits.tile_dims.k, 16); + + // Verify A tile transfer info + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.k1, 8); + 
EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(traits.a_tile_transfer.transfer_params.lds_padding); + + // Verify B tile transfer info + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(traits.b_tile_transfer.transfer_params.lds_padding); + + // Verify warp GEMM params + EXPECT_EQ(traits.warp_gemm.gemm_m, 32); + EXPECT_EQ(traits.warp_gemm.gemm_n, 32); + EXPECT_EQ(traits.warp_gemm.m_iter, 4); + EXPECT_EQ(traits.warp_gemm.n_iter, 4); + + // Verify output tile transfer info + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.m_gemms_per_shuffle, 1); + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.n_gemms_per_shuffle, 1); + EXPECT_THAT(traits.c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8)); + EXPECT_EQ(traits.c_tile_transfer.scalar_per_vector, 8); + + // Verify 
pipeline configuration + EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT); + EXPECT_EQ(traits.pipeline_version, PipelineVersion::V1); +} + // test conv traits device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp TEST_F(ConvTraitsTest, ConvFwdTraitsMultipleDCshuffleWmmaExtraction) { From ea1e3fe68b5958cd73cc4aa5b5303329f76334d8 Mon Sep 17 00:00:00 2001 From: Kevin Abraham Date: Fri, 16 Jan 2026 10:04:41 +0000 Subject: [PATCH 03/12] added reflection for device_grouped_conv_bwd_weight_xdl_cshuffle v3 --- .../builder/reflect/conv_traits_device | 0 ...rouped_conv_bwd_weight_xdl_cshuffle_v3.hpp | 53 ++++++++ .../reflect/instance_to_conv_traits.hpp | 3 +- ...rouped_conv_bwd_weight_xdl_cshuffle_v3.hpp | 70 +++++++--- .../builder/test/conv/ck/test_conv_traits.cpp | 120 ++++++++++++++++++ 5 files changed, 224 insertions(+), 22 deletions(-) delete mode 100644 experimental/builder/include/ck_tile/builder/reflect/conv_traits_device create mode 100644 experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp new file mode 100644 index 00000000000..de986455143 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp @@ -0,0 +1,53 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include + +#include "ck_tile/builder/reflect/conv_traits.hpp" +#include "ck_tile/builder/reflect/conv_traits_helpers.hpp" +#include "ck_tile/builder/reflect/instance_traits.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp" + +namespace ck_tile::reflect::conv { + +/// @brief Tag dispatch implementation for DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3_Tag +template + requires HasInstanceTraits && + std::same_as::device_kernel_tag, + DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3_Tag> +constexpr ConvTraits instance_to_conv_traits() +{ + using InstTraits = InstanceTraits; + + return ConvTraits{ + .spatial_dim = InstTraits::kSpatialDim, + .direction = conv_direction(), + .layout = bwd_wei_conv_layout(), + .data_type = conv_data_type(), + .input_element_op = elementwise_op(), + .weight_element_op = elementwise_op(), + .output_element_op = elementwise_op(), + .conv_specialization = conv_spec(), + .thread_block_size = InstTraits::kBlockSize, + .tile_dims = conv_traits_data_tile(InstTraits::kK0PerBlock), + .a_tile_transfer = + conv_traits_a_transfer_params(InstTraits::kK1, InstTraits::kK0PerBlock), + .b_tile_transfer = + conv_traits_b_transfer_params(InstTraits::kK1, InstTraits::kK0PerBlock), + .warp_gemm = conv_traits_xdl_warp_gemm_params(), + .c_tile_transfer = + {.shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMXdlPerWavePerShuffle, + .n_gemms_per_shuffle = InstTraits::kCShuffleNXdlPerWavePerShuffle}, + .thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0], + InstTraits::kCThreadClusterLengths[1], + InstTraits::kCThreadClusterLengths[2], + InstTraits::kCThreadClusterLengths[3]}, + .scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector_NWaveNPerXdl}, + .pipeline_version = get_pipeline_version(), + .pipeline_scheduler = get_pipeline_scheduler(), + }; +} + +} // namespace ck_tile::reflect::conv diff --git 
a/experimental/builder/include/ck_tile/builder/reflect/instance_to_conv_traits.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_to_conv_traits.hpp index 586715eb35d..8049e204371 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_to_conv_traits.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_to_conv_traits.hpp @@ -9,5 +9,6 @@ #include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp" #include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp" -// Bwd instances +// Bwd weight instances #include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp" +#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp" diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp index 147028f9cfb..ce23dac1d7e 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp @@ -61,6 +61,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3; namespace ck_tile { namespace reflect { +/// @brief Tag type for DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3 device kernel +struct DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3_Tag +{ +}; + template > { + + /// @brief Tag type identifying this device kernel variant + using device_kernel_tag = DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3_Tag; static constexpr auto kTensorOpName = "DeviceGroupedConvBwdWeight_Xdl_CShuffleV3"; - static constexpr ck::index_t kNDimSpatial = NDimSpatial; + static constexpr ck::index_t kSpatialDim = NDimSpatial; 
using InLayout = InLayout_; using WeiLayout = WeiLayout_; @@ -167,7 +175,7 @@ struct InstanceTraits::value; + static constexpr auto kAThreadClusterArrangeOrder = + detail::SequenceToArray::value; + static constexpr auto kABlockTransferSrcAccessOrder = + detail::SequenceToArray::value; + static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim; static constexpr ck::index_t kABlockTransferSrcScalarPerVector = ABlockTransferSrcScalarPerVector; - static constexpr ck::index_t kABlockTransferDstScalarPerVector_K1 = + static constexpr ck::index_t kABlockTransferDstScalarPerVectorK1 = ABlockTransferDstScalarPerVector_K1; - static constexpr bool kABlockLdsAddExtraM = ABlockLdsAddExtraM; + static constexpr bool kABlockLdsExtraM = ABlockLdsAddExtraM; using BBlockTransferThreadClusterLengths_K0_N_K1 = BBlockTransferThreadClusterLengths_K0_N_K1_; using BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_; using BBlockTransferSrcAccessOrder = BBlockTransferSrcAccessOrder_; + + // B block transfer thread cluster dimensions (converted to std::array) + static constexpr auto kBThreadClusterLengths = + detail::SequenceToArray::value; + static constexpr auto kBThreadClusterArrangeOrder = + detail::SequenceToArray::value; + static constexpr auto kBBlockTransferSrcAccessOrder = + detail::SequenceToArray::value; + static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim; static constexpr ck::index_t kBBlockTransferSrcScalarPerVector = BBlockTransferSrcScalarPerVector; - static constexpr ck::index_t kBBlockTransferDstScalarPerVector_K1 = + static constexpr ck::index_t kBBlockTransferDstScalarPerVectorK1 = BBlockTransferDstScalarPerVector_K1; - static constexpr bool kBBlockLdsAddExtraN = BBlockLdsAddExtraN; + static constexpr bool kBBlockLdsExtraN = BBlockLdsAddExtraN; static constexpr ck::index_t kCShuffleMXdlPerWavePerShuffle = CShuffleMXdlPerWavePerShuffle; static constexpr ck::index_t 
kCShuffleNXdlPerWavePerShuffle = CShuffleNXdlPerWavePerShuffle; using CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_; + static constexpr auto kCThreadClusterLengths = detail::SequenceToArray< + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value; static constexpr ck::index_t kCBlockTransferScalarPerVector_NWaveNPerXdl = CBlockTransferScalarPerVector_NWaveNPerXdl; @@ -222,7 +250,7 @@ struct InstanceTraits(); // 2. InLayout oss << "," << detail::layout_name(); // 3. WeiLayout oss << "," << detail::layout_name(); // 4. OutLayout @@ -240,30 +268,30 @@ struct InstanceTraits(); // 22. oss << "," << detail::sequence_name(); // 23. oss << "," << detail::sequence_name(); // 24. oss << "," << kABlockTransferSrcVectorDim; // 25. oss << "," << kABlockTransferSrcScalarPerVector; // 26. - oss << "," << kABlockTransferDstScalarPerVector_K1; // 27. - oss << "," << (kABlockLdsAddExtraM ? "true" : "false"); // 28. + oss << "," << kABlockTransferDstScalarPerVectorK1; // 27. + oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 28. oss << "," << detail::sequence_name(); // 29. oss << "," << detail::sequence_name(); // 30. oss << "," << detail::sequence_name(); // 31. oss << "," << kBBlockTransferSrcVectorDim; // 32. oss << "," << kBBlockTransferSrcScalarPerVector; // 33. - oss << "," << kBBlockTransferDstScalarPerVector_K1; // 34. - oss << "," << (kBBlockLdsAddExtraN ? "true" : "false"); // 35. + oss << "," << kBBlockTransferDstScalarPerVectorK1; // 34. + oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 35. oss << "," << kCShuffleMXdlPerWavePerShuffle; // 36. oss << "," << kCShuffleNXdlPerWavePerShuffle; // 37. 
oss << "," diff --git a/experimental/builder/test/conv/ck/test_conv_traits.cpp b/experimental/builder/test/conv/ck/test_conv_traits.cpp index 100f8c438fd..bbd59b5bf16 100644 --- a/experimental/builder/test/conv/ck/test_conv_traits.cpp +++ b/experimental/builder/test/conv/ck/test_conv_traits.cpp @@ -28,6 +28,126 @@ class ConvTraitsTest : public ::testing::Test { }; +// Test ConvTraits with DeviceGroupedConvBwdWeightMultipleDXdlCshuffle +TEST_F(ConvTraitsTest, ConvBwdWeightMultipleDCshuffleV3TraitsExtraction) +{ + // Define a concrete instance type with specific template parameters + using DeviceInstance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< + 2, // NDimSpatial + ck::tensor_layout::convolution::GNHWC, // InLayout + ck::tensor_layout::convolution::GKYXC, // WeiLayout + ck::tensor_layout::convolution::GNHWK, // OutLayout + ck::half_t, // InDataType + ck::half_t, // WeiDataType + ck::half_t, // OutDataType + float, // AccDataType + ck::tensor_operation::element_wise::PassThrough, // InElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // WeiElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // OutElementwiseOperation + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization:: + Default, // ConvBackwardWeightSpecialization + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // K0PerBlock + 8, // K1 + 32, // MPerXDL + 32, // NPerXDL + 4, // MXdlPerWave + 4, // NXdlPerWave + ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder_ + ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_K1 + 1, // ABlockLdsAddExtraM + ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder_ + ck::Sequence<1, 
0, 2>, // BBlockTransferSrcAccessOrder_ + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_K1 + 1, // BBlockLdsAddExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + ck::Sequence<1, + 32, + 1, + 8>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_ + 8, // CDEBlockTransferScalarPerVector_NPerBlock_ + ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched + ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer + ck::half_t, // AComputeDataType + ck::half_t>; // BComputeDataType + + // Use ConvTraitsTmpl to extract compile-time information + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + + // Verify signature information + EXPECT_EQ(traits.spatial_dim, 2); + EXPECT_EQ(traits.direction, ConvDirection::BACKWARD_WEIGHT); + EXPECT_THAT(traits.layout, + ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK)); + EXPECT_EQ(traits.data_type, DataType::FP16); + EXPECT_EQ(traits.input_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.weight_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.output_element_op, ElementwiseOperation::PASS_THROUGH); + + // Verify specializations + EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT); + + // Verify algorithm information + EXPECT_EQ(traits.thread_block_size, 256); + + // Verify tile dimensions + EXPECT_EQ(traits.tile_dims.m, 128); + EXPECT_EQ(traits.tile_dims.n, 128); + EXPECT_EQ(traits.tile_dims.k, 16); + + // Verify A tile transfer info + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + 
EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(traits.a_tile_transfer.transfer_params.lds_padding); + + // Verify B tile transfer info + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(traits.b_tile_transfer.transfer_params.lds_padding); + + // Verify warp GEMM params + EXPECT_EQ(traits.warp_gemm.gemm_m, 32); + EXPECT_EQ(traits.warp_gemm.gemm_n, 32); + EXPECT_EQ(traits.warp_gemm.m_iter, 4); + EXPECT_EQ(traits.warp_gemm.n_iter, 4); + + // Verify output tile transfer info + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.m_gemms_per_shuffle, 1); + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.n_gemms_per_shuffle, 1); + EXPECT_THAT(traits.c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8)); + EXPECT_EQ(traits.c_tile_transfer.scalar_per_vector, 8); + + // Verify pipeline configuration + EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT); + 
EXPECT_EQ(traits.pipeline_version, PipelineVersion::V1); +} + // Test ConvTraits with DeviceGroupedConvBwdWeightMultipleDXdlCshuffle TEST_F(ConvTraitsTest, ConvBwdWeightMultipleDCshuffleTraitsExtraction) { From b9f197d21fb23750dd20b01c4892d948ffcdc68a Mon Sep 17 00:00:00 2001 From: Kevin Abraham Date: Fri, 16 Jan 2026 10:27:35 +0000 Subject: [PATCH 04/12] added reflection of max_transpose parameters --- .../builder/include/ck_tile/builder/reflect/conv_traits.hpp | 3 +++ ...conv_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp | 3 +++ 2 files changed, 6 insertions(+) diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp index 0041e4352db..6e4a04bdad3 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp @@ -106,6 +106,9 @@ struct ConvTraits builder::PipelineVersion pipeline_version; builder::PipelineScheduler pipeline_scheduler; + + std::optional max_transpose_transfer_src_scalar_per_vector = std::nullopt; + std::optional max_transpose_dst_scalar_per_vector = std::nullopt; }; } // namespace ck_tile::reflect::conv diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp index 0aa19c64a3e..39fde332178 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp @@ -47,6 +47,9 @@ constexpr ConvTraits instance_to_conv_traits() .scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector_NWaveNPerXdl}, .pipeline_version = get_pipeline_version(), .pipeline_scheduler = get_pipeline_scheduler(), + 
.max_transpose_transfer_src_scalar_per_vector = + InstTraits::kMaxTransposeTransferSrcScalarPerVector, + .max_transpose_dst_scalar_per_vector = InstTraits::kMaxTransposeTransferDstScalarPerVector, }; } From 8013a02265a2098ad02a382784a65f0e1c228ed7 Mon Sep 17 00:00:00 2001 From: Kevin Abraham Date: Fri, 16 Jan 2026 12:10:49 +0000 Subject: [PATCH 05/12] fix printing of std optional parameters --- .../include/ck_tile/builder/reflect/conv_description.hpp | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp index a7b6c60a73e..f858cc32b64 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp @@ -64,7 +64,7 @@ struct GemmAlgorithmInfo builder::PipelineVersion pipeline_version; builder::PipelineScheduler pipeline_scheduler; builder::ConvSpecialization conv_specialization; - builder::GemmPadding padding; + std::optional padding; }; /// @brief Provides human-readable descriptions of convolution kernel instances @@ -121,7 +121,11 @@ class ConvDescription : public Description algorithm_.tile_dims.n, "×", algorithm_.tile_dims.k); - f.writeLine(2, "Gemm padding: ", algorithm_.padding); + if(algorithm_.padding) + f.writeLine( + 2, "Gemm padding: ", algorithm_.padding.value_or(builder::GemmPadding::DEFAULT)); + else + f.writeLine(2, "Struct does not contain optional padding argument"); f.writeLine(2, "Convolution specialization: ", algorithm_.conv_specialization); // Pipeline section f.writeLine(2, "Pipeline version: ", algorithm_.pipeline_version); From 7ef80d057026f6c6d7d277e0c162e8fa54e951c7 Mon Sep 17 00:00:00 2001 From: Kevin Abraham Date: Fri, 16 Jan 2026 14:46:24 +0000 Subject: [PATCH 06/12] fix use of undefined ck::index --- .../builder/include/ck_tile/builder/reflect/conv_traits.hpp | 4 ++-- 1 
file changed, 2 insertions(+), 2 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp index 6e4a04bdad3..28271b25261 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp @@ -107,8 +107,8 @@ struct ConvTraits builder::PipelineVersion pipeline_version; builder::PipelineScheduler pipeline_scheduler; - std::optional max_transpose_transfer_src_scalar_per_vector = std::nullopt; - std::optional max_transpose_dst_scalar_per_vector = std::nullopt; + std::optional max_transpose_transfer_src_scalar_per_vector = std::nullopt; + std::optional max_transpose_dst_scalar_per_vector = std::nullopt; }; } // namespace ck_tile::reflect::conv From c5da43d3061e3d013622ecc0c1d66cb91e64f9b8 Mon Sep 17 00:00:00 2001 From: Kevin Abraham Date: Mon, 19 Jan 2026 12:08:30 +0000 Subject: [PATCH 07/12] added conv traits for device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle --- ...onv_bwd_weight_multiple_d_xdl_cshuffle.hpp | 53 +++++++ .../reflect/instance_to_conv_traits.hpp | 1 + ...onv_bwd_weight_multiple_d_xdl_cshuffle.hpp | 66 ++++++--- .../builder/test/conv/ck/test_conv_traits.cpp | 130 +++++++++++++++++- 4 files changed, 225 insertions(+), 25 deletions(-) create mode 100644 experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp new file mode 100644 index 00000000000..2f7c68458f9 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp @@ -0,0 +1,53 @@ +// 
Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include + +#include "ck_tile/builder/reflect/conv_traits.hpp" +#include "ck_tile/builder/reflect/conv_traits_helpers.hpp" +#include "ck_tile/builder/reflect/instance_traits.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp" + +namespace ck_tile::reflect::conv { + +/// @brief Tag dispatch implementation for DeviceGroupedConvBwdWeight_multiple_d_Xdl_CShuffle_Tag +template + requires HasInstanceTraits && + std::same_as::device_kernel_tag, + DeviceGroupedConvBwdWeight_multiple_d_Xdl_CShuffle_Tag> +constexpr ConvTraits instance_to_conv_traits() +{ + using InstTraits = InstanceTraits; + + return ConvTraits{ + .spatial_dim = InstTraits::kSpatialDim, + .direction = conv_direction(), + .layout = bwd_wei_conv_layout(), + .data_type = conv_data_type(), + .input_element_op = elementwise_op(), + .weight_element_op = elementwise_op(), + .output_element_op = elementwise_op(), + .conv_specialization = conv_spec(), + .thread_block_size = InstTraits::kBlockSize, + .tile_dims = conv_traits_data_tile(InstTraits::kK0PerBlock), + .a_tile_transfer = + conv_traits_a_transfer_params(InstTraits::kK1, InstTraits::kK0PerBlock), + .b_tile_transfer = + conv_traits_b_transfer_params(InstTraits::kK1, InstTraits::kK0PerBlock), + .warp_gemm = conv_traits_xdl_warp_gemm_params(), + .c_tile_transfer = + {.shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMXdlPerWavePerShuffle, + .n_gemms_per_shuffle = InstTraits::kCShuffleNXdlPerWavePerShuffle}, + .thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0], + InstTraits::kCThreadClusterLengths[1], + InstTraits::kCThreadClusterLengths[2], + InstTraits::kCThreadClusterLengths[3]}, + .scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector_NWaveNPerXdl}, + .pipeline_version = get_pipeline_version(), + .pipeline_scheduler = get_pipeline_scheduler(), + }; +} + +} // 
namespace ck_tile::reflect::conv diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_to_conv_traits.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_to_conv_traits.hpp index 8049e204371..5d2b43f5051 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_to_conv_traits.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_to_conv_traits.hpp @@ -12,3 +12,4 @@ // Bwd weight instances #include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp" #include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp" +#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp" diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp index 6508ac7d6eb..173da8268a3 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp @@ -59,6 +59,10 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle; namespace ck_tile { namespace reflect { +/// @brief Tag type for DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle device kernel +struct DeviceGroupedConvBwdWeight_multiple_d_Xdl_CShuffle_Tag +{ +}; template ::value; + static constexpr auto kAThreadClusterArrangeOrder = + detail::SequenceToArray::value; + static constexpr auto kABlockTransferSrcAccessOrder = + detail::SequenceToArray::value; static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim; static constexpr ck::index_t kABlockTransferSrcScalarPerVector = ABlockTransferSrcScalarPerVector; - 
static constexpr ck::index_t kABlockTransferDstScalarPerVector_K1 = + static constexpr ck::index_t kABlockTransferDstScalarPerVectorK1 = ABlockTransferDstScalarPerVector_K1; - static constexpr bool kABlockLdsAddExtraM = ABlockLdsAddExtraM; + static constexpr bool kABlockLdsExtraM = ABlockLdsAddExtraM; using BBlockTransferThreadClusterLengths_K0_N_K1 = BBlockTransferThreadClusterLengths_K0_N_K1_; using BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_; using BBlockTransferSrcAccessOrder = BBlockTransferSrcAccessOrder_; + // B block transfer thread cluster dimensions (converted to std::array) + static constexpr auto kBThreadClusterLengths = + detail::SequenceToArray::value; + static constexpr auto kBThreadClusterArrangeOrder = + detail::SequenceToArray::value; + static constexpr auto kBBlockTransferSrcAccessOrder = + detail::SequenceToArray::value; static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim; static constexpr ck::index_t kBBlockTransferSrcScalarPerVector = BBlockTransferSrcScalarPerVector; - static constexpr ck::index_t kBBlockTransferDstScalarPerVector_K1 = + static constexpr ck::index_t kBBlockTransferDstScalarPerVectorK1 = BBlockTransferDstScalarPerVector_K1; - static constexpr bool kBBlockLdsAddExtraN = BBlockLdsAddExtraN; + static constexpr bool kBBlockLdsExtraN = BBlockLdsAddExtraN; using CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_; @@ -211,6 +232,9 @@ struct InstanceTraits< using ComputeTypeA = ComputeTypeA_; using ComputeTypeB = ComputeTypeB_; + static constexpr auto kCThreadClusterLengths = detail::SequenceToArray< + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value; + // Static member function to generate instance string static std::string instance_string() { @@ -220,7 +244,7 @@ struct InstanceTraits< oss << "DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle"; // Template 
parameters in exact order - oss << "<" << kNDimSpatial; // 1. NDimSpatial + oss << "<" << kSpatialDim; // 1. NDimSpatial oss << "," << detail::layout_name(); // 2. InLayout oss << "," << detail::layout_name(); // 3. WeiLayout oss << "," << detail::layout_name(); // 4. OutLayout @@ -240,30 +264,30 @@ struct InstanceTraits< // OutElementwiseOperation oss << "," << detail::conv_bwd_weight_spec_name( - kConvBackwardWeightSpecialization); // 14. ConvBackwardWeightSpecialization - oss << "," << kBlockSize; // 15. BlockSize - oss << "," << kMPerBlock; // 16. MPerBlock - oss << "," << kNPerBlock; // 17. NPerBlock - oss << "," << kK0PerBlock; // 18. K0PerBlock - oss << "," << kK1; // 19. K1 - oss << "," << kMPerXDL; // 20. MPerXDL - oss << "," << kNPerXDL; // 21. NPerXDL - oss << "," << kMXdlPerWave; // 22. MXdlPerWave - oss << "," << kNXdlPerWave; // 23. NXdlPerWave + kConvBwdWeightSpecialization); // 14. ConvBackwardWeightSpecialization + oss << "," << kBlockSize; // 15. BlockSize + oss << "," << kMPerBlock; // 16. MPerBlock + oss << "," << kNPerBlock; // 17. NPerBlock + oss << "," << kK0PerBlock; // 18. K0PerBlock + oss << "," << kK1; // 19. K1 + oss << "," << kMPerXDL; // 20. MPerXDL + oss << "," << kNPerXDL; // 21. NPerXDL + oss << "," << kMXdlPerWave; // 22. MXdlPerWave + oss << "," << kNXdlPerWave; // 23. NXdlPerWave oss << "," << detail::sequence_name(); // 24. oss << "," << detail::sequence_name(); // 25. oss << "," << detail::sequence_name(); // 26. oss << "," << kABlockTransferSrcVectorDim; // 27. oss << "," << kABlockTransferSrcScalarPerVector; // 28. - oss << "," << kABlockTransferDstScalarPerVector_K1; // 29. - oss << "," << (kABlockLdsAddExtraM ? "true" : "false"); // 30. + oss << "," << kABlockTransferDstScalarPerVectorK1; // 29. + oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 30. oss << "," << detail::sequence_name(); // 31. oss << "," << detail::sequence_name(); // 32. oss << "," << detail::sequence_name(); // 33. 
oss << "," << kBBlockTransferSrcVectorDim; // 34. oss << "," << kBBlockTransferSrcScalarPerVector; // 35. - oss << "," << kBBlockTransferDstScalarPerVector_K1; // 36. - oss << "," << (kBBlockLdsAddExtraN ? "true" : "false"); // 37. + oss << "," << kBBlockTransferDstScalarPerVectorK1; // 36. + oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 37. oss << "," << kCShuffleMXdlPerWavePerShuffle; // 38. oss << "," << kCShuffleNXdlPerWavePerShuffle; // 39. oss << "," diff --git a/experimental/builder/test/conv/ck/test_conv_traits.cpp b/experimental/builder/test/conv/ck/test_conv_traits.cpp index bbd59b5bf16..3780304bb0e 100644 --- a/experimental/builder/test/conv/ck/test_conv_traits.cpp +++ b/experimental/builder/test/conv/ck/test_conv_traits.cpp @@ -13,6 +13,7 @@ #include #include +#include namespace { using ck_tile::builder::ConvDirection; @@ -28,8 +29,129 @@ class ConvTraitsTest : public ::testing::Test { }; -// Test ConvTraits with DeviceGroupedConvBwdWeightMultipleDXdlCshuffle -TEST_F(ConvTraitsTest, ConvBwdWeightMultipleDCshuffleV3TraitsExtraction) +// Test ConvTraits with DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 +TEST_F(ConvTraitsTest, ConvBwdWeightMultipleDCshuffleTraitsExtraction) +{ + // Define a concrete instance type with specific template parameters + using DeviceInstance = + ck::tensor_operation::device::DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< + 2, // NDimSpatial + ck::tensor_layout::convolution::GNHWC, // InLayout + ck::tensor_layout::convolution::GKYXC, // WeiLayout + ck::tensor_layout::convolution::GNHWK, // OutLayout + ck::Tuple<>, // DsLayout + ck::half_t, // InDataType + ck::half_t, // WeiDataType + ck::half_t, // OutDataType + float, // AccDataType + ck::Tuple<>, // DsDataType + ck::tensor_operation::element_wise::PassThrough, // InElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // WeiElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // OutElementwiseOperation + 
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization:: + Default, // ConvBackwardWeightSpecialization + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // K0PerBlock + 8, // K1 + 32, // MPerXDL + 32, // NPerXDL + 4, // MXdlPerWave + 4, // NXdlPerWave + ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder_ + ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_K1 + 1, // ABlockLdsAddExtraM + ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder_ + ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder_ + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_K1 + 1, // BBlockLdsAddExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + ck::Sequence<1, + 32, + 1, + 8>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_ + 8, // CDEBlockTransferScalarPerVector_NPerBlock_ + ck::half_t, // AComputeDataType + ck::half_t>; // BComputeDataType + + // Use ConvTraitsTmpl to extract compile-time information + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + + // Verify signature information + EXPECT_EQ(traits.spatial_dim, 2); + EXPECT_EQ(traits.direction, ConvDirection::BACKWARD_WEIGHT); + EXPECT_THAT(traits.layout, + ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK)); + EXPECT_EQ(traits.data_type, DataType::FP16); + EXPECT_EQ(traits.input_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.weight_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.output_element_op, ElementwiseOperation::PASS_THROUGH); + + // Verify specializations + EXPECT_EQ(traits.pipeline_scheduler, 
PipelineScheduler::DEFAULT); + + // Verify algorithm information + EXPECT_EQ(traits.thread_block_size, 256); + + // Verify tile dimensions + EXPECT_EQ(traits.tile_dims.m, 128); + EXPECT_EQ(traits.tile_dims.n, 128); + EXPECT_EQ(traits.tile_dims.k, 16); + + // Verify A tile transfer info + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(traits.a_tile_transfer.transfer_params.lds_padding); + + // Verify B tile transfer info + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(traits.b_tile_transfer.transfer_params.lds_padding); + + // Verify warp 
GEMM params + EXPECT_EQ(traits.warp_gemm.gemm_m, 32); + EXPECT_EQ(traits.warp_gemm.gemm_n, 32); + EXPECT_EQ(traits.warp_gemm.m_iter, 4); + EXPECT_EQ(traits.warp_gemm.n_iter, 4); + + // Verify output tile transfer info + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.m_gemms_per_shuffle, 1); + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.n_gemms_per_shuffle, 1); + EXPECT_THAT(traits.c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8)); + EXPECT_EQ(traits.c_tile_transfer.scalar_per_vector, 8); + + // Verify pipeline configuration + EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT); + EXPECT_EQ(traits.pipeline_version, PipelineVersion::V1); +} + +// Test ConvTraits with DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 +TEST_F(ConvTraitsTest, ConvBwdWeightXdlCshuffleV3TraitsExtraction) { // Define a concrete instance type with specific template parameters using DeviceInstance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< @@ -148,8 +270,8 @@ TEST_F(ConvTraitsTest, ConvBwdWeightMultipleDCshuffleV3TraitsExtraction) EXPECT_EQ(traits.pipeline_version, PipelineVersion::V1); } -// Test ConvTraits with DeviceGroupedConvBwdWeightMultipleDXdlCshuffle -TEST_F(ConvTraitsTest, ConvBwdWeightMultipleDCshuffleTraitsExtraction) +// Test ConvTraits with DeviceGroupedConvBwdWeight_Xdl_CShuffle +TEST_F(ConvTraitsTest, ConvBwdWeightXdlCshuffleTraitsExtraction) { // Define a concrete instance type with specific template parameters using DeviceInstance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle< From 3e034ba050ca3bb5a08ef4cd4b6d320c044cfeaf Mon Sep 17 00:00:00 2001 From: Kevin Abraham Date: Mon, 19 Jan 2026 15:40:19 +0000 Subject: [PATCH 08/12] added xdl two stage instance to reflection --- ...conv_bwd_weight_two_stage_xdl_cshuffle.hpp | 56 ++++++++ .../reflect/instance_to_conv_traits.hpp | 1 + ..._bwd_weight_two_stage_wmma_cshuffle_v3.hpp | 22 +-- ...conv_bwd_weight_two_stage_xdl_cshuffle.hpp | 68 +++++++--- 
.../builder/test/conv/ck/test_conv_traits.cpp | 127 +++++++++++++++++- 5 files changed, 240 insertions(+), 34 deletions(-) create mode 100644 experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp new file mode 100644 index 00000000000..a218c7ca448 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp @@ -0,0 +1,56 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include + +#include "ck_tile/builder/reflect/conv_traits.hpp" +#include "ck_tile/builder/reflect/conv_traits_helpers.hpp" +#include "ck_tile/builder/reflect/instance_traits.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp" + +namespace ck_tile::reflect::conv { + +/// @brief Tag dispatch implementation for DeviceGroupedConvBwdWeight_two_stage_Xdl_CShuffle_Tag +template + requires HasInstanceTraits && + std::same_as::device_kernel_tag, + DeviceGroupedConvBwdWeight_two_stage_Xdl_CShuffle_Tag> +constexpr ConvTraits instance_to_conv_traits() +{ + using InstTraits = InstanceTraits; + + return ConvTraits{ + .spatial_dim = InstTraits::kSpatialDim, + .direction = conv_direction(), + .layout = bwd_wei_conv_layout(), + .data_type = conv_data_type(), + .input_element_op = elementwise_op(), + .weight_element_op = elementwise_op(), + .output_element_op = elementwise_op(), + .conv_specialization = conv_spec(), + .thread_block_size = InstTraits::kBlockSize, + .tile_dims = conv_traits_data_tile(InstTraits::kKPerBlock), + .a_tile_transfer = + 
conv_traits_a_transfer_params(InstTraits::kK1, InstTraits::kKPerBlock), + .b_tile_transfer = + conv_traits_b_transfer_params(InstTraits::kK1, InstTraits::kKPerBlock), + .warp_gemm = conv_traits_xdl_warp_gemm_params(), + .c_tile_transfer = + {.shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMXdlPerWavePerShuffle, + .n_gemms_per_shuffle = InstTraits::kCShuffleNXdlPerWavePerShuffle}, + .thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0], + InstTraits::kCThreadClusterLengths[1], + InstTraits::kCThreadClusterLengths[2], + InstTraits::kCThreadClusterLengths[3]}, + .scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector_NWaveNPerXdl}, + .pipeline_version = get_pipeline_version(), + .pipeline_scheduler = get_pipeline_scheduler(), + .max_transpose_transfer_src_scalar_per_vector = + InstTraits::kTransposeTransferSrcScalarPerVector, + .max_transpose_dst_scalar_per_vector = InstTraits::kTransposeTransferDstScalarPerVector, + }; +} + +} // namespace ck_tile::reflect::conv diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_to_conv_traits.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_to_conv_traits.hpp index 5d2b43f5051..6b33569072d 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_to_conv_traits.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_to_conv_traits.hpp @@ -13,3 +13,4 @@ #include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp" #include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp" #include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp" +#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp" diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp 
b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp index f1e40de7d21..a2b9099acee 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp @@ -176,7 +176,7 @@ struct InstanceTraits< using WeiElementwiseOperation = WeiElementwiseOperation_; using OutElementwiseOperation = OutElementwiseOperation_; - static constexpr auto kConvBackwardWeightSpecialization = ConvBackwardWeightSpecialization; + static constexpr auto kConvBwdWeightSpecialization = ConvBackwardWeightSpecialization; static constexpr ck::index_t kBlockSize = BlockSize; static constexpr ck::index_t kMPerBlock = MPerBlock; @@ -255,16 +255,16 @@ struct InstanceTraits< // OutElementwiseOperation oss << "," << detail::conv_bwd_weight_spec_name( - kConvBackwardWeightSpecialization); // 12. ConvBackwardWeightSpecialization - oss << "," << kBlockSize; // 13. BlockSize - oss << "," << kMPerBlock; // 14. MPerBlock - oss << "," << kNPerBlock; // 15. NPerBlock - oss << "," << kKPerBlock; // 16. KPerBlock - oss << "," << kABK1; // 17. ABK1 - oss << "," << kMPerWmma; // 18. MPerWmma - oss << "," << kNPerWmma; // 19. NPerWmma - oss << "," << kMRepeat; // 20. MRepeat - oss << "," << kNRepeat; // 21. NRepeat + kConvBwdWeightSpecialization); // 12. ConvBackwardWeightSpecialization + oss << "," << kBlockSize; // 13. BlockSize + oss << "," << kMPerBlock; // 14. MPerBlock + oss << "," << kNPerBlock; // 15. NPerBlock + oss << "," << kKPerBlock; // 16. KPerBlock + oss << "," << kABK1; // 17. ABK1 + oss << "," << kMPerWmma; // 18. MPerWmma + oss << "," << kNPerWmma; // 19. NPerWmma + oss << "," << kMRepeat; // 20. MRepeat + oss << "," << kNRepeat; // 21. NRepeat oss << "," << detail::sequence_name(); // 22. 
oss << "," << detail::sequence_name(); // 23. oss << "," << detail::sequence_name(); // 24. diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp index 460b49de937..999aff6f1e1 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp @@ -63,6 +63,11 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle; namespace ck_tile { namespace reflect { +/// @brief Tag type for DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle device kernel +struct DeviceGroupedConvBwdWeight_two_stage_Xdl_CShuffle_Tag +{ +}; + template ::value; + static constexpr auto kAThreadClusterArrangeOrder = + detail::SequenceToArray::value; + static constexpr auto kABlockTransferSrcAccessOrder = + detail::SequenceToArray::value; + static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim; static constexpr ck::index_t kABlockTransferSrcScalarPerVector = ABlockTransferSrcScalarPerVector; - static constexpr ck::index_t kABlockTransferDstScalarPerVector_K1 = + static constexpr ck::index_t kABlockTransferDstScalarPerVectorK1 = ABlockTransferDstScalarPerVector_K1; - static constexpr bool kABlockLdsAddExtraM = ABlockLdsAddExtraM; + static constexpr bool kABlockLdsExtraM = ABlockLdsAddExtraM; using BBlockTransferThreadClusterLengths_K0_N_K1 = BBlockTransferThreadClusterLengths_K0_N_K1_; using BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_; using BBlockTransferSrcAccessOrder = BBlockTransferSrcAccessOrder_; + + // B block transfer thread cluster dimensions (converted to std::array) + static constexpr auto 
kBThreadClusterLengths = + detail::SequenceToArray::value; + static constexpr auto kBThreadClusterArrangeOrder = + detail::SequenceToArray::value; + static constexpr auto kBBlockTransferSrcAccessOrder = + detail::SequenceToArray::value; + static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim; static constexpr ck::index_t kBBlockTransferSrcScalarPerVector = BBlockTransferSrcScalarPerVector; - static constexpr ck::index_t kBBlockTransferDstScalarPerVector_K1 = + static constexpr ck::index_t kBBlockTransferDstScalarPerVectorK1 = BBlockTransferDstScalarPerVector_K1; - static constexpr bool kBBlockLdsAddExtraN = BBlockLdsAddExtraN; + static constexpr bool kBBlockLdsExtraN = BBlockLdsAddExtraN; using CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_; + static constexpr auto kCThreadClusterLengths = detail::SequenceToArray< + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value; static constexpr ck::BlockGemmPipelineScheduler kBlkGemmPipeSched = BlkGemmPipeSched; static constexpr ck::BlockGemmPipelineVersion kBlkGemmPipelineVer = BlkGemmPipelineVer; @@ -234,7 +260,7 @@ struct InstanceTraits(); // 2. InLayout oss << "," << detail::layout_name(); // 3. WeiLayout oss << "," << detail::layout_name(); // 4. OutLayout @@ -252,30 +278,30 @@ struct InstanceTraits(); // 22. oss << "," << detail::sequence_name(); // 23. oss << "," << detail::sequence_name(); // 24. oss << "," << kABlockTransferSrcVectorDim; // 25. oss << "," << kABlockTransferSrcScalarPerVector; // 26. - oss << "," << kABlockTransferDstScalarPerVector_K1; // 27. - oss << "," << (kABlockLdsAddExtraM ? "true" : "false"); // 28. + oss << "," << kABlockTransferDstScalarPerVectorK1; // 27. + oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 28. oss << "," << detail::sequence_name(); // 29. oss << "," << detail::sequence_name(); // 30. oss << "," << detail::sequence_name(); // 31. 
oss << "," << kBBlockTransferSrcVectorDim; // 32. oss << "," << kBBlockTransferSrcScalarPerVector; // 33. - oss << "," << kBBlockTransferDstScalarPerVector_K1; // 34. - oss << "," << (kBBlockLdsAddExtraN ? "true" : "false"); // 35. + oss << "," << kBBlockTransferDstScalarPerVectorK1; // 34. + oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 35. oss << "," << kCShuffleMXdlPerWavePerShuffle; // 36. oss << "," << kCShuffleNXdlPerWavePerShuffle; // 37. oss << "," diff --git a/experimental/builder/test/conv/ck/test_conv_traits.cpp b/experimental/builder/test/conv/ck/test_conv_traits.cpp index 3780304bb0e..4afe2398d55 100644 --- a/experimental/builder/test/conv/ck/test_conv_traits.cpp +++ b/experimental/builder/test/conv/ck/test_conv_traits.cpp @@ -14,6 +14,7 @@ #include #include +#include namespace { using ck_tile::builder::ConvDirection; @@ -29,6 +30,130 @@ class ConvTraitsTest : public ::testing::Test { }; +// Test ConvTraits with DeviceGroupedConvBwdWeight_two_stage_Xdl_CShuffleV3 +TEST_F(ConvTraitsTest, ConvBwdWeightTwoStageCshuffleTraitsExtraction) +{ + // Define a concrete instance type with specific template parameters + using DeviceInstance = + ck::tensor_operation::device::DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< + 2, // NDimSpatial + ck::tensor_layout::convolution::GNHWC, // InLayout + ck::tensor_layout::convolution::GKYXC, // WeiLayout + ck::tensor_layout::convolution::GNHWK, // OutLayout + ck::half_t, // InDataType + ck::half_t, // WeiDataType + ck::half_t, // OutDataType + float, // AccDataType + ck::tensor_operation::element_wise::PassThrough, // InElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // WeiElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // OutElementwiseOperation + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization:: + Default, // ConvBackwardWeightSpecialization + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // K0PerBlock + 8, // K1 + 32, // 
MPerXDL + 32, // NPerXDL + 4, // MXdlPerWave + 4, // NXdlPerWave + ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder_ + ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_K1 + 1, // ABlockLdsAddExtraM + ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder_ + ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder_ + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_K1 + 1, // BBlockLdsAddExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + ck::Sequence<1, + 32, + 1, + 8>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_ + 8, // CDEBlockTransferScalarPerVector_NPerBlock_ + ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched + ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer + 4, // NumGroupsToMerge + ck::half_t, // AComputeDataType + ck::half_t, // BComputeDataType + 1, // MaxTransposeTransferSrcScalarPerVector + 1>; // MaxTransposeTransferDstScalarPerVector> + + // Use ConvTraitsTmpl to extract compile-time information + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + + // Verify signature information + EXPECT_EQ(traits.spatial_dim, 2); + EXPECT_EQ(traits.direction, ConvDirection::BACKWARD_WEIGHT); + EXPECT_THAT(traits.layout, + ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK)); + EXPECT_EQ(traits.data_type, DataType::FP16); + EXPECT_EQ(traits.input_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.weight_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.output_element_op, ElementwiseOperation::PASS_THROUGH); + + // Verify specializations + 
EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT); + + // Verify algorithm information + EXPECT_EQ(traits.thread_block_size, 256); + + // Verify tile dimensions + EXPECT_EQ(traits.tile_dims.m, 128); + EXPECT_EQ(traits.tile_dims.n, 128); + EXPECT_EQ(traits.tile_dims.k, 16); + + // Verify A tile transfer info + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(traits.a_tile_transfer.transfer_params.lds_padding); + + // Verify B tile transfer info + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + 
EXPECT_TRUE(traits.b_tile_transfer.transfer_params.lds_padding); + + // Verify warp GEMM params + EXPECT_EQ(traits.warp_gemm.gemm_m, 32); + EXPECT_EQ(traits.warp_gemm.gemm_n, 32); + EXPECT_EQ(traits.warp_gemm.m_iter, 4); + EXPECT_EQ(traits.warp_gemm.n_iter, 4); + + // Verify output tile transfer info + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.m_gemms_per_shuffle, 1); + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.n_gemms_per_shuffle, 1); + EXPECT_THAT(traits.c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8)); + EXPECT_EQ(traits.c_tile_transfer.scalar_per_vector, 8); + + // Verify pipeline configuration + EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT); + EXPECT_EQ(traits.pipeline_version, PipelineVersion::V1); +} + // Test ConvTraits with DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 TEST_F(ConvTraitsTest, ConvBwdWeightMultipleDCshuffleTraitsExtraction) { @@ -146,8 +271,6 @@ TEST_F(ConvTraitsTest, ConvBwdWeightMultipleDCshuffleTraitsExtraction) EXPECT_EQ(traits.c_tile_transfer.scalar_per_vector, 8); // Verify pipeline configuration - EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT); - EXPECT_EQ(traits.pipeline_version, PipelineVersion::V1); } // Test ConvTraits with DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 From 9907357c8a23d07e950f963e48e3c913d92434d2 Mon Sep 17 00:00:00 2001 From: Kevin Abraham Date: Mon, 19 Jan 2026 16:15:04 +0000 Subject: [PATCH 09/12] added additional variables --- .../builder/include/ck_tile/builder/reflect/conv_traits.hpp | 1 + .../conv_traits_device_grouped_conv_bwd_weight_multiple_d_ | 0 ...its_device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp | 1 + 3 files changed, 2 insertions(+) create mode 100644 experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_multiple_d_ diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp index 
28271b25261..17bfc930ee7 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp @@ -109,6 +109,7 @@ struct ConvTraits std::optional max_transpose_transfer_src_scalar_per_vector = std::nullopt; std::optional max_transpose_dst_scalar_per_vector = std::nullopt; + std::optional num_groups_to_merge = std::nullopt; }; } // namespace ck_tile::reflect::conv diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_multiple_d_ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_multiple_d_ new file mode 100644 index 00000000000..e69de29bb2d diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp index a218c7ca448..dc5be82b35c 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp @@ -50,6 +50,7 @@ constexpr ConvTraits instance_to_conv_traits() .max_transpose_transfer_src_scalar_per_vector = InstTraits::kTransposeTransferSrcScalarPerVector, .max_transpose_dst_scalar_per_vector = InstTraits::kTransposeTransferDstScalarPerVector, + .num_groups_to_merge = InstTraits::kNumGroupsToMerge, }; } From c84973a4331a74e5ed9293c5d60a3a14badd5721 Mon Sep 17 00:00:00 2001 From: Kevin Abraham Date: Tue, 20 Jan 2026 07:23:23 +0000 Subject: [PATCH 10/12] added reflection for grouped_conv_bwd_weight_multiple_d_wmma_cshuffle, _v3, grouped_conv_two_stage_wmma_cshuffle_v3, --- ...device_grouped_conv_bwd_weight_multiple_d_ | 0 ...bwd_weight_multiple_d_wmma_cshuffle_v3.hpp | 46 ++++ 
..._bwd_weight_two_stage_wmma_cshuffle_v3.hpp | 50 ++++ ...conv_bwd_weight_two_stage_xdl_cshuffle.hpp | 2 +- .../reflect/instance_to_conv_traits.hpp | 2 + ...bwd_weight_multiple_d_wmma_cshuffle_v3.hpp | 71 +++-- ..._bwd_weight_two_stage_wmma_cshuffle_v3.hpp | 49 +++- .../builder/test/conv/ck/test_conv_traits.cpp | 254 +++++++++++++++++- 8 files changed, 437 insertions(+), 37 deletions(-) delete mode 100644 experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_multiple_d_ create mode 100644 experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp create mode 100644 experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_multiple_d_ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_multiple_d_ deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp new file mode 100644 index 00000000000..f052a9701bc --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp @@ -0,0 +1,46 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include + +#include "ck_tile/builder/reflect/conv_traits.hpp" +#include "ck_tile/builder/reflect/conv_traits_helpers.hpp" +#include "ck_tile/builder/reflect/instance_traits.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp" + +namespace ck_tile::reflect::conv { + +/// @brief Tag dispatch implementation for DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle_Tag +template + requires HasInstanceTraits && + std::same_as::device_kernel_tag, + DeviceGroupedConvBwdWeight_multiple_d_Wmma_CShuffle_V3_Tag> +constexpr ConvTraits instance_to_conv_traits() +{ + using InstTraits = InstanceTraits; + + return ConvTraits{ + .spatial_dim = InstTraits::kSpatialDim, + .direction = conv_direction(), + .layout = bwd_wei_conv_layout(), + .data_type = conv_data_type(), + .input_element_op = elementwise_op(), + .weight_element_op = elementwise_op(), + .output_element_op = elementwise_op(), + .conv_specialization = conv_spec(), + .thread_block_size = InstTraits::kBlockSize, + .tile_dims = conv_traits_data_tile(InstTraits::kKPerBlock), + .a_tile_transfer = + conv_traits_a_transfer_params(InstTraits::kK1, InstTraits::kKPerBlock), + .b_tile_transfer = + conv_traits_b_transfer_params(InstTraits::kK1, InstTraits::kKPerBlock), + .warp_gemm = conv_traits_wmma_warp_gemm_params(), + .c_tile_transfer = conv_traits_wmma_c_tile_transfer(), + .pipeline_version = get_pipeline_version(), + .pipeline_scheduler = get_pipeline_scheduler(), + }; +} + +} // namespace ck_tile::reflect::conv diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp new file mode 100644 index 00000000000..4f39b00b5cc --- /dev/null +++ 
b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp @@ -0,0 +1,50 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include + +#include "ck_tile/builder/reflect/conv_traits.hpp" +#include "ck_tile/builder/reflect/conv_traits_helpers.hpp" +#include "ck_tile/builder/reflect/instance_traits.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp" + +namespace ck_tile::reflect::conv { + +/// @brief Tag dispatch implementation for DeviceGroupedConvBwdWeight_wmma_CShuffle_Tag +template + requires HasInstanceTraits && + std::same_as::device_kernel_tag, + DeviceGroupedConvBwdWeight_two_stage_Wmma_CShuffle_Tag> +constexpr ConvTraits instance_to_conv_traits() +{ + using InstTraits = InstanceTraits; + + return ConvTraits{ + .spatial_dim = InstTraits::kSpatialDim, + .direction = conv_direction(), + .layout = bwd_wei_conv_layout(), + .data_type = conv_data_type(), + .input_element_op = elementwise_op(), + .weight_element_op = elementwise_op(), + .output_element_op = elementwise_op(), + .conv_specialization = conv_spec(), + .thread_block_size = InstTraits::kBlockSize, + .tile_dims = conv_traits_data_tile(InstTraits::kKPerBlock), + .a_tile_transfer = + conv_traits_a_transfer_params(InstTraits::kABK1, InstTraits::kKPerBlock), + .b_tile_transfer = + conv_traits_b_transfer_params(InstTraits::kABK1, InstTraits::kKPerBlock), + .warp_gemm = conv_traits_wmma_warp_gemm_params(), + .c_tile_transfer = conv_traits_wmma_c_tile_transfer(), + .pipeline_version = get_pipeline_version(), + .pipeline_scheduler = get_pipeline_scheduler(), + .max_transpose_transfer_src_scalar_per_vector = + InstTraits::kTransposeTransferSrcScalarPerVector, + .max_transpose_dst_scalar_per_vector = InstTraits::kTransposeTransferDstScalarPerVector, + .num_groups_to_merge = InstTraits::kNumGroupsToMerge, + }; 
+} + +} // namespace ck_tile::reflect::conv diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp index dc5be82b35c..5666233091e 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp @@ -12,7 +12,7 @@ namespace ck_tile::reflect::conv { -/// @brief Tag dispatch implementation for DeviceGroupedConvBwdWeight_Xdl_CShuffle_Tag +/// @brief Tag dispatch implementation for DeviceGroupedConvBwdTwoStage_Xdl_CShuffle_Tag template requires HasInstanceTraits && std::same_as::device_kernel_tag, diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_to_conv_traits.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_to_conv_traits.hpp index 6b33569072d..8e0a67b2a60 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_to_conv_traits.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_to_conv_traits.hpp @@ -13,4 +13,6 @@ #include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp" #include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp" #include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp" +#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp" #include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp" +#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp" diff --git 
a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp index cde1896993b..3325ff0b1c8 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp @@ -62,6 +62,10 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3; namespace ck_tile { namespace reflect { +/// @brief Tag type for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle device kernel +struct DeviceGroupedConvBwdWeight_multiple_d_Wmma_CShuffle_V3_Tag +{ +}; template ::value; + static constexpr auto kAThreadClusterArrangeOrder = + detail::SequenceToArray::value; + static constexpr auto kABlockTransferSrcAccessOrder = + detail::SequenceToArray::value; + static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim; static constexpr ck::index_t kABlockTransferSrcScalarPerVector = ABlockTransferSrcScalarPerVector; - static constexpr ck::index_t kABlockTransferDstScalarPerVector_AK1 = + static constexpr ck::index_t kABlockTransferDstScalarPerVectorK1 = ABlockTransferDstScalarPerVector_AK1; - static constexpr bool kABlockLdsAddExtraM = ABlockLdsAddExtraM; + static constexpr bool kABlockLdsExtraM = ABlockLdsAddExtraM; using BBlockTransferThreadClusterLengths_BK0_N_BK1 = BBlockTransferThreadClusterLengths_BK0_N_BK1_; using BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_; using BBlockTransferSrcAccessOrder = BBlockTransferSrcAccessOrder_; + // B block transfer thread cluster dimensions (converted to std::array) + static constexpr auto kBThreadClusterLengths = + detail::SequenceToArray::value; + static constexpr auto 
kBThreadClusterArrangeOrder = + detail::SequenceToArray::value; + static constexpr auto kBBlockTransferSrcAccessOrder = + detail::SequenceToArray::value; static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim; static constexpr ck::index_t kBBlockTransferSrcScalarPerVector = BBlockTransferSrcScalarPerVector; - static constexpr ck::index_t kBBlockTransferDstScalarPerVector_BK1 = + static constexpr ck::index_t kBBlockTransferDstScalarPerVectorK1 = BBlockTransferDstScalarPerVector_BK1; - static constexpr bool kBBlockLdsAddExtraN = BBlockLdsAddExtraN; + static constexpr bool kBBlockLdsExtraN = BBlockLdsAddExtraN; using CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_; - + static constexpr auto kCDEThreadClusterLengths = detail::SequenceToArray< + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value; + static constexpr int kCDEBlockTransferScalarPerVector = + CShuffleBlockTransferScalarPerVector_NPerBlock; static constexpr ck::BlockGemmPipelineScheduler kBlkGemmPipeSched = BlkGemmPipeSched; static constexpr ck::BlockGemmPipelineVersion kBlkGemmPipelineVer = BlkGemmPipelineVer; @@ -231,7 +256,7 @@ struct InstanceTraits< oss << "DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3"; // Template parameters in exact order - oss << "<" << kNDimSpatial; // 1. NDimSpatial + oss << "<" << kSpatialDim; // 1. NDimSpatial oss << "," << detail::layout_name(); // 2. InLayout oss << "," << detail::layout_name(); // 3. WeiLayout oss << "," << detail::layout_name(); // 4. OutLayout @@ -251,30 +276,30 @@ struct InstanceTraits< // OutElementwiseOperation oss << "," << detail::conv_bwd_weight_spec_name( - kConvBackwardWeightSpecialization); // 14. ConvBackwardWeightSpecialization - oss << "," << kBlockSize; // 15. BlockSize - oss << "," << kMPerBlock; // 16. MPerBlock - oss << "," << kNPerBlock; // 17. 
NPerBlock - oss << "," << kKPerBlock; // 18. KPerBlock - oss << "," << kABK1; // 19. ABK1 - oss << "," << kMPerWmma; // 20. MPerWmma - oss << "," << kNPerWmma; // 21. NPerWmma - oss << "," << kMRepeat; // 22. MRepeat - oss << "," << kNRepeat; // 23. NRepeat + kConvBwdWeightSpecialization); // 14. ConvBackwardWeightSpecialization + oss << "," << kBlockSize; // 15. BlockSize + oss << "," << kMPerBlock; // 16. MPerBlock + oss << "," << kNPerBlock; // 17. NPerBlock + oss << "," << kKPerBlock; // 18. KPerBlock + oss << "," << kK1; // 19. ABK1 + oss << "," << kMPerWmma; // 20. MPerWmma + oss << "," << kNPerWmma; // 21. NPerWmma + oss << "," << kMRepeat; // 22. MRepeat + oss << "," << kNRepeat; // 23. NRepeat oss << "," << detail::sequence_name(); // 24. oss << "," << detail::sequence_name(); // 25. oss << "," << detail::sequence_name(); // 26. oss << "," << kABlockTransferSrcVectorDim; // 27. oss << "," << kABlockTransferSrcScalarPerVector; // 28. - oss << "," << kABlockTransferDstScalarPerVector_AK1; // 29. - oss << "," << (kABlockLdsAddExtraM ? "true" : "false"); // 30. + oss << "," << kABlockTransferDstScalarPerVectorK1; // 29. + oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 30. oss << "," << detail::sequence_name(); // 31. oss << "," << detail::sequence_name(); // 32. oss << "," << detail::sequence_name(); // 33. oss << "," << kBBlockTransferSrcVectorDim; // 34. oss << "," << kBBlockTransferSrcScalarPerVector; // 35. - oss << "," << kBBlockTransferDstScalarPerVector_BK1; // 36. - oss << "," << (kBBlockLdsAddExtraN ? "true" : "false"); // 37. + oss << "," << kBBlockTransferDstScalarPerVectorK1; // 36. + oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 37. oss << "," << kCShuffleMRepeatPerShuffle; // 38. oss << "," << kCShuffleNRepeatPerShuffle; // 39. 
oss << "," diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp index a2b9099acee..4b90a6ab64d 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp @@ -63,6 +63,11 @@ struct DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3; namespace ck_tile { namespace reflect { +/// @brief Tag type for DeviceGroupedConvBwdWeight_two_stage_Wmma_CShuffle_Tag device kernel +struct DeviceGroupedConvBwdWeight_two_stage_Wmma_CShuffle_Tag +{ +}; + template ::value; + static constexpr auto kAThreadClusterArrangeOrder = + detail::SequenceToArray::value; + static constexpr auto kABlockTransferSrcAccessOrder = + detail::SequenceToArray::value; + static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim; static constexpr ck::index_t kABlockTransferSrcScalarPerVector = ABlockTransferSrcScalarPerVector; - static constexpr ck::index_t kABlockTransferDstScalarPerVector_AK1 = + static constexpr ck::index_t kABlockTransferDstScalarPerVectorK1 = ABlockTransferDstScalarPerVector_AK1; - static constexpr bool kABlockLdsAddExtraM = ABlockLdsAddExtraM; + static constexpr bool kABlockLdsExtraM = ABlockLdsAddExtraM; using BBlockTransferThreadClusterLengths_BK0_N_BK1 = BBlockTransferThreadClusterLengths_BK0_N_BK1_; @@ -215,13 +231,26 @@ struct InstanceTraits< static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim; static constexpr ck::index_t kBBlockTransferSrcScalarPerVector = BBlockTransferSrcScalarPerVector; - static constexpr ck::index_t kBBlockTransferDstScalarPerVector_BK1 = + static constexpr 
ck::index_t kBBlockTransferDstScalarPerVectorK1 = BBlockTransferDstScalarPerVector_BK1; - static constexpr bool kBBlockLdsAddExtraN = BBlockLdsAddExtraN; + + // B block transfer thread cluster dimensions (converted to std::array) + static constexpr auto kBThreadClusterLengths = + detail::SequenceToArray::value; + static constexpr auto kBThreadClusterArrangeOrder = + detail::SequenceToArray::value; + static constexpr auto kBBlockTransferSrcAccessOrder = + detail::SequenceToArray::value; + static constexpr bool kBBlockLdsExtraN = BBlockLdsAddExtraN; using CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_; + static constexpr auto kCDEThreadClusterLengths = detail::SequenceToArray< + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value; + static constexpr int kCDEBlockTransferScalarPerVector = + CShuffleBlockTransferScalarPerVector_NPerBlock; + static constexpr ck::BlockGemmPipelineScheduler kBlkGemmPipeSched = BlkGemmPipeSched; static constexpr ck::BlockGemmPipelineVersion kBlkGemmPipelineVer = BlkGemmPipelineVer; @@ -237,7 +266,7 @@ struct InstanceTraits< oss << "DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3"; // Template parameters in exact order - oss << "<" << kNDimSpatial; // 1. NDimSpatial + oss << "<" << kSpatialDim; // 1. NDimSpatial oss << "," << detail::layout_name(); // 2. InLayout oss << "," << detail::layout_name(); // 3. WeiLayout oss << "," << detail::layout_name(); // 4. OutLayout @@ -270,15 +299,15 @@ struct InstanceTraits< oss << "," << detail::sequence_name(); // 24. oss << "," << kABlockTransferSrcVectorDim; // 25. oss << "," << kABlockTransferSrcScalarPerVector; // 26. - oss << "," << kABlockTransferDstScalarPerVector_AK1; // 27. - oss << "," << (kABlockLdsAddExtraM ? "true" : "false"); // 28. + oss << "," << kABlockTransferDstScalarPerVectorK1; // 27. + oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 28. 
oss << "," << detail::sequence_name(); // 29. oss << "," << detail::sequence_name(); // 30. oss << "," << detail::sequence_name(); // 31. oss << "," << kBBlockTransferSrcVectorDim; // 32. oss << "," << kBBlockTransferSrcScalarPerVector; // 33. - oss << "," << kBBlockTransferDstScalarPerVector_BK1; // 34. - oss << "," << (kBBlockLdsAddExtraN ? "true" : "false"); // 35. + oss << "," << kBBlockTransferDstScalarPerVectorK1; // 34. + oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 35. oss << "," << kCShuffleMRepeatPerShuffle; // 36. oss << "," << kCShuffleNRepeatPerShuffle; // 37. oss << "," diff --git a/experimental/builder/test/conv/ck/test_conv_traits.cpp b/experimental/builder/test/conv/ck/test_conv_traits.cpp index 4afe2398d55..21ddc9f5c93 100644 --- a/experimental/builder/test/conv/ck/test_conv_traits.cpp +++ b/experimental/builder/test/conv/ck/test_conv_traits.cpp @@ -14,7 +14,10 @@ #include #include +#include #include +#include + namespace { using ck_tile::builder::ConvDirection; @@ -30,8 +33,253 @@ class ConvTraitsTest : public ::testing::Test { }; +// Test ConvTraits with DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle_V3 +TEST_F(ConvTraitsTest, ConvBwdWeightMultipleDCshuffleWmmaV3TraitsExtraction) +{ + // Define a concrete instance type with specific template parameters + using DeviceInstance = + ck::tensor_operation::device::DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< + 2, // NDimSpatial + ck::tensor_layout::convolution::GNHWC, // InLayout + ck::tensor_layout::convolution::GKYXC, // WeiLayout + ck::tensor_layout::convolution::GNHWK, // OutLayout + ck::Tuple<>, // DsLayout + ck::half_t, // InDataType + ck::half_t, // WeiDataType + ck::half_t, // OutDataType + float, // AccDataType + ck::Tuple<>, // DsDataType + ck::tensor_operation::element_wise::PassThrough, // InElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // WeiElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // OutElementwiseOperation + 
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization:: + Default, // ConvBackwardWeightSpecialization + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // K0PerBlock + 8, // K1 + 32, // MPerWmma + 32, // NPerWmma + 4, // MRepeat + 4, // NRepeat + ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder_ + ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_K1 + 1, // ABlockLdsAddExtraM + ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder_ + ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder_ + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_K1 + 1, // BBlockLdsAddExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + ck::Sequence<1, + 32, + 1, + 8>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_ + 8, // CDEBlockTransferScalarPerVector_NPerBlock_ + ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched + ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer + ck::half_t, // AComputeDataType + ck::half_t>; // BComputeDataType + + // Use ConvTraitsTmpl to extract compile-time information + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + + // Verify signature information + EXPECT_EQ(traits.spatial_dim, 2); + EXPECT_EQ(traits.direction, ConvDirection::BACKWARD_WEIGHT); + EXPECT_THAT(traits.layout, + ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK)); + EXPECT_EQ(traits.data_type, DataType::FP16); + EXPECT_EQ(traits.input_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.weight_element_op, ElementwiseOperation::PASS_THROUGH); + 
EXPECT_EQ(traits.output_element_op, ElementwiseOperation::PASS_THROUGH); + + // Verify specializations + EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT); + + // Verify algorithm information + EXPECT_EQ(traits.thread_block_size, 256); + + // Verify tile dimensions + EXPECT_EQ(traits.tile_dims.m, 128); + EXPECT_EQ(traits.tile_dims.n, 128); + EXPECT_EQ(traits.tile_dims.k, 16); + + // Verify A tile transfer info + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(traits.a_tile_transfer.transfer_params.lds_padding); + + // Verify B tile transfer info + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_scalar_per_vector, 8); + 
EXPECT_EQ(traits.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(traits.b_tile_transfer.transfer_params.lds_padding); + + // Verify warp GEMM params + EXPECT_EQ(traits.warp_gemm.gemm_m, 32); + EXPECT_EQ(traits.warp_gemm.gemm_n, 32); + EXPECT_EQ(traits.warp_gemm.m_iter, 4); + EXPECT_EQ(traits.warp_gemm.n_iter, 4); + + // Verify output tile transfer info + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.m_gemms_per_shuffle, 1); + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.n_gemms_per_shuffle, 1); + EXPECT_THAT(traits.c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8)); + EXPECT_EQ(traits.c_tile_transfer.scalar_per_vector, 8); + + // Verify pipeline configuration +} + // Test ConvTraits with DeviceGroupedConvBwdWeight_two_stage_Xdl_CShuffleV3 -TEST_F(ConvTraitsTest, ConvBwdWeightTwoStageCshuffleTraitsExtraction) +TEST_F(ConvTraitsTest, ConvBwdWeightTwoStageWmmaCshuffleTraitsExtraction) +{ + // Define a concrete instance type with specific template parameters + using DeviceInstance = + ck::tensor_operation::device::DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< + 2, // NDimSpatial + ck::tensor_layout::convolution::GNHWC, // InLayout + ck::tensor_layout::convolution::GKYXC, // WeiLayout + ck::tensor_layout::convolution::GNHWK, // OutLayout + ck::half_t, // InDataType + ck::half_t, // WeiDataType + ck::half_t, // OutDataType + float, // AccDataType + ck::tensor_operation::element_wise::PassThrough, // InElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // WeiElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // OutElementwiseOperation + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization:: + Default, // ConvBackwardWeightSpecialization + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // K0PerBlock + 8, // AK1 + 32, // MPerWMMA + 32, // NPerXDL + 4, // MRepeat + 4, // NRepeat + ck::Sequence<4, 64, 1>, // 
ABlockTransferThreadClusterLengths_AK0_M_AK1 + ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder_ + ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_K1 + 1, // ABlockLdsAddExtraM + ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder_ + ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder_ + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_K1 + 1, // BBlockLdsAddExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + ck::Sequence<1, + 32, + 1, + 8>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_ + 8, // CDEBlockTransferScalarPerVector_NPerBlock_ + ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched + ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer + 4, // NumGroupsToMerge + ck::half_t, // AComputeDataType + ck::half_t, // BComputeDataType + 1, // MaxTransposeTransferSrcScalarPerVector + 1>; // MaxTransposeTransferDstScalarPerVector> + + // Use ConvTraitsTmpl to extract compile-time information + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + + // Verify signature information + EXPECT_EQ(traits.spatial_dim, 2); + EXPECT_EQ(traits.direction, ConvDirection::BACKWARD_WEIGHT); + EXPECT_THAT(traits.layout, + ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK)); + EXPECT_EQ(traits.data_type, DataType::FP16); + EXPECT_EQ(traits.input_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.weight_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.output_element_op, ElementwiseOperation::PASS_THROUGH); + + // Verify specializations + EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT); + + // Verify algorithm information + 
EXPECT_EQ(traits.thread_block_size, 256); + + // Verify tile dimensions + EXPECT_EQ(traits.tile_dims.m, 128); + EXPECT_EQ(traits.tile_dims.n, 128); + EXPECT_EQ(traits.tile_dims.k, 16); + + // Verify A tile transfer info + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(traits.a_tile_transfer.transfer_params.lds_padding); + + // Verify B tile transfer info + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(traits.b_tile_transfer.transfer_params.lds_padding); + + // Verify warp GEMM params + EXPECT_EQ(traits.warp_gemm.gemm_m, 32); + 
EXPECT_EQ(traits.warp_gemm.gemm_n, 32); + EXPECT_EQ(traits.warp_gemm.m_iter, 4); + EXPECT_EQ(traits.warp_gemm.n_iter, 4); + + // Verify output tile transfer info + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.m_gemms_per_shuffle, 1); + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.n_gemms_per_shuffle, 1); + EXPECT_THAT(traits.c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8)); + EXPECT_EQ(traits.c_tile_transfer.scalar_per_vector, 8); + + // Verify pipeline configuration + EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT); + EXPECT_EQ(traits.pipeline_version, PipelineVersion::V1); +} + +// Test ConvTraits with DeviceGroupedConvBwdWeight_two_stage_Xdl_CShuffleV3 +TEST_F(ConvTraitsTest, ConvBwdWeightTwoStageXdlCshuffleTraitsExtraction) { // Define a concrete instance type with specific template parameters using DeviceInstance = @@ -154,8 +402,8 @@ TEST_F(ConvTraitsTest, ConvBwdWeightTwoStageCshuffleTraitsExtraction) EXPECT_EQ(traits.pipeline_version, PipelineVersion::V1); } -// Test ConvTraits with DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 -TEST_F(ConvTraitsTest, ConvBwdWeightMultipleDCshuffleTraitsExtraction) +// Test ConvTraits with DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle +TEST_F(ConvTraitsTest, ConvBwdWeightMultipleDCshuffleXDLTraitsExtraction) { // Define a concrete instance type with specific template parameters using DeviceInstance = From e8c333bb14b36317e070a1e273817d87d54ccc3c Mon Sep 17 00:00:00 2001 From: Kevin Abraham Date: Tue, 20 Jan 2026 11:41:31 +0000 Subject: [PATCH 11/12] added reflection for device_grouped_conv_bwd_weigh_wmma_cshuffle_v3 --- .../ck_tile/builder/reflect/conv_traits.hpp | 2 +- ...ouped_conv_bwd_weight_wmma_cshuffle_v3.hpp | 50 +++++++ ...uped_conv_fwd_multiple_d_wmma_cshuffle.hpp | 36 ++--- .../reflect/instance_to_conv_traits.hpp | 1 + ...bwd_weight_multiple_d_wmma_cshuffle_v3.hpp | 2 +- ..._grouped_conv_bwd_weight_wmma_cshuffle.hpp | 26 ++-- ...ouped_conv_bwd_weight_wmma_cshuffle_v3.hpp | 71 
++++++---- .../builder/test/conv/ck/test_conv_traits.cpp | 125 +++++++++++++++++- 8 files changed, 255 insertions(+), 58 deletions(-) create mode 100644 experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp index 17bfc930ee7..16a9c47f7eb 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp @@ -102,7 +102,7 @@ struct ConvTraits OutputTileTransferInfo c_tile_transfer; - std::optional num_gemm_prefetch_stage = std::nullopt; + std::optional num_gemm_k_prefetch_stage = std::nullopt; builder::PipelineVersion pipeline_version; builder::PipelineScheduler pipeline_scheduler; diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp new file mode 100644 index 00000000000..13625aa1822 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp @@ -0,0 +1,50 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include + +#include "ck_tile/builder/reflect/conv_traits.hpp" +#include "ck_tile/builder/reflect/conv_traits_helpers.hpp" +#include "ck_tile/builder/reflect/instance_traits.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp" + +namespace ck_tile::reflect::conv { + +/// @brief Tag dispatch implementation for DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3_Tag +template + requires HasInstanceTraits && + std::same_as::device_kernel_tag, + DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3_Tag> +constexpr ConvTraits instance_to_conv_traits() +{ + using InstTraits = InstanceTraits; + + return ConvTraits{ + .spatial_dim = InstTraits::kSpatialDim, + .direction = conv_direction(), + .layout = bwd_wei_conv_layout(), + .data_type = conv_data_type(), + .input_element_op = elementwise_op(), + .weight_element_op = elementwise_op(), + .output_element_op = elementwise_op(), + .conv_specialization = conv_spec(), + .thread_block_size = InstTraits::kBlockSize, + .tile_dims = conv_traits_data_tile(InstTraits::kKPerBlock), + .a_tile_transfer = + conv_traits_a_transfer_params(InstTraits::kK1, InstTraits::kKPerBlock), + .b_tile_transfer = + conv_traits_b_transfer_params(InstTraits::kK1, InstTraits::kKPerBlock), + .warp_gemm = conv_traits_wmma_warp_gemm_params(), + .c_tile_transfer = conv_traits_wmma_c_tile_transfer(), + .pipeline_version = get_pipeline_version(), + .pipeline_scheduler = get_pipeline_scheduler(), + .max_transpose_transfer_src_scalar_per_vector = + InstTraits::kMaxTransposeTransferSrcScalarPerVector, + .max_transpose_dst_scalar_per_vector = InstTraits::kMaxTransposeTransferDstScalarPerVector, + + }; +} + +} // namespace ck_tile::reflect::conv diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp 
b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp index 5e88ee17839..9413107df7e 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp @@ -22,24 +22,24 @@ constexpr ConvTraits instance_to_conv_traits() using InstTraits = InstanceTraits; return ConvTraits{ - .spatial_dim = InstTraits::kSpatialDim, - .direction = conv_direction(), - .layout = fwd_conv_layout(), - .data_type = conv_data_type(), - .input_element_op = elementwise_op(), - .weight_element_op = elementwise_op(), - .output_element_op = elementwise_op(), - .gemm_padding = gemm_spec(), - .conv_specialization = conv_spec(), - .thread_block_size = InstTraits::kBlockSize, - .tile_dims = conv_traits_data_tile(), - .a_tile_transfer = conv_traits_a_transfer_params(InstTraits::kK1), - .b_tile_transfer = conv_traits_b_transfer_params(InstTraits::kK1), - .warp_gemm = conv_traits_wmma_warp_gemm_params(), - .c_tile_transfer = conv_traits_wmma_c_tile_transfer(), - .num_gemm_prefetch_stage = InstTraits::kNumGemmKPrefetchStage, - .pipeline_version = get_pipeline_version(), - .pipeline_scheduler = get_pipeline_scheduler(), + .spatial_dim = InstTraits::kSpatialDim, + .direction = conv_direction(), + .layout = fwd_conv_layout(), + .data_type = conv_data_type(), + .input_element_op = elementwise_op(), + .weight_element_op = elementwise_op(), + .output_element_op = elementwise_op(), + .gemm_padding = gemm_spec(), + .conv_specialization = conv_spec(), + .thread_block_size = InstTraits::kBlockSize, + .tile_dims = conv_traits_data_tile(), + .a_tile_transfer = conv_traits_a_transfer_params(InstTraits::kK1), + .b_tile_transfer = conv_traits_b_transfer_params(InstTraits::kK1), + .warp_gemm = conv_traits_wmma_warp_gemm_params(), + .c_tile_transfer = 
conv_traits_wmma_c_tile_transfer(), + .num_gemm_k_prefetch_stage = InstTraits::kNumGemmKPrefetchStage, + .pipeline_version = get_pipeline_version(), + .pipeline_scheduler = get_pipeline_scheduler(), }; } diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_to_conv_traits.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_to_conv_traits.hpp index 8e0a67b2a60..c186bde6cd7 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_to_conv_traits.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_to_conv_traits.hpp @@ -16,3 +16,4 @@ #include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp" #include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp" #include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp" +#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp" diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp index 3325ff0b1c8..c3a5f9df29d 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp @@ -62,7 +62,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3; namespace ck_tile { namespace reflect { -/// @brief Tag type for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle device kernel +/// @brief Tag type for DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 device kernel struct DeviceGroupedConvBwdWeight_multiple_d_Wmma_CShuffle_V3_Tag { }; diff --git 
a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_wmma_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_wmma_cshuffle.hpp index f87e295159a..d3017868773 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_wmma_cshuffle.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_wmma_cshuffle.hpp @@ -149,7 +149,7 @@ struct InstanceTraits(); // 2. InLayout oss << "," << detail::layout_name(); // 3. WeiLayout oss << "," << detail::layout_name(); // 4. OutLayout @@ -234,16 +234,16 @@ struct InstanceTraits(); // 22. oss << "," << detail::sequence_name(); // 23. oss << "," << detail::sequence_name(); // 24. diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp index 29459d67b01..cfc8b4e05af 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp @@ -62,6 +62,11 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3; namespace ck_tile { namespace reflect { +/// @brief Tag type for DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 device kernel +struct DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3_Tag +{ +}; + template > { static constexpr auto kTensorOpName = "DeviceGroupedConvBwdWeight_Wmma_CShuffleV3"; + using device_kernel_tag = DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3_Tag; - static constexpr ck::index_t kNDimSpatial = NDimSpatial; + static constexpr ck::index_t kSpatialDim = NDimSpatial; using InLayout = InLayout_; using WeiLayout = WeiLayout_; @@ 
-172,13 +178,13 @@ struct InstanceTraits::value; + static constexpr auto kAThreadClusterArrangeOrder = + detail::SequenceToArray::value; + static constexpr auto kABlockTransferSrcAccessOrder = + detail::SequenceToArray::value; static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim; static constexpr ck::index_t kABlockTransferSrcScalarPerVector = ABlockTransferSrcScalarPerVector; - static constexpr ck::index_t kABlockTransferDstScalarPerVector_AK1 = + static constexpr ck::index_t kABlockTransferDstScalarPerVectorK1 = ABlockTransferDstScalarPerVector_AK1; - static constexpr bool kABlockLdsAddExtraM = ABlockLdsAddExtraM; + static constexpr bool kABlockLdsExtraM = ABlockLdsAddExtraM; using BBlockTransferThreadClusterLengths_BK0_N_BK1 = BBlockTransferThreadClusterLengths_BK0_N_BK1_; using BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_; using BBlockTransferSrcAccessOrder = BBlockTransferSrcAccessOrder_; + + // B block transfer thread cluster dimensions (converted to std::array) + static constexpr auto kBThreadClusterLengths = + detail::SequenceToArray::value; + static constexpr auto kBThreadClusterArrangeOrder = + detail::SequenceToArray::value; + static constexpr auto kBBlockTransferSrcAccessOrder = + detail::SequenceToArray::value; + static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim; static constexpr ck::index_t kBBlockTransferSrcScalarPerVector = BBlockTransferSrcScalarPerVector; - static constexpr ck::index_t kBBlockTransferDstScalarPerVector_BK1 = + static constexpr ck::index_t kBBlockTransferDstScalarPerVectorK1 = BBlockTransferDstScalarPerVector_BK1; - static constexpr bool kBBlockLdsAddExtraN = BBlockLdsAddExtraN; + static constexpr bool kBBlockLdsExtraN = BBlockLdsAddExtraN; using CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_; - + static constexpr auto 
kCDEThreadClusterLengths = detail::SequenceToArray< + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value; + static constexpr int kCDEBlockTransferScalarPerVector = + CShuffleBlockTransferScalarPerVector_NPerBlock; static constexpr ck::BlockGemmPipelineScheduler kBlkGemmPipeSched = BlkGemmPipeSched; static constexpr ck::BlockGemmPipelineVersion kBlkGemmPipelineVer = BlkGemmPipelineVer; @@ -232,7 +257,7 @@ struct InstanceTraits(); // 2. InLayout oss << "," << detail::layout_name(); // 3. WeiLayout oss << "," << detail::layout_name(); // 4. OutLayout @@ -250,30 +275,30 @@ struct InstanceTraits(); // 22. oss << "," << detail::sequence_name(); // 23. oss << "," << detail::sequence_name(); // 24. oss << "," << kABlockTransferSrcVectorDim; // 25. oss << "," << kABlockTransferSrcScalarPerVector; // 26. - oss << "," << kABlockTransferDstScalarPerVector_AK1; // 27. - oss << "," << (kABlockLdsAddExtraM ? "true" : "false"); // 28. + oss << "," << kABlockTransferDstScalarPerVectorK1; // 27. + oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 28. oss << "," << detail::sequence_name(); // 29. oss << "," << detail::sequence_name(); // 30. oss << "," << detail::sequence_name(); // 31. oss << "," << kBBlockTransferSrcVectorDim; // 32. oss << "," << kBBlockTransferSrcScalarPerVector; // 33. - oss << "," << kBBlockTransferDstScalarPerVector_BK1; // 34. - oss << "," << (kBBlockLdsAddExtraN ? "true" : "false"); // 35. + oss << "," << kBBlockTransferDstScalarPerVectorK1; // 34. + oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 35. oss << "," << kCShuffleMRepeatPerShuffle; // 36. oss << "," << kCShuffleNRepeatPerShuffle; // 37. 
oss << "," diff --git a/experimental/builder/test/conv/ck/test_conv_traits.cpp b/experimental/builder/test/conv/ck/test_conv_traits.cpp index 21ddc9f5c93..bc77c0f2b5b 100644 --- a/experimental/builder/test/conv/ck/test_conv_traits.cpp +++ b/experimental/builder/test/conv/ck/test_conv_traits.cpp @@ -17,6 +17,7 @@ #include #include #include +#include namespace { @@ -33,7 +34,127 @@ class ConvTraitsTest : public ::testing::Test { }; -// Test ConvTraits with DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle_V3 +// Test ConvTraits with DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3 +TEST_F(ConvTraitsTest, ConvBwdWeightCshuffleWmmaV3TraitsExtraction) +{ + // Define a concrete instance type with specific template parameters + using DeviceInstance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< + 2, // NDimSpatial + ck::tensor_layout::convolution::GNHWC, // InLayout + ck::tensor_layout::convolution::GKYXC, // WeiLayout + ck::tensor_layout::convolution::GNHWK, // OutLayout + ck::half_t, // InDataType + ck::half_t, // WeiDataType + ck::half_t, // OutDataType + float, // AccDataType + ck::tensor_operation::element_wise::PassThrough, // InElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // WeiElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // OutElementwiseOperation + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization:: + Default, // ConvBackwardWeightSpecialization + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // K0PerBlock + 8, // K1 + 32, // MPerWmma + 32, // NPerWmma + 4, // MRepeat + 4, // NRepeat + ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder_ + ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_K1 + 1, // ABlockLdsAddExtraM + ck::Sequence<4, 64, 1>, // 
BBlockTransferThreadClusterLengths_K0_N_K1 + ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder_ + ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder_ + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_K1 + 1, // BBlockLdsAddExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + ck::Sequence<1, + 32, + 1, + 8>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_ + 8, // CDEBlockTransferScalarPerVector_NPerBlock_ + ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched + ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer + ck::half_t, // AComputeDataType + ck::half_t, // BComputeDataType + 1, // MaxTransposeTransferSrcScalarPerVector + 1>; // MaxTransposeTransferDstScalarPerVector> + + // Use ConvTraitsTmpl to extract compile-time information + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + + // Verify signature information + EXPECT_EQ(traits.spatial_dim, 2); + EXPECT_EQ(traits.direction, ConvDirection::BACKWARD_WEIGHT); + EXPECT_THAT(traits.layout, + ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK)); + EXPECT_EQ(traits.data_type, DataType::FP16); + EXPECT_EQ(traits.input_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.weight_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.output_element_op, ElementwiseOperation::PASS_THROUGH); + + // Verify specializations + EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT); + + // Verify algorithm information + EXPECT_EQ(traits.thread_block_size, 256); + + // Verify tile dimensions + EXPECT_EQ(traits.tile_dims.m, 128); + EXPECT_EQ(traits.tile_dims.n, 128); + EXPECT_EQ(traits.tile_dims.k, 16); + + // Verify A tile transfer info + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.m_or_n, 128); + 
EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(traits.a_tile_transfer.transfer_params.lds_padding); + + // Verify B tile transfer info + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(traits.b_tile_transfer.transfer_params.lds_padding); + + // Verify warp GEMM params + EXPECT_EQ(traits.warp_gemm.gemm_m, 32); + EXPECT_EQ(traits.warp_gemm.gemm_n, 32); + EXPECT_EQ(traits.warp_gemm.m_iter, 4); + EXPECT_EQ(traits.warp_gemm.n_iter, 4); + + // Verify output tile transfer info + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.m_gemms_per_shuffle, 1); + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.n_gemms_per_shuffle, 1); + 
EXPECT_THAT(traits.c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8)); + EXPECT_EQ(traits.c_tile_transfer.scalar_per_vector, 8); + + // Verify pipeline configuration +} + +// Test ConvTraits with DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle_V3 TEST_F(ConvTraitsTest, ConvBwdWeightMultipleDCshuffleWmmaV3TraitsExtraction) { // Define a concrete instance type with specific template parameters @@ -834,7 +955,7 @@ TEST_F(ConvTraitsTest, ConvFwdTraitsMultipleDCshuffleWmmaExtraction) // Verify specializations EXPECT_EQ(traits.gemm_padding, ck_tile::builder::GemmPadding::DEFAULT); EXPECT_EQ(traits.conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT); - EXPECT_EQ(traits.num_gemm_prefetch_stage, 1); + EXPECT_EQ(traits.num_gemm_k_prefetch_stage, 1); // Verify algorithm information EXPECT_EQ(traits.thread_block_size, 256); From 71df5c58636277bb2059146e0ba1e75c074b95bd Mon Sep 17 00:00:00 2001 From: Kevin Abraham Date: Tue, 20 Jan 2026 12:04:02 +0000 Subject: [PATCH 12/12] added reflection for bwd_weight_wmma_cshuffle --- ..._grouped_conv_bwd_weight_wmma_cshuffle.hpp | 48 +++++++ .../reflect/instance_to_conv_traits.hpp | 1 + ..._grouped_conv_bwd_weight_wmma_cshuffle.hpp | 49 ++++++-- .../builder/test/conv/ck/test_conv_traits.cpp | 119 ++++++++++++++++++ 4 files changed, 204 insertions(+), 13 deletions(-) create mode 100644 experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_wmma_cshuffle.hpp diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_wmma_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_wmma_cshuffle.hpp new file mode 100644 index 00000000000..470a10d0317 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_wmma_cshuffle.hpp @@ -0,0 +1,48 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include + +#include "ck_tile/builder/reflect/conv_traits.hpp" +#include "ck_tile/builder/reflect/conv_traits_helpers.hpp" +#include "ck_tile/builder/reflect/instance_traits.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_wmma_cshuffle.hpp" + +namespace ck_tile::reflect::conv { + +/// @brief Tag dispatch implementation for DeviceGroupedConvBwdWeight_Wmma_CShuffle_Tag +template + requires HasInstanceTraits && + std::same_as::device_kernel_tag, + DeviceGroupedConvBwdWeight_Wmma_CShuffle_Tag> +constexpr ConvTraits instance_to_conv_traits() +{ + using InstTraits = InstanceTraits; + + return ConvTraits{ + .spatial_dim = InstTraits::kSpatialDim, + .direction = conv_direction(), + .layout = bwd_wei_conv_layout(), + .data_type = conv_data_type(), + .input_element_op = elementwise_op(), + .weight_element_op = elementwise_op(), + .output_element_op = elementwise_op(), + .conv_specialization = conv_spec(), + .thread_block_size = InstTraits::kBlockSize, + .tile_dims = conv_traits_data_tile(InstTraits::kK0PerBlock), + .a_tile_transfer = + conv_traits_a_transfer_params(InstTraits::kK1, InstTraits::kK0PerBlock), + .b_tile_transfer = + conv_traits_b_transfer_params(InstTraits::kK1, InstTraits::kK0PerBlock), + .warp_gemm = conv_traits_wmma_warp_gemm_params(), + .c_tile_transfer = conv_traits_wmma_c_tile_transfer(), + .num_gemm_k_prefetch_stage = InstTraits::kNumGemmKPrefetchStage, + .pipeline_version = get_pipeline_version(), + .pipeline_scheduler = get_pipeline_scheduler(), + + }; +} + +} // namespace ck_tile::reflect::conv diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_to_conv_traits.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_to_conv_traits.hpp index c186bde6cd7..e10baaf7120 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_to_conv_traits.hpp +++ 
b/experimental/builder/include/ck_tile/builder/reflect/instance_to_conv_traits.hpp @@ -17,3 +17,4 @@ #include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp" #include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp" #include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp" +#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_wmma_cshuffle.hpp" diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_wmma_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_wmma_cshuffle.hpp index d3017868773..eba422b85f0 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_wmma_cshuffle.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_wmma_cshuffle.hpp @@ -59,6 +59,11 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle; namespace ck_tile { namespace reflect { +/// @brief Tag type for DeviceGroupedConvBwdWeight_Wmma_CShuffle device kernel +struct DeviceGroupedConvBwdWeight_Wmma_CShuffle_Tag +{ +}; + template > // Use false to match with the default value { static constexpr auto kTensorOpName = "DeviceGroupedConvBwdWeight_Wmma_CShuffle"; + using device_kernel_tag = DeviceGroupedConvBwdWeight_Wmma_CShuffle_Tag; static constexpr ck::index_t kSpatialDim = NDimSpatial; @@ -171,8 +177,8 @@ struct InstanceTraits::value; + static constexpr auto kAThreadClusterArrangeOrder = + detail::SequenceToArray::value; + static constexpr auto kABlockTransferSrcAccessOrder = + detail::SequenceToArray::value; static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim; static constexpr ck::index_t kABlockTransferSrcScalarPerVector = ABlockTransferSrcScalarPerVector; - static 
constexpr ck::index_t kABlockTransferDstScalarPerVector_K1 = + static constexpr ck::index_t kABlockTransferDstScalarPerVectorK1 = ABlockTransferDstScalarPerVector_K1; - static constexpr bool kABlockLdsAddExtraM = ABlockLdsAddExtraM; + static constexpr bool kABlockLdsExtraM = ABlockLdsAddExtraM; using BBlockTransferThreadClusterLengths_K0_N_K1 = BBlockTransferThreadClusterLengths_K0_N_K1_; using BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_; using BBlockTransferSrcAccessOrder = BBlockTransferSrcAccessOrder_; + // B block transfer thread cluster dimensions (converted to std::array) + static constexpr auto kBThreadClusterLengths = + detail::SequenceToArray::value; + static constexpr auto kBThreadClusterArrangeOrder = + detail::SequenceToArray::value; + static constexpr auto kBBlockTransferSrcAccessOrder = + detail::SequenceToArray::value; static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim; static constexpr ck::index_t kBBlockTransferSrcScalarPerVector = BBlockTransferSrcScalarPerVector; - static constexpr ck::index_t kBBlockTransferDstScalarPerVector_K1 = + static constexpr ck::index_t kBBlockTransferDstScalarPerVectorK1 = BBlockTransferDstScalarPerVector_K1; - static constexpr bool kBBlockLdsAddExtraN = BBlockLdsAddExtraN; + static constexpr bool kBBlockLdsExtraN = BBlockLdsAddExtraN; using CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_; - + static constexpr auto kCDEThreadClusterLengths = detail::SequenceToArray< + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value; + static constexpr int kCDEBlockTransferScalarPerVector = + CShuffleBlockTransferScalarPerVector_NPerBlock; static constexpr ck::LoopScheduler kLoopSched = LoopSched; static constexpr ck::PipelineVersion kPipelineVer = PipelineVer; @@ -240,8 +263,8 @@ struct InstanceTraits(); // 22. 
@@ -249,15 +272,15 @@ struct InstanceTraits(); // 24. oss << "," << kABlockTransferSrcVectorDim; // 25. oss << "," << kABlockTransferSrcScalarPerVector; // 26. - oss << "," << kABlockTransferDstScalarPerVector_K1; // 27. - oss << "," << (kABlockLdsAddExtraM ? "true" : "false"); // 28. + oss << "," << kABlockTransferDstScalarPerVectorK1; // 27. + oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 28. oss << "," << detail::sequence_name(); // 29. oss << "," << detail::sequence_name(); // 30. oss << "," << detail::sequence_name(); // 31. oss << "," << kBBlockTransferSrcVectorDim; // 32. oss << "," << kBBlockTransferSrcScalarPerVector; // 33. - oss << "," << kBBlockTransferDstScalarPerVector_K1; // 34. - oss << "," << (kBBlockLdsAddExtraN ? "true" : "false"); // 35. + oss << "," << kBBlockTransferDstScalarPerVectorK1; // 34. + oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 35. oss << "," << kCShuffleMRepeatPerShuffle; // 36. oss << "," << kCShuffleNRepeatPerShuffle; // 37. oss << "," diff --git a/experimental/builder/test/conv/ck/test_conv_traits.cpp b/experimental/builder/test/conv/ck/test_conv_traits.cpp index bc77c0f2b5b..32211135653 100644 --- a/experimental/builder/test/conv/ck/test_conv_traits.cpp +++ b/experimental/builder/test/conv/ck/test_conv_traits.cpp @@ -18,6 +18,7 @@ #include #include #include +#include namespace { @@ -34,6 +35,124 @@ class ConvTraitsTest : public ::testing::Test { }; +// Test ConvTraits with DeviceGroupedConvBwdWeight_Wmma_CShuffle +TEST_F(ConvTraitsTest, ConvBwdWeightCshuffleWmmaTraitsExtraction) +{ + // Define a concrete instance type with specific template parameters + using DeviceInstance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffle< + 3, // NDimSpatial + ck::tensor_layout::convolution::GNDHWC, // InLayout + ck::tensor_layout::convolution::GKZYXC, // WeiLayout + ck::tensor_layout::convolution::GNDHWK, // OutLayout + ck::half_t, // InDataType + ck::half_t, // WeiDataType + ck::half_t, // 
OutDataType + float, // AccDataType + ck::tensor_operation::element_wise::PassThrough, // InElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // WeiElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // OutElementwiseOperation + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization:: + Default, // ConvBackwardWeightSpecialization + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // K0PerBlock + 8, // K1 + 32, // MPerWmma + 32, // NPerWmma + 4, // MRepeat + 4, // NRepeat + ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder_ + ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_K1 + 1, // ABlockLdsAddExtraM + ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder_ + ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder_ + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_K1 + 1, // BBlockLdsAddExtraN + 1, // CShuffleMRepeatPerShuffle + 1, // CShuffleNRepeatPerShuffle + ck::Sequence<1, + 32, + 1, + 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_ + 8, // CShuffleBlockTransferScalarPerVector_NPerBlock + 1, // NumGemmKPrefetchStage + ck::LoopScheduler::Default, // LoopSched + ck::PipelineVersion::v1, // PipelineVer + false>; // trailing bool argument (false matches the default value) + + // Use instance_to_conv_traits to extract compile-time information + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + + // Verify signature information + EXPECT_EQ(traits.spatial_dim, 3); + EXPECT_EQ(traits.direction, ConvDirection::BACKWARD_WEIGHT); + EXPECT_THAT(traits.layout, + ElementsAre(TensorLayout::GNDHWC, TensorLayout::GKZYXC, 
TensorLayout::GNDHWK)); + EXPECT_EQ(traits.data_type, DataType::FP16); + EXPECT_EQ(traits.input_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.weight_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.output_element_op, ElementwiseOperation::PASS_THROUGH); + + // Verify specializations + EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT); + + // Verify algorithm information + EXPECT_EQ(traits.thread_block_size, 256); + + // Verify tile dimensions + EXPECT_EQ(traits.tile_dims.m, 128); + EXPECT_EQ(traits.tile_dims.n, 128); + EXPECT_EQ(traits.tile_dims.k, 16); + + // Verify A tile transfer info + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(traits.a_tile_transfer.transfer_params.lds_padding); + + // Verify B tile transfer info + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + 
EXPECT_THAT(traits.b_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(traits.b_tile_transfer.transfer_params.lds_padding); + + // Verify warp GEMM params + EXPECT_EQ(traits.warp_gemm.gemm_m, 32); + EXPECT_EQ(traits.warp_gemm.gemm_n, 32); + EXPECT_EQ(traits.warp_gemm.m_iter, 4); + EXPECT_EQ(traits.warp_gemm.n_iter, 4); + + // Verify output tile transfer info + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.m_gemms_per_shuffle, 1); + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.n_gemms_per_shuffle, 1); + EXPECT_THAT(traits.c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8)); + EXPECT_EQ(traits.c_tile_transfer.scalar_per_vector, 8); + + // Verify pipeline configuration +} + // Test ConvTraits with DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3 TEST_F(ConvTraitsTest, ConvBwdWeightCshuffleWmmaV3TraitsExtraction) {