Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ template <ConvSignatureDescriptor auto SIGNATURE,
struct ConvBwdWeightDlFactory
{
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Layouts = internal::ConvTensorLayouts<SIGNATURE>;
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ template <ConvSignatureDescriptor auto SIGNATURE,
struct ConvBwdWeightMultiDWmmaV3Factory
{
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Layouts = internal::ConvTensorLayouts<SIGNATURE>;
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ template <ConvSignatureDescriptor auto SIGNATURE,
struct ConvBwdWeightMultiDXdlFactory
{
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Layouts = internal::ConvTensorLayouts<SIGNATURE>;
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ template <ConvSignatureDescriptor auto SIGNATURE,
struct ConvBwdWeightTwoStageWmmaV3Factory
{
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Layouts = internal::ConvTensorLayouts<SIGNATURE>;
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ template <ConvSignatureDescriptor auto SIGNATURE,
struct ConvBwdWeightTwoStageXdlFactory
{
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Layouts = internal::ConvTensorLayouts<SIGNATURE>;
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ template <ConvSignatureDescriptor auto SIGNATURE,
struct ConvBwdWeightWmmaFactory
{
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Layouts = internal::ConvTensorLayouts<SIGNATURE>;
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ template <ConvSignatureDescriptor auto SIGNATURE,
struct ConvBwdWeightWmmaV3Factory
{
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Layouts = internal::ConvTensorLayouts<SIGNATURE>;
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ template <ConvSignatureDescriptor auto SIGNATURE,
struct ConvBwdWeightXdlFactory
{
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Layouts = internal::ConvTensorLayouts<SIGNATURE>;
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ template <ConvSignatureDescriptor auto SIGNATURE,
struct ConvBwdWeightXdlV3Factory
{
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Layouts = internal::ConvTensorLayouts<SIGNATURE>;
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ template <ConvSignatureDescriptor auto SIGNATURE,
struct ConvFwdDlFactory
{
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Layouts = internal::ConvTensorLayouts<SIGNATURE>;
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ template <ConvSignatureDescriptor auto SIGNATURE,
struct ConvFwdLargeTensorFactory
{
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Layouts = internal::ConvTensorLayouts<SIGNATURE>;
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ template <ConvSignatureDescriptor auto SIGNATURE,
struct ConvFwdXdlV3Factory
{
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Layouts = internal::ConvTensorLayouts<SIGNATURE>;
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ template <ConvSignatureDescriptor auto SIGNATURE,
struct ConvFwdWmmaFactory
{
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Layouts = internal::ConvTensorLayouts<SIGNATURE>;
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ template <ConvSignatureDescriptor auto SIGNATURE,
struct ConvFwdXdlFactory
{
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Layouts = internal::ConvTensorLayouts<SIGNATURE>;
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ template <ConvSignatureDescriptor auto SIGNATURE,
struct ConvTileFactory
{
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::TileConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Layouts = internal::TileConvTensorLayouts<SIGNATURE>;
using Types = internal::TileConvTensorTypes<SIGNATURE.data_type>;
using Ops = internal::TileElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,60 +172,63 @@ struct LayoutToCK<TensorLayout::GNDHWK>
using type = ck::tensor_layout::convolution::GNDHWK;
};

template <TensorLayout Layout>
template <TensorLayout LAYOUT>
consteval auto TensorLayoutToCK()
{
return typename LayoutToCK<Layout>::type{};
return typename LayoutToCK<LAYOUT>::type{};
}

struct EmptyAuxiliaryTensorLayout
{
using type = ck::Tuple<>;
};

template <auto AuxiliaryTensorConfigsArray, size_t... Indices>
template <auto AUXILIARY_TENSOR_CONFIGS_ARRAY, size_t... Indices>
consteval auto GetAuxiliaryTensorLayoutTuple(std::index_sequence<Indices...>)
{
return ck::Tuple<
decltype(TensorLayoutToCK<AuxiliaryTensorConfigsArray[Indices].layout>())...>{};
decltype(TensorLayoutToCK<AUXILIARY_TENSOR_CONFIGS_ARRAY[Indices].layout>())...>{};
}

template <auto AuxiliaryTensorConfigsValue, size_t SPATIAL_DIM>
template <auto AUXILIARY_TENSOR_CONFIGS_VALUE, size_t SPATIAL_DIM>
requires(ConvSpatialDim<SPATIAL_DIM>)
struct AuxiliaryTensorLayouts
{
static constexpr auto Size = AuxiliaryTensorConfigsValue.size();
using type = decltype(GetAuxiliaryTensorLayoutTuple<AuxiliaryTensorConfigsValue>(
static constexpr auto Size = AUXILIARY_TENSOR_CONFIGS_VALUE.size();
using type = decltype(GetAuxiliaryTensorLayoutTuple<AUXILIARY_TENSOR_CONFIGS_VALUE>(
std::make_index_sequence<Size>{}));
};

// TODO: Currently only the ouput tensor can have auxiliary tensors (e.g., bias).
template <auto Signature, size_t SPATIAL_DIM>
requires(HasElementwiseOpWithAuxiliaryOperands<decltype(Signature.output)>)
template <auto SIGNATURE>
requires HasElementwiseOpWithAuxiliaryOperands<decltype(SIGNATURE.output)>
consteval auto GetAuxiliaryTensorLayouts()
{
return AuxiliaryTensorLayouts<Signature.output.operation.auxiliary_operand_configs,
SPATIAL_DIM>{};
return AuxiliaryTensorLayouts<SIGNATURE.output.operation.auxiliary_operand_configs,
SIGNATURE.spatial_dim>{};
}

template <auto Signature, size_t SPATIAL_DIM>
requires(!HasElementwiseOpWithAuxiliaryOperands<decltype(Signature.output)>)
template <auto SIGNATURE>
requires(!HasElementwiseOpWithAuxiliaryOperands<decltype(SIGNATURE.output)>)
consteval auto GetAuxiliaryTensorLayouts()
{
return EmptyAuxiliaryTensorLayout{};
}

template <auto Signature, size_t SPATIAL_DIM>
requires(ConvSpatialDim<SPATIAL_DIM> &&
ValidConvInputLayoutForSpatialDim<Signature.input.config.layout, SPATIAL_DIM> &&
ValidConvWeightLayoutForSpatialDim<Signature.weight.config.layout, SPATIAL_DIM> &&
ValidConvOutputLayoutForSpatialDim<Signature.output.config.layout, SPATIAL_DIM>)
template <auto SIGNATURE>
requires ConvSpatialDim<SIGNATURE.spatial_dim> &&
ValidConvInputLayoutForSpatialDim<SIGNATURE.input.config.layout,
SIGNATURE.spatial_dim> &&
ValidConvWeightLayoutForSpatialDim<SIGNATURE.weight.config.layout,
SIGNATURE.spatial_dim> &&
ValidConvOutputLayoutForSpatialDim<SIGNATURE.output.config.layout,
SIGNATURE.spatial_dim>
struct ConvTensorLayouts
{
using InLayout = decltype(TensorLayoutToCK<Signature.input.config.layout>());
using WeiLayout = decltype(TensorLayoutToCK<Signature.weight.config.layout>());
using OutLayout = decltype(TensorLayoutToCK<Signature.output.config.layout>());
using DsLayout = decltype(GetAuxiliaryTensorLayouts<Signature, SPATIAL_DIM>())::type;
using InLayout = decltype(TensorLayoutToCK<SIGNATURE.input.config.layout>());
using WeiLayout = decltype(TensorLayoutToCK<SIGNATURE.weight.config.layout>());
using OutLayout = decltype(TensorLayoutToCK<SIGNATURE.output.config.layout>());
using DsLayout = decltype(GetAuxiliaryTensorLayouts<SIGNATURE>())::type;
};

} // namespace ck_tile::builder::factory::internal
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@

namespace ck_tile::builder::factory::internal {
using ALayout = ck_tile::tensor_layout::convolution::NWGC;
template <TensorLayout Layout>
template <TensorLayout LAYOUT>
struct LayoutToCKTile
{
static_assert(sizeof(UnsupportedEnumValue<Layout>) == 0,
static_assert(sizeof(UnsupportedEnumValue<LAYOUT>) == 0,
"Unsupported layout conversion to CK.");
};

Expand Down Expand Up @@ -152,49 +152,52 @@ struct EmptyAuxiliaryTileTensorLayout
using type = ck_tile::tuple<>;
};

template <auto AuxiliaryTileTensorConfigsArray, size_t... Indices>
template <auto AUXILIARY_TILE_TENSOR_CONFIGS_ARRAY, size_t... Indices>
consteval auto GetAuxiliaryTileTensorLayoutTuple(std::index_sequence<Indices...>)
{
return ck_tile::tuple<
decltype(TensorLayoutToCKTile<AuxiliaryTileTensorConfigsArray[Indices].layout>())...>{};
decltype(TensorLayoutToCKTile<AUXILIARY_TILE_TENSOR_CONFIGS_ARRAY[Indices].layout>())...>{};
}

template <auto AuxiliaryTileTensorConfigsValue, size_t SPATIAL_DIM>
requires(ConvSpatialDim<SPATIAL_DIM>)
template <auto AUXILIARY_TILE_TENSOR_CONFIGS_VALUE, size_t SPATIAL_DIM>
requires ConvSpatialDim<SPATIAL_DIM>
struct AuxiliaryTileTensorLayouts
{
static constexpr auto Size = AuxiliaryTileTensorConfigsValue.size();
using type = decltype(GetAuxiliaryTileTensorLayoutTuple<AuxiliaryTileTensorConfigsValue>(
static constexpr auto Size = AUXILIARY_TILE_TENSOR_CONFIGS_VALUE.size();
using type = decltype(GetAuxiliaryTileTensorLayoutTuple<AUXILIARY_TILE_TENSOR_CONFIGS_VALUE>(
std::make_index_sequence<Size>{}));
};

// TODO: Currently only the ouput tensor can have auxiliary tensors (e.g., bias).
template <auto Signature, size_t SPATIAL_DIM>
requires(HasElementwiseOpWithAuxiliaryOperands<decltype(Signature.output)>)
template <auto SIGNATURE>
requires HasElementwiseOpWithAuxiliaryOperands<decltype(SIGNATURE.output)>
consteval auto GetAuxiliaryTileTensorLayouts()
{
return AuxiliaryTileTensorLayouts<Signature.output.operation.auxiliary_operand_configs,
SPATIAL_DIM>{};
return AuxiliaryTileTensorLayouts<SIGNATURE.output.operation.auxiliary_operand_configs,
SIGNATURE.spatial_dim>{};
}

template <auto Signature, size_t SPATIAL_DIM>
requires(!HasElementwiseOpWithAuxiliaryOperands<decltype(Signature.output)>)
template <auto SIGNATURE>
requires(!HasElementwiseOpWithAuxiliaryOperands<decltype(SIGNATURE.output)>)
consteval auto GetAuxiliaryTileTensorLayouts()
{
return EmptyAuxiliaryTileTensorLayout{};
}

template <auto Signature, size_t SPATIAL_DIM>
requires(ConvSpatialDim<SPATIAL_DIM> &&
ValidConvInputLayoutForSpatialDim<Signature.input.config.layout, SPATIAL_DIM> &&
ValidConvWeightLayoutForSpatialDim<Signature.weight.config.layout, SPATIAL_DIM> &&
ValidConvOutputLayoutForSpatialDim<Signature.output.config.layout, SPATIAL_DIM>)
template <auto SIGNATURE>
requires ConvSpatialDim<SIGNATURE.spatial_dim> &&
ValidConvInputLayoutForSpatialDim<SIGNATURE.input.config.layout,
SIGNATURE.spatial_dim> &&
ValidConvWeightLayoutForSpatialDim<SIGNATURE.weight.config.layout,
SIGNATURE.spatial_dim> &&
ValidConvOutputLayoutForSpatialDim<SIGNATURE.output.config.layout,
SIGNATURE.spatial_dim>
struct TileConvTensorLayouts
{
using ALayout = decltype(TensorLayoutToCKTile<Signature.input.config.layout>());
using BLayout = decltype(TensorLayoutToCKTile<Signature.weight.config.layout>());
using ELayout = decltype(TensorLayoutToCKTile<Signature.output.config.layout>());
using DsLayout = decltype(GetAuxiliaryTileTensorLayouts<Signature, SPATIAL_DIM>())::type;
using ALayout = decltype(TensorLayoutToCKTile<SIGNATURE.input.config.layout>());
using BLayout = decltype(TensorLayoutToCKTile<SIGNATURE.weight.config.layout>());
using ELayout = decltype(TensorLayoutToCKTile<SIGNATURE.output.config.layout>());
using DsLayout = decltype(GetAuxiliaryTileTensorLayouts<SIGNATURE>())::type;
};

} // namespace ck_tile::builder::factory::internal
Loading