Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 23 additions & 35 deletions include/ck/tensor_description/tensor_adaptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,27 +45,29 @@ struct TensorAdaptor
return BottomDimensionHiddenIds{};
}

__host__ __device__ static constexpr auto InitializeElementSize(const Transforms& transforms)
// Helper to get length of a top dimension from transforms
template <index_t I>
__host__ __device__ static constexpr auto
GetTopDimLengthFromTransforms(const Transforms& transforms)
{
const auto lengths = generate_tuple(
[&](auto idim_top) {
constexpr auto tmp = GetTransformAndItsUpperDimension(idim_top);

constexpr index_t itran = tmp[Number<0>{}];
constexpr index_t idim_up = tmp[Number<1>{}];
constexpr bool found = tmp[Number<2>{}];

static_assert(found == true,
"wrong! not found matching transformation and upper-dimension");

const auto length =
transforms[Number<itran>{}].GetUpperLengths()[Number<idim_up>{}];
constexpr auto result = find_in_tuple_of_sequences<TopDimensionHiddenIds::At(Number<I>{})>(
UpperDimensionHiddenIdss{});
static_assert(result.found, "wrong! not found matching transformation and upper-dimension");
return transforms[Number<result.itran>{}].GetUpperLengths()[Number<result.idim_up>{}];
}

return length;
},
Number<ndim_top_>{});
// Compute element size using pack expansion instead of generate_tuple with lambda
template <index_t... Is>
__host__ __device__ static constexpr auto ComputeElementSizeImpl(const Transforms& transforms,
Sequence<Is...>)
{
return (GetTopDimLengthFromTransforms<Is>(transforms) * ...);
}

return container_product(lengths);
__host__ __device__ static constexpr auto InitializeElementSize(const Transforms& transforms)
{
return ComputeElementSizeImpl(transforms,
typename arithmetic_sequence_gen<0, ndim_top_, 1>::type{});
}

template <index_t IDim>
Expand All @@ -75,24 +77,10 @@ struct TensorAdaptor

constexpr index_t idim_hidden = TopDimensionHiddenIds::At(idim_top);

index_t itran_found = 0;
index_t idim_up_found = 0;
bool found = false;

static_for<0, ntransform_, 1>{}([&](auto itran) {
constexpr auto up_dim_ids = UpperDimensionHiddenIdss{}[itran];

static_for<0, up_dim_ids.Size(), 1>{}([&](auto idim_up) {
if constexpr(up_dim_ids[idim_up] == idim_hidden)
{
itran_found = itran;
idim_up_found = idim_up;
found = true;
}
});
});
// Use compile-time search helper instead of nested static_for with lambdas
constexpr auto result = find_in_tuple_of_sequences<idim_hidden>(UpperDimensionHiddenIdss{});

return make_tuple(itran_found, idim_up_found, found);
return make_tuple(result.itran, result.idim_up, result.found);
}

__host__ __device__ static constexpr index_t GetNumOfBottomDimension()
Expand Down
59 changes: 24 additions & 35 deletions include/ck/tensor_description/tensor_descriptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,27 +49,29 @@ struct TensorDescriptor
return unique_sort_all_dim_ids::Size();
}

__host__ __device__ static constexpr auto InitializeElementSize(const Transforms& transforms)
// Helper to get length of a visible dimension from transforms
template <index_t I>
__host__ __device__ static constexpr auto
GetVisibleDimLengthFromTransforms(const Transforms& transforms)
{
const auto lengths = generate_tuple(
[&](auto idim_visible) {
constexpr auto tmp = GetTransformAndItsUpperDimension(idim_visible);

constexpr index_t itran = tmp[Number<0>{}];
constexpr index_t idim_up = tmp[Number<1>{}];
constexpr bool found = tmp[Number<2>{}];

static_assert(found == true,
"wrong! not found matching transformation and upper-dimension");

const auto length =
transforms[Number<itran>{}].GetUpperLengths()[Number<idim_up>{}];
constexpr auto result =
find_in_tuple_of_sequences<VisibleDimensionIds::At(Number<I>{})>(UpperDimensionIdss{});
static_assert(result.found, "wrong! not found matching transformation and upper-dimension");
return transforms[Number<result.itran>{}].GetUpperLengths()[Number<result.idim_up>{}];
}

return length;
},
Number<ndim_visible_>{});
// Compute element size using pack expansion instead of generate_tuple with lambda
template <index_t... Is>
__host__ __device__ static constexpr auto ComputeElementSizeImpl(const Transforms& transforms,
Sequence<Is...>)
{
return (GetVisibleDimLengthFromTransforms<Is>(transforms) * ...);
}

return container_product(lengths);
__host__ __device__ static constexpr auto InitializeElementSize(const Transforms& transforms)
{
return ComputeElementSizeImpl(
transforms, typename arithmetic_sequence_gen<0, ndim_visible_, 1>::type{});
}

template <index_t IDim>
Expand All @@ -79,24 +81,11 @@ struct TensorDescriptor

constexpr index_t idim_hidden = VisibleDimensionIds::At(idim_visible);

index_t itran_found = 0;
index_t idim_up_found = 0;
bool found = false;

static_for<0, ntransform_, 1>{}([&](auto itran) {
constexpr auto up_dim_ids = UpperDimensionIdss{}[itran];

static_for<0, up_dim_ids.Size(), 1>{}([&](auto idim_up) {
if constexpr(up_dim_ids[idim_up] == idim_hidden)
{
itran_found = itran;
idim_up_found = idim_up;
found = true;
}
});
});
// Use compile-time search helper instead of nested static_for with lambdas
// This eliminates ~918 applier::operator() instantiations
constexpr auto result = find_in_tuple_of_sequences<idim_hidden>(UpperDimensionIdss{});

return make_tuple(itran_found, idim_up_found, found);
return make_tuple(result.itran, result.idim_up, result.found);
}

constexpr static index_t ntransform_ = GetNumOfTransform();
Expand Down
87 changes: 87 additions & 0 deletions include/ck/utility/sequence_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,91 @@ __host__ __device__ constexpr auto unpack_and_merge_sequences(TupleOfSequences)
return unpack(merge_sequences_functor{}, TupleOfSequences{});
}

// Find index of Target in Sequence, returns -1 if not found
// Uses constexpr array lookup for O(1) template depth
template <index_t Target, index_t... Is>
__host__ __device__ constexpr index_t sequence_find_value(Sequence<Is...>)
{
if constexpr(sizeof...(Is) == 0)
{
return -1;
}
else
{
constexpr bool matches[] = {(Is == Target)...};
for(index_t i = 0; i < static_cast<index_t>(sizeof...(Is)); ++i)
{
if(matches[i])
return i;
}
return -1;
}
}

// Result type for find_in_tuple_of_sequences
template <index_t ITran, index_t IDimUp, bool Found>
struct FindTransformResult
{
static constexpr index_t itran = ITran;
static constexpr index_t idim_up = IDimUp;
static constexpr bool found = Found;
};

// O(1) template depth implementation using pack expansion
// Avoids O(N) recursive template instantiations
template <index_t Target, typename... Seqs>
struct FindInTupleOfSequencesCompute
{
private:
// Result struct for constexpr computation
struct ResultData
{
index_t itran;
index_t idim_up;
bool found;
};

// Compute result using constexpr function with array lookup
static constexpr ResultData compute()
{
if constexpr(sizeof...(Seqs) == 0)
{
return {0, 0, false};
}
else
{
// Pack expansion creates array - O(1) template depth
constexpr index_t indices[] = {sequence_find_value<Target>(Seqs{})...};

// Find first matching sequence
for(index_t i = 0; i < static_cast<index_t>(sizeof...(Seqs)); ++i)
{
if(indices[i] >= 0)
{
return {i, indices[i], true};
}
}
return {0, 0, false};
}
}

static constexpr ResultData result_ = compute();

public:
static constexpr index_t itran = result_.itran;
static constexpr index_t idim_up = result_.idim_up;
static constexpr bool found = result_.found;

using type = FindTransformResult<itran, idim_up, found>;
};

// Find target value in a tuple of sequences
// Returns FindTransformResult<itran, idim_up, found>
// Uses O(1) template depth via pack expansion (no recursion)
template <index_t Target, typename... Seqs>
__host__ __device__ constexpr auto find_in_tuple_of_sequences(Tuple<Seqs...>)
{
return typename FindInTupleOfSequencesCompute<Target, Seqs...>::type{};
}

} // namespace ck