Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 12 additions & 38 deletions include/ck/tensor_description/tensor_descriptor_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,18 @@ namespace ck {
* functions on GPU without worrying about scratch memory usage.
*/

#if CK_WORKAROUND_SWDEV_275126
template <typename Lengths, typename Strides, index_t I, typename AccOld>
__host__ __device__ constexpr auto calculate_element_space_size_impl(const Lengths& lengths,
const Strides& strides,
Number<I> i,
AccOld acc_old)
// O(1) template depth helper for element space size calculation using fold expression
// Computes: 1 + sum((length[i] - 1) * stride[i]) for all i
namespace detail {
template <typename... Lengths, typename... Strides, index_t... Is>
__host__ __device__ constexpr auto compute_element_space_size(const Tuple<Lengths...>& lengths,
const Tuple<Strides...>& strides,
Sequence<Is...>)
{
auto acc_new = acc_old + (lengths[i] - Number<1>{}) * strides[i];

if constexpr(i.value < Lengths::Size() - 1)
{
return calculate_element_space_size_impl(lengths, strides, i + Number<1>{}, acc_new);
}
else
{
return acc_new;
}
return (LongNumber<1>{} + ... +
((lengths[Number<Is>{}] - Number<1>{}) * strides[Number<Is>{}]));
}
#endif
} // namespace detail

// Lengths..., Strides... could be:
// 1) index_t, which is known at run-time, or
Expand All @@ -60,27 +53,8 @@ __host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple<Leng

constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{};

#if !CK_WORKAROUND_SWDEV_275126
// rocm-4.1 compiler would crash for recursive lambda
// recursive function for reduction
auto f = [&](auto fs, auto i, auto acc_old) {
auto acc_new = acc_old + (lengths[i] - Number<1>{}) * strides[i];

if constexpr(i.value < N - 1)
{
return fs(fs, i + Number<1>{}, acc_new);
}
else
{
return acc_new;
}
};

const auto element_space_size = f(f, Number<0>{}, LongNumber<1>{});
#else
const auto element_space_size =
calculate_element_space_size_impl(lengths, strides, Number<0>{}, LongNumber<1>{});
#endif
const auto element_space_size = detail::compute_element_space_size(
lengths, strides, typename arithmetic_sequence_gen<0, N, 1>::type{});

return TensorDescriptor<remove_cv_t<decltype(transforms)>,
remove_cv_t<decltype(low_dim_hidden_idss)>,
Expand Down
101 changes: 34 additions & 67 deletions include/ck/utility/sequence.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,57 +199,20 @@ template <index_t N>
using make_index_sequence =
typename __make_integer_seq<impl::__integer_sequence, index_t, N>::seq_type;

// merge sequence - optimized to avoid recursive instantiation
namespace detail {

// Helper to concatenate multiple sequences in one step using fold expression
template <typename... Seqs>
struct sequence_merge_impl;

// Base case: single sequence
template <index_t... Is>
struct sequence_merge_impl<Sequence<Is...>>
// merge sequence - O(1) template depth using fold expression
// Binary merge operator for fold expression - enables O(1) depth via (S1 | S2 | S3 | ...)
// Must be in ck namespace for ADL to find it when used with Sequence types
template <index_t... As, index_t... Bs>
constexpr Sequence<As..., Bs...> operator|(Sequence<As...>, Sequence<Bs...>)
{
using type = Sequence<Is...>;
};

// Two sequences: direct concatenation
template <index_t... Xs, index_t... Ys>
struct sequence_merge_impl<Sequence<Xs...>, Sequence<Ys...>>
{
using type = Sequence<Xs..., Ys...>;
};

// Three sequences: direct concatenation (avoids one level of recursion)
template <index_t... Xs, index_t... Ys, index_t... Zs>
struct sequence_merge_impl<Sequence<Xs...>, Sequence<Ys...>, Sequence<Zs...>>
{
using type = Sequence<Xs..., Ys..., Zs...>;
};

// Four sequences: direct concatenation
template <index_t... As, index_t... Bs, index_t... Cs, index_t... Ds>
struct sequence_merge_impl<Sequence<As...>, Sequence<Bs...>, Sequence<Cs...>, Sequence<Ds...>>
{
using type = Sequence<As..., Bs..., Cs..., Ds...>;
};

// General case: binary tree reduction (O(log N) depth instead of O(N))
template <typename S1, typename S2, typename S3, typename S4, typename... Rest>
struct sequence_merge_impl<S1, S2, S3, S4, Rest...>
{
// Merge pairs first, then recurse
using left = typename sequence_merge_impl<S1, S2>::type;
using right = typename sequence_merge_impl<S3, S4, Rest...>::type;
using type = typename sequence_merge_impl<left, right>::type;
};

} // namespace detail
return {};
}

template <typename... Seqs>
struct sequence_merge
{
using type = typename detail::sequence_merge_impl<Seqs...>::type;
// Unary right fold: S1 | (S2 | (S3 | ...)) - concatenation is associative, so the
// element order is the same as a left fold - O(1) template depth
using type = decltype((Seqs{} | ...));
};

template <>
Expand Down Expand Up @@ -576,31 +539,35 @@ struct is_valid_sequence_map : is_same<typename arithmetic_sequence_gen<0, SeqMa
{
};

template <typename SeqMap>
struct sequence_map_inverse
// O(1) template depth helper to find source index in permutation inversion
// For a permutation X2Y, finds i such that X2Y[i] == Target
namespace detail {
template <index_t Target, index_t... Is>
__host__ __device__ constexpr index_t find_source_index(Sequence<Is...>)
{
template <typename X2Y, typename WorkingY2X, index_t XBegin, index_t XRemain>
struct sequence_map_inverse_impl
constexpr index_t values[] = {Is...};
for(index_t i = 0; i < static_cast<index_t>(sizeof...(Is)); ++i)
{
static constexpr auto new_y2x =
WorkingY2X::Modify(X2Y::At(Number<XBegin>{}), Number<XBegin>{});

using type =
typename sequence_map_inverse_impl<X2Y, decltype(new_y2x), XBegin + 1, XRemain - 1>::
type;
};
if(values[i] == Target)
return i;
}
return 0; // should not reach for valid permutation
}

template <typename X2Y, typename WorkingY2X, index_t XBegin>
struct sequence_map_inverse_impl<X2Y, WorkingY2X, XBegin, 0>
{
using type = WorkingY2X;
};
// Builds the inverse of the permutation SeqMap with a single pack expansion.
// For every target value p in Positions, the resulting Sequence holds the index i
// reported by find_source_index<p>(SeqMap{}), i.e. the position where SeqMap equals p.
// The function argument is used only to deduce Positions.
template <typename SeqMap, index_t... Positions>
__host__ __device__ constexpr auto invert_permutation_impl(Sequence<Positions...>)
{
    using inverted = Sequence<find_source_index<Positions>(SeqMap{})...>;
    return inverted{};
}
} // namespace detail

using type =
typename sequence_map_inverse_impl<SeqMap,
typename uniform_sequence_gen<SeqMap::Size(), 0>::type,
0,
SeqMap::Size()>::type;
// Invert a permutation sequence using O(1) template depth pack expansion
// For X2Y = {a, b, c, ...}, computes Y2X where Y2X[X2Y[i]] = i
//
// `type` is the Sequence returned by detail::invert_permutation_impl<SeqMap> when it
// is handed the positions 0 .. SeqMap::Size() - 1.
//
// NOTE(review): SeqMap is not validated here. detail::find_source_index returns 0 for
// a position that never occurs in SeqMap and the first match for a repeated value, so
// a non-permutation input yields a sequence without any diagnostic - confirm callers
// check is_valid_sequence_map where that matters.
template <typename SeqMap>
struct sequence_map_inverse
{
// decltype is an unevaluated context: only the helper's deduced return type is used.
using type = decltype(detail::invert_permutation_impl<SeqMap>(
typename arithmetic_sequence_gen<0, SeqMap::Size(), 1>::type{}));
};

template <index_t... Xs, index_t... Ys>
Expand Down