diff --git a/include/ck/tensor_description/tensor_adaptor.hpp b/include/ck/tensor_description/tensor_adaptor.hpp index 55a44198b2..031082e1a0 100644 --- a/include/ck/tensor_description/tensor_adaptor.hpp +++ b/include/ck/tensor_description/tensor_adaptor.hpp @@ -45,27 +45,29 @@ struct TensorAdaptor return BottomDimensionHiddenIds{}; } - __host__ __device__ static constexpr auto InitializeElementSize(const Transforms& transforms) + // Helper to get length of a top dimension from transforms + template + __host__ __device__ static constexpr auto + GetTopDimLengthFromTransforms(const Transforms& transforms) { - const auto lengths = generate_tuple( - [&](auto idim_top) { - constexpr auto tmp = GetTransformAndItsUpperDimension(idim_top); - - constexpr index_t itran = tmp[Number<0>{}]; - constexpr index_t idim_up = tmp[Number<1>{}]; - constexpr bool found = tmp[Number<2>{}]; - - static_assert(found == true, - "wrong! not found matching transformation and upper-dimension"); - - const auto length = - transforms[Number{}].GetUpperLengths()[Number{}]; + constexpr auto result = find_in_tuple_of_sequences{})>( + UpperDimensionHiddenIdss{}); + static_assert(result.found, "wrong! not found matching transformation and upper-dimension"); + return transforms[Number{}].GetUpperLengths()[Number{}]; + } - return length; - }, - Number{}); + // Compute element size using pack expansion instead of generate_tuple with lambda + template + __host__ __device__ static constexpr auto ComputeElementSizeImpl(const Transforms& transforms, + Sequence) + { + return (GetTopDimLengthFromTransforms(transforms) * ...); + } - return container_product(lengths); + __host__ __device__ static constexpr auto InitializeElementSize(const Transforms& transforms) + { + return ComputeElementSizeImpl(transforms, + typename arithmetic_sequence_gen<0, ndim_top_, 1>::type{}); } template @@ -75,24 +77,10 @@ struct TensorAdaptor constexpr index_t idim_hidden = TopDimensionHiddenIds::At(idim_top); - index_t itran_found = 0; - index_t idim_up_found = 0; - bool found = false; - - static_for<0, ntransform_, 1>{}([&](auto itran) { - constexpr auto up_dim_ids = UpperDimensionHiddenIdss{}[itran]; - - static_for<0, up_dim_ids.Size(), 1>{}([&](auto idim_up) { - if constexpr(up_dim_ids[idim_up] == idim_hidden) - { - itran_found = itran; - idim_up_found = idim_up; - found = true; - } - }); - }); + // Use compile-time search helper instead of nested static_for with lambdas + constexpr auto result = find_in_tuple_of_sequences(UpperDimensionHiddenIdss{}); - return make_tuple(itran_found, idim_up_found, found); + return make_tuple(result.itran, result.idim_up, result.found); } __host__ __device__ static constexpr index_t GetNumOfBottomDimension() diff --git a/include/ck/tensor_description/tensor_descriptor.hpp b/include/ck/tensor_description/tensor_descriptor.hpp index 4f827e51ea..b066dc0d00 100644 --- a/include/ck/tensor_description/tensor_descriptor.hpp +++ b/include/ck/tensor_description/tensor_descriptor.hpp @@ -49,27 +49,29 @@ struct TensorDescriptor return unique_sort_all_dim_ids::Size(); } - __host__ __device__ static constexpr auto InitializeElementSize(const Transforms& transforms) + // Helper to get length of a visible dimension from transforms + template + __host__ __device__ static constexpr auto + GetVisibleDimLengthFromTransforms(const Transforms& transforms) { - const auto lengths = generate_tuple( - [&](auto idim_visible) { - constexpr auto tmp = GetTransformAndItsUpperDimension(idim_visible); - - constexpr index_t itran = tmp[Number<0>{}]; - constexpr index_t idim_up = tmp[Number<1>{}]; - constexpr bool found = tmp[Number<2>{}]; - - static_assert(found == true, - "wrong! not found matching transformation and upper-dimension"); - - const auto length = - transforms[Number{}].GetUpperLengths()[Number{}]; + constexpr auto result = + find_in_tuple_of_sequences{})>(UpperDimensionIdss{}); + static_assert(result.found, "wrong! not found matching transformation and upper-dimension"); + return transforms[Number{}].GetUpperLengths()[Number{}]; + } - return length; - }, - Number{}); + // Compute element size using pack expansion instead of generate_tuple with lambda + template + __host__ __device__ static constexpr auto ComputeElementSizeImpl(const Transforms& transforms, + Sequence) + { + return (GetVisibleDimLengthFromTransforms(transforms) * ...); + } - return container_product(lengths); + __host__ __device__ static constexpr auto InitializeElementSize(const Transforms& transforms) + { + return ComputeElementSizeImpl( + transforms, typename arithmetic_sequence_gen<0, ndim_visible_, 1>::type{}); } template @@ -79,24 +81,11 @@ struct TensorDescriptor constexpr index_t idim_hidden = VisibleDimensionIds::At(idim_visible); - index_t itran_found = 0; - index_t idim_up_found = 0; - bool found = false; - - static_for<0, ntransform_, 1>{}([&](auto itran) { - constexpr auto up_dim_ids = UpperDimensionIdss{}[itran]; - - static_for<0, up_dim_ids.Size(), 1>{}([&](auto idim_up) { - if constexpr(up_dim_ids[idim_up] == idim_hidden) - { - itran_found = itran; - idim_up_found = idim_up; - found = true; - } - }); - }); + // Use compile-time search helper instead of nested static_for with lambdas + // This eliminates ~918 applier::operator() instantiations + constexpr auto result = find_in_tuple_of_sequences(UpperDimensionIdss{}); - return make_tuple(itran_found, idim_up_found, found); + return make_tuple(result.itran, result.idim_up, result.found); } constexpr static index_t ntransform_ = GetNumOfTransform(); diff --git a/include/ck/utility/sequence_helper.hpp b/include/ck/utility/sequence_helper.hpp index 1482c936ad..0e70b2466b 100644 --- a/include/ck/utility/sequence_helper.hpp +++ b/include/ck/utility/sequence_helper.hpp @@ -52,4 +52,91 @@ __host__ __device__ constexpr auto unpack_and_merge_sequences(TupleOfSequences) return unpack(merge_sequences_functor{}, TupleOfSequences{}); } +// Find index of Target in Sequence, returns -1 if not found +// Uses constexpr array lookup for O(1) template depth +template +__host__ __device__ constexpr index_t sequence_find_value(Sequence) +{ + if constexpr(sizeof...(Is) == 0) + { + return -1; + } + else + { + constexpr bool matches[] = {(Is == Target)...}; + for(index_t i = 0; i < static_cast(sizeof...(Is)); ++i) + { + if(matches[i]) + return i; + } + return -1; + } +} + +// Result type for find_in_tuple_of_sequences +template +struct FindTransformResult +{ + static constexpr index_t itran = ITran; + static constexpr index_t idim_up = IDimUp; + static constexpr bool found = Found; +}; + +// O(1) template depth implementation using pack expansion +// Avoids O(N) recursive template instantiations +template +struct FindInTupleOfSequencesCompute +{ + private: + // Result struct for constexpr computation + struct ResultData + { + index_t itran; + index_t idim_up; + bool found; + }; + + // Compute result using constexpr function with array lookup + static constexpr ResultData compute() + { + if constexpr(sizeof...(Seqs) == 0) + { + return {0, 0, false}; + } + else + { + // Pack expansion creates array - O(1) template depth + constexpr index_t indices[] = {sequence_find_value(Seqs{})...}; + + // Find first matching sequence + for(index_t i = 0; i < static_cast(sizeof...(Seqs)); ++i) + { + if(indices[i] >= 0) + { + return {i, indices[i], true}; + } + } + return {0, 0, false}; + } + } + + static constexpr ResultData result_ = compute(); + + public: + static constexpr index_t itran = result_.itran; + static constexpr index_t idim_up = result_.idim_up; + static constexpr bool found = result_.found; + + using type = FindTransformResult; +}; + +// Find target value in a tuple of sequences +// Returns FindTransformResult +// Uses O(1) template depth via pack expansion (no recursion) +template +__host__ __device__ constexpr auto find_in_tuple_of_sequences(Tuple) +{ + return typename FindInTupleOfSequencesCompute::type{}; +} + } // namespace ck