diff --git a/include/ck/tensor_description/tensor_descriptor.hpp b/include/ck/tensor_description/tensor_descriptor.hpp index 2437132d114..c153c1f894a 100644 --- a/include/ck/tensor_description/tensor_descriptor.hpp +++ b/include/ck/tensor_description/tensor_descriptor.hpp @@ -36,11 +36,9 @@ struct TensorDescriptor __host__ __device__ static constexpr index_t GetNumOfHiddenDimension() { - constexpr auto all_low_dim_ids = unpack( - [](auto&&... xs) constexpr { return merge_sequences(xs...); }, LowerDimensionIdss{}); + constexpr auto all_low_dim_ids = unpack_and_merge_sequences(LowerDimensionIdss{}); - constexpr auto all_up_dim_ids = unpack( - [](auto&&... xs) constexpr { return merge_sequences(xs...); }, UpperDimensionIdss{}); + constexpr auto all_up_dim_ids = unpack_and_merge_sequences(UpperDimensionIdss{}); constexpr auto all_dim_ids = merge_sequences(all_low_dim_ids, all_up_dim_ids); @@ -311,6 +309,45 @@ struct lambda_get_up_dim_num } }; +// Functor to convert a single visible dimension id to hidden id +// Replaces inner lambda in transform_tensor_descriptor +// Note: transform_sequences passes index_t values, not Number<> types +template +struct convert_visible_to_hidden_id +{ + __host__ __device__ constexpr auto operator()(index_t low_dim_visible_id) const + { + return OldTensorDescriptor::GetVisibleDimensionIds().At(low_dim_visible_id); + } +}; + +// Functor to convert a sequence of visible dimension ids to hidden ids +// Replaces outer lambda in transform_tensor_descriptor +template +struct convert_visible_ids_to_hidden_ids +{ + template + __host__ __device__ constexpr auto operator()(LowDimVisibleIds low_dim_visible_ids) const + { + return transform_sequences(convert_visible_to_hidden_id{}, + low_dim_visible_ids); + } +}; + +// Functor to generate arithmetic sequences from scan results +// Replaces lambda in transform_tensor_descriptor that generates up_dim_hidden_idss +template +struct generate_arithmetic_sequence_from_scan +{ + template + __host__ __device__ constexpr auto operator()(I) const + { + constexpr index_t start = OldHiddenDimNumber + UpDimNumbersScan{}.At(I{}); + constexpr index_t end = OldHiddenDimNumber + UpDimNumbersScan{}.At(I{} + Number<1>{}); + return typename arithmetic_sequence_gen::type{}; + } +}; + template ::value && is_valid_sequence_map::value, @@ -341,17 +378,9 @@ transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc, // lower dimension's hidden idss // convert lower dimension visible idss (tuple of sequences) to hidden idss (tuple of // sequences) - constexpr auto low_dim_hidden_idss = transform_tuples( - // convert lower dimension visible ids (a sequence) to hidden ids (a sequence) - [](auto low_dim_visible_ids) constexpr { - return transform_sequences( - // convert lower dimension visible id to hidden id - [](auto low_dim_visible_id) constexpr { - return OldTensorDescriptor::GetVisibleDimensionIds()[low_dim_visible_id]; - }, - low_dim_visible_ids); - }, - NewLowerDimensionOldVisibleIdss{}); + constexpr auto low_dim_hidden_idss = + transform_tuples(convert_visible_ids_to_hidden_ids{}, + NewLowerDimensionOldVisibleIdss{}); constexpr index_t num_new_transform = NewTransforms::Size(); @@ -364,22 +393,17 @@ transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc, constexpr auto up_dim_numbers_scan = merge_sequences( Sequence<0>{}, inclusive_scan_sequence(up_dim_numbers, math::plus{}, Number<0>{})); + using UpDimNumbersScanType = remove_cvref_t; constexpr auto up_dim_hidden_idss = generate_tuple( - [old_hidden_dim_number, up_dim_numbers_scan](auto i) constexpr { - return - typename arithmetic_sequence_gen::type{}; - }, + generate_arithmetic_sequence_from_scan{}, Number{}); // new visible dimension's hidden ids constexpr auto unordered_new_visible_dim_hidden_ids = - unpack([](auto... xs) constexpr { return merge_sequences(xs...); }, up_dim_hidden_idss); + unpack_and_merge_sequences(up_dim_hidden_idss); constexpr auto new_visible_dim_unordered2ordered = - unpack([](auto... xs) constexpr { return merge_sequences(xs...); }, - NewUpperDimensionNewVisibleIdss{}); + unpack_and_merge_sequences(NewUpperDimensionNewVisibleIdss{}); constexpr auto new_visible_dim_hidden_ids = unordered_new_visible_dim_hidden_ids.ReorderGivenOld2New(new_visible_dim_unordered2ordered); diff --git a/include/ck/tensor_operation/gpu/device/matrix_padder.hpp b/include/ck/tensor_operation/gpu/device/matrix_padder.hpp index 95e7bd367ae..2a71e2e4efe 100644 --- a/include/ck/tensor_operation/gpu/device/matrix_padder.hpp +++ b/include/ck/tensor_operation/gpu/device/matrix_padder.hpp @@ -43,11 +43,8 @@ PadTensorDescriptor(const TensorDesc& desc, const TileLengths& tile_lengths, DoP }, Number{}); - // lower dimension Id - const auto lower_dimss = - generate_tuple([&](auto idim) { return Sequence{}; }, Number{}); - - // upper dimension Id + // lower/upper dimension Ids + const auto lower_dimss = generate_identity_sequences(); const auto upper_dimss = lower_dimss; return transform_tensor_descriptor(desc, transforms, lower_dimss, upper_dimss); diff --git a/include/ck/utility/sequence_helper.hpp b/include/ck/utility/sequence_helper.hpp index 35a6a486324..1482c936ad7 100644 --- a/include/ck/utility/sequence_helper.hpp +++ b/include/ck/utility/sequence_helper.hpp @@ -34,4 +34,22 @@ __host__ __device__ constexpr auto to_sequence(Tuple...>) return Sequence{}; } +// Functor for merge_sequences to avoid lambda instantiation overhead +struct merge_sequences_functor +{ + template + __host__ __device__ constexpr auto operator()(Seqs... seqs) const + { + return merge_sequences(seqs...); + } +}; + +// Helper to unpack a tuple of sequences and merge them +// Replaces: unpack([](auto... xs) { return merge_sequences(xs...); }, tuple_of_sequences) +template +__host__ __device__ constexpr auto unpack_and_merge_sequences(TupleOfSequences) +{ + return unpack(merge_sequences_functor{}, TupleOfSequences{}); +} + } // namespace ck