@@ -617,32 +617,32 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor
const auto m_block = GridwiseGemm::CalculateMBlock(gemm_m);
const auto n_block = GridwiseGemm::CalculateNBlock(gemm_n);

GemmArgs new_args{};
new_args.a_ptrs_ = p_as_grid;
new_args.b_ptrs_ = p_bs_grid;
new_args.ds_ptrs_ = p_ds_grid;
new_args.e_ptr_ = p_e_grid;

new_args.a_element_op_ = a_element_op_;
new_args.b_element_op_ = b_element_op_;
new_args.cde_element_op_ = cde_element_op_;

new_args.M_ = gemm_m;
new_args.N_ = gemm_n;

new_args.a_grid_desc_ = a_grid_desc;
new_args.b_grid_desc_ = b_grid_desc;
new_args.ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
const auto ds_desc_mblock_mperblock_nblock_nperblock =
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n, m_block, n_block);
new_args.e_grid_desc_mblock_mperblock_nblock_nperblock_ =
const auto e_desc_mblock_mperblock_nblock_nperblock =
GridwiseGemm::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n, m_block, n_block);

new_args.BlockStart_ = BlockStart;
new_args.BlockEnd_ = BlockEnd;

gemm_desc_kernel_args_.At(valid_gemms_count_) = new_args;
gemm_desc_kernel_args_.Emplace(
valid_gemms_count_,
GemmArgs{.a_ptrs_ = p_as_grid,
.b_ptrs_ = p_bs_grid,
.ds_ptrs_ = p_ds_grid,
.e_ptr_ = p_e_grid,
.a_element_op_ = a_element_op_,
.b_element_op_ = b_element_op_,
.cde_element_op_ = cde_element_op_,
.M_ = gemm_m,
.N_ = gemm_n,
.a_grid_desc_ = a_grid_desc,
.b_grid_desc_ = b_grid_desc,
.ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
ds_desc_mblock_mperblock_nblock_nperblock,
.e_grid_desc_mblock_mperblock_nblock_nperblock_ =
e_desc_mblock_mperblock_nblock_nperblock,
.BlockStart_ = BlockStart,
.BlockEnd_ = BlockEnd});

valid_gemms_count_++;
}
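
The hunk above replaces per-field assignment into a default-constructed GemmArgs slot with a single in-place Emplace of a designated-initializer aggregate. A minimal standalone sketch of that initialization form only (the Args struct is hypothetical, not CK code; designated initializers are a C++20 feature that clang, and hence hipcc, also accepts in C++17 mode as an extension):

// designated_init_sketch.cpp -- hypothetical Args type, illustrating only the
// GemmArgs{.field = value, ...} form used in the hunk above.
#include <cstdio>

struct Args
{
    const void* a_ptr;
    const void* b_ptr;
    int         M;
    int         N;
};

int main()
{
    // Designators must appear in declaration order (a_ptr, b_ptr, M, N);
    // members that are skipped (here b_ptr) are initialized from {}, i.e. null/zero.
    Args args{.a_ptr = nullptr, .M = 128, .N = 256};
    std::printf("M=%d N=%d b_ptr_is_null=%d\n", args.M, args.N, args.b_ptr == nullptr);
    return 0;
}
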
@@ -789,11 +789,14 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor
compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides_[1] * conv_N_per_block_;
compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides_[1] * conv_N_per_block_;

static_for<0, NumDTensor, 1>{}([&](auto i) {
compute_ptr_offset_of_groups_.BatchStrideDs_(i) = ds_g_n_k_wos_strides_[i][0];
compute_ptr_offset_of_n_.BatchStrideDs_(i) =
ds_g_n_k_wos_strides_[i][1] * conv_N_per_block_;
});
if constexpr(NumDTensor > 0)
{
static_for<0, NumDTensor, 1>{}([&](auto i) {
compute_ptr_offset_of_groups_.BatchStrideDs_(i) = ds_g_n_k_wos_strides_[i][0];
compute_ptr_offset_of_n_.BatchStrideDs_(i) =
ds_g_n_k_wos_strides_[i][1] * conv_N_per_block_;
});
}
}
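
The D-tensor stride setup above is now guarded by if constexpr(NumDTensor > 0), so the static_for body is only instantiated when the Ds tuple is non-empty. A standalone sketch of the same guard pattern, with simplified hypothetical helpers (std::tuple and a fold expression stand in for the CK containers and static_for):

// d_guard_sketch.cpp -- simplified, hypothetical analogue of the
// if constexpr(NumDTensor > 0) guards used in this file.
#include <cstddef>
#include <tuple>
#include <utility>

// Apply f to every tuple element at compile time (rough analogue of
// ck::static_for over tuple indices).
template <class Tup, class F, std::size_t... Is>
void for_each_impl(Tup& t, F&& f, std::index_sequence<Is...>)
{
    (f(std::get<Is>(t)), ...);
}

template <class... Ds, class F>
void for_each_d(std::tuple<Ds...>& ds, F&& f)
{
    // Only instantiate the loop helper when there is at least one D tensor.
    if constexpr(sizeof...(Ds) > 0)
    {
        for_each_impl(ds, std::forward<F>(f), std::index_sequence_for<Ds...>{});
    }
}

int main()
{
    std::tuple<long, long> d_strides{0, 0};
    for_each_d(d_strides, [](long& s) { s = 128; }); // lambda called for both elements

    std::tuple<> no_ds{};
    for_each_d(no_ds, [](auto& s) { s = 128; });     // body never instantiated
    return 0;
}
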

void Print() const
@@ -807,12 +810,15 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor
<< ", is_split_valid=" << std::boolalpha << is_split_valid_
<< std::noboolalpha << ", grid_size=" << grid_size_ << std::endl;

static_for<0, NumDTensor, 1>{}([&](auto i) {
std::cout << " Ds[" << i.value
<< "] group stride=" << compute_ptr_offset_of_groups_.BatchStrideDs_(i)
<< ", n stride=" << compute_ptr_offset_of_n_.BatchStrideDs_(i)
<< std::endl;
});
if constexpr(NumDTensor > 0)
{
static_for<0, NumDTensor, 1>{}([&](auto i) {
std::cout << " Ds[" << i.value << "] group stride="
<< compute_ptr_offset_of_groups_.BatchStrideDs_.At(i)
<< ", n stride=" << compute_ptr_offset_of_n_.BatchStrideDs_.At(i)
<< std::endl;
});
}

std::cout << "===== GEMM splits =====" << std::endl;
for(index_t i = 0; i < valid_gemms_count_; ++i)
@@ -836,11 +842,14 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor
std::cout << " E[MBlock, MPerBlock, NBlock, NPerBlock]: "
<< gemm.e_grid_desc_mblock_mperblock_nblock_nperblock_ << std::endl;

static_for<0, NumDTensor, 1>{}([&](auto d_idx) {
std::cout << " D" << d_idx.value << " descriptor: "
<< gemm.ds_grid_desc_mblock_mperblock_nblock_nperblock_(d_idx)
<< std::endl;
});
if constexpr(NumDTensor > 0)
{
static_for<0, NumDTensor, 1>{}([&](auto d_idx) {
std::cout << " D" << d_idx.value << " descriptor: "
<< gemm.ds_grid_desc_mblock_mperblock_nblock_nperblock_.At(d_idx)
<< std::endl;
});
}
}
}

11 changes: 11 additions & 0 deletions include/ck/utility/array.hpp
@@ -6,6 +6,8 @@

#include "functional2.hpp"
#include "sequence.hpp"
#include <type_traits>
#include <cassert>

namespace ck {

@@ -27,6 +29,15 @@ struct Array

__host__ __device__ constexpr TData& operator()(index_t i) { return At(i); }

template <typename... Args>
__host__ constexpr auto Emplace(index_t i, Args&&... args)
-> std::enable_if_t<std::is_nothrow_constructible_v<TData, Args&&...>>
{
assert(i >= 0 && i < NSize);
mData[i].~TData();
new(mData + i) TData(ck::forward<Args>(args)...);
}

template <typename T>
__host__ __device__ constexpr auto operator=(const T& a)
{
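
A usage sketch for the new Array::Emplace added above: the element in slot i is destroyed and the replacement is constructed in place with placement new, the member is constrained to nothrow-constructible arguments, and it is declared __host__ only, so the sketch runs host-side. The Args struct is hypothetical; the sketch assumes ck/utility/array.hpp is on the include path.

// emplace_sketch.cpp -- hypothetical Args type; host-only code.
#include <cstdio>
#include <new> // placement operator new is used inside Emplace
#include "ck/utility/array.hpp"

struct Args
{
    int M;
    int N;
};

int main()
{
    ck::Array<Args, 4> args{}; // value-initialized slots
    // Destroy slot 0, then construct the new value in place; no separate
    // default construction followed by member-wise assignment.
    args.Emplace(0, Args{.M = 128, .N = 256});
    std::printf("M=%d N=%d\n", args.At(0).M, args.At(0).N);
    return 0;
}
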
@@ -297,6 +297,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{
add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances(
op_ptrs);

add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP16
@@ -306,6 +309,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{
add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances(
op_ptrs);

add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances(
op_ptrs);
}
#endif
}
@@ -322,6 +328,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{
add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
op_ptrs);

add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP16
@@ -331,6 +340,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{
add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances(
op_ptrs);

add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances(
op_ptrs);
}
#endif
}
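
With the registrations above in place, the large-tensor instances surface through the usual CK instance-factory lookup. A hedged client-side sketch for the 2D bf16 signature from this diff follows; it is a fragment, since the header that declares this DeviceOperationInstanceFactory specialization is not shown in the diff, and GetInstances()/GetTypeString() are the customary CK entry points and should be checked against the target CK revision.

// factory_lookup_sketch.cpp (fragment) -- assumes the header declaring the
// DeviceOperationInstanceFactory specialization edited above has already been
// included; its path is not shown in this diff.
#include <iostream>

namespace ew  = ck::tensor_operation::element_wise;
namespace lay = ck::tensor_layout::convolution;

using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
    2,
    lay::NHWGC,
    lay::GKYXC,
    ck::Tuple<lay::NHWGK, lay::NHWGK, lay::NHWGK, lay::NHWGK, lay::NHWGK>,
    lay::NHWGK,
    ck::bhalf_t,
    ck::bhalf_t,
    ck::Tuple<ck::bhalf_t, ck::bhalf_t, ck::bhalf_t, ck::bhalf_t, ck::bhalf_t>,
    ck::bhalf_t,
    ew::PassThrough,
    ew::PassThrough,
    ew::BiasNormalizeInInferClamp>;

int main()
{
    // GetInstances() aggregates every add_device_..._instances() registration
    // for this DeviceOp signature, now including the large-tensor kernels.
    const auto op_ptrs = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

    for(const auto& op : op_ptrs)
        std::cout << op->GetTypeString() << std::endl;
    return 0;
}
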
@@ -24,6 +24,21 @@ void add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhw
PassThrough,
BiasNormalizeInInferClamp>>>& instances);

void add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances(
std::vector<
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
NHWGK,
BF16,
BF16,
Tuple<BF16, BF16, BF16, BF16, BF16>,
BF16,
PassThrough,
PassThrough,
BiasNormalizeInInferClamp>>>& instances);

void add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
std::vector<std::unique_ptr<
DeviceGroupedConvFwdMultipleABD<3,
@@ -38,6 +53,21 @@ void add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_n
PassThrough,
PassThrough,
BiasNormalizeInInferClamp>>>& instances);

void add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
std::vector<std::unique_ptr<
DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
NDHWGK,
BF16,
BF16,
Tuple<BF16, BF16, BF16, BF16, BF16>,
BF16,
PassThrough,
PassThrough,
BiasNormalizeInInferClamp>>>& instances);
#endif

#ifdef CK_ENABLE_FP16
@@ -56,6 +86,21 @@ void add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhw
PassThrough,
BiasNormalizeInInferClamp>>>& instances);

void add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances(
std::vector<
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
NHWGK,
F16,
F16,
Tuple<F16, F16, F16, F16, F16>,
F16,
PassThrough,
PassThrough,
BiasNormalizeInInferClamp>>>& instances);

void add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances(
std::vector<std::unique_ptr<
DeviceGroupedConvFwdMultipleABD<3,
@@ -70,6 +115,21 @@ void add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_n
PassThrough,
PassThrough,
BiasNormalizeInInferClamp>>>& instances);

void add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances(
std::vector<std::unique_ptr<
DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
NDHWGK,
F16,
F16,
Tuple<F16, F16, F16, F16, F16>,
F16,
PassThrough,
PassThrough,
BiasNormalizeInInferClamp>>>& instances);
#endif

} // namespace instance
@@ -328,6 +328,8 @@ generate_sharded_instantiations(
add_instance_library(device_grouped_conv2d_fwd_bias_bnorm_clamp_instance
wmma/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
wmma/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance.cpp
wmma/large_tensor/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
wmma/large_tensor/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instance.cpp
${GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP}
)

@@ -0,0 +1,44 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using BiasNormalizeInInferClamp = ck::tensor_operation::element_wise::BiasNormalizeInInferClamp;

void add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances(
std::vector<
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
NHWGK,
BF16,
BF16,
Tuple<BF16, BF16, BF16, BF16, BF16>,
BF16,
PassThrough,
PassThrough,
BiasNormalizeInInferClamp>>>& instances)
{
add_device_operation_instances(instances,
device_grouped_conv_fwd_wmma_large_tensor_bf16_instances<
2,
NHWGC,
GKYXC,
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
NHWGK,
ConvFwdDefault,
Tuple<BF16, BF16, BF16, BF16, BF16>,
BiasNormalizeInInferClamp>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -0,0 +1,44 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using BiasNormalizeInInferClamp = ck::tensor_operation::element_wise::BiasNormalizeInInferClamp;

void add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances(
std::vector<
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
NHWGK,
F16,
F16,
Tuple<F16, F16, F16, F16, F16>,
F16,
PassThrough,
PassThrough,
BiasNormalizeInInferClamp>>>& instances)
{
add_device_operation_instances(instances,
device_grouped_conv_fwd_wmma_large_tensor_f16_instances<
2,
NHWGC,
GKYXC,
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
NHWGK,
ConvFwdDefault,
Tuple<F16, F16, F16, F16, F16>,
BiasNormalizeInInferClamp>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -309,6 +309,8 @@ generate_sharded_instantiations(
add_instance_library(device_grouped_conv3d_fwd_bias_bnorm_clamp_instance
wmma/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
wmma/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance.cpp
wmma/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
wmma/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
${GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP}
)
