diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp index a7b6c60a73..f858cc32b6 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp @@ -64,7 +64,7 @@ struct GemmAlgorithmInfo builder::PipelineVersion pipeline_version; builder::PipelineScheduler pipeline_scheduler; builder::ConvSpecialization conv_specialization; - builder::GemmPadding padding; + std::optional padding; }; /// @brief Provides human-readable descriptions of convolution kernel instances @@ -121,7 +121,11 @@ class ConvDescription : public Description algorithm_.tile_dims.n, "×", algorithm_.tile_dims.k); - f.writeLine(2, "Gemm padding: ", algorithm_.padding); + if(algorithm_.padding) + f.writeLine( + 2, "Gemm padding: ", algorithm_.padding.value_or(builder::GemmPadding::DEFAULT)); + else + f.writeLine(2, "Struct does not contain optional padding argument"); f.writeLine(2, "Convolution specialization: ", algorithm_.conv_specialization); // Pipeline section f.writeLine(2, "Pipeline version: ", algorithm_.pipeline_version); diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp index 451a74be34..16a9c47f7e 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp @@ -88,7 +88,7 @@ struct ConvTraits builder::ElementwiseOperation weight_element_op; builder::ElementwiseOperation output_element_op; - builder::GemmPadding gemm_padding; + std::optional gemm_padding = std::nullopt; builder::ConvSpecialization conv_specialization; // --- Algorithm Information --- @@ -102,8 +102,14 @@ struct ConvTraits OutputTileTransferInfo c_tile_transfer; + std::optional num_gemm_k_prefetch_stage = std::nullopt; + 
builder::PipelineVersion pipeline_version; builder::PipelineScheduler pipeline_scheduler; + + std::optional max_transpose_transfer_src_scalar_per_vector = std::nullopt; + std::optional max_transpose_dst_scalar_per_vector = std::nullopt; + std::optional num_groups_to_merge = std::nullopt; }; } // namespace ck_tile::reflect::conv diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp new file mode 100644 index 0000000000..f052a9701b --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp @@ -0,0 +1,46 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include + +#include "ck_tile/builder/reflect/conv_traits.hpp" +#include "ck_tile/builder/reflect/conv_traits_helpers.hpp" +#include "ck_tile/builder/reflect/instance_traits.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp" + +namespace ck_tile::reflect::conv { + +/// @brief Tag dispatch implementation for DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle_Tag +template + requires HasInstanceTraits && + std::same_as::device_kernel_tag, + DeviceGroupedConvBwdWeight_multiple_d_Wmma_CShuffle_V3_Tag> +constexpr ConvTraits instance_to_conv_traits() +{ + using InstTraits = InstanceTraits; + + return ConvTraits{ + .spatial_dim = InstTraits::kSpatialDim, + .direction = conv_direction(), + .layout = bwd_wei_conv_layout(), + .data_type = conv_data_type(), + .input_element_op = elementwise_op(), + .weight_element_op = elementwise_op(), + .output_element_op = elementwise_op(), + .conv_specialization = conv_spec(), + .thread_block_size = InstTraits::kBlockSize, + .tile_dims = 
conv_traits_data_tile(InstTraits::kKPerBlock), + .a_tile_transfer = + conv_traits_a_transfer_params(InstTraits::kK1, InstTraits::kKPerBlock), + .b_tile_transfer = + conv_traits_b_transfer_params(InstTraits::kK1, InstTraits::kKPerBlock), + .warp_gemm = conv_traits_wmma_warp_gemm_params(), + .c_tile_transfer = conv_traits_wmma_c_tile_transfer(), + .pipeline_version = get_pipeline_version(), + .pipeline_scheduler = get_pipeline_scheduler(), + }; +} + +} // namespace ck_tile::reflect::conv diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp new file mode 100644 index 0000000000..2f7c68458f --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp @@ -0,0 +1,53 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include + +#include "ck_tile/builder/reflect/conv_traits.hpp" +#include "ck_tile/builder/reflect/conv_traits_helpers.hpp" +#include "ck_tile/builder/reflect/instance_traits.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp" + +namespace ck_tile::reflect::conv { + +/// @brief Tag dispatch implementation for DeviceGroupedConvBwdWeight_Xdl_CShuffle_Tag +template + requires HasInstanceTraits && + std::same_as::device_kernel_tag, + DeviceGroupedConvBwdWeight_multiple_d_Xdl_CShuffle_Tag> +constexpr ConvTraits instance_to_conv_traits() +{ + using InstTraits = InstanceTraits; + + return ConvTraits{ + .spatial_dim = InstTraits::kSpatialDim, + .direction = conv_direction(), + .layout = bwd_wei_conv_layout(), + .data_type = conv_data_type(), + .input_element_op = elementwise_op(), + .weight_element_op = elementwise_op(), + .output_element_op = elementwise_op(), + .conv_specialization = conv_spec(), + .thread_block_size = InstTraits::kBlockSize, + .tile_dims = conv_traits_data_tile(InstTraits::kK0PerBlock), + .a_tile_transfer = + conv_traits_a_transfer_params(InstTraits::kK1, InstTraits::kK0PerBlock), + .b_tile_transfer = + conv_traits_b_transfer_params(InstTraits::kK1, InstTraits::kK0PerBlock), + .warp_gemm = conv_traits_xdl_warp_gemm_params(), + .c_tile_transfer = + {.shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMXdlPerWavePerShuffle, + .n_gemms_per_shuffle = InstTraits::kCShuffleNXdlPerWavePerShuffle}, + .thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0], + InstTraits::kCThreadClusterLengths[1], + InstTraits::kCThreadClusterLengths[2], + InstTraits::kCThreadClusterLengths[3]}, + .scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector_NWaveNPerXdl}, + .pipeline_version = get_pipeline_version(), + .pipeline_scheduler = get_pipeline_scheduler(), + }; +} + +} // namespace ck_tile::reflect::conv diff --git 
a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp new file mode 100644 index 0000000000..4f39b00b5c --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp @@ -0,0 +1,50 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include + +#include "ck_tile/builder/reflect/conv_traits.hpp" +#include "ck_tile/builder/reflect/conv_traits_helpers.hpp" +#include "ck_tile/builder/reflect/instance_traits.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp" + +namespace ck_tile::reflect::conv { + +/// @brief Tag dispatch implementation for DeviceGroupedConvBwdWeight_wmma_CShuffle_Tag +template + requires HasInstanceTraits && + std::same_as::device_kernel_tag, + DeviceGroupedConvBwdWeight_two_stage_Wmma_CShuffle_Tag> +constexpr ConvTraits instance_to_conv_traits() +{ + using InstTraits = InstanceTraits; + + return ConvTraits{ + .spatial_dim = InstTraits::kSpatialDim, + .direction = conv_direction(), + .layout = bwd_wei_conv_layout(), + .data_type = conv_data_type(), + .input_element_op = elementwise_op(), + .weight_element_op = elementwise_op(), + .output_element_op = elementwise_op(), + .conv_specialization = conv_spec(), + .thread_block_size = InstTraits::kBlockSize, + .tile_dims = conv_traits_data_tile(InstTraits::kKPerBlock), + .a_tile_transfer = + conv_traits_a_transfer_params(InstTraits::kABK1, InstTraits::kKPerBlock), + .b_tile_transfer = + conv_traits_b_transfer_params(InstTraits::kABK1, InstTraits::kKPerBlock), + .warp_gemm = conv_traits_wmma_warp_gemm_params(), + .c_tile_transfer = conv_traits_wmma_c_tile_transfer(), + .pipeline_version = 
get_pipeline_version(), + .pipeline_scheduler = get_pipeline_scheduler(), + .max_transpose_transfer_src_scalar_per_vector = + InstTraits::kTransposeTransferSrcScalarPerVector, + .max_transpose_dst_scalar_per_vector = InstTraits::kTransposeTransferDstScalarPerVector, + .num_groups_to_merge = InstTraits::kNumGroupsToMerge, + }; +} + +} // namespace ck_tile::reflect::conv diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp new file mode 100644 index 0000000000..5666233091 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp @@ -0,0 +1,57 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include + +#include "ck_tile/builder/reflect/conv_traits.hpp" +#include "ck_tile/builder/reflect/conv_traits_helpers.hpp" +#include "ck_tile/builder/reflect/instance_traits.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp" + +namespace ck_tile::reflect::conv { + +/// @brief Tag dispatch implementation for DeviceGroupedConvBwdTwoStage_Xdl_CShuffle_Tag +template + requires HasInstanceTraits && + std::same_as::device_kernel_tag, + DeviceGroupedConvBwdWeight_two_stage_Xdl_CShuffle_Tag> +constexpr ConvTraits instance_to_conv_traits() +{ + using InstTraits = InstanceTraits; + + return ConvTraits{ + .spatial_dim = InstTraits::kSpatialDim, + .direction = conv_direction(), + .layout = bwd_wei_conv_layout(), + .data_type = conv_data_type(), + .input_element_op = elementwise_op(), + .weight_element_op = elementwise_op(), + .output_element_op = elementwise_op(), + .conv_specialization = conv_spec(), + .thread_block_size = InstTraits::kBlockSize, + .tile_dims = 
conv_traits_data_tile(InstTraits::kKPerBlock), + .a_tile_transfer = + conv_traits_a_transfer_params(InstTraits::kK1, InstTraits::kKPerBlock), + .b_tile_transfer = + conv_traits_b_transfer_params(InstTraits::kK1, InstTraits::kKPerBlock), + .warp_gemm = conv_traits_xdl_warp_gemm_params(), + .c_tile_transfer = + {.shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMXdlPerWavePerShuffle, + .n_gemms_per_shuffle = InstTraits::kCShuffleNXdlPerWavePerShuffle}, + .thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0], + InstTraits::kCThreadClusterLengths[1], + InstTraits::kCThreadClusterLengths[2], + InstTraits::kCThreadClusterLengths[3]}, + .scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector_NWaveNPerXdl}, + .pipeline_version = get_pipeline_version(), + .pipeline_scheduler = get_pipeline_scheduler(), + .max_transpose_transfer_src_scalar_per_vector = + InstTraits::kTransposeTransferSrcScalarPerVector, + .max_transpose_dst_scalar_per_vector = InstTraits::kTransposeTransferDstScalarPerVector, + .num_groups_to_merge = InstTraits::kNumGroupsToMerge, + }; +} + +} // namespace ck_tile::reflect::conv diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_wmma_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_wmma_cshuffle.hpp new file mode 100644 index 0000000000..470a10d031 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_wmma_cshuffle.hpp @@ -0,0 +1,48 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include + +#include "ck_tile/builder/reflect/conv_traits.hpp" +#include "ck_tile/builder/reflect/conv_traits_helpers.hpp" +#include "ck_tile/builder/reflect/instance_traits.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_wmma_cshuffle.hpp" + +namespace ck_tile::reflect::conv { + +/// @brief Tag dispatch implementation for DeviceGroupedConvBwdWeight_Wmma_CShuffle_Tag +template + requires HasInstanceTraits && + std::same_as::device_kernel_tag, + DeviceGroupedConvBwdWeight_Wmma_CShuffle_Tag> +constexpr ConvTraits instance_to_conv_traits() +{ + using InstTraits = InstanceTraits; + + return ConvTraits{ + .spatial_dim = InstTraits::kSpatialDim, + .direction = conv_direction(), + .layout = bwd_wei_conv_layout(), + .data_type = conv_data_type(), + .input_element_op = elementwise_op(), + .weight_element_op = elementwise_op(), + .output_element_op = elementwise_op(), + .conv_specialization = conv_spec(), + .thread_block_size = InstTraits::kBlockSize, + .tile_dims = conv_traits_data_tile(InstTraits::kK0PerBlock), + .a_tile_transfer = + conv_traits_a_transfer_params(InstTraits::kK1, InstTraits::kK0PerBlock), + .b_tile_transfer = + conv_traits_b_transfer_params(InstTraits::kK1, InstTraits::kK0PerBlock), + .warp_gemm = conv_traits_wmma_warp_gemm_params(), + .c_tile_transfer = conv_traits_wmma_c_tile_transfer(), + .num_gemm_k_prefetch_stage = InstTraits::kNumGemmKPrefetchStage, + .pipeline_version = get_pipeline_version(), + .pipeline_scheduler = get_pipeline_scheduler(), + + }; +} + +} // namespace ck_tile::reflect::conv diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp new file mode 100644 index 0000000000..13625aa182 --- /dev/null +++ 
b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp @@ -0,0 +1,50 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include + +#include "ck_tile/builder/reflect/conv_traits.hpp" +#include "ck_tile/builder/reflect/conv_traits_helpers.hpp" +#include "ck_tile/builder/reflect/instance_traits.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp" + +namespace ck_tile::reflect::conv { + +/// @brief Tag dispatch implementation for DeviceGroupedConvBwdWeight_Wmma_CShuffle_Tag +template + requires HasInstanceTraits && + std::same_as::device_kernel_tag, + DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3_Tag> +constexpr ConvTraits instance_to_conv_traits() +{ + using InstTraits = InstanceTraits; + + return ConvTraits{ + .spatial_dim = InstTraits::kSpatialDim, + .direction = conv_direction(), + .layout = bwd_wei_conv_layout(), + .data_type = conv_data_type(), + .input_element_op = elementwise_op(), + .weight_element_op = elementwise_op(), + .output_element_op = elementwise_op(), + .conv_specialization = conv_spec(), + .thread_block_size = InstTraits::kBlockSize, + .tile_dims = conv_traits_data_tile(InstTraits::kKPerBlock), + .a_tile_transfer = + conv_traits_a_transfer_params(InstTraits::kK1, InstTraits::kKPerBlock), + .b_tile_transfer = + conv_traits_b_transfer_params(InstTraits::kK1, InstTraits::kKPerBlock), + .warp_gemm = conv_traits_wmma_warp_gemm_params(), + .c_tile_transfer = conv_traits_wmma_c_tile_transfer(), + .pipeline_version = get_pipeline_version(), + .pipeline_scheduler = get_pipeline_scheduler(), + .max_transpose_transfer_src_scalar_per_vector = + InstTraits::kMaxTransposeTransferSrcScalarPerVector, + .max_transpose_dst_scalar_per_vector = InstTraits::kMaxTransposeTransferDstScalarPerVector, + + }; +} + +} // namespace ck_tile::reflect::conv diff --git 
a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp new file mode 100644 index 0000000000..39fde33217 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp @@ -0,0 +1,56 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include + +#include "ck_tile/builder/reflect/conv_traits.hpp" +#include "ck_tile/builder/reflect/conv_traits_helpers.hpp" +#include "ck_tile/builder/reflect/instance_traits.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp" + +namespace ck_tile::reflect::conv { + +/// @brief Tag dispatch implementation for DeviceGroupedConvBwdWeight_Xdl_CShuffle_Tag +template + requires HasInstanceTraits && + std::same_as::device_kernel_tag, + DeviceGroupedConvBwdWeight_Xdl_CShuffle_Tag> +constexpr ConvTraits instance_to_conv_traits() +{ + using InstTraits = InstanceTraits; + + return ConvTraits{ + .spatial_dim = InstTraits::kSpatialDim, + .direction = conv_direction(), + .layout = bwd_wei_conv_layout(), + .data_type = conv_data_type(), + .input_element_op = elementwise_op(), + .weight_element_op = elementwise_op(), + .output_element_op = elementwise_op(), + .conv_specialization = conv_spec(), + .thread_block_size = InstTraits::kBlockSize, + .tile_dims = conv_traits_data_tile(InstTraits::kK0PerBlock), + .a_tile_transfer = + conv_traits_a_transfer_params(InstTraits::kK1, InstTraits::kK0PerBlock), + .b_tile_transfer = + conv_traits_b_transfer_params(InstTraits::kK1, InstTraits::kK0PerBlock), + .warp_gemm = conv_traits_xdl_warp_gemm_params(), + .c_tile_transfer = + {.shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMXdlPerWavePerShuffle, + .n_gemms_per_shuffle = 
InstTraits::kCShuffleNXdlPerWavePerShuffle}, + .thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0], + InstTraits::kCThreadClusterLengths[1], + InstTraits::kCThreadClusterLengths[2], + InstTraits::kCThreadClusterLengths[3]}, + .scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector_NWaveNPerXdl}, + .pipeline_version = get_pipeline_version(), + .pipeline_scheduler = get_pipeline_scheduler(), + .max_transpose_transfer_src_scalar_per_vector = + InstTraits::kMaxTransposeTransferSrcScalarPerVector, + .max_transpose_dst_scalar_per_vector = InstTraits::kMaxTransposeTransferDstScalarPerVector, + }; +} + +} // namespace ck_tile::reflect::conv diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp new file mode 100644 index 0000000000..de98645514 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp @@ -0,0 +1,53 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include + +#include "ck_tile/builder/reflect/conv_traits.hpp" +#include "ck_tile/builder/reflect/conv_traits_helpers.hpp" +#include "ck_tile/builder/reflect/instance_traits.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp" + +namespace ck_tile::reflect::conv { + +/// @brief Tag dispatch implementation for DeviceGroupedConvBwdWeight_Xdl_V3_CShuffle_Tag +template + requires HasInstanceTraits && + std::same_as::device_kernel_tag, + DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3_Tag> +constexpr ConvTraits instance_to_conv_traits() +{ + using InstTraits = InstanceTraits; + + return ConvTraits{ + .spatial_dim = InstTraits::kSpatialDim, + .direction = conv_direction(), + .layout = bwd_wei_conv_layout(), + .data_type = conv_data_type(), + .input_element_op = elementwise_op(), + .weight_element_op = elementwise_op(), + .output_element_op = elementwise_op(), + .conv_specialization = conv_spec(), + .thread_block_size = InstTraits::kBlockSize, + .tile_dims = conv_traits_data_tile(InstTraits::kK0PerBlock), + .a_tile_transfer = + conv_traits_a_transfer_params(InstTraits::kK1, InstTraits::kK0PerBlock), + .b_tile_transfer = + conv_traits_b_transfer_params(InstTraits::kK1, InstTraits::kK0PerBlock), + .warp_gemm = conv_traits_xdl_warp_gemm_params(), + .c_tile_transfer = + {.shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMXdlPerWavePerShuffle, + .n_gemms_per_shuffle = InstTraits::kCShuffleNXdlPerWavePerShuffle}, + .thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0], + InstTraits::kCThreadClusterLengths[1], + InstTraits::kCThreadClusterLengths[2], + InstTraits::kCThreadClusterLengths[3]}, + .scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector_NWaveNPerXdl}, + .pipeline_version = get_pipeline_version(), + .pipeline_scheduler = get_pipeline_scheduler(), + }; +} + +} // namespace ck_tile::reflect::conv diff --git 
a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index cdd238f36a..2f5d84a4a8 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -24,8 +24,8 @@ constexpr ConvTraits instance_to_conv_traits() return ConvTraits{ .spatial_dim = InstTraits::kSpatialDim, .direction = conv_direction(), - .layout = conv_layout(), - .data_type = conv_data_type(), + .layout = fwd_conv_layout(), + .data_type = conv_data_type(), .input_element_op = elementwise_op(), .weight_element_op = elementwise_op(), .output_element_op = elementwise_op(), diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp index 28c43c342f..2108c79054 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp @@ -24,8 +24,8 @@ constexpr ConvTraits instance_to_conv_traits() return ConvTraits{ .spatial_dim = InstTraits::kSpatialDim, .direction = conv_direction(), - .layout = conv_layout(), - .data_type = conv_data_type(), + .layout = fwd_conv_layout(), + .data_type = conv_data_type(), .input_element_op = elementwise_op(), .weight_element_op = elementwise_op(), .output_element_op = elementwise_op(), diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp 
b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp new file mode 100644 index 0000000000..9413107df7 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp @@ -0,0 +1,46 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include + +#include "ck_tile/builder/reflect/conv_traits.hpp" +#include "ck_tile/builder/reflect/conv_traits_helpers.hpp" +#include "ck_tile/builder/reflect/instance_traits.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp" + +namespace ck_tile::reflect::conv { + +/// @brief Tag dispatch implementation for DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle +template + requires HasInstanceTraits && + std::same_as::device_kernel_tag, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_Tag> +constexpr ConvTraits instance_to_conv_traits() +{ + using InstTraits = InstanceTraits; + + return ConvTraits{ + .spatial_dim = InstTraits::kSpatialDim, + .direction = conv_direction(), + .layout = fwd_conv_layout(), + .data_type = conv_data_type(), + .input_element_op = elementwise_op(), + .weight_element_op = elementwise_op(), + .output_element_op = elementwise_op(), + .gemm_padding = gemm_spec(), + .conv_specialization = conv_spec(), + .thread_block_size = InstTraits::kBlockSize, + .tile_dims = conv_traits_data_tile(), + .a_tile_transfer = conv_traits_a_transfer_params(InstTraits::kK1), + .b_tile_transfer = conv_traits_b_transfer_params(InstTraits::kK1), + .warp_gemm = conv_traits_wmma_warp_gemm_params(), + .c_tile_transfer = conv_traits_wmma_c_tile_transfer(), + .num_gemm_k_prefetch_stage = InstTraits::kNumGemmKPrefetchStage, + .pipeline_version = get_pipeline_version(), + .pipeline_scheduler = get_pipeline_scheduler(), + }; +} + +} // namespace ck_tile::reflect::conv diff --git 
a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp index c4bed850eb..0cce3bf513 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp @@ -24,8 +24,8 @@ constexpr ConvTraits instance_to_conv_traits() return ConvTraits{ .spatial_dim = InstTraits::kSpatialDim, .direction = conv_direction(), - .layout = conv_layout(), - .data_type = conv_data_type(), + .layout = fwd_conv_layout(), + .data_type = conv_data_type(), .input_element_op = elementwise_op(), .weight_element_op = elementwise_op(), .output_element_op = elementwise_op(), diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_helpers.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_helpers.hpp index 46c196e95a..c17284417d 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_helpers.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_helpers.hpp @@ -80,6 +80,22 @@ namespace ck_tile::reflect::conv { // SECTION 1: ENUM CONVERSIONS // ============================================================================ +// Forward convolution layout concept - checks for A/B/E layout types +template +concept HasFwdConvLayouts = requires { + typename T::ALayout; + typename T::BLayout; + typename T::ELayout; +}; + +// Backwards weight layout concept - checks for In, wei and out layouts +template +concept HasBwdWeiLayouts = requires { + typename T::InLayout; + typename T::WeiLayout; + typename T::OutLayout; +}; + /// @brief Converts a CK BlockGemmPipelineVersion enum to a builder PipelineVersion enum. 
/// @tparam ck_ver The CK BlockGemmPipelineVersion enum value to convert. /// @return The corresponding builder::PipelineVersion enum value. @@ -322,12 +338,23 @@ constexpr builder::ConvSpecialization conv_spec() // Tensor Layouts // ---------------------------------------------------------------------------- +// Helper variable template to check if CK layout enums match +template +inline constexpr bool layouts_are = + std::is_same_v && std::is_same_v && std::is_same_v; + /// @brief Helper function to report unsupported layout combinations with a clear error message. -/// @details This consteval function uses throw (not static_assert) to ensure the error is not -/// silently ignored during SFINAE. The thrown string becomes part of the compiler error message. +/// @details This consteval function is designed to fail at compile time with a descriptive +/// error message when an unsupported layout combination is encountered. template [[noreturn]] consteval void report_unsupported_layout_error() { + // This will produce a compile-time error with the exception message throw "Unsupported convolution layout combination detected!\n" "The combination of ALayout, BLayout, and ELayout template parameters\n" "is not recognized for the given spatial dimension.\n" @@ -335,111 +362,99 @@ template "Check the conv_layout() function for the list of supported layout combinations."; } -/// @brief Derives the grouped convolution layout from a device kernel Instance type. -/// @tparam Instance The device kernel instance type. -/// @return An std::array containing the layouts for: -/// - [0] Input tensor layout -/// - [1] Weight tensor layout -/// - [2] Output tensor layout -/// @details This function examines the Instance's ALayout, BLayout, and ELayout types -/// along with the spatial dimension to determine the appropriate layout configuration. -/// -/// Supported layout combinations vary by spatial dimension (1D, 2D, 3D convolutions). 
-/// Common patterns include GNHWC (grouped, batch, spatial, channels) and variants. -/// -/// @note Compilation will fail with a clear error message if the layout combination -/// is not supported for the given spatial dimension. -/// -/// TODO: If we don't check for supported layouts, this function can be simplified. -template -constexpr std::array conv_layout() +template +constexpr auto conv_layout() { - using InstTraits = InstanceTraits; - using A = typename InstTraits::ALayout; - using B = typename InstTraits::BLayout; - using E = typename InstTraits::ELayout; - namespace ctl = ck::tensor_layout::convolution; - using enum builder::TensorLayout; - - // Helper to check if layouts match expected types - constexpr auto layouts_match = []() { - return std::is_same_v && std::is_same_v && std::is_same_v; - }; - // Helper to construct layout array - constexpr auto make_layouts = [](auto in, auto weight, auto out) { - return std::array{in, weight, out}; - }; + // Helper lambda to construct layout array + auto layouts = [](auto... 
Ls) { return std::array{Ls...}; }; - constexpr int spatial_dim = InstTraits::kSpatialDim; + namespace ctl = ck::tensor_layout::convolution; + using enum builder::TensorLayout; - if constexpr(spatial_dim == 1) - { - if constexpr(layouts_match.template operator()()) - return make_layouts(GNWC, GKXC, GNWK); - else if constexpr(layouts_match - .template operator()()) - return make_layouts(GNWC, GKXC, GNWK); - else if constexpr(layouts_match.template operator()()) - return make_layouts(NWGC, GKXC, NWGK); - else if constexpr(layouts_match.template operator()()) - return make_layouts(NGCW, GKXC, NGKW); - else if constexpr(layouts_match.template operator()()) - return make_layouts(NGCW, GKCX, NGKW); - else - { - report_unsupported_layout_error(); - return make_layouts(GNWC, GKXC, GNWK); // Unreachable - } - } - else if constexpr(spatial_dim == 2) - { - if constexpr(layouts_match.template operator()()) - return make_layouts(GNHWC, GKYXC, GNHWK); - else if constexpr(layouts_match - .template operator()()) - return make_layouts(GNHWC, GKYXC, GNHWK); - else if constexpr(layouts_match.template operator()()) - return make_layouts(NHWGC, GKYXC, NHWGK); - else if constexpr(layouts_match.template operator()()) - return make_layouts(NHWGC, GKYXC, NHWGK); - else if constexpr(layouts_match.template operator()()) - return make_layouts(NGCHW, GKYXC, NGKHW); - else if constexpr(layouts_match.template operator()()) - return make_layouts(NGCHW, GKCYX, NGKHW); - else - { - report_unsupported_layout_error(); - return make_layouts(GNHWC, GKYXC, GNHWK); // Unreachable - } - } - else if constexpr(spatial_dim == 3) - { - if constexpr(layouts_match.template operator()()) - return make_layouts(GNDHWC, GKZYXC, GNDHWK); - else if constexpr(layouts_match - .template operator()()) - return make_layouts(GNDHWC, GKZYXC, GNDHWK); - else if constexpr(layouts_match - .template operator()()) - return make_layouts(NDHWGC, GKZYXC, NDHWGK); - else if constexpr(layouts_match - .template operator()()) - return 
make_layouts(NGCDHW, GKZYXC, NGKDHW); - else if constexpr(layouts_match - .template operator()()) - return make_layouts(NGCDHW, GKCZYX, NGKDHW); - else - { - report_unsupported_layout_error(); - return make_layouts(GNDHWC, GKZYXC, GNDHWK); // Unreachable - } - } - else + switch(kSpatialDim) { - report_unsupported_layout_error(); - return make_layouts(GNHWC, GKYXC, GNHWK); // Unreachable + case 1: + if constexpr(layouts_are) + return layouts(GNWC, GKXC, GNWK); + if constexpr(layouts_are) + return layouts(GNWC, GKXC, GNWK); + if constexpr(layouts_are) + return layouts(NWGC, GKXC, NWGK); + if constexpr(layouts_are) + return layouts(NGCW, GKXC, NGKW); + if constexpr(layouts_are) + return layouts(NGCW, GKCX, NGKW); + break; + case 2: + if constexpr(layouts_are) + return layouts(GNHWC, GKYXC, GNHWK); + if constexpr(layouts_are) + return layouts(GNHWC, GKYXC, GNHWK); + if constexpr(layouts_are) + return layouts(NHWGC, GKYXC, NHWGK); + if constexpr(layouts_are) + return layouts(NHWGC, GKYXC, NHWGK); + if constexpr(layouts_are) + return layouts(NGCHW, GKYXC, NGKHW); + if constexpr(layouts_are) + return layouts(NGCHW, GKCYX, NGKHW); + break; + case 3: + if constexpr(layouts_are) + return layouts(GNDHWC, GKZYXC, GNDHWK); + if constexpr(layouts_are) + return layouts(GNDHWC, GKZYXC, GNDHWK); + if constexpr(layouts_are) + return layouts(NDHWGC, GKZYXC, NDHWGK); + if constexpr(layouts_are) + return layouts(NGCDHW, GKZYXC, NGKDHW); + if constexpr(layouts_are) + return layouts(NGCDHW, GKCZYX, NGKDHW); + break; } + + // If we reach here, the layout combination is not supported + // Call consteval function to trigger a compile-time error with a clear message + report_unsupported_layout_error(); + + // This return is unreachable but needed to satisfy the compiler + return layouts(GNHWC, GKYXC, GNHWK); +} + +/// @brief Derives the grouped convolution layout from a device kernel `Instance` type. +/// @tparam Instance The device kernel instance type. 
+/// @return An std::array corresponding to the tensor layouts: +/// index 0 -> Input layout +/// index 1 -> Weight layout +/// index 2 -> Output layout + +template +constexpr auto fwd_conv_layout() + requires HasFwdConvLayouts> +{ + + using A = typename InstanceTraits::ALayout; + using B = typename InstanceTraits::BLayout; + using E = typename InstanceTraits::ELayout; + return conv_layout::kSpatialDim>(); +} + +/// @brief Derives the grouped convolution layout from a device kernel `Instance` type. +/// @tparam Instance The device kernel instance type. +/// @return An std::array corresponding to the tensor layouts: +/// index 0 -> Input layout +/// index 1 -> Weight layout +/// index 2 -> Output layout +template +constexpr auto bwd_wei_conv_layout() + requires HasBwdWeiLayouts> +{ + + using A = typename InstanceTraits::InLayout; + using B = typename InstanceTraits::WeiLayout; + using E = typename InstanceTraits::OutLayout; + return conv_layout::kSpatialDim>(); } // ---------------------------------------------------------------------------- @@ -447,13 +462,11 @@ constexpr std::array conv_layout() // ---------------------------------------------------------------------------- /// @brief Helper function to report unsupported data type with a clear error message. -/// @details This consteval function uses throw (not static_assert) to ensure the error is not -/// silently ignored during SFINAE. The thrown string becomes part of the compiler error message. 
-template +template [[noreturn]] consteval void report_unsupported_data_type_error() { throw "Unsupported data type detected!\n" - "The ADataType is not recognized.\n" + "The DataTypeFromInstance is not recognized.\n" "Supported types are: ck::half_t (FP16), ck::Tuple (FP16_FP16), " "ck::bhalf_t (BF16), ck::Tuple (BF16_BF16), float (FP32), " "ck::Tuple (FP32_FP32), double (FP64), ck::f8_t (FP8), ck::bf8_fnuz_t " @@ -462,62 +475,44 @@ template "Please verify that your kernel instance uses a supported data type."; } -/// @brief Derives the data type from a device kernel Instance type. -/// @tparam Instance The device kernel instance type. -/// @return A builder::DataType enum value representing the input data type. -/// @details This function examines the Instance's ADataType to determine the data type -/// used for the input tensor. The function supports various floating-point and integer -/// types, including tuple types for mixed-precision operations. -/// -/// Supported data types include: -/// - FP16 (ck::half_t) -/// - FP16_FP16 (ck::Tuple) -/// - BF16 (ck::bhalf_t) -/// - BF16_BF16 (ck::Tuple) -/// - FP32 (float) -/// - FP32_FP32 (ck::Tuple) -/// - FP64 (double) -/// - FP8 (ck::f8_t) -/// - BF8 (ck::bf8_fnuz_t, ck::bf8_ocp_t) -/// - I8 (int8_t) -/// - I8_I8 (ck::Tuple) -/// - U8 (uint8_t) -template +/// @brief Derives the data type from a device kernel `Instance` type. +/// Returns a `builder::DataType` enum value (e.g., FP16, BF16, FP32, BF8). +// Note: maybe move to types.hpp? 
+template constexpr builder::DataType conv_data_type() + { - using InstTraits = InstanceTraits; - using ADataType = typename InstTraits::ADataType; using enum builder::DataType; - if constexpr(std::is_same_v) + if constexpr(std::is_same_v) return FP16; - else if constexpr(std::is_same_v>) + else if constexpr(std::is_same_v>) return FP16_FP16; - else if constexpr(std::is_same_v) + else if constexpr(std::is_same_v) return BF16; - else if constexpr(std::is_same_v>) + else if constexpr(std::is_same_v>) return BF16_BF16; - else if constexpr(std::is_same_v) + else if constexpr(std::is_same_v) return FP32; - else if constexpr(std::is_same_v>) + else if constexpr(std::is_same_v>) return FP32_FP32; - else if constexpr(std::is_same_v) + else if constexpr(std::is_same_v) return FP64; - else if constexpr(std::is_same_v) + else if constexpr(std::is_same_v) return FP8; - else if constexpr(std::is_same_v) + else if constexpr(std::is_same_v) return BF8; - else if constexpr(std::is_same_v) + else if constexpr(std::is_same_v) return BF8; - else if constexpr(std::is_same_v) + else if constexpr(std::is_same_v) return I8; - else if constexpr(std::is_same_v>) + else if constexpr(std::is_same_v>) return I8_I8; - else if constexpr(std::is_same_v) + else if constexpr(std::is_same_v) return U8; else { - report_unsupported_data_type_error(); + report_unsupported_data_type_error(); return FP32; // Unreachable } } @@ -736,4 +731,92 @@ constexpr builder::PipelineScheduler get_pipeline_scheduler() } } +// ============================================================================ +// SECTION 4: Helper functions for common structures often used in reflection +// ============================================================================ + +template +constexpr DataTileInfo conv_traits_data_tile(int k_or_k0 = InstTraits::kKPerBlock) +{ + return DataTileInfo{.m = InstTraits::kMPerBlock, .n = InstTraits::kNPerBlock, .k = k_or_k0}; +} + +template +constexpr InputTileTransferInfo 
+conv_traits_a_transfer_params(int _k1, int kPerBlock = InstTraits::kKPerBlock) +{ + return InputTileTransferInfo{ + .tile_dimensions = {.k0 = kPerBlock / _k1, .m_or_n = InstTraits::kMPerBlock, .k1 = _k1}, + .transfer_params = {.k1 = _k1, + .thread_cluster_dims = InstTraits::kAThreadClusterLengths, + .thread_cluster_order = InstTraits::kAThreadClusterArrangeOrder, + .src_access_order = InstTraits::kABlockTransferSrcAccessOrder, + .src_vector_dim = InstTraits::kABlockTransferSrcVectorDim, + .src_scalar_per_vector = InstTraits::kABlockTransferSrcScalarPerVector, + .dst_scalar_per_vector_k1 = + InstTraits::kABlockTransferDstScalarPerVectorK1, + .lds_padding = static_cast(InstTraits::kABlockLdsExtraM)}}; +} + +template +constexpr InputTileTransferInfo +conv_traits_b_transfer_params(int _k1, int kPerBlock = InstTraits::kKPerBlock) +{ + return InputTileTransferInfo{ + .tile_dimensions = {.k0 = kPerBlock / _k1, .m_or_n = InstTraits::kNPerBlock, .k1 = _k1}, + .transfer_params = {.k1 = _k1, + .thread_cluster_dims = InstTraits::kBThreadClusterLengths, + .thread_cluster_order = InstTraits::kBThreadClusterArrangeOrder, + .src_access_order = InstTraits::kBBlockTransferSrcAccessOrder, + .src_vector_dim = InstTraits::kBBlockTransferSrcVectorDim, + .src_scalar_per_vector = InstTraits::kBBlockTransferSrcScalarPerVector, + .dst_scalar_per_vector_k1 = + InstTraits::kBBlockTransferDstScalarPerVectorK1, + .lds_padding = static_cast(InstTraits::kBBlockLdsExtraN)}}; +} + +template +constexpr WarpGemmParams conv_traits_wmma_warp_gemm_params() +{ + return WarpGemmParams{.gemm_m = InstTraits::kMPerWmma, + .gemm_n = InstTraits::kNPerWmma, + .m_iter = InstTraits::kMRepeat, + .n_iter = InstTraits::kNRepeat}; +} + +template +constexpr WarpGemmParams conv_traits_xdl_warp_gemm_params() +{ + return WarpGemmParams{.gemm_m = InstTraits::kMPerXDL, + .gemm_n = InstTraits::kNPerXDL, + .m_iter = InstTraits::kMXdlPerWave, + .n_iter = InstTraits::kNXdlPerWave}; +} + +template +constexpr 
OutputTileTransferInfo conv_traits_wmma_c_tile_transfer() +{ + return OutputTileTransferInfo{ + .shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMRepeatPerShuffle, + .n_gemms_per_shuffle = InstTraits::kCShuffleNRepeatPerShuffle}, + .thread_cluster_dims = {InstTraits::kCDEThreadClusterLengths[0], + InstTraits::kCDEThreadClusterLengths[1], + InstTraits::kCDEThreadClusterLengths[2], + InstTraits::kCDEThreadClusterLengths[3]}, + .scalar_per_vector = InstTraits::kCDEBlockTransferScalarPerVector}; +} + +template +constexpr OutputTileTransferInfo conv_traits_xdl_c_tile_transfer() +{ + return OutputTileTransferInfo{ + .shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMXdlPerWavePerShuffle, + .n_gemms_per_shuffle = InstTraits::kCShuffleNXdlPerWavePerShuffle}, + .thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0], + InstTraits::kCThreadClusterLengths[1], + InstTraits::kCThreadClusterLengths[2], + InstTraits::kCThreadClusterLengths[3]}, + .scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector}; +} + } // namespace ck_tile::reflect::conv diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_to_conv_traits.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_to_conv_traits.hpp index 00010e2d48..e10baaf712 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_to_conv_traits.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_to_conv_traits.hpp @@ -3,6 +3,18 @@ #pragma once +// Fwd instances #include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" #include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp" #include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp" +#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp" + +// Bwd weight instances +#include 
"ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp" +#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp" +#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp" +#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp" +#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp" +#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp" +#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_wmma_cshuffle.hpp" diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp index cde1896993..c3a5f9df29 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp @@ -62,6 +62,10 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3; namespace ck_tile { namespace reflect { +/// @brief Tag type for DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_v3 device kernel +struct DeviceGroupedConvBwdWeight_multiple_d_Wmma_CShuffle_V3_Tag +{ +}; template ::value; + static constexpr auto kAThreadClusterArrangeOrder = + detail::SequenceToArray::value; + static constexpr auto kABlockTransferSrcAccessOrder = + detail::SequenceToArray::value; + static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim; static constexpr ck::index_t 
kABlockTransferSrcScalarPerVector = ABlockTransferSrcScalarPerVector; - static constexpr ck::index_t kABlockTransferDstScalarPerVector_AK1 = + static constexpr ck::index_t kABlockTransferDstScalarPerVectorK1 = ABlockTransferDstScalarPerVector_AK1; - static constexpr bool kABlockLdsAddExtraM = ABlockLdsAddExtraM; + static constexpr bool kABlockLdsExtraM = ABlockLdsAddExtraM; using BBlockTransferThreadClusterLengths_BK0_N_BK1 = BBlockTransferThreadClusterLengths_BK0_N_BK1_; using BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_; using BBlockTransferSrcAccessOrder = BBlockTransferSrcAccessOrder_; + // B block transfer thread cluster dimensions (converted to std::array) + static constexpr auto kBThreadClusterLengths = + detail::SequenceToArray::value; + static constexpr auto kBThreadClusterArrangeOrder = + detail::SequenceToArray::value; + static constexpr auto kBBlockTransferSrcAccessOrder = + detail::SequenceToArray::value; static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim; static constexpr ck::index_t kBBlockTransferSrcScalarPerVector = BBlockTransferSrcScalarPerVector; - static constexpr ck::index_t kBBlockTransferDstScalarPerVector_BK1 = + static constexpr ck::index_t kBBlockTransferDstScalarPerVectorK1 = BBlockTransferDstScalarPerVector_BK1; - static constexpr bool kBBlockLdsAddExtraN = BBlockLdsAddExtraN; + static constexpr bool kBBlockLdsExtraN = BBlockLdsAddExtraN; using CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_; - + static constexpr auto kCDEThreadClusterLengths = detail::SequenceToArray< + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value; + static constexpr int kCDEBlockTransferScalarPerVector = + CShuffleBlockTransferScalarPerVector_NPerBlock; static constexpr ck::BlockGemmPipelineScheduler kBlkGemmPipeSched = BlkGemmPipeSched; static constexpr 
ck::BlockGemmPipelineVersion kBlkGemmPipelineVer = BlkGemmPipelineVer; @@ -231,7 +256,7 @@ struct InstanceTraits< oss << "DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3"; // Template parameters in exact order - oss << "<" << kNDimSpatial; // 1. NDimSpatial + oss << "<" << kSpatialDim; // 1. NDimSpatial oss << "," << detail::layout_name(); // 2. InLayout oss << "," << detail::layout_name(); // 3. WeiLayout oss << "," << detail::layout_name(); // 4. OutLayout @@ -251,30 +276,30 @@ struct InstanceTraits< // OutElementwiseOperation oss << "," << detail::conv_bwd_weight_spec_name( - kConvBackwardWeightSpecialization); // 14. ConvBackwardWeightSpecialization - oss << "," << kBlockSize; // 15. BlockSize - oss << "," << kMPerBlock; // 16. MPerBlock - oss << "," << kNPerBlock; // 17. NPerBlock - oss << "," << kKPerBlock; // 18. KPerBlock - oss << "," << kABK1; // 19. ABK1 - oss << "," << kMPerWmma; // 20. MPerWmma - oss << "," << kNPerWmma; // 21. NPerWmma - oss << "," << kMRepeat; // 22. MRepeat - oss << "," << kNRepeat; // 23. NRepeat + kConvBwdWeightSpecialization); // 14. ConvBackwardWeightSpecialization + oss << "," << kBlockSize; // 15. BlockSize + oss << "," << kMPerBlock; // 16. MPerBlock + oss << "," << kNPerBlock; // 17. NPerBlock + oss << "," << kKPerBlock; // 18. KPerBlock + oss << "," << kK1; // 19. ABK1 + oss << "," << kMPerWmma; // 20. MPerWmma + oss << "," << kNPerWmma; // 21. NPerWmma + oss << "," << kMRepeat; // 22. MRepeat + oss << "," << kNRepeat; // 23. NRepeat oss << "," << detail::sequence_name(); // 24. oss << "," << detail::sequence_name(); // 25. oss << "," << detail::sequence_name(); // 26. oss << "," << kABlockTransferSrcVectorDim; // 27. oss << "," << kABlockTransferSrcScalarPerVector; // 28. - oss << "," << kABlockTransferDstScalarPerVector_AK1; // 29. - oss << "," << (kABlockLdsAddExtraM ? "true" : "false"); // 30. + oss << "," << kABlockTransferDstScalarPerVectorK1; // 29. + oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 30. 
oss << "," << detail::sequence_name(); // 31. oss << "," << detail::sequence_name(); // 32. oss << "," << detail::sequence_name(); // 33. oss << "," << kBBlockTransferSrcVectorDim; // 34. oss << "," << kBBlockTransferSrcScalarPerVector; // 35. - oss << "," << kBBlockTransferDstScalarPerVector_BK1; // 36. - oss << "," << (kBBlockLdsAddExtraN ? "true" : "false"); // 37. + oss << "," << kBBlockTransferDstScalarPerVectorK1; // 36. + oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 37. oss << "," << kCShuffleMRepeatPerShuffle; // 38. oss << "," << kCShuffleNRepeatPerShuffle; // 39. oss << "," diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp index 6508ac7d6e..173da8268a 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp @@ -59,6 +59,10 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle; namespace ck_tile { namespace reflect { +/// @brief Tag type for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle device kernel +struct DeviceGroupedConvBwdWeight_multiple_d_Xdl_CShuffle_Tag +{ +}; template ::value; + static constexpr auto kAThreadClusterArrangeOrder = + detail::SequenceToArray::value; + static constexpr auto kABlockTransferSrcAccessOrder = + detail::SequenceToArray::value; static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim; static constexpr ck::index_t kABlockTransferSrcScalarPerVector = ABlockTransferSrcScalarPerVector; - static constexpr ck::index_t kABlockTransferDstScalarPerVector_K1 = + static constexpr ck::index_t kABlockTransferDstScalarPerVectorK1 = 
ABlockTransferDstScalarPerVector_K1; - static constexpr bool kABlockLdsAddExtraM = ABlockLdsAddExtraM; + static constexpr bool kABlockLdsExtraM = ABlockLdsAddExtraM; using BBlockTransferThreadClusterLengths_K0_N_K1 = BBlockTransferThreadClusterLengths_K0_N_K1_; using BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_; using BBlockTransferSrcAccessOrder = BBlockTransferSrcAccessOrder_; + // B block transfer thread cluster dimensions (converted to std::array) + static constexpr auto kBThreadClusterLengths = + detail::SequenceToArray::value; + static constexpr auto kBThreadClusterArrangeOrder = + detail::SequenceToArray::value; + static constexpr auto kBBlockTransferSrcAccessOrder = + detail::SequenceToArray::value; static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim; static constexpr ck::index_t kBBlockTransferSrcScalarPerVector = BBlockTransferSrcScalarPerVector; - static constexpr ck::index_t kBBlockTransferDstScalarPerVector_K1 = + static constexpr ck::index_t kBBlockTransferDstScalarPerVectorK1 = BBlockTransferDstScalarPerVector_K1; - static constexpr bool kBBlockLdsAddExtraN = BBlockLdsAddExtraN; + static constexpr bool kBBlockLdsExtraN = BBlockLdsAddExtraN; using CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_; @@ -211,6 +232,9 @@ struct InstanceTraits< using ComputeTypeA = ComputeTypeA_; using ComputeTypeB = ComputeTypeB_; + static constexpr auto kCThreadClusterLengths = detail::SequenceToArray< + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value; + // Static member function to generate instance string static std::string instance_string() { @@ -220,7 +244,7 @@ struct InstanceTraits< oss << "DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle"; // Template parameters in exact order - oss << "<" << kNDimSpatial; // 1. NDimSpatial + oss << "<" << kSpatialDim; // 1. 
NDimSpatial oss << "," << detail::layout_name(); // 2. InLayout oss << "," << detail::layout_name(); // 3. WeiLayout oss << "," << detail::layout_name(); // 4. OutLayout @@ -240,30 +264,30 @@ struct InstanceTraits< // OutElementwiseOperation oss << "," << detail::conv_bwd_weight_spec_name( - kConvBackwardWeightSpecialization); // 14. ConvBackwardWeightSpecialization - oss << "," << kBlockSize; // 15. BlockSize - oss << "," << kMPerBlock; // 16. MPerBlock - oss << "," << kNPerBlock; // 17. NPerBlock - oss << "," << kK0PerBlock; // 18. K0PerBlock - oss << "," << kK1; // 19. K1 - oss << "," << kMPerXDL; // 20. MPerXDL - oss << "," << kNPerXDL; // 21. NPerXDL - oss << "," << kMXdlPerWave; // 22. MXdlPerWave - oss << "," << kNXdlPerWave; // 23. NXdlPerWave + kConvBwdWeightSpecialization); // 14. ConvBackwardWeightSpecialization + oss << "," << kBlockSize; // 15. BlockSize + oss << "," << kMPerBlock; // 16. MPerBlock + oss << "," << kNPerBlock; // 17. NPerBlock + oss << "," << kK0PerBlock; // 18. K0PerBlock + oss << "," << kK1; // 19. K1 + oss << "," << kMPerXDL; // 20. MPerXDL + oss << "," << kNPerXDL; // 21. NPerXDL + oss << "," << kMXdlPerWave; // 22. MXdlPerWave + oss << "," << kNXdlPerWave; // 23. NXdlPerWave oss << "," << detail::sequence_name(); // 24. oss << "," << detail::sequence_name(); // 25. oss << "," << detail::sequence_name(); // 26. oss << "," << kABlockTransferSrcVectorDim; // 27. oss << "," << kABlockTransferSrcScalarPerVector; // 28. - oss << "," << kABlockTransferDstScalarPerVector_K1; // 29. - oss << "," << (kABlockLdsAddExtraM ? "true" : "false"); // 30. + oss << "," << kABlockTransferDstScalarPerVectorK1; // 29. + oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 30. oss << "," << detail::sequence_name(); // 31. oss << "," << detail::sequence_name(); // 32. oss << "," << detail::sequence_name(); // 33. oss << "," << kBBlockTransferSrcVectorDim; // 34. oss << "," << kBBlockTransferSrcScalarPerVector; // 35. 
- oss << "," << kBBlockTransferDstScalarPerVector_K1; // 36. - oss << "," << (kBBlockLdsAddExtraN ? "true" : "false"); // 37. + oss << "," << kBBlockTransferDstScalarPerVectorK1; // 36. + oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 37. oss << "," << kCShuffleMXdlPerWavePerShuffle; // 38. oss << "," << kCShuffleNXdlPerWavePerShuffle; // 39. oss << "," diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp index f1e40de7d2..4b90a6ab64 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp @@ -63,6 +63,11 @@ struct DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3; namespace ck_tile { namespace reflect { +/// @brief Tag type for DeviceGroupedConvBwdWeight_two_stage_Wmma_CShuffle_Tag device kernel +struct DeviceGroupedConvBwdWeight_two_stage_Wmma_CShuffle_Tag +{ +}; + template ::value; + static constexpr auto kAThreadClusterArrangeOrder = + detail::SequenceToArray::value; + static constexpr auto kABlockTransferSrcAccessOrder = + detail::SequenceToArray::value; + static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim; static constexpr ck::index_t kABlockTransferSrcScalarPerVector = ABlockTransferSrcScalarPerVector; - static constexpr ck::index_t kABlockTransferDstScalarPerVector_AK1 = + static constexpr ck::index_t kABlockTransferDstScalarPerVectorK1 = ABlockTransferDstScalarPerVector_AK1; - static constexpr bool kABlockLdsAddExtraM = ABlockLdsAddExtraM; + static constexpr bool kABlockLdsExtraM = ABlockLdsAddExtraM; using BBlockTransferThreadClusterLengths_BK0_N_BK1 = 
BBlockTransferThreadClusterLengths_BK0_N_BK1_; @@ -215,13 +231,26 @@ struct InstanceTraits< static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim; static constexpr ck::index_t kBBlockTransferSrcScalarPerVector = BBlockTransferSrcScalarPerVector; - static constexpr ck::index_t kBBlockTransferDstScalarPerVector_BK1 = + static constexpr ck::index_t kBBlockTransferDstScalarPerVectorK1 = BBlockTransferDstScalarPerVector_BK1; - static constexpr bool kBBlockLdsAddExtraN = BBlockLdsAddExtraN; + + // B block transfer thread cluster dimensions (converted to std::array) + static constexpr auto kBThreadClusterLengths = + detail::SequenceToArray::value; + static constexpr auto kBThreadClusterArrangeOrder = + detail::SequenceToArray::value; + static constexpr auto kBBlockTransferSrcAccessOrder = + detail::SequenceToArray::value; + static constexpr bool kBBlockLdsExtraN = BBlockLdsAddExtraN; using CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_; + static constexpr auto kCDEThreadClusterLengths = detail::SequenceToArray< + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value; + static constexpr int kCDEBlockTransferScalarPerVector = + CShuffleBlockTransferScalarPerVector_NPerBlock; + static constexpr ck::BlockGemmPipelineScheduler kBlkGemmPipeSched = BlkGemmPipeSched; static constexpr ck::BlockGemmPipelineVersion kBlkGemmPipelineVer = BlkGemmPipelineVer; @@ -237,7 +266,7 @@ struct InstanceTraits< oss << "DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3"; // Template parameters in exact order - oss << "<" << kNDimSpatial; // 1. NDimSpatial + oss << "<" << kSpatialDim; // 1. NDimSpatial oss << "," << detail::layout_name(); // 2. InLayout oss << "," << detail::layout_name(); // 3. WeiLayout oss << "," << detail::layout_name(); // 4. 
OutLayout @@ -255,30 +284,30 @@ struct InstanceTraits< // OutElementwiseOperation oss << "," << detail::conv_bwd_weight_spec_name( - kConvBackwardWeightSpecialization); // 12. ConvBackwardWeightSpecialization - oss << "," << kBlockSize; // 13. BlockSize - oss << "," << kMPerBlock; // 14. MPerBlock - oss << "," << kNPerBlock; // 15. NPerBlock - oss << "," << kKPerBlock; // 16. KPerBlock - oss << "," << kABK1; // 17. ABK1 - oss << "," << kMPerWmma; // 18. MPerWmma - oss << "," << kNPerWmma; // 19. NPerWmma - oss << "," << kMRepeat; // 20. MRepeat - oss << "," << kNRepeat; // 21. NRepeat + kConvBwdWeightSpecialization); // 12. ConvBackwardWeightSpecialization + oss << "," << kBlockSize; // 13. BlockSize + oss << "," << kMPerBlock; // 14. MPerBlock + oss << "," << kNPerBlock; // 15. NPerBlock + oss << "," << kKPerBlock; // 16. KPerBlock + oss << "," << kABK1; // 17. ABK1 + oss << "," << kMPerWmma; // 18. MPerWmma + oss << "," << kNPerWmma; // 19. NPerWmma + oss << "," << kMRepeat; // 20. MRepeat + oss << "," << kNRepeat; // 21. NRepeat oss << "," << detail::sequence_name(); // 22. oss << "," << detail::sequence_name(); // 23. oss << "," << detail::sequence_name(); // 24. oss << "," << kABlockTransferSrcVectorDim; // 25. oss << "," << kABlockTransferSrcScalarPerVector; // 26. - oss << "," << kABlockTransferDstScalarPerVector_AK1; // 27. - oss << "," << (kABlockLdsAddExtraM ? "true" : "false"); // 28. + oss << "," << kABlockTransferDstScalarPerVectorK1; // 27. + oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 28. oss << "," << detail::sequence_name(); // 29. oss << "," << detail::sequence_name(); // 30. oss << "," << detail::sequence_name(); // 31. oss << "," << kBBlockTransferSrcVectorDim; // 32. oss << "," << kBBlockTransferSrcScalarPerVector; // 33. - oss << "," << kBBlockTransferDstScalarPerVector_BK1; // 34. - oss << "," << (kBBlockLdsAddExtraN ? "true" : "false"); // 35. + oss << "," << kBBlockTransferDstScalarPerVectorK1; // 34. 
+ oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 35. oss << "," << kCShuffleMRepeatPerShuffle; // 36. oss << "," << kCShuffleNRepeatPerShuffle; // 37. oss << "," diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp index 460b49de93..999aff6f1e 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp @@ -63,6 +63,11 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle; namespace ck_tile { namespace reflect { +/// @brief Tag type for DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle device kernel +struct DeviceGroupedConvBwdWeight_two_stage_Xdl_CShuffle_Tag +{ +}; + template ::value; + static constexpr auto kAThreadClusterArrangeOrder = + detail::SequenceToArray::value; + static constexpr auto kABlockTransferSrcAccessOrder = + detail::SequenceToArray::value; + static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim; static constexpr ck::index_t kABlockTransferSrcScalarPerVector = ABlockTransferSrcScalarPerVector; - static constexpr ck::index_t kABlockTransferDstScalarPerVector_K1 = + static constexpr ck::index_t kABlockTransferDstScalarPerVectorK1 = ABlockTransferDstScalarPerVector_K1; - static constexpr bool kABlockLdsAddExtraM = ABlockLdsAddExtraM; + static constexpr bool kABlockLdsExtraM = ABlockLdsAddExtraM; using BBlockTransferThreadClusterLengths_K0_N_K1 = BBlockTransferThreadClusterLengths_K0_N_K1_; using BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_; using BBlockTransferSrcAccessOrder = BBlockTransferSrcAccessOrder_; + + // B block transfer thread cluster 
dimensions (converted to std::array) + static constexpr auto kBThreadClusterLengths = + detail::SequenceToArray::value; + static constexpr auto kBThreadClusterArrangeOrder = + detail::SequenceToArray::value; + static constexpr auto kBBlockTransferSrcAccessOrder = + detail::SequenceToArray::value; + static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim; static constexpr ck::index_t kBBlockTransferSrcScalarPerVector = BBlockTransferSrcScalarPerVector; - static constexpr ck::index_t kBBlockTransferDstScalarPerVector_K1 = + static constexpr ck::index_t kBBlockTransferDstScalarPerVectorK1 = BBlockTransferDstScalarPerVector_K1; - static constexpr bool kBBlockLdsAddExtraN = BBlockLdsAddExtraN; + static constexpr bool kBBlockLdsExtraN = BBlockLdsAddExtraN; using CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_; + static constexpr auto kCThreadClusterLengths = detail::SequenceToArray< + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value; static constexpr ck::BlockGemmPipelineScheduler kBlkGemmPipeSched = BlkGemmPipeSched; static constexpr ck::BlockGemmPipelineVersion kBlkGemmPipelineVer = BlkGemmPipelineVer; @@ -234,7 +260,7 @@ struct InstanceTraits(); // 2. InLayout oss << "," << detail::layout_name(); // 3. WeiLayout oss << "," << detail::layout_name(); // 4. OutLayout @@ -252,30 +278,30 @@ struct InstanceTraits(); // 22. oss << "," << detail::sequence_name(); // 23. oss << "," << detail::sequence_name(); // 24. oss << "," << kABlockTransferSrcVectorDim; // 25. oss << "," << kABlockTransferSrcScalarPerVector; // 26. - oss << "," << kABlockTransferDstScalarPerVector_K1; // 27. - oss << "," << (kABlockLdsAddExtraM ? "true" : "false"); // 28. + oss << "," << kABlockTransferDstScalarPerVectorK1; // 27. + oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 28. oss << "," << detail::sequence_name(); // 29. 
oss << "," << detail::sequence_name(); // 30. oss << "," << detail::sequence_name(); // 31. oss << "," << kBBlockTransferSrcVectorDim; // 32. oss << "," << kBBlockTransferSrcScalarPerVector; // 33. - oss << "," << kBBlockTransferDstScalarPerVector_K1; // 34. - oss << "," << (kBBlockLdsAddExtraN ? "true" : "false"); // 35. + oss << "," << kBBlockTransferDstScalarPerVectorK1; // 34. + oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 35. oss << "," << kCShuffleMXdlPerWavePerShuffle; // 36. oss << "," << kCShuffleNXdlPerWavePerShuffle; // 37. oss << "," diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_wmma_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_wmma_cshuffle.hpp index f87e295159..eba422b85f 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_wmma_cshuffle.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_wmma_cshuffle.hpp @@ -59,6 +59,11 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle; namespace ck_tile { namespace reflect { +/// @brief Tag type for DeviceGroupedConvBwdWeight_Wmma_CShuffle device kernel +struct DeviceGroupedConvBwdWeight_Wmma_CShuffle_Tag +{ +}; + template > // Use false to match with the default value { static constexpr auto kTensorOpName = "DeviceGroupedConvBwdWeight_Wmma_CShuffle"; + using device_kernel_tag = DeviceGroupedConvBwdWeight_Wmma_CShuffle_Tag; - static constexpr ck::index_t kNDimSpatial = NDimSpatial; + static constexpr ck::index_t kSpatialDim = NDimSpatial; using InLayout = InLayout_; using WeiLayout = WeiLayout_; @@ -164,15 +170,15 @@ struct InstanceTraits::value; + static constexpr auto kAThreadClusterArrangeOrder = + detail::SequenceToArray::value; + static constexpr auto kABlockTransferSrcAccessOrder = + detail::SequenceToArray::value; static constexpr 
ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim; static constexpr ck::index_t kABlockTransferSrcScalarPerVector = ABlockTransferSrcScalarPerVector; - static constexpr ck::index_t kABlockTransferDstScalarPerVector_K1 = + static constexpr ck::index_t kABlockTransferDstScalarPerVectorK1 = ABlockTransferDstScalarPerVector_K1; - static constexpr bool kABlockLdsAddExtraM = ABlockLdsAddExtraM; + static constexpr bool kABlockLdsExtraM = ABlockLdsAddExtraM; using BBlockTransferThreadClusterLengths_K0_N_K1 = BBlockTransferThreadClusterLengths_K0_N_K1_; using BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_; using BBlockTransferSrcAccessOrder = BBlockTransferSrcAccessOrder_; + // B block transfer thread cluster dimensions (converted to std::array) + static constexpr auto kBThreadClusterLengths = + detail::SequenceToArray::value; + static constexpr auto kBThreadClusterArrangeOrder = + detail::SequenceToArray::value; + static constexpr auto kBBlockTransferSrcAccessOrder = + detail::SequenceToArray::value; static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim; static constexpr ck::index_t kBBlockTransferSrcScalarPerVector = BBlockTransferSrcScalarPerVector; - static constexpr ck::index_t kBBlockTransferDstScalarPerVector_K1 = + static constexpr ck::index_t kBBlockTransferDstScalarPerVectorK1 = BBlockTransferDstScalarPerVector_K1; - static constexpr bool kBBlockLdsAddExtraN = BBlockLdsAddExtraN; + static constexpr bool kBBlockLdsExtraN = BBlockLdsAddExtraN; using CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_; - + static constexpr auto kCDEThreadClusterLengths = detail::SequenceToArray< + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value; + static constexpr int kCDEBlockTransferScalarPerVector = + CShuffleBlockTransferScalarPerVector_NPerBlock; static constexpr 
ck::LoopScheduler kLoopSched = LoopSched; static constexpr ck::PipelineVersion kPipelineVer = PipelineVer; @@ -216,7 +239,7 @@ struct InstanceTraits(); // 2. InLayout oss << "," << detail::layout_name(); // 3. WeiLayout oss << "," << detail::layout_name(); // 4. OutLayout @@ -234,30 +257,30 @@ struct InstanceTraits(); // 22. oss << "," << detail::sequence_name(); // 23. oss << "," << detail::sequence_name(); // 24. oss << "," << kABlockTransferSrcVectorDim; // 25. oss << "," << kABlockTransferSrcScalarPerVector; // 26. - oss << "," << kABlockTransferDstScalarPerVector_K1; // 27. - oss << "," << (kABlockLdsAddExtraM ? "true" : "false"); // 28. + oss << "," << kABlockTransferDstScalarPerVectorK1; // 27. + oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 28. oss << "," << detail::sequence_name(); // 29. oss << "," << detail::sequence_name(); // 30. oss << "," << detail::sequence_name(); // 31. oss << "," << kBBlockTransferSrcVectorDim; // 32. oss << "," << kBBlockTransferSrcScalarPerVector; // 33. - oss << "," << kBBlockTransferDstScalarPerVector_K1; // 34. - oss << "," << (kBBlockLdsAddExtraN ? "true" : "false"); // 35. + oss << "," << kBBlockTransferDstScalarPerVectorK1; // 34. + oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 35. oss << "," << kCShuffleMRepeatPerShuffle; // 36. oss << "," << kCShuffleNRepeatPerShuffle; // 37. 
oss << "," diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp index 29459d67b0..cfc8b4e05a 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp @@ -62,6 +62,11 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3; namespace ck_tile { namespace reflect { +/// @brief Tag type for DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 device kernel +struct DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3_Tag +{ +}; + template > { static constexpr auto kTensorOpName = "DeviceGroupedConvBwdWeight_Wmma_CShuffleV3"; + using device_kernel_tag = DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3_Tag; - static constexpr ck::index_t kNDimSpatial = NDimSpatial; + static constexpr ck::index_t kSpatialDim = NDimSpatial; using InLayout = InLayout_; using WeiLayout = WeiLayout_; @@ -172,13 +178,13 @@ struct InstanceTraits::value; + static constexpr auto kAThreadClusterArrangeOrder = + detail::SequenceToArray::value; + static constexpr auto kABlockTransferSrcAccessOrder = + detail::SequenceToArray::value; static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim; static constexpr ck::index_t kABlockTransferSrcScalarPerVector = ABlockTransferSrcScalarPerVector; - static constexpr ck::index_t kABlockTransferDstScalarPerVector_AK1 = + static constexpr ck::index_t kABlockTransferDstScalarPerVectorK1 = ABlockTransferDstScalarPerVector_AK1; - static constexpr bool kABlockLdsAddExtraM = ABlockLdsAddExtraM; + static constexpr bool kABlockLdsExtraM = ABlockLdsAddExtraM; using BBlockTransferThreadClusterLengths_BK0_N_BK1 = BBlockTransferThreadClusterLengths_BK0_N_BK1_; using 
BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_; using BBlockTransferSrcAccessOrder = BBlockTransferSrcAccessOrder_; + + // B block transfer thread cluster dimensions (converted to std::array) + static constexpr auto kBThreadClusterLengths = + detail::SequenceToArray::value; + static constexpr auto kBThreadClusterArrangeOrder = + detail::SequenceToArray::value; + static constexpr auto kBBlockTransferSrcAccessOrder = + detail::SequenceToArray::value; + static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim; static constexpr ck::index_t kBBlockTransferSrcScalarPerVector = BBlockTransferSrcScalarPerVector; - static constexpr ck::index_t kBBlockTransferDstScalarPerVector_BK1 = + static constexpr ck::index_t kBBlockTransferDstScalarPerVectorK1 = BBlockTransferDstScalarPerVector_BK1; - static constexpr bool kBBlockLdsAddExtraN = BBlockLdsAddExtraN; + static constexpr bool kBBlockLdsExtraN = BBlockLdsAddExtraN; using CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_; - + static constexpr auto kCDEThreadClusterLengths = detail::SequenceToArray< + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value; + static constexpr int kCDEBlockTransferScalarPerVector = + CShuffleBlockTransferScalarPerVector_NPerBlock; static constexpr ck::BlockGemmPipelineScheduler kBlkGemmPipeSched = BlkGemmPipeSched; static constexpr ck::BlockGemmPipelineVersion kBlkGemmPipelineVer = BlkGemmPipelineVer; @@ -232,7 +257,7 @@ struct InstanceTraits(); // 2. InLayout oss << "," << detail::layout_name(); // 3. WeiLayout oss << "," << detail::layout_name(); // 4. OutLayout @@ -250,30 +275,30 @@ struct InstanceTraits(); // 22. oss << "," << detail::sequence_name(); // 23. oss << "," << detail::sequence_name(); // 24. oss << "," << kABlockTransferSrcVectorDim; // 25. oss << "," << kABlockTransferSrcScalarPerVector; // 26. 
- oss << "," << kABlockTransferDstScalarPerVector_AK1; // 27. - oss << "," << (kABlockLdsAddExtraM ? "true" : "false"); // 28. + oss << "," << kABlockTransferDstScalarPerVectorK1; // 27. + oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 28. oss << "," << detail::sequence_name(); // 29. oss << "," << detail::sequence_name(); // 30. oss << "," << detail::sequence_name(); // 31. oss << "," << kBBlockTransferSrcVectorDim; // 32. oss << "," << kBBlockTransferSrcScalarPerVector; // 33. - oss << "," << kBBlockTransferDstScalarPerVector_BK1; // 34. - oss << "," << (kBBlockLdsAddExtraN ? "true" : "false"); // 35. + oss << "," << kBBlockTransferDstScalarPerVectorK1; // 34. + oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 35. oss << "," << kCShuffleMRepeatPerShuffle; // 36. oss << "," << kCShuffleNRepeatPerShuffle; // 37. oss << "," diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp index 2c893b9c1d..1edf03740f 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp @@ -61,6 +61,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle; namespace ck_tile { namespace reflect { +/// @brief Tag type for DeviceGroupedConvBwdWeight_Xdl_CShuffle device kernel +struct DeviceGroupedConvBwdWeight_Xdl_CShuffle_Tag +{ +}; + template ::value; + static constexpr auto kAThreadClusterArrangeOrder = + detail::SequenceToArray::value; + static constexpr auto kABlockTransferSrcAccessOrder = + detail::SequenceToArray::value; + static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim; static constexpr ck::index_t kABlockTransferSrcScalarPerVector = 
ABlockTransferSrcScalarPerVector; - static constexpr ck::index_t kABlockTransferDstScalarPerVector_K1 = + static constexpr ck::index_t kABlockTransferDstScalarPerVectorK1 = ABlockTransferDstScalarPerVector_K1; - static constexpr bool kABlockLdsAddExtraM = ABlockLdsAddExtraM; + static constexpr bool kABlockLdsExtraM = ABlockLdsAddExtraM; using BBlockTransferThreadClusterLengths_K0_N_K1 = BBlockTransferThreadClusterLengths_K0_N_K1_; using BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_; using BBlockTransferSrcAccessOrder = BBlockTransferSrcAccessOrder_; + + // B block transfer thread cluster dimensions (converted to std::array) + static constexpr auto kBThreadClusterLengths = + detail::SequenceToArray::value; + static constexpr auto kBThreadClusterArrangeOrder = + detail::SequenceToArray::value; + static constexpr auto kBBlockTransferSrcAccessOrder = + detail::SequenceToArray::value; + static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim; static constexpr ck::index_t kBBlockTransferSrcScalarPerVector = BBlockTransferSrcScalarPerVector; - static constexpr ck::index_t kBBlockTransferDstScalarPerVector_K1 = + static constexpr ck::index_t kBBlockTransferDstScalarPerVectorK1 = BBlockTransferDstScalarPerVector_K1; - static constexpr bool kBBlockLdsAddExtraN = BBlockLdsAddExtraN; + static constexpr bool kBBlockLdsExtraN = BBlockLdsAddExtraN; static constexpr ck::index_t kCShuffleMXdlPerWavePerShuffle = CShuffleMXdlPerWavePerShuffle; static constexpr ck::index_t kCShuffleNXdlPerWavePerShuffle = CShuffleNXdlPerWavePerShuffle; using CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_; + static constexpr auto kCThreadClusterLengths = detail::SequenceToArray< + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value; + static constexpr ck::index_t kCBlockTransferScalarPerVector_NWaveNPerXdl = 
CBlockTransferScalarPerVector_NWaveNPerXdl; @@ -224,7 +250,7 @@ struct InstanceTraits(); // 2. InLayout oss << "," << detail::layout_name(); // 3. WeiLayout oss << "," << detail::layout_name(); // 4. OutLayout @@ -242,30 +268,30 @@ struct InstanceTraits(); // 22. oss << "," << detail::sequence_name(); // 23. oss << "," << detail::sequence_name(); // 24. oss << "," << kABlockTransferSrcVectorDim; // 25. oss << "," << kABlockTransferSrcScalarPerVector; // 26. - oss << "," << kABlockTransferDstScalarPerVector_K1; // 27. - oss << "," << (kABlockLdsAddExtraM ? "true" : "false"); // 28. + oss << "," << kABlockTransferDstScalarPerVectorK1; // 27. + oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 28. oss << "," << detail::sequence_name(); // 29. oss << "," << detail::sequence_name(); // 30. oss << "," << detail::sequence_name(); // 31. oss << "," << kBBlockTransferSrcVectorDim; // 32. oss << "," << kBBlockTransferSrcScalarPerVector; // 33. - oss << "," << kBBlockTransferDstScalarPerVector_K1; // 34. - oss << "," << (kBBlockLdsAddExtraN ? "true" : "false"); // 35. + oss << "," << kBBlockTransferDstScalarPerVectorK1; // 34. + oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 35. oss << "," << kCShuffleMXdlPerWavePerShuffle; // 36. oss << "," << kCShuffleNXdlPerWavePerShuffle; // 37. 
oss << "," diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp index 147028f9cf..ce23dac1d7 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp @@ -61,6 +61,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3; namespace ck_tile { namespace reflect { +/// @brief Tag type for DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3 device kernel +struct DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3_Tag +{ +}; + template > { + + /// @brief Tag type identifying this device kernel variant + using device_kernel_tag = DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3_Tag; static constexpr auto kTensorOpName = "DeviceGroupedConvBwdWeight_Xdl_CShuffleV3"; - static constexpr ck::index_t kNDimSpatial = NDimSpatial; + static constexpr ck::index_t kSpatialDim = NDimSpatial; using InLayout = InLayout_; using WeiLayout = WeiLayout_; @@ -167,7 +175,7 @@ struct InstanceTraits::value; + static constexpr auto kAThreadClusterArrangeOrder = + detail::SequenceToArray::value; + static constexpr auto kABlockTransferSrcAccessOrder = + detail::SequenceToArray::value; + static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim; static constexpr ck::index_t kABlockTransferSrcScalarPerVector = ABlockTransferSrcScalarPerVector; - static constexpr ck::index_t kABlockTransferDstScalarPerVector_K1 = + static constexpr ck::index_t kABlockTransferDstScalarPerVectorK1 = ABlockTransferDstScalarPerVector_K1; - static constexpr bool kABlockLdsAddExtraM = ABlockLdsAddExtraM; + static constexpr bool kABlockLdsExtraM = ABlockLdsAddExtraM; using BBlockTransferThreadClusterLengths_K0_N_K1 = 
BBlockTransferThreadClusterLengths_K0_N_K1_; using BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_; using BBlockTransferSrcAccessOrder = BBlockTransferSrcAccessOrder_; + + // B block transfer thread cluster dimensions (converted to std::array) + static constexpr auto kBThreadClusterLengths = + detail::SequenceToArray::value; + static constexpr auto kBThreadClusterArrangeOrder = + detail::SequenceToArray::value; + static constexpr auto kBBlockTransferSrcAccessOrder = + detail::SequenceToArray::value; + static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim; static constexpr ck::index_t kBBlockTransferSrcScalarPerVector = BBlockTransferSrcScalarPerVector; - static constexpr ck::index_t kBBlockTransferDstScalarPerVector_K1 = + static constexpr ck::index_t kBBlockTransferDstScalarPerVectorK1 = BBlockTransferDstScalarPerVector_K1; - static constexpr bool kBBlockLdsAddExtraN = BBlockLdsAddExtraN; + static constexpr bool kBBlockLdsExtraN = BBlockLdsAddExtraN; static constexpr ck::index_t kCShuffleMXdlPerWavePerShuffle = CShuffleMXdlPerWavePerShuffle; static constexpr ck::index_t kCShuffleNXdlPerWavePerShuffle = CShuffleNXdlPerWavePerShuffle; using CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_; + static constexpr auto kCThreadClusterLengths = detail::SequenceToArray< + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value; static constexpr ck::index_t kCBlockTransferScalarPerVector_NWaveNPerXdl = CBlockTransferScalarPerVector_NWaveNPerXdl; @@ -222,7 +250,7 @@ struct InstanceTraits(); // 2. InLayout oss << "," << detail::layout_name(); // 3. WeiLayout oss << "," << detail::layout_name(); // 4. OutLayout @@ -240,30 +268,30 @@ struct InstanceTraits(); // 22. oss << "," << detail::sequence_name(); // 23. oss << "," << detail::sequence_name(); // 24. oss << "," << kABlockTransferSrcVectorDim; // 25. 
oss << "," << kABlockTransferSrcScalarPerVector; // 26. - oss << "," << kABlockTransferDstScalarPerVector_K1; // 27. - oss << "," << (kABlockLdsAddExtraM ? "true" : "false"); // 28. + oss << "," << kABlockTransferDstScalarPerVectorK1; // 27. + oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 28. oss << "," << detail::sequence_name(); // 29. oss << "," << detail::sequence_name(); // 30. oss << "," << detail::sequence_name(); // 31. oss << "," << kBBlockTransferSrcVectorDim; // 32. oss << "," << kBBlockTransferSrcScalarPerVector; // 33. - oss << "," << kBBlockTransferDstScalarPerVector_K1; // 34. - oss << "," << (kBBlockLdsAddExtraN ? "true" : "false"); // 35. + oss << "," << kBBlockTransferDstScalarPerVectorK1; // 34. + oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 35. oss << "," << kCShuffleMXdlPerWavePerShuffle; // 36. oss << "," << kCShuffleNXdlPerWavePerShuffle; // 37. oss << "," diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp index 782fd158c5..645d75258e 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp @@ -79,6 +79,11 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle; } // namespace ck::tensor_operation::device +/// @brief Tag type for DeviceGroupedConvFwdMultipleD_Wmma_CShuffle device kernel +struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_Tag +{ +}; + namespace ck_tile::reflect { // Specialization for DeviceGroupedConvFwdMultipleD_Wmma_CShuffle @@ -176,6 +181,8 @@ struct InstanceTraits> { + /// @brief Tag type identifying this device kernel variant + using device_kernel_tag = 
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_Tag; // Spatial dimension static constexpr int kSpatialDim = NDimSpatial; diff --git a/experimental/builder/test/conv/ck/test_conv_traits.cpp b/experimental/builder/test/conv/ck/test_conv_traits.cpp index 42235df2fe..3221113565 100644 --- a/experimental/builder/test/conv/ck/test_conv_traits.cpp +++ b/experimental/builder/test/conv/ck/test_conv_traits.cpp @@ -9,8 +9,17 @@ #include #include #include +#include #include +#include +#include +#include +#include +#include +#include +#include + namespace { using ck_tile::builder::ConvDirection; @@ -26,6 +35,1099 @@ class ConvTraitsTest : public ::testing::Test { }; +// Test ConvTraits with DeviceGroupedConvBwdWeight_Wmma_CShuffle +TEST_F(ConvTraitsTest, ConvBwdWeightCshuffleWmmaTraitsExtraction) +{ + // Define a concrete instance type with specific template parameters + using DeviceInstance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffle< + 3, // NDimSpatial + ck::tensor_layout::convolution::GNDHWC, // InLayout + ck::tensor_layout::convolution::GKZYXC, // WeiLayout + ck::tensor_layout::convolution::GNDHWK, // OutLayout + ck::half_t, // InDataType + ck::half_t, // WeiDataType + ck::half_t, // OutDataType + float, // AccDataType + ck::tensor_operation::element_wise::PassThrough, // InElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // WeiElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // OutElementwiseOperation + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization:: + Default, // ConvBackwardWeightSpecialization + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // K0PerBlock + 8, // K1 + 32, // MPerWmma + 32, // NPerWmma + 4, // MRepeat + 4, // NRepeat + ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder_ + ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // 
ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_K1 + 1, // ABlockLdsAddExtraM + ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder_ + ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder_ + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_K1 + 1, // BBlockLdsAddExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + ck::Sequence<1, + 32, + 1, + 8>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_ + 8, // CDEBlockTransferScalarPerVector_NPerBlock_ + 1, // NumGemmKPrefetchStage + ck::LoopScheduler::Default, // LoopSched + ck::PipelineVersion::v1, // PipelineVer + false>; // BComputeDataType + + // Use ConvTraitsTmpl to extract compile-time information + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + + // Verify signature information + EXPECT_EQ(traits.spatial_dim, 3); + EXPECT_EQ(traits.direction, ConvDirection::BACKWARD_WEIGHT); + EXPECT_THAT(traits.layout, + ElementsAre(TensorLayout::GNDHWC, TensorLayout::GKZYXC, TensorLayout::GNDHWK)); + EXPECT_EQ(traits.data_type, DataType::FP16); + EXPECT_EQ(traits.input_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.weight_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.output_element_op, ElementwiseOperation::PASS_THROUGH); + + // Verify specializations + EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT); + + // Verify algorithm information + EXPECT_EQ(traits.thread_block_size, 256); + + // Verify tile dimensions + EXPECT_EQ(traits.tile_dims.m, 128); + EXPECT_EQ(traits.tile_dims.n, 128); + EXPECT_EQ(traits.tile_dims.k, 16); + + // Verify A tile transfer info + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.m_or_n, 
128); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(traits.a_tile_transfer.transfer_params.lds_padding); + + // Verify B tile transfer info + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(traits.b_tile_transfer.transfer_params.lds_padding); + + // Verify warp GEMM params + EXPECT_EQ(traits.warp_gemm.gemm_m, 32); + EXPECT_EQ(traits.warp_gemm.gemm_n, 32); + EXPECT_EQ(traits.warp_gemm.m_iter, 4); + EXPECT_EQ(traits.warp_gemm.n_iter, 4); + + // Verify output tile transfer info + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.m_gemms_per_shuffle, 1); + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.n_gemms_per_shuffle, 1); + 
EXPECT_THAT(traits.c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8)); + EXPECT_EQ(traits.c_tile_transfer.scalar_per_vector, 8); + + // Verify pipeline configuration +} + +// Test ConvTraits with DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 +TEST_F(ConvTraitsTest, ConvBwdWeightCshuffleWmmaV3TraitsExtraction) +{ + // Define a concrete instance type with specific template parameters + using DeviceInstance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< + 2, // NDimSpatial + ck::tensor_layout::convolution::GNHWC, // InLayout + ck::tensor_layout::convolution::GKYXC, // WeiLayout + ck::tensor_layout::convolution::GNHWK, // OutLayout + ck::half_t, // InDataType + ck::half_t, // WeiDataType + ck::half_t, // OutDataType + float, // AccDataType + ck::tensor_operation::element_wise::PassThrough, // InElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // WeiElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // OutElementwiseOperation + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization:: + Default, // ConvBackwardWeightSpecialization + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // K0PerBlock + 8, // K1 + 32, // MPerWmma + 32, // NPerWmma + 4, // MRepeat + 4, // NRepeat + ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder_ + ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_K1 + 1, // ABlockLdsAddExtraM + ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder_ + ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder_ + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_K1 + 1, // BBlockLdsAddExtraN + 1, // 
CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + ck::Sequence<1, + 32, + 1, + 8>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_ + 8, // CDEBlockTransferScalarPerVector_NPerBlock_ + ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched + ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer + ck::half_t, // AComputeDataType + ck::half_t, // BComputeDataType + 1, // MaxTransposeTransferSrcScalarPerVector + 1>; // MaxTransposeTransferDstScalarPerVector> + + // Use ConvTraitsTmpl to extract compile-time information + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + + // Verify signature information + EXPECT_EQ(traits.spatial_dim, 2); + EXPECT_EQ(traits.direction, ConvDirection::BACKWARD_WEIGHT); + EXPECT_THAT(traits.layout, + ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK)); + EXPECT_EQ(traits.data_type, DataType::FP16); + EXPECT_EQ(traits.input_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.weight_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.output_element_op, ElementwiseOperation::PASS_THROUGH); + + // Verify specializations + EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT); + + // Verify algorithm information + EXPECT_EQ(traits.thread_block_size, 256); + + // Verify tile dimensions + EXPECT_EQ(traits.tile_dims.m, 128); + EXPECT_EQ(traits.tile_dims.n, 128); + EXPECT_EQ(traits.tile_dims.k, 16); + + // Verify A tile transfer info + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + 
EXPECT_THAT(traits.a_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(traits.a_tile_transfer.transfer_params.lds_padding); + + // Verify B tile transfer info + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(traits.b_tile_transfer.transfer_params.lds_padding); + + // Verify warp GEMM params + EXPECT_EQ(traits.warp_gemm.gemm_m, 32); + EXPECT_EQ(traits.warp_gemm.gemm_n, 32); + EXPECT_EQ(traits.warp_gemm.m_iter, 4); + EXPECT_EQ(traits.warp_gemm.n_iter, 4); + + // Verify output tile transfer info + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.m_gemms_per_shuffle, 1); + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.n_gemms_per_shuffle, 1); + EXPECT_THAT(traits.c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8)); + EXPECT_EQ(traits.c_tile_transfer.scalar_per_vector, 8); + + // Verify pipeline configuration +} + +// Test ConvTraits with DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 +TEST_F(ConvTraitsTest, ConvBwdWeightMultipleDCshuffleWmmaV3TraitsExtraction) +{ + // Define a 
concrete instance type with specific template parameters + using DeviceInstance = + ck::tensor_operation::device::DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< + 2, // NDimSpatial + ck::tensor_layout::convolution::GNHWC, // InLayout + ck::tensor_layout::convolution::GKYXC, // WeiLayout + ck::tensor_layout::convolution::GNHWK, // OutLayout + ck::Tuple<>, // DsLayout + ck::half_t, // InDataType + ck::half_t, // WeiDataType + ck::half_t, // OutDataType + float, // AccDataType + ck::Tuple<>, // DsDataType + ck::tensor_operation::element_wise::PassThrough, // InElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // WeiElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // OutElementwiseOperation + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization:: + Default, // ConvBackwardWeightSpecialization + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // K0PerBlock + 8, // K1 + 32, // MPerWmma + 32, // NPerWmma + 4, // MRepeat + 4, // NRepeat + ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder_ + ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_K1 + 1, // ABlockLdsAddExtraM + ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder_ + ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder_ + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_K1 + 1, // BBlockLdsAddExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + ck::Sequence<1, + 32, + 1, + 8>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_ + 8, // CDEBlockTransferScalarPerVector_NPerBlock_ + ck::BlockGemmPipelineScheduler::Intrawave, // 
BlkGemmPipeSched + ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer + ck::half_t, // AComputeDataType + ck::half_t>; // BComputeDataType + + // Use ConvTraitsTmpl to extract compile-time information + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + + // Verify signature information + EXPECT_EQ(traits.spatial_dim, 2); + EXPECT_EQ(traits.direction, ConvDirection::BACKWARD_WEIGHT); + EXPECT_THAT(traits.layout, + ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK)); + EXPECT_EQ(traits.data_type, DataType::FP16); + EXPECT_EQ(traits.input_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.weight_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.output_element_op, ElementwiseOperation::PASS_THROUGH); + + // Verify specializations + EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT); + + // Verify algorithm information + EXPECT_EQ(traits.thread_block_size, 256); + + // Verify tile dimensions + EXPECT_EQ(traits.tile_dims.m, 128); + EXPECT_EQ(traits.tile_dims.n, 128); + EXPECT_EQ(traits.tile_dims.k, 16); + + // Verify A tile transfer info + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(traits.a_tile_transfer.transfer_params.lds_padding); + + // Verify B tile 
transfer info + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(traits.b_tile_transfer.transfer_params.lds_padding); + + // Verify warp GEMM params + EXPECT_EQ(traits.warp_gemm.gemm_m, 32); + EXPECT_EQ(traits.warp_gemm.gemm_n, 32); + EXPECT_EQ(traits.warp_gemm.m_iter, 4); + EXPECT_EQ(traits.warp_gemm.n_iter, 4); + + // Verify output tile transfer info + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.m_gemms_per_shuffle, 1); + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.n_gemms_per_shuffle, 1); + EXPECT_THAT(traits.c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8)); + EXPECT_EQ(traits.c_tile_transfer.scalar_per_vector, 8); + + // Verify pipeline configuration +} + +// Test ConvTraits with DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3 +TEST_F(ConvTraitsTest, ConvBwdWeightTwoStageWmmaCshuffleTraitsExtraction) +{ + // Define a concrete instance type with specific template parameters + using DeviceInstance = + ck::tensor_operation::device::DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< + 2, // NDimSpatial + ck::tensor_layout::convolution::GNHWC, // InLayout + ck::tensor_layout::convolution::GKYXC, // WeiLayout + ck::tensor_layout::convolution::GNHWK, // OutLayout + ck::half_t, // InDataType + ck::half_t, // WeiDataType + 
ck::half_t, // OutDataType + float, // AccDataType + ck::tensor_operation::element_wise::PassThrough, // InElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // WeiElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // OutElementwiseOperation + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization:: + Default, // ConvBackwardWeightSpecialization + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // K0PerBlock + 8, // AK1 + 32, // MPerWMMA + 32, // NPerXDL + 4, // MRepeat + 4, // NRepeat + ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder_ + ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_K1 + 1, // ABlockLdsAddExtraM + ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder_ + ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder_ + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_K1 + 1, // BBlockLdsAddExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + ck::Sequence<1, + 32, + 1, + 8>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_ + 8, // CDEBlockTransferScalarPerVector_NPerBlock_ + ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched + ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer + 4, // NumGroupsToMerge + ck::half_t, // AComputeDataType + ck::half_t, // BComputeDataType + 1, // MaxTransposeTransferSrcScalarPerVector + 1>; // MaxTransposeTransferDstScalarPerVector> + + // Use ConvTraitsTmpl to extract compile-time information + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + + // Verify signature information + 
EXPECT_EQ(traits.spatial_dim, 2); + EXPECT_EQ(traits.direction, ConvDirection::BACKWARD_WEIGHT); + EXPECT_THAT(traits.layout, + ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK)); + EXPECT_EQ(traits.data_type, DataType::FP16); + EXPECT_EQ(traits.input_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.weight_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.output_element_op, ElementwiseOperation::PASS_THROUGH); + + // Verify specializations + EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT); + + // Verify algorithm information + EXPECT_EQ(traits.thread_block_size, 256); + + // Verify tile dimensions + EXPECT_EQ(traits.tile_dims.m, 128); + EXPECT_EQ(traits.tile_dims.n, 128); + EXPECT_EQ(traits.tile_dims.k, 16); + + // Verify A tile transfer info + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(traits.a_tile_transfer.transfer_params.lds_padding); + + // Verify B tile transfer info + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.k1, 8); + 
EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(traits.b_tile_transfer.transfer_params.lds_padding); + + // Verify warp GEMM params + EXPECT_EQ(traits.warp_gemm.gemm_m, 32); + EXPECT_EQ(traits.warp_gemm.gemm_n, 32); + EXPECT_EQ(traits.warp_gemm.m_iter, 4); + EXPECT_EQ(traits.warp_gemm.n_iter, 4); + + // Verify output tile transfer info + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.m_gemms_per_shuffle, 1); + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.n_gemms_per_shuffle, 1); + EXPECT_THAT(traits.c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8)); + EXPECT_EQ(traits.c_tile_transfer.scalar_per_vector, 8); + + // Verify pipeline configuration + EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT); + EXPECT_EQ(traits.pipeline_version, PipelineVersion::V1); +} + +// Test ConvTraits with DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle +TEST_F(ConvTraitsTest, ConvBwdWeightTwoStageXdlCshuffleTraitsExtraction) +{ + // Define a concrete instance type with specific template parameters + using DeviceInstance = + ck::tensor_operation::device::DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< + 2, // NDimSpatial + ck::tensor_layout::convolution::GNHWC, // InLayout + ck::tensor_layout::convolution::GKYXC, // WeiLayout + ck::tensor_layout::convolution::GNHWK, // OutLayout + ck::half_t, // InDataType + ck::half_t, // WeiDataType + ck::half_t, // OutDataType + float, // AccDataType + ck::tensor_operation::element_wise::PassThrough, // InElementwiseOperation + 
ck::tensor_operation::element_wise::PassThrough, // WeiElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // OutElementwiseOperation + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization:: + Default, // ConvBackwardWeightSpecialization + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // K0PerBlock + 8, // K1 + 32, // MPerXDL + 32, // NPerXDL + 4, // MXdlPerWave + 4, // NXdlPerWave + ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder_ + ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_K1 + 1, // ABlockLdsAddExtraM + ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder_ + ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder_ + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_K1 + 1, // BBlockLdsAddExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + ck::Sequence<1, + 32, + 1, + 8>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_ + 8, // CDEBlockTransferScalarPerVector_NPerBlock_ + ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched + ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer + 4, // NumGroupsToMerge + ck::half_t, // AComputeDataType + ck::half_t, // BComputeDataType + 1, // MaxTransposeTransferSrcScalarPerVector + 1>; // MaxTransposeTransferDstScalarPerVector> + + // Use ConvTraitsTmpl to extract compile-time information + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + + // Verify signature information + EXPECT_EQ(traits.spatial_dim, 2); + EXPECT_EQ(traits.direction, ConvDirection::BACKWARD_WEIGHT); + EXPECT_THAT(traits.layout, + 
ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK)); + EXPECT_EQ(traits.data_type, DataType::FP16); + EXPECT_EQ(traits.input_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.weight_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.output_element_op, ElementwiseOperation::PASS_THROUGH); + + // Verify specializations + EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT); + + // Verify algorithm information + EXPECT_EQ(traits.thread_block_size, 256); + + // Verify tile dimensions + EXPECT_EQ(traits.tile_dims.m, 128); + EXPECT_EQ(traits.tile_dims.n, 128); + EXPECT_EQ(traits.tile_dims.k, 16); + + // Verify A tile transfer info + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(traits.a_tile_transfer.transfer_params.lds_padding); + + // Verify B tile transfer info + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 
0, 2)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(traits.b_tile_transfer.transfer_params.lds_padding); + + // Verify warp GEMM params + EXPECT_EQ(traits.warp_gemm.gemm_m, 32); + EXPECT_EQ(traits.warp_gemm.gemm_n, 32); + EXPECT_EQ(traits.warp_gemm.m_iter, 4); + EXPECT_EQ(traits.warp_gemm.n_iter, 4); + + // Verify output tile transfer info + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.m_gemms_per_shuffle, 1); + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.n_gemms_per_shuffle, 1); + EXPECT_THAT(traits.c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8)); + EXPECT_EQ(traits.c_tile_transfer.scalar_per_vector, 8); + + // Verify pipeline configuration + EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT); + EXPECT_EQ(traits.pipeline_version, PipelineVersion::V1); +} + +// Test ConvTraits with DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle +TEST_F(ConvTraitsTest, ConvBwdWeightMultipleDCshuffleXDLTraitsExtraction) +{ + // Define a concrete instance type with specific template parameters + using DeviceInstance = + ck::tensor_operation::device::DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< + 2, // NDimSpatial + ck::tensor_layout::convolution::GNHWC, // InLayout + ck::tensor_layout::convolution::GKYXC, // WeiLayout + ck::tensor_layout::convolution::GNHWK, // OutLayout + ck::Tuple<>, // DsLayout + ck::half_t, // InDataType + ck::half_t, // WeiDataType + ck::half_t, // OutDataType + float, // AccDataType + ck::Tuple<>, // DsDataType + ck::tensor_operation::element_wise::PassThrough, // InElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // WeiElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // 
OutElementwiseOperation + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization:: + Default, // ConvBackwardWeightSpecialization + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // K0PerBlock + 8, // K1 + 32, // MPerXDL + 32, // NPerXDL + 4, // MXdlPerWave + 4, // NXdlPerWave + ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder_ + ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_K1 + 1, // ABlockLdsAddExtraM + ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder_ + ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder_ + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_K1 + 1, // BBlockLdsAddExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + ck::Sequence<1, + 32, + 1, + 8>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_ + 8, // CDEBlockTransferScalarPerVector_NPerBlock_ + ck::half_t, // AComputeDataType + ck::half_t>; // BComputeDataType + + // Use ConvTraitsTmpl to extract compile-time information + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + + // Verify signature information + EXPECT_EQ(traits.spatial_dim, 2); + EXPECT_EQ(traits.direction, ConvDirection::BACKWARD_WEIGHT); + EXPECT_THAT(traits.layout, + ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK)); + EXPECT_EQ(traits.data_type, DataType::FP16); + EXPECT_EQ(traits.input_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.weight_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.output_element_op, ElementwiseOperation::PASS_THROUGH); + + // Verify specializations + 
EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT); + + // Verify algorithm information + EXPECT_EQ(traits.thread_block_size, 256); + + // Verify tile dimensions + EXPECT_EQ(traits.tile_dims.m, 128); + EXPECT_EQ(traits.tile_dims.n, 128); + EXPECT_EQ(traits.tile_dims.k, 16); + + // Verify A tile transfer info + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(traits.a_tile_transfer.transfer_params.lds_padding); + + // Verify B tile transfer info + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + 
EXPECT_TRUE(traits.b_tile_transfer.transfer_params.lds_padding); + + // Verify warp GEMM params + EXPECT_EQ(traits.warp_gemm.gemm_m, 32); + EXPECT_EQ(traits.warp_gemm.gemm_n, 32); + EXPECT_EQ(traits.warp_gemm.m_iter, 4); + EXPECT_EQ(traits.warp_gemm.n_iter, 4); + + // Verify output tile transfer info + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.m_gemms_per_shuffle, 1); + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.n_gemms_per_shuffle, 1); + EXPECT_THAT(traits.c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8)); + EXPECT_EQ(traits.c_tile_transfer.scalar_per_vector, 8); + + // Verify pipeline configuration +} + +// Test ConvTraits with DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 +TEST_F(ConvTraitsTest, ConvBwdWeightXdlCshuffleV3TraitsExtraction) +{ + // Define a concrete instance type with specific template parameters + using DeviceInstance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< + 2, // NDimSpatial + ck::tensor_layout::convolution::GNHWC, // InLayout + ck::tensor_layout::convolution::GKYXC, // WeiLayout + ck::tensor_layout::convolution::GNHWK, // OutLayout + ck::half_t, // InDataType + ck::half_t, // WeiDataType + ck::half_t, // OutDataType + float, // AccDataType + ck::tensor_operation::element_wise::PassThrough, // InElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // WeiElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // OutElementwiseOperation + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization:: + Default, // ConvBackwardWeightSpecialization + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // K0PerBlock + 8, // K1 + 32, // MPerXDL + 32, // NPerXDL + 4, // MXdlPerWave + 4, // NXdlPerWave + ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder_ + ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // 
ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_K1 + 1, // ABlockLdsAddExtraM + ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder_ + ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder_ + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_K1 + 1, // BBlockLdsAddExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + ck::Sequence<1, + 32, + 1, + 8>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_ + 8, // CDEBlockTransferScalarPerVector_NPerBlock_ + ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched + ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer + ck::half_t, // AComputeDataType + ck::half_t>; // BComputeDataType + + // Use ConvTraitsTmpl to extract compile-time information + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + + // Verify signature information + EXPECT_EQ(traits.spatial_dim, 2); + EXPECT_EQ(traits.direction, ConvDirection::BACKWARD_WEIGHT); + EXPECT_THAT(traits.layout, + ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK)); + EXPECT_EQ(traits.data_type, DataType::FP16); + EXPECT_EQ(traits.input_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.weight_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.output_element_op, ElementwiseOperation::PASS_THROUGH); + + // Verify specializations + EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT); + + // Verify algorithm information + EXPECT_EQ(traits.thread_block_size, 256); + + // Verify tile dimensions + EXPECT_EQ(traits.tile_dims.m, 128); + EXPECT_EQ(traits.tile_dims.n, 128); + EXPECT_EQ(traits.tile_dims.k, 16); + + // Verify A tile transfer info + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.m_or_n, 128); 
+ EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(traits.a_tile_transfer.transfer_params.lds_padding); + + // Verify B tile transfer info + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(traits.b_tile_transfer.transfer_params.lds_padding); + + // Verify warp GEMM params + EXPECT_EQ(traits.warp_gemm.gemm_m, 32); + EXPECT_EQ(traits.warp_gemm.gemm_n, 32); + EXPECT_EQ(traits.warp_gemm.m_iter, 4); + EXPECT_EQ(traits.warp_gemm.n_iter, 4); + + // Verify output tile transfer info + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.m_gemms_per_shuffle, 1); + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.n_gemms_per_shuffle, 1); + 
EXPECT_THAT(traits.c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8));
+    EXPECT_EQ(traits.c_tile_transfer.scalar_per_vector, 8);
+
+    // Verify pipeline configuration
+    EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT);
+    EXPECT_EQ(traits.pipeline_version, PipelineVersion::V1);
+}
+
+// Test ConvTraits with DeviceGroupedConvBwdWeight_Xdl_CShuffle.
+//
+// Instantiates one concrete 2D FP16 backward-weight instance and checks that
+// each field of the traits object returned by instance_to_conv_traits matches
+// the template argument it was extracted from.
+TEST_F(ConvTraitsTest, ConvBwdWeightXdlCshuffleTraitsExtraction)
+{
+    // Define a concrete instance type with specific template parameters
+    using DeviceInstance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle<
+        2,                                               // NDimSpatial
+        ck::tensor_layout::convolution::GNHWC,           // InLayout
+        ck::tensor_layout::convolution::GKYXC,           // WeiLayout
+        ck::tensor_layout::convolution::GNHWK,           // OutLayout
+        ck::half_t,                                      // InDataType
+        ck::half_t,                                      // WeiDataType
+        ck::half_t,                                      // OutDataType
+        float,                                           // AccDataType
+        ck::tensor_operation::element_wise::PassThrough, // InElementwiseOperation
+        ck::tensor_operation::element_wise::PassThrough, // WeiElementwiseOperation
+        ck::tensor_operation::element_wise::PassThrough, // OutElementwiseOperation
+        ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::
+            Default,            // ConvBackwardWeightSpecialization
+        256,                    // BlockSize
+        128,                    // MPerBlock
+        128,                    // NPerBlock
+        16,                     // K0PerBlock
+        8,                      // K1
+        32,                     // MPerXDL
+        32,                     // NPerXDL
+        4,                      // MXdlPerWave
+        4,                      // NXdlPerWave
+        ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
+        ck::Sequence<1, 0, 2>,  // ABlockTransferThreadClusterArrangeOrder_
+        ck::Sequence<1, 0, 2>,  // ABlockTransferSrcAccessOrder
+        2,                      // ABlockTransferSrcVectorDim
+        8,                      // ABlockTransferSrcScalarPerVector
+        8,                      // ABlockTransferDstScalarPerVector_K1
+        1,                      // ABlockLdsAddExtraM
+        ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1
+        ck::Sequence<1, 0, 2>,  // BBlockTransferThreadClusterArrangeOrder_
+        ck::Sequence<1, 0, 2>,  // BBlockTransferSrcAccessOrder_
+        2,                      // BBlockTransferSrcVectorDim
+        8,                      // BBlockTransferSrcScalarPerVector
+        8,                      // BBlockTransferDstScalarPerVector_K1
+        1,                      // BBlockLdsAddExtraN
+        1,                      // CShuffleMXdlPerWavePerShuffle
+        1,                      // CShuffleNXdlPerWavePerShuffle
+        ck::Sequence<1,
+                     32,
+                     1,
+                     8>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_
+        8,               // CDEBlockTransferScalarPerVector_NPerBlock_
+        ck::half_t,      // AComputeDataType
+        ck::half_t,      // BComputeDataType
+        1,               // MaxTransposeTransferSrcScalarPerVector
+        1>;              // MaxTransposeTransferDstScalarPerVector
+
+    // Extract the compile-time information of DeviceInstance. The instance type
+    // must be passed explicitly: the function takes no arguments to deduce from.
+    const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
+
+    // Verify signature information
+    EXPECT_EQ(traits.spatial_dim, 2);
+    EXPECT_EQ(traits.direction, ConvDirection::BACKWARD_WEIGHT);
+    EXPECT_THAT(traits.layout,
+                ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK));
+    EXPECT_EQ(traits.data_type, DataType::FP16);
+    EXPECT_EQ(traits.input_element_op, ElementwiseOperation::PASS_THROUGH);
+    EXPECT_EQ(traits.weight_element_op, ElementwiseOperation::PASS_THROUGH);
+    EXPECT_EQ(traits.output_element_op, ElementwiseOperation::PASS_THROUGH);
+
+    // Specializations
+    // NOTE(review): this instance takes no GemmSpec or prefetch-stage template
+    // argument, so traits.gemm_padding and traits.num_gemm_k_prefetch_stage are
+    // presumably std::nullopt here - consider asserting that once confirmed.
+    // The pipeline scheduler is checked once, under "pipeline configuration".
+
+    // Verify algorithm information
+    EXPECT_EQ(traits.thread_block_size, 256);
+
+    // Verify tile dimensions
+    // NOTE(review): the instance comment labels 16 as K0PerBlock, yet the
+    // expectations below treat it as the full K tile (k == 16, k0 == 2 with
+    // K1 == 8) - confirm which interpretation is intended.
+    EXPECT_EQ(traits.tile_dims.m, 128);
+    EXPECT_EQ(traits.tile_dims.n, 128);
+    EXPECT_EQ(traits.tile_dims.k, 16);
+
+    // Verify A tile transfer info
+    EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k0, 2);
+    EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.m_or_n, 128);
+    EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k1, 8);
+    EXPECT_EQ(traits.a_tile_transfer.transfer_params.k1, 8);
+    EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1));
+    EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2));
+    EXPECT_THAT(traits.a_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2));
+    EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_vector_dim, 2);
+    EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_scalar_per_vector, 8);
+    EXPECT_EQ(traits.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8);
+    EXPECT_TRUE(traits.a_tile_transfer.transfer_params.lds_padding);
+
+    // Verify B tile transfer info
+    EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k0, 2);
+    EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.m_or_n, 128);
+    EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k1, 8);
+    EXPECT_EQ(traits.b_tile_transfer.transfer_params.k1, 8);
+    EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1));
+    EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2));
+    EXPECT_THAT(traits.b_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2));
+    EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_vector_dim, 2);
+    EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_scalar_per_vector, 8);
+    EXPECT_EQ(traits.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8);
+    EXPECT_TRUE(traits.b_tile_transfer.transfer_params.lds_padding);
+
+    // Verify warp GEMM params
+    EXPECT_EQ(traits.warp_gemm.gemm_m, 32);
+    EXPECT_EQ(traits.warp_gemm.gemm_n, 32);
+    EXPECT_EQ(traits.warp_gemm.m_iter, 4);
+    EXPECT_EQ(traits.warp_gemm.n_iter, 4);
+
+    // Verify output tile transfer info
+    EXPECT_EQ(traits.c_tile_transfer.shuffle_params.m_gemms_per_shuffle, 1);
+    EXPECT_EQ(traits.c_tile_transfer.shuffle_params.n_gemms_per_shuffle, 1);
+    EXPECT_THAT(traits.c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8));
+    EXPECT_EQ(traits.c_tile_transfer.scalar_per_vector, 8);
+
+    // Verify pipeline configuration
+    EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT);
+    EXPECT_EQ(traits.pipeline_version, PipelineVersion::V1);
+}
+
+// test conv traits
+// device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp
+//
+// Instantiates one concrete 2D FP16 forward DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
+// instance and checks that each field of the traits object returned by
+// instance_to_conv_traits matches the template argument it was extracted from.
+TEST_F(ConvTraitsTest, ConvFwdTraitsMultipleDCshuffleWmmaExtraction)
+{
+    // Define a concrete instance type with specific template parameters
+    using DeviceInstance =
+        ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<
+            2,                                               // NDimSpatial
+            ck::tensor_layout::convolution::GNHWC,           // ALayout
+            ck::tensor_layout::convolution::GKYXC,           // BLayout
+            ck::Tuple<>,                                     // DsLayout
+            ck::tensor_layout::convolution::GNHWK,           // ELayout
+            ck::half_t,                                      // ADataType
+            ck::half_t,                                      // BDataType
+            float,                                           // AccDataType
+            ck::half_t,                                      // CShuffleDataType
+            ck::Tuple<>,                                     // DsDataType
+            ck::half_t,                                      // EDataType
+            ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation
+            ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation
+            ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation
+            ck::tensor_operation::device::ConvolutionForwardSpecialization::
+                Default,                                               // ConvForwardSpecialization
+            ck::tensor_operation::device::GemmSpecialization::Default, // GemmSpec
+            1,                      // NumGemmKPrefetchStage
+            256,                    // BlockSize
+            128,                    // MPerBlock
+            128,                    // NPerBlock
+            16,                     // KPerBlock
+            8,                      // K1
+            32,                     // MPerWmma
+            32,                     // NPerWmma
+            4,                      // MRepeat
+            4,                      // NRepeat
+            ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
+            ck::Sequence<1, 0, 2>,  // ABlockTransferThreadClusterArrangeOrder
+            ck::Sequence<1, 0, 2>,  // ABlockTransferSrcAccessOrder
+            2,                      // ABlockTransferSrcVectorDim
+            8,                      // ABlockTransferSrcScalarPerVector
+            8,                      // ABlockTransferDstScalarPerVector_AK1
+            1,                      // ABlockLdsExtraM
+            ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
+            ck::Sequence<1, 0, 2>,  // BBlockTransferThreadClusterArrangeOrder
+            ck::Sequence<1, 0, 2>,  // BBlockTransferSrcAccessOrder
+            2,                      // BBlockTransferSrcVectorDim
+            8,                      // BBlockTransferSrcScalarPerVector
+            8,                      // BBlockTransferDstScalarPerVector_BK1
+            1,                      // BBlockLdsExtraN
+            1,                      // CShuffleMRepeatPerShuffle
+            1,                      // CShuffleNRepeatPerShuffle
+            ck::Sequence<
+                1,
+                32,
+                1,
+                8>, // CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
+            8,      // CDEShuffleBlockTransferScalarPerVector_NPerBlock
+            ck::LoopScheduler::Default, // BlkGemmPipeSched
+            ck::PipelineVersion::v1>;   // BlkGemmPipelineVer
+
+    // Extract the compile-time information of DeviceInstance. The instance type
+    // must be passed explicitly: the function takes no arguments to deduce from.
+    const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
+
+    // Verify signature information
+    EXPECT_EQ(traits.spatial_dim, 2);
+    EXPECT_EQ(traits.direction, ConvDirection::FORWARD);
+    EXPECT_THAT(traits.layout,
+                ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK));
+    EXPECT_EQ(traits.data_type, DataType::FP16);
+    EXPECT_EQ(traits.input_element_op, ElementwiseOperation::PASS_THROUGH);
+    EXPECT_EQ(traits.weight_element_op, ElementwiseOperation::PASS_THROUGH);
+    EXPECT_EQ(traits.output_element_op, ElementwiseOperation::PASS_THROUGH);
+
+    // Verify specializations
+    EXPECT_EQ(traits.gemm_padding, ck_tile::builder::GemmPadding::DEFAULT);
+    EXPECT_EQ(traits.conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT);
+    EXPECT_EQ(traits.num_gemm_k_prefetch_stage, 1);
+
+    // Verify algorithm information
+    EXPECT_EQ(traits.thread_block_size, 256);
+
+    // Verify tile dimensions
+    EXPECT_EQ(traits.tile_dims.m, 128);
+    EXPECT_EQ(traits.tile_dims.n, 128);
+    EXPECT_EQ(traits.tile_dims.k, 16);
+
+    // Verify A tile transfer info
+    EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k0, 2);
+    EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.m_or_n, 128);
+    EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k1, 8);
+    EXPECT_EQ(traits.a_tile_transfer.transfer_params.k1, 8);
+    EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1));
+    EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2));
+    EXPECT_THAT(traits.a_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2));
+    EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_vector_dim, 2);
+    EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_scalar_per_vector, 8);
+    EXPECT_EQ(traits.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8);
+    EXPECT_TRUE(traits.a_tile_transfer.transfer_params.lds_padding);
+
+    // Verify B tile transfer info
+    EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k0, 2);
+    EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.m_or_n, 128);
+    EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k1, 8);
+    EXPECT_EQ(traits.b_tile_transfer.transfer_params.k1, 8);
+    EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1));
+    EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2));
+    EXPECT_THAT(traits.b_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2));
+    EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_vector_dim, 2);
+    EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_scalar_per_vector, 8);
+    EXPECT_EQ(traits.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8);
+    EXPECT_TRUE(traits.b_tile_transfer.transfer_params.lds_padding);
+
+    // Verify warp GEMM params
+    EXPECT_EQ(traits.warp_gemm.gemm_m, 32);
+    EXPECT_EQ(traits.warp_gemm.gemm_n, 32);
+    EXPECT_EQ(traits.warp_gemm.m_iter, 4);
+    EXPECT_EQ(traits.warp_gemm.n_iter, 4);
+
+    // Verify output tile transfer info
+    EXPECT_EQ(traits.c_tile_transfer.shuffle_params.m_gemms_per_shuffle, 1);
+    EXPECT_EQ(traits.c_tile_transfer.shuffle_params.n_gemms_per_shuffle, 1);
+    EXPECT_THAT(traits.c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8));
+    EXPECT_EQ(traits.c_tile_transfer.scalar_per_vector, 8);
+
+    // Verify pipeline configuration
+    // NOTE(review): the template argument mirrors the BlkGemmPipeSched value
+    // given to DeviceInstance above - confirm it against the declaration of
+    // convert_pipeline_scheduler.
+    EXPECT_EQ(traits.pipeline_scheduler,
+              ck_tile::reflect::conv::convert_pipeline_scheduler<ck::LoopScheduler::Default>());
+    EXPECT_EQ(traits.pipeline_version, PipelineVersion::V1);
+}
+
 // Test ConvTraits with DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
 TEST_F(ConvTraitsTest, ConvFwdTraitsExtraction)
 {