From 1d7c221c956c42906ec653c9c854f16c75497024 Mon Sep 17 00:00:00 2001 From: Max Podkorytov <4273004+tenpercent@users.noreply.github.com> Date: Fri, 16 Jan 2026 15:49:59 -0600 Subject: [PATCH 1/3] Replace nested static_for lambdas with compile-time search helper The GetTransformAndItsUpperDimension function used nested static_for loops with lambdas to search for a hidden dimension in UpperDimensionIdss. This caused 918 applier::operator() instantiations (81% of all applier instantiations). Replace with find_in_tuple_of_sequences helper that uses constexpr array lookup and if-constexpr recursion, eliminating the lambda instantiation overhead. Results on example_grouped_conv_fwd_xdl_fp16: - applier instantiations: 1132 -> 127 (89% reduction) - TensorDescriptor instantiations: 2503 -> 664 (73% reduction) - Template instantiation time: 23.4s -> 19.4s (17% reduction) --- .../tensor_description/tensor_descriptor.hpp | 21 ++---- include/ck/utility/sequence_helper.hpp | 70 +++++++++++++++++++ 2 files changed, 74 insertions(+), 17 deletions(-) diff --git a/include/ck/tensor_description/tensor_descriptor.hpp b/include/ck/tensor_description/tensor_descriptor.hpp index 4f827e51ea..52d876e814 100644 --- a/include/ck/tensor_description/tensor_descriptor.hpp +++ b/include/ck/tensor_description/tensor_descriptor.hpp @@ -79,24 +79,11 @@ struct TensorDescriptor constexpr index_t idim_hidden = VisibleDimensionIds::At(idim_visible); - index_t itran_found = 0; - index_t idim_up_found = 0; - bool found = false; - - static_for<0, ntransform_, 1>{}([&](auto itran) { - constexpr auto up_dim_ids = UpperDimensionIdss{}[itran]; - - static_for<0, up_dim_ids.Size(), 1>{}([&](auto idim_up) { - if constexpr(up_dim_ids[idim_up] == idim_hidden) - { - itran_found = itran; - idim_up_found = idim_up; - found = true; - } - }); - }); + // Use compile-time search helper instead of nested static_for with lambdas + // This eliminates ~918 applier::operator() instantiations + constexpr auto result = find_in_tuple_of_sequences(UpperDimensionIdss{}); - return make_tuple(itran_found, idim_up_found, found); + return make_tuple(result.itran, result.idim_up, result.found); } constexpr static index_t ntransform_ = GetNumOfTransform(); diff --git a/include/ck/utility/sequence_helper.hpp b/include/ck/utility/sequence_helper.hpp index 1482c936ad..7b26b3e9e4 100644 --- a/include/ck/utility/sequence_helper.hpp +++ b/include/ck/utility/sequence_helper.hpp @@ -52,4 +52,74 @@ __host__ __device__ constexpr auto unpack_and_merge_sequences(TupleOfSequences) return unpack(merge_sequences_functor{}, TupleOfSequences{}); } +// Find index of Target in Sequence, returns -1 if not found +// Uses constexpr array lookup for O(1) template depth +template +__host__ __device__ constexpr index_t sequence_find_value(Sequence) +{ + if constexpr(sizeof...(Is) == 0) + { + return -1; + } + else + { + constexpr bool matches[] = {(Is == Target)...}; + for(index_t i = 0; i < static_cast(sizeof...(Is)); ++i) + { + if(matches[i]) + return i; + } + return -1; + } +} + +// Result type for find_in_tuple_of_sequences +template +struct FindTransformResult +{ + static constexpr index_t itran = ITran; + static constexpr index_t idim_up = IDimUp; + static constexpr bool found = Found; +}; + +namespace detail { + +// Helper to search through a tuple of sequences for a target value +// Returns FindTransformResult with (transform_index, index_within_sequence, found) +template +__host__ __device__ constexpr auto find_in_tuple_of_sequences_impl() +{ + constexpr index_t idx = sequence_find_value(FirstSeq{}); + if constexpr(idx >= 0) + { + return FindTransformResult{}; + } + else if constexpr(sizeof...(RestSeqs) > 0) + { + return find_in_tuple_of_sequences_impl(); + } + else + { + return FindTransformResult<0, 0, false>{}; + } +} + +} // namespace detail + +// Find target value in a tuple of sequences +// Returns FindTransformResult +// This replaces nested static_for loops with O(1) template depth +template +__host__ __device__ constexpr auto find_in_tuple_of_sequences(Tuple) +{ + if constexpr(sizeof...(Seqs) == 0) + { + return FindTransformResult<0, 0, false>{}; + } + else + { + return detail::find_in_tuple_of_sequences_impl(); + } +} + } // namespace ck From bbf5c5e9265b27346e633b512938c260fdb3cf73 Mon Sep 17 00:00:00 2001 From: Max Podkorytov <4273004+tenpercent@users.noreply.github.com> Date: Fri, 16 Jan 2026 16:06:37 -0600 Subject: [PATCH 2/3] Replace generate_tuple lambda with pack expansion in InitializeElementSize The InitializeElementSize function used generate_tuple with a lambda to compute visible dimension lengths. Each TensorDescriptor type created a unique lambda type, causing 78 instantiations (385ms). Replace with direct pack expansion using helper functions, eliminating the lambda instantiation overhead entirely. Results on example_grouped_conv_fwd_xdl_fp16: - generate_tuple lambdas: 178 -> 100 (44% reduction) - Template instantiation time: 19.5s -> 19.0s --- .../tensor_description/tensor_descriptor.hpp | 38 ++++++++++--------- 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/include/ck/tensor_description/tensor_descriptor.hpp b/include/ck/tensor_description/tensor_descriptor.hpp index 52d876e814..b066dc0d00 100644 --- a/include/ck/tensor_description/tensor_descriptor.hpp +++ b/include/ck/tensor_description/tensor_descriptor.hpp @@ -49,27 +49,29 @@ struct TensorDescriptor return unique_sort_all_dim_ids::Size(); } - __host__ __device__ static constexpr auto InitializeElementSize(const Transforms& transforms) + // Helper to get length of a visible dimension from transforms + template + __host__ __device__ static constexpr auto + GetVisibleDimLengthFromTransforms(const Transforms& transforms) { - const auto lengths = generate_tuple( - [&](auto idim_visible) { - constexpr auto tmp = GetTransformAndItsUpperDimension(idim_visible); - - constexpr index_t itran = tmp[Number<0>{}]; - constexpr index_t idim_up = tmp[Number<1>{}]; - constexpr bool found = tmp[Number<2>{}]; - - static_assert(found == true, - "wrong! not found matching transformation and upper-dimension"); - - const auto length = - transforms[Number{}].GetUpperLengths()[Number{}]; + constexpr auto result = + find_in_tuple_of_sequences{})>(UpperDimensionIdss{}); + static_assert(result.found, "wrong! not found matching transformation and upper-dimension"); + return transforms[Number{}].GetUpperLengths()[Number{}]; + } - return length; - }, - Number{}); + // Compute element size using pack expansion instead of generate_tuple with lambda + template + __host__ __device__ static constexpr auto ComputeElementSizeImpl(const Transforms& transforms, + Sequence) + { + return (GetVisibleDimLengthFromTransforms(transforms) * ...); + } - return container_product(lengths); + __host__ __device__ static constexpr auto InitializeElementSize(const Transforms& transforms) + { + return ComputeElementSizeImpl( + transforms, typename arithmetic_sequence_gen<0, ndim_visible_, 1>::type{}); } template From a565d87e08b0f57120bd0b733d79e0102f3d543f Mon Sep 17 00:00:00 2001 From: Max Podkorytov <4273004+tenpercent@users.noreply.github.com> Date: Fri, 16 Jan 2026 16:37:56 -0600 Subject: [PATCH 3/3] Apply same optimization pattern to TensorAdaptor TensorAdaptor has identical InitializeElementSize and GetTransformAndItsUpperDimension patterns as TensorDescriptor. Apply the same optimization: - Replace nested static_for lambdas with find_in_tuple_of_sequences - Replace generate_tuple lambda with pack expansion Results: generate_tuple lambdas 100 -> 96 (4 events, 17ms eliminated) --- .../ck/tensor_description/tensor_adaptor.hpp | 58 ++++++--------- include/ck/utility/sequence_helper.hpp | 71 ++++++++++++------- 2 files changed, 67 insertions(+), 62 deletions(-) diff --git a/include/ck/tensor_description/tensor_adaptor.hpp b/include/ck/tensor_description/tensor_adaptor.hpp index 55a44198b2..031082e1a0 100644 --- a/include/ck/tensor_description/tensor_adaptor.hpp +++ b/include/ck/tensor_description/tensor_adaptor.hpp @@ -45,27 +45,29 @@ struct TensorAdaptor return BottomDimensionHiddenIds{}; } - __host__ __device__ static constexpr auto InitializeElementSize(const Transforms& transforms) + // Helper to get length of a top dimension from transforms + template + __host__ __device__ static constexpr auto + GetTopDimLengthFromTransforms(const Transforms& transforms) { - const auto lengths = generate_tuple( - [&](auto idim_top) { - constexpr auto tmp = GetTransformAndItsUpperDimension(idim_top); - - constexpr index_t itran = tmp[Number<0>{}]; - constexpr index_t idim_up = tmp[Number<1>{}]; - constexpr bool found = tmp[Number<2>{}]; - - static_assert(found == true, - "wrong! not found matching transformation and upper-dimension"); - - const auto length = - transforms[Number{}].GetUpperLengths()[Number{}]; + constexpr auto result = find_in_tuple_of_sequences{})>( + UpperDimensionHiddenIdss{}); + static_assert(result.found, "wrong! not found matching transformation and upper-dimension"); + return transforms[Number{}].GetUpperLengths()[Number{}]; + } - return length; - }, - Number{}); + // Compute element size using pack expansion instead of generate_tuple with lambda + template + __host__ __device__ static constexpr auto ComputeElementSizeImpl(const Transforms& transforms, + Sequence) + { + return (GetTopDimLengthFromTransforms(transforms) * ...); + } - return container_product(lengths); + __host__ __device__ static constexpr auto InitializeElementSize(const Transforms& transforms) + { + return ComputeElementSizeImpl(transforms, + typename arithmetic_sequence_gen<0, ndim_top_, 1>::type{}); } template @@ -75,24 +77,10 @@ struct TensorAdaptor constexpr index_t idim_hidden = TopDimensionHiddenIds::At(idim_top); - index_t itran_found = 0; - index_t idim_up_found = 0; - bool found = false; - - static_for<0, ntransform_, 1>{}([&](auto itran) { - constexpr auto up_dim_ids = UpperDimensionHiddenIdss{}[itran]; - - static_for<0, up_dim_ids.Size(), 1>{}([&](auto idim_up) { - if constexpr(up_dim_ids[idim_up] == idim_hidden) - { - itran_found = itran; - idim_up_found = idim_up; - found = true; - } - }); - }); + // Use compile-time search helper instead of nested static_for with lambdas + constexpr auto result = find_in_tuple_of_sequences(UpperDimensionHiddenIdss{}); - return make_tuple(itran_found, idim_up_found, found); + return make_tuple(result.itran, result.idim_up, result.found); } __host__ __device__ static constexpr index_t GetNumOfBottomDimension() diff --git a/include/ck/utility/sequence_helper.hpp b/include/ck/utility/sequence_helper.hpp index 7b26b3e9e4..0e70b2466b 100644 --- a/include/ck/utility/sequence_helper.hpp +++ b/include/ck/utility/sequence_helper.hpp @@ -82,44 +82,61 @@ struct FindTransformResult static constexpr bool found = Found; }; -namespace detail { - -// Helper to search through a tuple of sequences for a target value -// Returns FindTransformResult with (transform_index, index_within_sequence, found) -template -__host__ __device__ constexpr auto find_in_tuple_of_sequences_impl() +// O(1) template depth implementation using pack expansion +// Avoids O(N) recursive template instantiations +template +struct FindInTupleOfSequencesCompute { - constexpr index_t idx = sequence_find_value(FirstSeq{}); - if constexpr(idx >= 0) + private: + // Result struct for constexpr computation + struct ResultData { - return FindTransformResult{}; - } - else if constexpr(sizeof...(RestSeqs) > 0) - { - return find_in_tuple_of_sequences_impl(); - } - else + index_t itran; + index_t idim_up; + bool found; + }; + + // Compute result using constexpr function with array lookup + static constexpr ResultData compute() { - return FindTransformResult<0, 0, false>{}; + if constexpr(sizeof...(Seqs) == 0) + { + return {0, 0, false}; + } + else + { + // Pack expansion creates array - O(1) template depth + constexpr index_t indices[] = {sequence_find_value(Seqs{})...}; + + // Find first matching sequence + for(index_t i = 0; i < static_cast(sizeof...(Seqs)); ++i) + { + if(indices[i] >= 0) + { + return {i, indices[i], true}; + } + } + return {0, 0, false}; + } } -} -} // namespace detail + static constexpr ResultData result_ = compute(); + + public: + static constexpr index_t itran = result_.itran; + static constexpr index_t idim_up = result_.idim_up; + static constexpr bool found = result_.found; + + using type = FindTransformResult; +}; // Find target value in a tuple of sequences // Returns FindTransformResult -// This replaces nested static_for loops with O(1) template depth +// Uses O(1) template depth via pack expansion (no recursion) template __host__ __device__ constexpr auto find_in_tuple_of_sequences(Tuple) { - if constexpr(sizeof...(Seqs) == 0) - { - return FindTransformResult<0, 0, false>{}; - } - else - { - return detail::find_in_tuple_of_sequences_impl(); - } + return typename FindInTupleOfSequencesCompute::type{}; } } // namespace ck