From a8c9be9378f537ba9cf1eebbe92e97e147d2e66c Mon Sep 17 00:00:00 2001 From: Max Podkorytov <4273004+tenpercent@users.noreply.github.com> Date: Fri, 16 Jan 2026 11:19:35 -0600 Subject: [PATCH 1/3] Rewrite sequence_map_inverse using O(1) depth pack expansion MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace O(N) recursive template sequence_map_inverse_impl with constexpr function and pack expansion for O(1) template depth. Results: - sequence_map_inverse: 45 instances, 187ms → 7 instances, 10ms (95% reduction) --- include/ck/utility/sequence.hpp | 46 ++++++++++++++++++--------------- 1 file changed, 25 insertions(+), 21 deletions(-) diff --git a/include/ck/utility/sequence.hpp b/include/ck/utility/sequence.hpp index 18bb36d112..1f392af070 100644 --- a/include/ck/utility/sequence.hpp +++ b/include/ck/utility/sequence.hpp @@ -576,31 +576,35 @@ struct is_valid_sequence_map : is_same -struct sequence_map_inverse +// O(1) template depth helper to find source index in permutation inversion +// For a permutation X2Y, finds i such that X2Y[i] == Target +namespace detail { +template +__host__ __device__ constexpr index_t find_source_index(Sequence) { - template - struct sequence_map_inverse_impl + constexpr index_t values[] = {Is...}; + for(index_t i = 0; i < static_cast(sizeof...(Is)); ++i) { - static constexpr auto new_y2x = - WorkingY2X::Modify(X2Y::At(Number{}), Number{}); - - using type = - typename sequence_map_inverse_impl:: - type; - }; + if(values[i] == Target) + return i; + } + return 0; // should not reach for valid permutation +} - template - struct sequence_map_inverse_impl - { - using type = WorkingY2X; - }; +template +__host__ __device__ constexpr auto invert_permutation_impl(Sequence) +{ + return Sequence(SeqMap{})...>{}; +} +} // namespace detail - using type = - typename sequence_map_inverse_impl::type, - 0, - SeqMap::Size()>::type; +// Invert a permutation sequence using O(1) template depth pack expansion +// For X2Y = {a, b, c, ...}, computes Y2X where Y2X[X2Y[i]] = i +template +struct sequence_map_inverse +{ + using type = decltype(detail::invert_permutation_impl( + typename arithmetic_sequence_gen<0, SeqMap::Size(), 1>::type{})); }; template From e74b611c14c342a82f1bd2565fd767c46ccaedbc Mon Sep 17 00:00:00 2001 From: Max Podkorytov <4273004+tenpercent@users.noreply.github.com> Date: Fri, 16 Jan 2026 11:26:46 -0600 Subject: [PATCH 2/3] Replace O(N) recursive element space size with O(1) fold expression MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Use pack expansion with fold expression to compute element space size instead of recursive template or recursive lambda. Results: - calculate_element_space_size: 24 instances, 35ms → 10 instances, 9ms - Max template depth: 24 → 23 --- .../tensor_descriptor_helper.hpp | 50 +++++-------------- 1 file changed, 12 insertions(+), 38 deletions(-) diff --git a/include/ck/tensor_description/tensor_descriptor_helper.hpp b/include/ck/tensor_description/tensor_descriptor_helper.hpp index 44ab0d90c3..e03eb4cf21 100644 --- a/include/ck/tensor_description/tensor_descriptor_helper.hpp +++ b/include/ck/tensor_description/tensor_descriptor_helper.hpp @@ -17,25 +17,18 @@ namespace ck { * functions on GPU without worrying about scratch memory usage. */ -#if CK_WORKAROUND_SWDEV_275126 -template -__host__ __device__ constexpr auto calculate_element_space_size_impl(const Lengths& lengths, - const Strides& strides, - Number i, - AccOld acc_old) +// O(1) template depth helper for element space size calculation using fold expression +// Computes: 1 + sum((length[i] - 1) * stride[i]) for all i +namespace detail { +template +__host__ __device__ constexpr auto compute_element_space_size(const Tuple& lengths, + const Tuple& strides, + Sequence) { - auto acc_new = acc_old + (lengths[i] - Number<1>{}) * strides[i]; - - if constexpr(i.value < Lengths::Size() - 1) - { - return calculate_element_space_size_impl(lengths, strides, i + Number<1>{}, acc_new); - } - else - { - return acc_new; - } + return (LongNumber<1>{} + ... + + ((lengths[Number{}] - Number<1>{}) * strides[Number{}])); } -#endif +} // namespace detail // Lengths..., Strides... could be: // 1) index_t, which is known at run-time, or @@ -60,27 +53,8 @@ __host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple::type{}; -#if !CK_WORKAROUND_SWDEV_275126 - // rocm-4.1 compiler would crash for recursive labmda - // recursive function for reduction - auto f = [&](auto fs, auto i, auto acc_old) { - auto acc_new = acc_old + (lengths[i] - Number<1>{}) * strides[i]; - - if constexpr(i.value < N - 1) - { - return fs(fs, i + Number<1>{}, acc_new); - } - else - { - return acc_new; - } - }; - - const auto element_space_size = f(f, Number<0>{}, LongNumber<1>{}); -#else - const auto element_space_size = - calculate_element_space_size_impl(lengths, strides, Number<0>{}, LongNumber<1>{}); -#endif + const auto element_space_size = detail::compute_element_space_size( + lengths, strides, typename arithmetic_sequence_gen<0, N, 1>::type{}); return TensorDescriptor, remove_cv_t, From 9942fd6ab9902cdb50b11e699cb72a7758aec913 Mon Sep 17 00:00:00 2001 From: Max Podkorytov <4273004+tenpercent@users.noreply.github.com> Date: Fri, 16 Jan 2026 14:01:19 -0600 Subject: [PATCH 3/3] Replace sequence_merge O(log N) recursion with O(1) fold expression MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Use operator| with fold expression (Seqs{} | ...) to merge sequences in O(1) template depth instead of O(log N) binary tree recursion. - Reduces sequence_merge instantiations from 449 to 167 (63% reduction) - Total template instantiations: 47,186 → 46,974 (-212) - ADL finds operator| since Sequence is in ck namespace --- include/ck/utility/sequence.hpp | 55 ++++++--------------------------- 1 file changed, 9 insertions(+), 46 deletions(-) diff --git a/include/ck/utility/sequence.hpp b/include/ck/utility/sequence.hpp index 1f392af070..f8d8fc6f00 100644 --- a/include/ck/utility/sequence.hpp +++ b/include/ck/utility/sequence.hpp @@ -199,57 +199,20 @@ template using make_index_sequence = typename __make_integer_seq::seq_type; -// merge sequence - optimized to avoid recursive instantiation -namespace detail { - -// Helper to concatenate multiple sequences in one step using fold expression -template -struct sequence_merge_impl; - -// Base case: single sequence -template -struct sequence_merge_impl> -{ - using type = Sequence; -}; - -// Two sequences: direct concatenation -template -struct sequence_merge_impl, Sequence> -{ - using type = Sequence; -}; - -// Three sequences: direct concatenation (avoids one level of recursion) -template -struct sequence_merge_impl, Sequence, Sequence> -{ - using type = Sequence; -}; - -// Four sequences: direct concatenation -template -struct sequence_merge_impl, Sequence, Sequence, Sequence> +// merge sequence - O(1) template depth using fold expression +// Binary merge operator for fold expression - enables O(1) depth via (S1 | S2 | S3 | ...) +// Must be in ck namespace for ADL to find it when used with Sequence types +template +constexpr Sequence operator|(Sequence, Sequence) { - using type = Sequence; -}; - -// General case: binary tree reduction (O(log N) depth instead of O(N)) -template -struct sequence_merge_impl -{ - // Merge pairs first, then recurse - using left = typename sequence_merge_impl::type; - using right = typename sequence_merge_impl::type; - using type = typename sequence_merge_impl::type; -}; - -} // namespace detail + return {}; +} template struct sequence_merge { - using type = typename detail::sequence_merge_impl::type; + // Left fold: ((S1 | S2) | S3) | ... - O(1) template depth + using type = decltype((Seqs{} | ...)); }; template <>