From aef254ca0dc283b470fbe44cc9716694b6c3e73a Mon Sep 17 00:00:00 2001 From: Max Podkorytov <4273004+tenpercent@users.noreply.github.com> Date: Thu, 15 Jan 2026 21:54:58 -0600 Subject: [PATCH] Rewrite StaticallyIndexedArray to use C-array instead of Tuple Replace the recursive template metaprogramming implementation of StaticallyIndexedArray with a simple C-array based struct. This avoids deep template instantiation while maintaining the same interface. Key changes: - StaticallyIndexedArray now stores `T data_[N]` instead of inheriting from Tuple - Added constexpr conversion constructor to convert from any indexed container (Tuple, etc.) - Added arithmetic operators (+, -, *, +=, -=) using C++20 concepts - Added overloads for container_reorder_given_new2old/old2new - Added overloads for get_container_subset and set_container_subset - Specialization for empty array (N=0) Co-Authored-By: Claude --- include/ck/utility/container_helper.hpp | 51 +++++ .../ck/utility/statically_indexed_array.hpp | 200 +++++++++++++++--- 2 files changed, 224 insertions(+), 27 deletions(-) diff --git a/include/ck/utility/container_helper.hpp b/include/ck/utility/container_helper.hpp index 8f2fe45796e..02462612213 100644 --- a/include/ck/utility/container_helper.hpp +++ b/include/ck/utility/container_helper.hpp @@ -76,6 +76,25 @@ __host__ __device__ constexpr auto container_reorder_given_old2new(const Tuple::type{}); } +template +__host__ __device__ constexpr auto +container_reorder_given_new2old(const StaticallyIndexedArray& old_arr, + Sequence /*new2old*/) +{ + static_assert(N == sizeof...(IRs), "wrong! size not consistent"); + static_assert(is_valid_sequence_map>{}, "wrong! invalid reorder map"); + return make_statically_indexed_array(old_arr[Number{}]...); +} + +template +__host__ __device__ constexpr auto +container_reorder_given_old2new(const StaticallyIndexedArray& old_arr, + Sequence old2new) +{ + return container_reorder_given_new2old( + old_arr, typename sequence_map_inverse::type{}); +} + template __host__ __device__ constexpr auto container_reorder_given_new2old(Sequence /* old_seq */, Sequence /*new2old*/) @@ -358,6 +377,15 @@ __host__ __device__ constexpr auto get_container_subset(const Tuple& tup, return make_tuple(tup[Number{}]...); } +template +__host__ __device__ constexpr auto get_container_subset(const StaticallyIndexedArray& arr, + Sequence) +{ + static_assert(N >= sizeof...(Is), "wrong! size"); + + return StaticallyIndexedArray{arr[Number{}]...}; +} + template __host__ __device__ constexpr void set_container_subset(Array& y, Sequence picks, const Array& x) @@ -376,6 +404,29 @@ set_container_subset(Tuple& y, Sequence picks, const Tuple& static_for<0, sizeof...(Is), 1>{}([&](auto i) { y(picks[i]) = x[i]; }); } +template +__host__ __device__ constexpr void +set_container_subset(StaticallyIndexedArray& y, + Sequence picks, + const StaticallyIndexedArray& x) +{ + static_assert(N >= sizeof...(Is), "wrong! size"); + + static_for<0, sizeof...(Is), 1>{}([&](auto i) { y(picks[i]) = x[i]; }); +} + +// Generic set_container_subset for StaticallyIndexedArray destination with any indexed source +template + requires requires { Src::Size(); } +__host__ __device__ constexpr void +set_container_subset(StaticallyIndexedArray& y, Sequence picks, const Src& x) +{ + static_assert(N >= sizeof...(Is), "wrong! size"); + static_assert(Src::Size() == sizeof...(Is), "wrong! size mismatch"); + + static_for<0, sizeof...(Is), 1>{}([&](auto i) { y(picks[i]) = x[i]; }); +} + template __host__ __device__ constexpr auto sequence_to_tuple_of_number(Sequence) { diff --git a/include/ck/utility/statically_indexed_array.hpp b/include/ck/utility/statically_indexed_array.hpp index d0735a32f6d..3dbb1d946ec 100644 --- a/include/ck/utility/statically_indexed_array.hpp +++ b/include/ck/utility/statically_indexed_array.hpp @@ -10,51 +10,124 @@ namespace ck { -namespace detail { -template -struct tuple_concat; - -template -struct tuple_concat, Tuple> -{ - using type = Tuple; -}; - +// StaticallyIndexedArray using simple C-array instead of template metaprogramming +// This avoids deep template instantiation while maintaining the same interface template -struct StaticallyIndexedArrayImpl +struct StaticallyIndexedArray { - using type = - typename tuple_concat::type, - typename StaticallyIndexedArrayImpl::type>::type; -}; + __host__ __device__ constexpr StaticallyIndexedArray() : data_{} {} + + // Single-element constructor - exclude containers with matching size (to prefer conversion + // constructor) + template + requires(N == 1 && + // Allow if X is same type as T or doesn't have Size() method + (is_same, T>::value || !requires { remove_cvref_t::Size(); })) + __host__ __device__ constexpr StaticallyIndexedArray(X&& x) + : data_{static_cast(ck::forward(x))} + { + } -template -struct StaticallyIndexedArrayImpl -{ - using type = Tuple<>; + // Multi-element constructor + template + requires(sizeof...(Xs) == N && N > 1) + __host__ __device__ constexpr StaticallyIndexedArray(Xs&&... xs) + : data_{static_cast(ck::forward(xs))...} + { + } + + // Conversion constructor from any indexed container (Tuple, etc.) + template + requires(!is_same, StaticallyIndexedArray>::value && + requires { Container::Size(); } && Container::Size() == N) + __host__ __device__ constexpr StaticallyIndexedArray(const Container& src) + : StaticallyIndexedArray( + make_from_container(src, typename arithmetic_sequence_gen<0, N, 1>::type{})) + { + } + + private: + template + __host__ __device__ static constexpr StaticallyIndexedArray + make_from_container(const Container& src, Sequence) + { + return StaticallyIndexedArray{static_cast(src[Number{}])...}; + } + + public: + __host__ __device__ static constexpr index_t Size() { return N; } + + // read access + template + __host__ __device__ constexpr const T& At(Number) const + { + static_assert(I < N, "wrong! out of range"); + return data_[I]; + } + + // write access + template + __host__ __device__ constexpr T& At(Number) + { + static_assert(I < N, "wrong! out of range"); + return data_[I]; + } + + // read access + template + __host__ __device__ constexpr const T& operator[](Number i) const + { + return At(i); + } + + // write access + template + __host__ __device__ constexpr T& operator()(Number i) + { + return At(i); + } + + template + __host__ __device__ constexpr auto operator=(const U& a) + { + static_assert(U::Size() == Size(), "wrong! size not the same"); + static_for<0, Size(), 1>{}([&](auto i) { operator()(i) = a[i]; }); + return *this; + } + + __host__ __device__ static constexpr bool IsStaticBuffer() { return true; } + + T data_[N]; }; +// Specialization for empty array template -struct StaticallyIndexedArrayImpl +struct StaticallyIndexedArray { - using type = Tuple; -}; -} // namespace detail + __host__ __device__ constexpr StaticallyIndexedArray() = default; -template -using StaticallyIndexedArray = typename detail::StaticallyIndexedArrayImpl::type; + __host__ __device__ static constexpr index_t Size() { return 0; } + + template + __host__ __device__ constexpr auto operator=(const U&) + { + return *this; + } + + __host__ __device__ static constexpr bool IsStaticBuffer() { return true; } +}; template __host__ __device__ constexpr auto make_statically_indexed_array(const X& x, const Xs&... xs) { - return StaticallyIndexedArray(x, static_cast(xs)...); + return StaticallyIndexedArray{x, static_cast(xs)...}; } // make empty StaticallyIndexedArray template __host__ __device__ constexpr auto make_statically_indexed_array() { - return StaticallyIndexedArray(); + return StaticallyIndexedArray{}; } template @@ -101,5 +174,78 @@ struct StaticallyIndexedArray_v2 T data_[N]; }; +// Concepts for StaticallyIndexedArray arithmetic operators +template +concept Scalar = ck::is_integral::value || ck::is_floating_point::value; + +template +concept IndexedContainer = !Scalar && requires { T::Size(); }; + +// Arithmetic operators for StaticallyIndexedArray (to match Tuple operators) + +// StaticallyIndexedArray += X +template +__host__ __device__ constexpr auto operator+=(StaticallyIndexedArray& y, const X& x) +{ + static_assert(X::Size() == N, "wrong! size not the same"); + static_for<0, N, 1>{}([&](auto i) { y(i) += x[i]; }); + return y; +} + +// StaticallyIndexedArray -= X +template +__host__ __device__ constexpr auto operator-=(StaticallyIndexedArray& y, const X& x) +{ + static_assert(X::Size() == N, "wrong! size not the same"); + static_for<0, N, 1>{}([&](auto i) { y(i) -= x[i]; }); + return y; +} + +// StaticallyIndexedArray + Y +template +__host__ __device__ constexpr auto operator+(const StaticallyIndexedArray& x, const Y& y) +{ + static_assert(Y::Size() == N, "wrong! size not the same"); + StaticallyIndexedArray r; + static_for<0, N, 1>{}([&](auto i) { r(i) = x[i] + y[i]; }); + return r; +} + +// StaticallyIndexedArray - Y +template +__host__ __device__ constexpr auto operator-(const StaticallyIndexedArray& x, const Y& y) +{ + static_assert(Y::Size() == N, "wrong! size not the same"); + StaticallyIndexedArray r; + static_for<0, N, 1>{}([&](auto i) { r(i) = x[i] - y[i]; }); + return r; +} + +// StaticallyIndexedArray * Y (element-wise) +template +__host__ __device__ constexpr auto operator*(const StaticallyIndexedArray& x, const Y& y) +{ + static_assert(Y::Size() == N, "wrong! size not the same"); + StaticallyIndexedArray r; + static_for<0, N, 1>{}([&](auto i) { r(i) = x[i] * y[i]; }); + return r; +} + +// scalar * StaticallyIndexedArray +template +__host__ __device__ constexpr auto operator*(S a, const StaticallyIndexedArray& x) +{ + StaticallyIndexedArray r; + static_for<0, N, 1>{}([&](auto i) { r(i) = a * x[i]; }); + return r; +} + +// StaticallyIndexedArray * scalar +template +__host__ __device__ constexpr auto operator*(const StaticallyIndexedArray& x, S a) +{ + return a * x; +} + } // namespace ck #endif