Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 8 additions & 12 deletions include/ck/tensor_description/multi_index_transform.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -487,8 +487,7 @@ struct Merge_v1_carry_check
using LowLengthsScan =
decltype(container_reverse_exclusive_scan(LowLengths{}, math::multiplies{}, Number<1>{}));

using UpLengths =
decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{})));
using UpLengths = decltype(make_tuple(container_product(LowLengths{})));

LowLengths low_lengths_;
LowLengthsScan low_lengths_scan_;
Expand All @@ -500,7 +499,7 @@ struct Merge_v1_carry_check
: low_lengths_{low_lengths},
low_lengths_scan_{
container_reverse_exclusive_scan(low_lengths, math::multiplies{}, Number<1>{})},
up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, Number<1>{}))}
up_lengths_{make_tuple(container_product(low_lengths))}
{
static_assert(LowerIndex::Size() == NDimLow, "wrong!");
}
Expand Down Expand Up @@ -1039,8 +1038,7 @@ struct Merge_v2_magic_division
using LowerIndex = MultiIndex<NDimLow>;
using UpperIndex = MultiIndex<1>;

using UpLengths =
decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{})));
using UpLengths = decltype(make_tuple(container_product(LowLengths{})));

using LowLengthsMagicDivisorMultipiler = decltype(generate_tuple(
lambda_merge_generate_MagicDivision_calculate_magic_multiplier<LowLengths>{},
Expand All @@ -1065,7 +1063,7 @@ struct Merge_v2_magic_division
low_lengths_magic_divisor_shift_{generate_tuple(
[&](auto i) { return MagicDivision::CalculateMagicShift(low_lengths[i]); },
Number<NDimLow>{})},
up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, Number<1>{}))}
up_lengths_{make_tuple(container_product(low_lengths))}
{
static_assert(LowerIndex::Size() == NDimLow, "wrong!");
}
Expand Down Expand Up @@ -1194,8 +1192,7 @@ struct Merge_v2r2_magic_division
using LowLengthsScan =
decltype(container_reverse_exclusive_scan(LowLengths{}, math::multiplies{}, Number<1>{}));

using UpLengths =
decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{})));
using UpLengths = decltype(make_tuple(container_product(LowLengths{})));

using LowLengthsScanMagicDivisorMultipiler = decltype(generate_tuple(
lambda_merge_generate_MagicDivision_calculate_magic_multiplier<LowLengthsScan>{},
Expand Down Expand Up @@ -1223,7 +1220,7 @@ struct Merge_v2r2_magic_division
low_lengths_scan_magic_divisor_shift_{generate_tuple(
[&](auto i) { return MagicDivision::CalculateMagicShift(low_lengths_scan_[i]); },
Number<NDimLow>{})},
up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, Number<1>{}))}
up_lengths_{make_tuple(container_product(low_lengths))}
{
static_assert(LowerIndex::Size() == NDimLow, "wrong!");
}
Expand Down Expand Up @@ -1344,8 +1341,7 @@ struct Merge_v3_division_mod
using LowLengthsScan =
decltype(container_reverse_exclusive_scan(LowLengths{}, math::multiplies{}, Number<1>{}));

using UpLengths =
decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{})));
using UpLengths = decltype(make_tuple(container_product(LowLengths{})));

LowLengths low_lengths_;
LowLengthsScan low_lengths_scan_;
Expand All @@ -1357,7 +1353,7 @@ struct Merge_v3_division_mod
: low_lengths_{low_lengths},
low_lengths_scan_{
container_reverse_exclusive_scan(low_lengths, math::multiplies{}, Number<1>{})},
up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, Number<1>{}))}
up_lengths_{make_tuple(container_product(low_lengths))}
{
static_assert(LowerIndex::Size() == NDimLow, "wrong!");
}
Expand Down
3 changes: 1 addition & 2 deletions include/ck/tensor_description/tensor_adaptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,7 @@ struct TensorAdaptor
},
Number<ndim_top_>{});

// TODO: make container_reduce support tuple of Number and index_t
return container_reduce(lengths, math::multiplies{}, Number<1>{});
return container_product(lengths);
}

template <index_t IDim>
Expand Down
3 changes: 1 addition & 2 deletions include/ck/tensor_description/tensor_descriptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,7 @@ struct TensorDescriptor
},
Number<ndim_visible_>{});

// TODO: make container_reduce support tuple of Number and index_t
return container_reduce(lengths, math::multiplies{}, Number<1>{});
return container_product(lengths);
}

template <index_t IDim>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -699,9 +699,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
if constexpr(isMultiA || isMultiB)
{
const auto as_grid_desc_ak0_m_ak1 =
generate_tuple([&](auto) { return a_grid_desc_m_k_; }, Number<NumATensor>{});
make_uniform_tuple(a_grid_desc_m_k_, Number<NumATensor>{});
const auto bs_grid_desc_bk0_n_bk1 =
generate_tuple([&](auto) { return b_grid_desc_n_k_; }, Number<NumBTensor>{});
make_uniform_tuple(b_grid_desc_n_k_, Number<NumBTensor>{});

if(GridwiseGemm::CheckValidity(as_grid_desc_ak0_m_ak1,
bs_grid_desc_bk0_n_bk1,
Expand Down
34 changes: 30 additions & 4 deletions include/ck/utility/container_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,14 @@ __host__ __device__ constexpr auto container_reduce(const Container& x,
}
#endif

// Product of all elements of a container, computed with a single fold expression.
// O(1) template depth alternative to container_reduce for computing products:
// unpack expands the elements into a parameter pack instead of O(N) linear recursion.
//
// The fold is seeded with Number<1>{}, the same identity the replaced
// container_reduce(x, math::multiplies{}, Number<1>{}) calls passed as init.
// A unary fold over operator* is ill-formed for an empty pack, so without the seed an
// empty container would fail to compile; with it, an empty container yields Number<1>{}.
// The seed sits on the right of a right fold, so the association of the non-empty case
// (x0 * (x1 * (... * xN))) is unchanged apart from the trailing multiplication by one.
template <typename Container>
__host__ __device__ constexpr auto container_product(const Container& x)
{
    return unpack([](auto... xs) { return (xs * ... * Number<1>{}); }, x);
}

template <typename TData, index_t NSize, typename Reduce>
__host__ __device__ constexpr auto
container_reverse_inclusive_scan(const Array<TData, NSize>& x, Reduce f, TData init)
Expand Down Expand Up @@ -316,6 +324,26 @@ container_reverse_inclusive_scan(const Tuple<Xs...>& x, Reduce f, TData init)
return y;
}

// Named functors used by container_concat. A lambda has a distinct closure type at every
// place it is written, whereas a named functor is one shared type, which cuts down on
// template instantiations.

// Perfect-forwards every argument to make_tuple and returns the resulting tuple.
struct make_tuple_functor
{
    template <typename... Args>
    __host__ __device__ constexpr auto operator()(Args&&... args) const
    {
        // static_cast<Args&&> is the spelled-out form of forwarding each argument.
        return make_tuple(static_cast<Args&&>(args)...);
    }
};

// Perfect-forwards its arguments to make_array and returns the resulting array.
// Requires at least one argument: the first is a separate template parameter.
struct make_array_functor
{
    template <typename Head, typename... Tail>
    __host__ __device__ constexpr auto operator()(Head&& head, Tail&&... tail) const
    {
        // static_cast<X&&> is the spelled-out form of forwarding each argument.
        return make_array(static_cast<Head&&>(head), static_cast<Tail&&>(tail)...);
    }
};

template <typename X, typename... Ys>
__host__ __device__ constexpr auto container_concat(const X& x, const Ys&... ys)
{
Expand All @@ -325,15 +353,13 @@ __host__ __device__ constexpr auto container_concat(const X& x, const Ys&... ys)
template <typename T, index_t NX, index_t NY>
__host__ __device__ constexpr auto container_concat(const Array<T, NX>& ax, const Array<T, NY>& ay)
{
return unpack2(
[&](auto&&... zs) { return make_array(ck::forward<decltype(zs)>(zs)...); }, ax, ay);
return unpack2(make_array_functor{}, ax, ay);
}

template <typename... X, typename... Y>
__host__ __device__ constexpr auto container_concat(const Tuple<X...>& tx, const Tuple<Y...>& ty)
{
return unpack2(
[&](auto&&... zs) { return make_tuple(ck::forward<decltype(zs)>(zs)...); }, tx, ty);
return unpack2(make_tuple_functor{}, tx, ty);
}

template <typename Container>
Expand Down
22 changes: 22 additions & 0 deletions include/ck/utility/tuple_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,28 @@ __host__ __device__ constexpr auto generate_identity_sequences(Number<N>)
return generate_identity_sequences<N>();
}

// Helper for the common pattern
//   generate_tuple([&](auto) { return value; }, Number<N>{})
// Builds a tuple holding one entry per index in the Sequence, each initialised from
// `value`, without instantiating a lambda.
namespace detail {
template <typename T, index_t... Idx>
__host__ __device__ constexpr auto make_uniform_tuple_impl(T&& value, Sequence<Idx...>)
{
    // Each pack element expands to `value`; the index itself is evaluated and discarded.
    // `value` is passed as an lvalue every time, so it is never moved from.
    return make_tuple((static_cast<void>(Idx), value)...);
}
} // namespace detail

// Returns a tuple with N entries, each initialised from `value`.
// N is given explicitly: make_uniform_tuple<N>(value).
template <index_t N, typename T>
__host__ __device__ constexpr auto make_uniform_tuple(T&& value)
{
    // One index per tuple slot; only the count matters to the implementation.
    using SlotIndices = make_index_sequence<N>;

    return detail::make_uniform_tuple_impl(static_cast<T&&>(value), SlotIndices{});
}

// Overload taking the count as a Number<N> tag, mirroring the generate_tuple call shape:
// make_uniform_tuple(value, Number<N>{}). Returns a tuple with N entries initialised
// from `value`.
template <typename T, index_t N>
__host__ __device__ constexpr auto make_uniform_tuple(T&& value, Number<N>)
{
    // Same work as make_uniform_tuple<N>(value), written out directly.
    return detail::make_uniform_tuple_impl(static_cast<T&&>(value), make_index_sequence<N>{});
}

// tx and ty are tuples of references; the return type will be a tuple of references (not rvalues)
template <typename... X, typename... Y>
__host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple<X&...>& tx,
Expand Down