Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 52 additions & 28 deletions include/ck/tensor_description/tensor_descriptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,9 @@ struct TensorDescriptor

__host__ __device__ static constexpr index_t GetNumOfHiddenDimension()
{
constexpr auto all_low_dim_ids = unpack(
[](auto&&... xs) constexpr { return merge_sequences(xs...); }, LowerDimensionIdss{});
constexpr auto all_low_dim_ids = unpack_and_merge_sequences(LowerDimensionIdss{});

constexpr auto all_up_dim_ids = unpack(
[](auto&&... xs) constexpr { return merge_sequences(xs...); }, UpperDimensionIdss{});
constexpr auto all_up_dim_ids = unpack_and_merge_sequences(UpperDimensionIdss{});

constexpr auto all_dim_ids = merge_sequences(all_low_dim_ids, all_up_dim_ids);

Expand Down Expand Up @@ -311,6 +309,45 @@ struct lambda_get_up_dim_num
}
};

// Functor to convert a single visible dimension id to hidden id
// Replaces inner lambda in transform_tensor_descriptor
// Note: transform_sequences passes index_t values, not Number<> types
template <typename OldTensorDescriptor>
struct convert_visible_to_hidden_id
{
__host__ __device__ constexpr auto operator()(index_t low_dim_visible_id) const
{
return OldTensorDescriptor::GetVisibleDimensionIds().At(low_dim_visible_id);
}
};

// Functor to convert a sequence of visible dimension ids to hidden ids
// Replaces outer lambda in transform_tensor_descriptor
template <typename OldTensorDescriptor>
struct convert_visible_ids_to_hidden_ids
{
template <typename LowDimVisibleIds>
__host__ __device__ constexpr auto operator()(LowDimVisibleIds low_dim_visible_ids) const
{
return transform_sequences(convert_visible_to_hidden_id<OldTensorDescriptor>{},
low_dim_visible_ids);
}
};

// Functor to generate arithmetic sequences from scan results
// Replaces lambda in transform_tensor_descriptor that generates up_dim_hidden_idss
template <index_t OldHiddenDimNumber, typename UpDimNumbersScan>
struct generate_arithmetic_sequence_from_scan
{
template <typename I>
__host__ __device__ constexpr auto operator()(I) const
{
constexpr index_t start = OldHiddenDimNumber + UpDimNumbersScan{}.At(I{});
constexpr index_t end = OldHiddenDimNumber + UpDimNumbersScan{}.At(I{} + Number<1>{});
return typename arithmetic_sequence_gen<start, end, 1>::type{};
}
};

template <typename OldTensorDescriptor,
typename NewTransforms,
typename NewLowerDimensionOldVisibleIdss,
Expand All @@ -327,11 +364,11 @@ transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
NewTransforms::Size() == NewUpperDimensionNewVisibleIdss::Size(),
"wrong! inconsitent number of transform");

constexpr auto all_old_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); },
NewLowerDimensionOldVisibleIdss{});
constexpr auto all_old_top_ids =
unpack_and_merge_sequences(NewLowerDimensionOldVisibleIdss{});

constexpr auto all_new_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); },
NewUpperDimensionNewVisibleIdss{});
constexpr auto all_new_top_ids =
unpack_and_merge_sequences(NewUpperDimensionNewVisibleIdss{});

static_assert(is_valid_sequence_map<decltype(all_old_top_ids)>::value &&
is_valid_sequence_map<decltype(all_new_top_ids)>::value,
Expand All @@ -341,17 +378,9 @@ transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
// lower dimension's hidden idss
// convert lower dimension visible idss (tuple of sequences) to hidden idss (tuple of
// sequences)
constexpr auto low_dim_hidden_idss = transform_tuples(
// convert lower dimension visible ids (a sequence) to hidden ids (a sequence)
[](auto low_dim_visible_ids) constexpr {
return transform_sequences(
// convert lower dimension visible id to hidden id
[](auto low_dim_visible_id) constexpr {
return OldTensorDescriptor::GetVisibleDimensionIds()[low_dim_visible_id];
},
low_dim_visible_ids);
},
NewLowerDimensionOldVisibleIdss{});
constexpr auto low_dim_hidden_idss =
transform_tuples(convert_visible_ids_to_hidden_ids<OldTensorDescriptor>{},
NewLowerDimensionOldVisibleIdss{});

constexpr index_t num_new_transform = NewTransforms::Size();

Expand All @@ -364,22 +393,17 @@ transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
constexpr auto up_dim_numbers_scan = merge_sequences(
Sequence<0>{}, inclusive_scan_sequence(up_dim_numbers, math::plus<index_t>{}, Number<0>{}));

using UpDimNumbersScanType = remove_cvref_t<decltype(up_dim_numbers_scan)>;
constexpr auto up_dim_hidden_idss = generate_tuple(
[old_hidden_dim_number, up_dim_numbers_scan](auto i) constexpr {
return
typename arithmetic_sequence_gen<old_hidden_dim_number + up_dim_numbers_scan[i],
old_hidden_dim_number + up_dim_numbers_scan[i + 1],
1>::type{};
},
generate_arithmetic_sequence_from_scan<old_hidden_dim_number, UpDimNumbersScanType>{},
Number<num_new_transform>{});

// new visible dimension's hidden ids
constexpr auto unordered_new_visible_dim_hidden_ids =
unpack([](auto... xs) constexpr { return merge_sequences(xs...); }, up_dim_hidden_idss);
unpack_and_merge_sequences(up_dim_hidden_idss);

constexpr auto new_visible_dim_unordered2ordered =
unpack([](auto... xs) constexpr { return merge_sequences(xs...); },
NewUpperDimensionNewVisibleIdss{});
unpack_and_merge_sequences(NewUpperDimensionNewVisibleIdss{});

constexpr auto new_visible_dim_hidden_ids =
unordered_new_visible_dim_hidden_ids.ReorderGivenOld2New(new_visible_dim_unordered2ordered);
Expand Down
7 changes: 2 additions & 5 deletions include/ck/tensor_operation/gpu/device/matrix_padder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,8 @@ PadTensorDescriptor(const TensorDesc& desc, const TileLengths& tile_lengths, DoP
},
Number<num_dim>{});

// lower dimension Id
const auto lower_dimss =
generate_tuple([&](auto idim) { return Sequence<idim.value>{}; }, Number<num_dim>{});

// upper dimension Id
// lower/upper dimension Ids
const auto lower_dimss = generate_identity_sequences<num_dim>();
const auto upper_dimss = lower_dimss;

return transform_tensor_descriptor(desc, transforms, lower_dimss, upper_dimss);
Expand Down
18 changes: 18 additions & 0 deletions include/ck/utility/sequence_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,22 @@ __host__ __device__ constexpr auto to_sequence(Tuple<Number<Is>...>)
return Sequence<Is...>{};
}

// Functor for merge_sequences to avoid lambda instantiation overhead
struct merge_sequences_functor
{
template <typename... Seqs>
__host__ __device__ constexpr auto operator()(Seqs... seqs) const
{
return merge_sequences(seqs...);
}
};

// Helper to unpack a tuple of sequences and merge them
// Replaces: unpack([](auto... xs) { return merge_sequences(xs...); }, tuple_of_sequences)
template <typename TupleOfSequences>
__host__ __device__ constexpr auto unpack_and_merge_sequences(TupleOfSequences)
{
return unpack(merge_sequences_functor{}, TupleOfSequences{});
}

} // namespace ck