Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions include/ck/utility/container_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,25 @@ __host__ __device__ constexpr auto container_reorder_given_old2new(const Tuple<T
old_tuple, typename sequence_map_inverse<decltype(old2new)>::type{});
}

// Reorder a StaticallyIndexedArray with a new2old map: element i of the result
// is old_arr[IRs[i]]. The map is carried purely in the type of the second
// argument; it must have exactly N entries and satisfy is_valid_sequence_map.
template <typename T, index_t N, index_t... IRs>
__host__ __device__ constexpr auto
container_reorder_given_new2old(const StaticallyIndexedArray<T, N>& old_arr,
                                Sequence<IRs...> /*new2old*/)
{
    using New2Old = Sequence<IRs...>;

    static_assert(N == sizeof...(IRs), "wrong! size not consistent");
    static_assert(is_valid_sequence_map<New2Old>{}, "wrong! invalid reorder map");

    // Gather the source elements in the order dictated by the map.
    return make_statically_indexed_array<T>(old_arr[Number<IRs>{}]...);
}

// Reorder a StaticallyIndexedArray with an old2new map. The map is inverted
// with sequence_map_inverse and the work is forwarded to the new2old overload.
template <typename T, index_t N, index_t... IRs>
__host__ __device__ constexpr auto
container_reorder_given_old2new(const StaticallyIndexedArray<T, N>& old_arr,
                                Sequence<IRs...> old2new)
{
    using New2Old = typename sequence_map_inverse<decltype(old2new)>::type;

    return container_reorder_given_new2old(old_arr, New2Old{});
}

template <index_t... Is, index_t... IRs>
__host__ __device__ constexpr auto container_reorder_given_new2old(Sequence<Is...> /* old_seq */,
Sequence<IRs...> /*new2old*/)
Expand Down Expand Up @@ -358,6 +377,15 @@ __host__ __device__ constexpr auto get_container_subset(const Tuple<Ts...>& tup,
return make_tuple(tup[Number<Is>{}]...);
}

// Pick the elements of arr at the compile-time indices Is... and return them,
// in that order, as a StaticallyIndexedArray<T, sizeof...(Is)>.
template <typename T, index_t N, index_t... Is>
__host__ __device__ constexpr auto get_container_subset(const StaticallyIndexedArray<T, N>& arr,
                                                        Sequence<Is...>)
{
    static_assert(N >= sizeof...(Is), "wrong! size");

    using Subset = StaticallyIndexedArray<T, sizeof...(Is)>;

    return Subset{arr[Number<Is>{}]...};
}

template <typename T, index_t N, index_t... Is>
__host__ __device__ constexpr void
set_container_subset(Array<T, N>& y, Sequence<Is...> picks, const Array<T, sizeof...(Is)>& x)
Expand All @@ -376,6 +404,29 @@ set_container_subset(Tuple<Ys...>& y, Sequence<Is...> picks, const Tuple<Xs...>&
static_for<0, sizeof...(Is), 1>{}([&](auto i) { y(picks[i]) = x[i]; });
}

// Scatter x into y: for every k in [0, sizeof...(Is)), y at index picks[k]
// receives x at index k. Source and destination share the element type T.
template <typename T, index_t N, index_t... Is>
__host__ __device__ constexpr void
set_container_subset(StaticallyIndexedArray<T, N>& y,
                     Sequence<Is...> picks,
                     const StaticallyIndexedArray<T, sizeof...(Is)>& x)
{
    static_assert(N >= sizeof...(Is), "wrong! size");

    static_for<0, sizeof...(Is), 1>{}([&](auto k) {
        y(picks[k]) = x[k];
    });
}

// Generic set_container_subset: StaticallyIndexedArray destination, any source
// type that exposes a static Size(). The source must hold exactly one element
// per picked index; element k of x is assigned to y at index picks[k].
template <typename T, index_t N, index_t... Is, typename Src>
    requires requires { Src::Size(); }
__host__ __device__ constexpr void
set_container_subset(StaticallyIndexedArray<T, N>& y, Sequence<Is...> picks, const Src& x)
{
    static_assert(N >= sizeof...(Is), "wrong! size");
    static_assert(Src::Size() == sizeof...(Is), "wrong! size mismatch");

    static_for<0, sizeof...(Is), 1>{}([&](auto k) {
        y(picks[k]) = x[k];
    });
}

template <index_t... Is>
__host__ __device__ constexpr auto sequence_to_tuple_of_number(Sequence<Is...>)
{
Expand Down
200 changes: 173 additions & 27 deletions include/ck/utility/statically_indexed_array.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,51 +10,124 @@

namespace ck {

namespace detail {
template <typename X, typename Y>
struct tuple_concat;

template <typename... Xs, typename... Ys>
struct tuple_concat<Tuple<Xs...>, Tuple<Ys...>>
{
using type = Tuple<Xs..., Ys...>;
};

// StaticallyIndexedArray using simple C-array instead of template metaprogramming
// This avoids deep template instantiation while maintaining the same interface
template <typename T, index_t N>
struct StaticallyIndexedArrayImpl
struct StaticallyIndexedArray
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What we are doing here is essentially a vector of a vector, no? Maybe we can refactor this into the vector_type class

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the current major problem with this class is that it has to be interface-compatible with a Tuple. We need to be careful with the call sites.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can retire the StaticallyIndexedArray and replace it with StaticallyIndexedArray_v2.

{
using type =
typename tuple_concat<typename StaticallyIndexedArrayImpl<T, N / 2>::type,
typename StaticallyIndexedArrayImpl<T, N - N / 2>::type>::type;
};
// Default constructor: value-initializes the whole backing array.
__host__ __device__ constexpr StaticallyIndexedArray()
    : data_{}
{
}

// Single-element constructor, only available when N == 1.
// The argument is accepted if it is exactly a T (after removing cv/ref), or if
// its type has no static Size(); anything else that exposes Size() is left to
// the container conversion constructor.
template <typename Elem>
    requires(N == 1 &&
             (is_same<remove_cvref_t<Elem>, T>::value ||
              !requires { remove_cvref_t<Elem>::Size(); }))
__host__ __device__ constexpr StaticallyIndexedArray(Elem&& x)
    : data_{static_cast<T>(ck::forward<Elem>(x))}
{
}

template <typename T>
struct StaticallyIndexedArrayImpl<T, 0>
{
using type = Tuple<>;
// Multi-element constructor, only available when N > 1 and exactly N arguments
// are supplied. Each argument is forwarded and converted to T with static_cast,
// in argument order.
template <typename... Elems>
    requires(sizeof...(Elems) == N && N > 1)
__host__ __device__ constexpr StaticallyIndexedArray(Elems&&... xs)
    : data_{static_cast<T>(ck::forward<Elems>(xs))...}
{
}

// Conversion constructor from any indexed container (Tuple, etc.)
template <typename Container>
requires(!is_same<remove_cvref_t<Container>, StaticallyIndexedArray>::value &&
requires { Container::Size(); } && Container::Size() == N)
__host__ __device__ constexpr StaticallyIndexedArray(const Container& src)
: StaticallyIndexedArray(
make_from_container(src, typename arithmetic_sequence_gen<0, N, 1>::type{}))
{
}

private:
// Helper for the conversion constructor: reads src at each compile-time index
// in Is..., converts every element to T, and builds the array from the result.
template <typename Container, index_t... Is>
__host__ __device__ static constexpr StaticallyIndexedArray
make_from_container(const Container& src, Sequence<Is...>)
{
    using Self = StaticallyIndexedArray;

    return Self{static_cast<T>(src[Number<Is>{}])...};
}

public:
// Number of elements, known at compile time.
__host__ __device__ static constexpr index_t Size()
{
    return N;
}

// read access
// Returns a const reference to element I. The index is checked at compile time
// against both ends of the range: the original assertion only rejected I >= N,
// so a negative Number<I> (index_t is presumably signed — confirm) would have
// indexed before the start of data_.
template <index_t I>
__host__ __device__ constexpr const T& At(Number<I>) const
{
    static_assert(0 <= I && I < N, "wrong! out of range");
    return data_[I];
}

// write access
// Returns a mutable reference to element I. The index is checked at compile
// time against both ends of the range: the original assertion only rejected
// I >= N, so a negative Number<I> (index_t is presumably signed — confirm)
// would have indexed before the start of data_.
template <index_t I>
__host__ __device__ constexpr T& At(Number<I>)
{
    static_assert(0 <= I && I < N, "wrong! out of range");
    return data_[I];
}

// read access: subscript with a compile-time index, forwards to the const At().
template <index_t I>
__host__ __device__ constexpr const T& operator[](Number<I> i) const
{
    const T& elem = At(i);
    return elem;
}

// write access: call operator with a compile-time index, forwards to the
// non-const At() and hands back a mutable reference.
template <index_t I>
__host__ __device__ constexpr T& operator()(Number<I> i)
{
    T& elem = At(i);
    return elem;
}

// Element-wise assignment from any container U whose static Size() equals
// this array's size. Element i of a is assigned to element i of *this.
// The return type is deduced with plain auto, so *this is returned by value.
template <typename U>
__host__ __device__ constexpr auto operator=(const U& a)
{
    static_assert(U::Size() == Size(), "wrong! size not the same");

    static_for<0, Size(), 1>{}([&](auto idx) {
        (*this)(idx) = a[idx];
    });

    return *this;
}

// Compile-time trait query; always reports true for this type.
__host__ __device__ static constexpr bool IsStaticBuffer()
{
    return true;
}

T data_[N];
};

// Specialization for empty array
template <typename T>
struct StaticallyIndexedArrayImpl<T, 1>
struct StaticallyIndexedArray<T, 0>
{
using type = Tuple<T>;
};
} // namespace detail
// Defaulted constructor of the empty specialization.
__host__ __device__ constexpr StaticallyIndexedArray() = default;

template <typename T, index_t N>
using StaticallyIndexedArray = typename detail::StaticallyIndexedArrayImpl<T, N>::type;
// Number of elements: always zero here.
__host__ __device__ static constexpr index_t Size()
{
    return 0;
}

// Assignment from any type: the argument is ignored and nothing is copied.
// The return type is deduced with plain auto, so *this is returned by value.
template <typename U>
__host__ __device__ constexpr auto operator=(const U& /*unused*/)
{
    return *this;
}

// Compile-time trait query; always reports true for this type.
__host__ __device__ static constexpr bool IsStaticBuffer()
{
    return true;
}
};

// Builds a StaticallyIndexedArray<X, 1 + sizeof...(Xs)> from the arguments.
// The element type X is taken from the first argument; every further argument
// is converted to X with static_cast.
// NOTE(review): this span holds two consecutive return statements (paren-init,
// then brace-init), so the second one is unreachable as written. It looks like
// an old line followed by its replacement — confirm which one is intended and
// drop the other.
template <typename X, typename... Xs>
__host__ __device__ constexpr auto make_statically_indexed_array(const X& x, const Xs&... xs)
{
return StaticallyIndexedArray<X, sizeof...(Xs) + 1>(x, static_cast<X>(xs)...);
return StaticallyIndexedArray<X, sizeof...(Xs) + 1>{x, static_cast<X>(xs)...};
}

// make empty StaticallyIndexedArray
// Takes no arguments, so the element type X must be given explicitly; returns
// a StaticallyIndexedArray<X, 0>.
// NOTE(review): this span holds two consecutive return statements (paren-init,
// then brace-init), so the second one is unreachable as written. It looks like
// an old line followed by its replacement — confirm which one is intended and
// drop the other.
template <typename X>
__host__ __device__ constexpr auto make_statically_indexed_array()
{
return StaticallyIndexedArray<X, 0>();
return StaticallyIndexedArray<X, 0>{};
}

template <typename T, index_t N>
Expand Down Expand Up @@ -101,5 +174,78 @@ struct StaticallyIndexedArray_v2
T data_[N];
};

// Concepts for StaticallyIndexedArray arithmetic operators
// Scalar<T> holds when ck::is_integral<T> or ck::is_floating_point<T> reports
// true. It selects the scalar-multiplication overloads and excludes scalars
// from IndexedContainer.
template <typename T>
concept Scalar = ck::is_integral<T>::value || ck::is_floating_point<T>::value;

// IndexedContainer<T> holds for any non-Scalar type on which the expression
// T::Size() is well-formed. Only the presence of a static Size() is checked
// here; the element-wise operators verify the size and index their operand
// themselves.
template <typename T>
concept IndexedContainer = !Scalar<T> && requires { T::Size(); };

// Arithmetic operators for StaticallyIndexedArray (to match Tuple operators)

// StaticallyIndexedArray += X
// Adds x[i] into element i of y for every i in [0, N). X must be an
// IndexedContainer whose Size() equals N. The return type is deduced with
// plain auto, so y is returned by value.
template <typename T, index_t N, IndexedContainer X>
__host__ __device__ constexpr auto operator+=(StaticallyIndexedArray<T, N>& y, const X& x)
{
    static_assert(X::Size() == N, "wrong! size not the same");

    static_for<0, N, 1>{}([&](auto idx) {
        y(idx) += x[idx];
    });

    return y;
}

// StaticallyIndexedArray -= X
// Subtracts x[i] from element i of y for every i in [0, N). X must be an
// IndexedContainer whose Size() equals N. The return type is deduced with
// plain auto, so y is returned by value.
template <typename T, index_t N, IndexedContainer X>
__host__ __device__ constexpr auto operator-=(StaticallyIndexedArray<T, N>& y, const X& x)
{
    static_assert(X::Size() == N, "wrong! size not the same");

    static_for<0, N, 1>{}([&](auto idx) {
        y(idx) -= x[idx];
    });

    return y;
}

// StaticallyIndexedArray + Y
// Element-wise sum. Y must be an IndexedContainer whose Size() equals N; the
// result is a fresh StaticallyIndexedArray<T, N> with element i = x[i] + y[i].
template <typename T, index_t N, IndexedContainer Y>
__host__ __device__ constexpr auto operator+(const StaticallyIndexedArray<T, N>& x, const Y& y)
{
    static_assert(Y::Size() == N, "wrong! size not the same");

    StaticallyIndexedArray<T, N> sum;
    static_for<0, N, 1>{}([&](auto idx) {
        sum(idx) = x[idx] + y[idx];
    });

    return sum;
}

// StaticallyIndexedArray - Y
// Element-wise difference. Y must be an IndexedContainer whose Size() equals
// N; the result is a fresh StaticallyIndexedArray<T, N> with
// element i = x[i] - y[i].
template <typename T, index_t N, IndexedContainer Y>
__host__ __device__ constexpr auto operator-(const StaticallyIndexedArray<T, N>& x, const Y& y)
{
    static_assert(Y::Size() == N, "wrong! size not the same");

    StaticallyIndexedArray<T, N> diff;
    static_for<0, N, 1>{}([&](auto idx) {
        diff(idx) = x[idx] - y[idx];
    });

    return diff;
}

// StaticallyIndexedArray * Y (element-wise)
// Y must be an IndexedContainer whose Size() equals N; the result is a fresh
// StaticallyIndexedArray<T, N> with element i = x[i] * y[i].
template <typename T, index_t N, IndexedContainer Y>
__host__ __device__ constexpr auto operator*(const StaticallyIndexedArray<T, N>& x, const Y& y)
{
    static_assert(Y::Size() == N, "wrong! size not the same");

    StaticallyIndexedArray<T, N> product;
    static_for<0, N, 1>{}([&](auto idx) {
        product(idx) = x[idx] * y[idx];
    });

    return product;
}

// scalar * StaticallyIndexedArray
// Scales every element: the result is a fresh StaticallyIndexedArray<T, N>
// with element i = a * x[i], stored back as T.
template <typename T, index_t N, Scalar S>
__host__ __device__ constexpr auto operator*(S a, const StaticallyIndexedArray<T, N>& x)
{
    StaticallyIndexedArray<T, N> scaled;

    static_for<0, N, 1>{}([&](auto idx) {
        scaled(idx) = a * x[idx];
    });

    return scaled;
}

// StaticallyIndexedArray * scalar
// Same as the scalar-on-the-left overload: the operands are swapped and the
// product is computed as a * x.
template <typename T, index_t N, Scalar S>
__host__ __device__ constexpr auto operator*(const StaticallyIndexedArray<T, N>& x, S a)
{
    // Reuse the scalar-first overload so both spellings share one implementation.
    return a * x;
}

} // namespace ck
#endif