Skip to content

Commit 59f0c32

Browse files
committed
Rewrite sequence_map_inverse using O(1) depth pack expansion
Replace the O(N)-depth recursive template sequence_map_inverse_impl with a constexpr function and a pack expansion, giving O(1) template instantiation depth. Results: sequence_map_inverse goes from 45 instances (187 ms) to 7 instances (10 ms), a 95% reduction.
1 parent 5190578 commit 59f0c32

1 file changed

Lines changed: 25 additions & 21 deletions

File tree

include/ck/utility/sequence.hpp

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -576,31 +576,35 @@ struct is_valid_sequence_map : is_same<typename arithmetic_sequence_gen<0, SeqMa
576576
{
577577
};
578578

579-
template <typename SeqMap>
580-
struct sequence_map_inverse
579+
// O(1) template depth helper to find source index in permutation inversion
580+
// For a permutation X2Y, finds i such that X2Y[i] == Target
581+
namespace detail {
582+
// Linear search over the values of a Sequence.
//   Target : the value to look for.
//   Is...  : the values carried by the (unnamed) Sequence argument; the argument
//            itself is only used to deduce this pack.
// Returns the index of the first element equal to Target. If no element matches,
// the function falls through and returns 0, which a caller cannot tell apart
// from a genuine match at index 0.
// NOTE(review): with an empty pack the local array below would have zero
// elements, which standard C++ does not allow -- confirm this helper is never
// instantiated with Sequence<>.
template <index_t Target, index_t... Is>
583+
__host__ __device__ constexpr index_t find_source_index(Sequence<Is...>)
581584
{
582-
template <typename X2Y, typename WorkingY2X, index_t XBegin, index_t XRemain>
583-
struct sequence_map_inverse_impl
585+
// Expand the pack into a local array so it can be indexed by the loop variable.
constexpr index_t values[] = {Is...};
586+
// sizeof...(Is) is unsigned; cast it so the comparison with i is done in index_t.
for(index_t i = 0; i < static_cast<index_t>(sizeof...(Is)); ++i)
584587
{
585-
static constexpr auto new_y2x =
586-
WorkingY2X::Modify(X2Y::At(Number<XBegin>{}), Number<XBegin>{});
587-
588-
using type =
589-
typename sequence_map_inverse_impl<X2Y, decltype(new_y2x), XBegin + 1, XRemain - 1>::
590-
type;
591-
};
588+
// Stop at the first match; later duplicates of Target are never examined.
if(values[i] == Target)
589+
return i;
590+
}
591+
return 0; // should not reach for valid permutation
592+
}
592593

593-
template <typename X2Y, typename WorkingY2X, index_t XBegin>
594-
struct sequence_map_inverse_impl<X2Y, WorkingY2X, XBegin, 0>
595-
{
596-
using type = WorkingY2X;
597-
};
594+
// Builds the result Sequence with a single pack expansion.
//   SeqMap       : the sequence type to search; a temporary SeqMap{} is created
//                  for each lookup, so it must be default-constructible.
//   Positions... : the values to look up, deduced from the (unnamed) Sequence
//                  argument.
// Returns a Sequence whose k-th entry is find_source_index<Positions[k]>(SeqMap{}),
// i.e. the index at which SeqMap holds the k-th value of Positions.
template <typename SeqMap, index_t... Positions>
595+
__host__ __device__ constexpr auto invert_permutation_impl(Sequence<Positions...>)
596+
{
597+
// One find_source_index instantiation per element of Positions; no recursion.
return Sequence<find_source_index<Positions>(SeqMap{})...>{};
598+
}
599+
} // namespace detail
598600

599-
using type =
600-
typename sequence_map_inverse_impl<SeqMap,
601-
typename uniform_sequence_gen<SeqMap::Size(), 0>::type,
602-
0,
603-
SeqMap::Size()>::type;
601+
// Invert a permutation sequence using O(1) template depth pack expansion
602+
// For X2Y = {a, b, c, ...}, computes Y2X where Y2X[X2Y[i]] = i
603+
// Metafunction exposing the inverse map of SeqMap as the nested alias `type`.
// `type` is the return type of detail::invert_permutation_impl<SeqMap> when it is
// passed arithmetic_sequence_gen<0, SeqMap::Size(), 1>::type{}. The call appears
// only inside decltype, so it is an unevaluated operand used for its type alone.
// NOTE(review): arithmetic_sequence_gen is defined outside this view; presumably
// the arguments <0, Size(), 1> yield the sequence 0, 1, ..., Size()-1 -- confirm.
template <typename SeqMap>
604+
struct sequence_map_inverse
605+
{
606+
using type = decltype(detail::invert_permutation_impl<SeqMap>(
607+
typename arithmetic_sequence_gen<0, SeqMap::Size(), 1>::type{}));
604608
};
605609

606610
template <index_t... Xs, index_t... Ys>

0 commit comments

Comments
 (0)