diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle_v3_large_tensor.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle_v3_large_tensor.hpp
index 08d0f296f0..ed0ead42d1 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle_v3_large_tensor.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle_v3_large_tensor.hpp
@@ -617,32 +617,32 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor
             const auto m_block = GridwiseGemm::CalculateMBlock(gemm_m);
             const auto n_block = GridwiseGemm::CalculateNBlock(gemm_n);
 
-            GemmArgs new_args{};
-            new_args.a_ptrs_ = p_as_grid;
-            new_args.b_ptrs_ = p_bs_grid;
-            new_args.ds_ptrs_ = p_ds_grid;
-            new_args.e_ptr_ = p_e_grid;
-
-            new_args.a_element_op_ = a_element_op_;
-            new_args.b_element_op_ = b_element_op_;
-            new_args.cde_element_op_ = cde_element_op_;
-
-            new_args.M_ = gemm_m;
-            new_args.N_ = gemm_n;
-
-            new_args.a_grid_desc_ = a_grid_desc;
-            new_args.b_grid_desc_ = b_grid_desc;
-            new_args.ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
+            const auto ds_desc_mblock_mperblock_nblock_nperblock =
                 GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                     ds_grid_desc_m_n, m_block, n_block);
-            new_args.e_grid_desc_mblock_mperblock_nblock_nperblock_ =
+            const auto e_desc_mblock_mperblock_nblock_nperblock =
                 GridwiseGemm::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                     e_grid_desc_m_n, m_block, n_block);
-            new_args.BlockStart_ = BlockStart;
-            new_args.BlockEnd_ = BlockEnd;
-
-            gemm_desc_kernel_args_.At(valid_gemms_count_) = new_args;
+            gemm_desc_kernel_args_.Emplace(
+                valid_gemms_count_,
+                GemmArgs{.a_ptrs_ = p_as_grid,
+                         .b_ptrs_ = p_bs_grid,
+                         .ds_ptrs_ = p_ds_grid,
+                         .e_ptr_ = p_e_grid,
+                         .a_element_op_ = a_element_op_,
+                         .b_element_op_ = b_element_op_,
+                         .cde_element_op_ = cde_element_op_,
+                         .M_ = gemm_m,
+                         .N_ = gemm_n,
+                         .a_grid_desc_ = a_grid_desc,
+                         .b_grid_desc_ = b_grid_desc,
+                         .ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
+                             ds_desc_mblock_mperblock_nblock_nperblock,
+                         .e_grid_desc_mblock_mperblock_nblock_nperblock_ =
+                             e_desc_mblock_mperblock_nblock_nperblock,
+                         .BlockStart_ = BlockStart,
+                         .BlockEnd_ = BlockEnd});
 
             valid_gemms_count_++;
         }
@@ -789,11 +789,14 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor
             compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides_[1] * conv_N_per_block_;
             compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides_[1] * conv_N_per_block_;
 
-            static_for<0, NumDTensor, 1>{}([&](auto i) {
-                compute_ptr_offset_of_groups_.BatchStrideDs_(i) = ds_g_n_k_wos_strides_[i][0];
-                compute_ptr_offset_of_n_.BatchStrideDs_(i) =
-                    ds_g_n_k_wos_strides_[i][1] * conv_N_per_block_;
-            });
+            if constexpr(NumDTensor > 0)
+            {
+                static_for<0, NumDTensor, 1>{}([&](auto i) {
+                    compute_ptr_offset_of_groups_.BatchStrideDs_(i) = ds_g_n_k_wos_strides_[i][0];
+                    compute_ptr_offset_of_n_.BatchStrideDs_(i) =
+                        ds_g_n_k_wos_strides_[i][1] * conv_N_per_block_;
+                });
+            }
         }
 
         void Print() const
@@ -807,12 +810,15 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor
                       << ", is_split_valid=" << std::boolalpha << is_split_valid_
                       << std::noboolalpha << ", grid_size=" << grid_size_ << std::endl;
 
-            static_for<0, NumDTensor, 1>{}([&](auto i) {
-                std::cout << "    Ds[" << i.value
-                          << "] group stride=" << compute_ptr_offset_of_groups_.BatchStrideDs_(i)
-                          << ", n stride=" << compute_ptr_offset_of_n_.BatchStrideDs_(i)
-                          << std::endl;
-            });
+            if constexpr(NumDTensor > 0)
+            {
+                static_for<0, NumDTensor, 1>{}([&](auto i) {
+                    std::cout << "    Ds[" << i.value << "] group stride="
+                              << compute_ptr_offset_of_groups_.BatchStrideDs_.At(i)
+                              << ", n stride=" << compute_ptr_offset_of_n_.BatchStrideDs_.At(i)
+                              << std::endl;
+                });
+            }
 
             std::cout << "===== GEMM splits =====" << std::endl;
             for(index_t i = 0; i < valid_gemms_count_; ++i)
@@ -836,11 +842,14 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor
                 std::cout << "    E[MBlock, MPerBlock, NBlock, NPerBlock]: "
                           << gemm.e_grid_desc_mblock_mperblock_nblock_nperblock_ << std::endl;
 
-                static_for<0, NumDTensor, 1>{}([&](auto d_idx) {
-                    std::cout << "    D" << d_idx.value << " descriptor: "
-                              << gemm.ds_grid_desc_mblock_mperblock_nblock_nperblock_(d_idx)
-                              << std::endl;
-                });
+                if constexpr(NumDTensor > 0)
+                {
+                    static_for<0, NumDTensor, 1>{}([&](auto d_idx) {
+                        std::cout << "    D" << d_idx.value << " descriptor: "
+                                  << gemm.ds_grid_desc_mblock_mperblock_nblock_nperblock_.At(d_idx)
+                                  << std::endl;
+                    });
+                }
             }
         }
 
diff --git a/include/ck/utility/array.hpp b/include/ck/utility/array.hpp
index 2b249884b6..73eb18fe16 100644
--- a/include/ck/utility/array.hpp
+++ b/include/ck/utility/array.hpp
@@ -6,6 +6,8 @@
 
 #include "functional2.hpp"
 #include "sequence.hpp"
+#include <cassert>
+#include <new>
 
 namespace ck {
 
@@ -27,6 +29,15 @@ struct Array
 
     __host__ __device__ constexpr TData& operator()(index_t i) { return At(i); }
 
+    template <typename... Args>
+    __host__ constexpr auto Emplace(index_t i, Args&&... args)
+        -> std::enable_if_t<std::is_constructible_v<TData, Args&&...>>
+    {
+        assert(i >= 0 && i < NSize);
+        mData[i].~TData();
+        new(mData + i) TData(ck::forward<Args>(args)...);
+    }
+
     template <typename T>
     __host__ __device__ constexpr auto operator=(const T& a)
     {
diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_bnorm_clamp.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_bnorm_clamp.hpp
index 295b2c21b5..e42a3f2045 100644
--- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_bnorm_clamp.hpp
+++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_bnorm_clamp.hpp
@@ -26,6 +26,21 @@ void add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances(
                                                                 PassThrough,
                                                                 BiasNormalizeInInferClamp>>>& instances);
 
+void add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances(
+    std::vector<
+        std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
+            NHWGC,
+            GKYXC,
+            Tuple<G_K, G_K, G_K, G_K, G_K>,
+            NHWGK,
+            BF16,
+            BF16,
+            Tuple<BF16, BF16, BF16, BF16, BF16>,
+            BF16,
+            PassThrough,
+            PassThrough,
+            BiasNormalizeInInferClamp>>>& instances);
+
 void add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
     std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
                                                                 NDHWGC,
@@ -41,6 +56,21 @@ void add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
                                                                 PassThrough,
                                                                 BiasNormalizeInInferClamp>>>& instances);
 
+void add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
+                                                                NDHWGC,
+                                                                GKZYXC,
+                                                                Tuple<G_K, G_K, G_K, G_K, G_K>,
+                                                                NDHWGK,
+                                                                BF16,
+                                                                BF16,
+                                                                Tuple<BF16, BF16, BF16, BF16, BF16>,
+                                                                BF16,
+                                                                PassThrough,
+                                                                PassThrough,
+                                                                BiasNormalizeInInferClamp>>>& instances);
+
 #endif
 
 #ifdef CK_ENABLE_FP16
@@ -56,6 +86,21 @@ void add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances(
                                                                 PassThrough,
                                                                 BiasNormalizeInInferClamp>>>& instances);
 
+void add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances(
+    std::vector<
+        std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
+            NHWGC,
+            GKYXC,
+            Tuple<G_K, G_K, G_K, G_K, G_K>,
+            NHWGK,
+            F16,
+            F16,
+            Tuple<F16, F16, F16, F16, F16>,
+            F16,
+            PassThrough,
+            PassThrough,
+            BiasNormalizeInInferClamp>>>& instances);
+
 void add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances(
     std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
                                                                 NDHWGC,
@@ -71,6 +116,21 @@ void add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances(
                                                                 PassThrough,
                                                                 BiasNormalizeInInferClamp>>>& instances);
 
+void add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
+                                                                NDHWGC,
+                                                                GKZYXC,
+                                                                Tuple<G_K, G_K, G_K, G_K, G_K>,
+                                                                NDHWGK,
+                                                                F16,
+                                                                F16,
+                                                                Tuple<F16, F16, F16, F16, F16>,
+                                                                F16,
+                                                                PassThrough,
+                                                                PassThrough,
+                                                                BiasNormalizeInInferClamp>>>& instances);
 #endif
 
 } // namespace instance
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/CMakeLists.txt
index d089663f37..1f381f5f7d 100644
--- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/CMakeLists.txt
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/CMakeLists.txt
@@ -328,6 +328,8 @@ generate_sharded_instantiations(
 add_instance_library(device_grouped_conv2d_fwd_bias_bnorm_clamp_instance
     wmma/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
     wmma/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance.cpp
+    wmma/large_tensor/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
+    wmma/large_tensor/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instance.cpp
     ${GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP}
 )
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/wmma/large_tensor/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/wmma/large_tensor/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
new file mode 100644
index 0000000000..6bd58617aa
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/wmma/large_tensor/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
@@ -0,0 +1,44 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using BiasNormalizeInInferClamp = ck::tensor_operation::element_wise::BiasNormalizeInInferClamp;
+
+void add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances(
+    std::vector<
+        std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
+            NHWGC,
+            GKYXC,
+            Tuple<G_K, G_K, G_K, G_K, G_K>,
+            NHWGK,
+            BF16,
+            BF16,
+            Tuple<BF16, BF16, BF16, BF16, BF16>,
+            BF16,
+            PassThrough,
+            PassThrough,
+            BiasNormalizeInInferClamp>>>& instances)
+{
+    add_device_operation_instances(instances,
+                                   device_grouped_conv_fwd_wmma_large_tensor_bf16_instances<
+                                       2,
+                                       NHWGC,
+                                       GKYXC,
+                                       Tuple<G_K, G_K, G_K, G_K, G_K>,
+                                       NHWGK,
+                                       ConvFwdDefault,
+                                       Tuple<BF16, BF16, BF16, BF16, BF16>,
+                                       BiasNormalizeInInferClamp>{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/wmma/large_tensor/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/wmma/large_tensor/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instance.cpp
new file mode 100644
index 0000000000..5eebe7f386
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/wmma/large_tensor/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instance.cpp
@@ -0,0 +1,44 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using BiasNormalizeInInferClamp = ck::tensor_operation::element_wise::BiasNormalizeInInferClamp;
+
+void add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances(
+    std::vector<
+        std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
+            NHWGC,
+            GKYXC,
+            Tuple<G_K, G_K, G_K, G_K, G_K>,
+            NHWGK,
+            F16,
+            F16,
+            Tuple<F16, F16, F16, F16, F16>,
+            F16,
+            PassThrough,
+            PassThrough,
+            BiasNormalizeInInferClamp>>>& instances)
+{
+    add_device_operation_instances(instances,
+                                   device_grouped_conv_fwd_wmma_large_tensor_f16_instances<
+                                       2,
+                                       NHWGC,
+                                       GKYXC,
+                                       Tuple<G_K, G_K, G_K, G_K, G_K>,
+                                       NHWGK,
+                                       ConvFwdDefault,
+                                       Tuple<F16, F16, F16, F16, F16>,
+                                       BiasNormalizeInInferClamp>{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/CMakeLists.txt
index dc759cbb54..f54588991f 100644
--- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/CMakeLists.txt
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/CMakeLists.txt
@@ -309,6 +309,8 @@ generate_sharded_instantiations(
 add_instance_library(device_grouped_conv3d_fwd_bias_bnorm_clamp_instance
     wmma/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
     wmma/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance.cpp
+    wmma/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
+    wmma/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
     ${GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP}
 )
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/wmma/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/wmma/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
new file mode 100644
index 0000000000..6d7ede939a
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/wmma/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
@@ -0,0 +1,44 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using BiasNormalizeInInferClamp = ck::tensor_operation::element_wise::BiasNormalizeInInferClamp;
+
+void add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
+                                                                NDHWGC,
+                                                                GKZYXC,
+                                                                Tuple<G_K, G_K, G_K, G_K, G_K>,
+                                                                NDHWGK,
+                                                                BF16,
+                                                                BF16,
+                                                                Tuple<BF16, BF16, BF16, BF16, BF16>,
+                                                                BF16,
+                                                                PassThrough,
+                                                                PassThrough,
+                                                                BiasNormalizeInInferClamp>>>& instances)
+{
+    add_device_operation_instances(instances,
+                                   device_grouped_conv_fwd_wmma_large_tensor_bf16_instances<
+                                       3,
+                                       NDHWGC,
+                                       GKZYXC,
+                                       Tuple<G_K, G_K, G_K, G_K, G_K>,
+                                       NDHWGK,
+                                       ConvFwdDefault,
+                                       Tuple<BF16, BF16, BF16, BF16, BF16>,
+                                       BiasNormalizeInInferClamp>{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/wmma/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/wmma/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
new file mode 100644
index 0000000000..0a6dcf2e75
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/wmma/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
@@ -0,0 +1,44 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using BiasNormalizeInInferClamp = ck::tensor_operation::element_wise::BiasNormalizeInInferClamp;
+
+void add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
+                                                                NDHWGC,
+                                                                GKZYXC,
+                                                                Tuple<G_K, G_K, G_K, G_K, G_K>,
+                                                                NDHWGK,
+                                                                F16,
+                                                                F16,
+                                                                Tuple<F16, F16, F16, F16, F16>,
+                                                                F16,
+                                                                PassThrough,
+                                                                PassThrough,
+                                                                BiasNormalizeInInferClamp>>>& instances)
+{
+    add_device_operation_instances(instances,
+                                   device_grouped_conv_fwd_wmma_large_tensor_f16_instances<
+                                       3,
+                                       NDHWGC,
+                                       GKZYXC,
+                                       Tuple<G_K, G_K, G_K, G_K, G_K>,
+                                       NDHWGK,
+                                       ConvFwdDefault,
+                                       Tuple<F16, F16, F16, F16, F16>,
+                                       BiasNormalizeInInferClamp>{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
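
Note on the `ck::Array::Emplace` addition: `GemmArgs` bundles grid descriptors, so constructing each split's arguments directly in its `Array` slot with designated initializers avoids the default-construct-then-copy-assign round trip that `gemm_desc_kernel_args_.At(valid_gemms_count_) = new_args;` performed. A minimal sketch of the destroy-then-placement-new pattern (simplified standalone types, `std::forward` in place of `ck::forward`, no `enable_if` constraint; compile with `-std=c++20` for the designated initializers):

```cpp
#include <cassert>
#include <new>
#include <utility>

// Stand-in for ck::Array<TData, NSize>, reduced to the Emplace pattern.
template <typename TData, int NSize>
struct Array
{
    TData mData[NSize];

    template <typename... Args>
    void Emplace(int i, Args&&... args)
    {
        assert(i >= 0 && i < NSize);
        mData[i].~TData();                                 // end the old element's lifetime
        new(mData + i) TData(std::forward<Args>(args)...); // construct the new value in place
    }
};

// Stand-in for the per-split kernel-argument struct.
struct GemmArgs
{
    int M_;
    int N_;
};

int main()
{
    Array<GemmArgs, 4> gemm_desc_kernel_args{};
    // Mirrors gemm_desc_kernel_args_.Emplace(valid_gemms_count_, GemmArgs{...}) in the diff:
    gemm_desc_kernel_args.Emplace(0, GemmArgs{.M_ = 128, .N_ = 256});
    assert(gemm_desc_kernel_args.mData[0].N_ == 256);
    return 0;
}
```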
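The new `if constexpr(NumDTensor > 0)` guards look intended to keep zero-D-tensor specializations compiling: the discarded branch of an `if constexpr` inside a template is never instantiated, so the `BatchStrideDs_` and D-descriptor accesses are not checked when `NumDTensor == 0`, which matters if those members do not exist (or are zero-sized) in that case. A self-contained illustration of the mechanism, using a hypothetical `Offsets` type rather than CK's real compute-ptr-offset helper:

```cpp
template <int N>
struct Offsets
{
    long BatchStrideDs_[N]; // present only when there are D tensors
};

template <>
struct Offsets<0>
{
    // no BatchStrideDs_ member at all when N == 0
};

template <int NumDTensor>
void set_strides(Offsets<NumDTensor>& off)
{
    // Without the guard, instantiating set_strides<0> would fail to compile,
    // because the loop body names a member that Offsets<0> does not have.
    // The discarded branch of if constexpr is not instantiated, so it is legal.
    if constexpr(NumDTensor > 0)
    {
        for(int i = 0; i < NumDTensor; ++i)
        {
            off.BatchStrideDs_[i] = 0;
        }
    }
}

int main()
{
    Offsets<2> two{};
    set_strides(two); // runs the loop
    Offsets<0> none{};
    set_strides(none); // compiles only thanks to the if constexpr guard
    return 0;
}
```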