diff --git a/include/ck/utility/array.hpp b/include/ck/utility/array.hpp index 2b249884b66..3e8a189e249 100644 --- a/include/ck/utility/array.hpp +++ b/include/ck/utility/array.hpp @@ -4,8 +4,7 @@ #ifndef CK_ARRAY_HPP #define CK_ARRAY_HPP -#include "functional2.hpp" -#include "sequence.hpp" +#include "type.hpp" namespace ck { @@ -32,7 +31,11 @@ struct Array { static_assert(T::Size() == Size(), "wrong! size not the same"); - static_for<0, Size(), 1>{}([&](auto i) { operator()(i) = a[i]; }); +#pragma unroll + for(int i = 0; i < NSize; i++) + { + mData[i] = a[i]; + } return *this; } diff --git a/include/ck/utility/sequence.hpp b/include/ck/utility/sequence.hpp index 6e68690048f..9180346a0c6 100644 --- a/include/ck/utility/sequence.hpp +++ b/include/ck/utility/sequence.hpp @@ -7,7 +7,7 @@ #include #endif -#include "ck/utility/integral_constant.hpp" +#include "ck/utility/array.hpp" #include "ck/utility/type.hpp" #include "ck/utility/functional.hpp" #include "ck/utility/math.hpp" @@ -296,29 +296,46 @@ struct uniform_sequence_gen }; // reverse inclusive scan (with init) sequence -template -struct sequence_reverse_inclusive_scan; +namespace impl { +template +struct sequence_reverse_inclusive_scan_impl; -template -struct sequence_reverse_inclusive_scan, Reduce, Init> +template +struct sequence_reverse_inclusive_scan_impl, Reduce, Init> { - using old_scan = typename sequence_reverse_inclusive_scan, Reduce, Init>::type; - - static constexpr index_t new_reduce = Reduce{}(I, old_scan{}.Front()); + template + static constexpr auto compute(Sequence) + { + constexpr index_t size = sizeof...(Is); - using type = typename sequence_merge, old_scan>::type; -}; + if constexpr(size == 0) + { + return Sequence<>{}; + } + else + { + constexpr Array arr = []() { + Array values = {Is...}; + Array result = {0}; + result.At(size - 1) = Reduce{}(values[size - 1], Init); + for(index_t i = size - 1; i > 0; --i) + { + result.At(i - 1) = Reduce{}(values[i - 1], result[i]); + } + return result; + }(); + return Sequence{}; + } + } -template -struct sequence_reverse_inclusive_scan, Reduce, Init> -{ - using type = Sequence; + using type = decltype(compute(make_index_sequence{})); }; +} // namespace impl -template -struct sequence_reverse_inclusive_scan, Reduce, Init> +template +struct sequence_reverse_inclusive_scan { - using type = Sequence<>; + using type = typename impl::sequence_reverse_inclusive_scan_impl::type; }; // split sequence diff --git a/test/util/CMakeLists.txt b/test/util/CMakeLists.txt index bf0a444f18b..5be55e1f4dc 100644 --- a/test/util/CMakeLists.txt +++ b/test/util/CMakeLists.txt @@ -2,6 +2,8 @@ # SPDX-License-Identifier: MIT add_gtest_executable(unit_sequence unit_sequence.cpp) +add_gtest_executable(unit_array unit_array.cpp) if(result EQUAL 0) target_link_libraries(unit_sequence PRIVATE utility) + target_link_libraries(unit_array PRIVATE utility) endif() diff --git a/test/util/unit_array.cpp b/test/util/unit_array.cpp new file mode 100644 index 00000000000..436ac9bf0ca --- /dev/null +++ b/test/util/unit_array.cpp @@ -0,0 +1,198 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include "ck/utility/array.hpp" + +using namespace ck; + +// Test basic Sequence construction and properties +TEST(Array, BasicConstruction) +{ + using Arr = Array; + EXPECT_EQ(Arr::Size(), 5); +} + +TEST(Array, InitListConstruction) +{ + using Arr = Array; + Arr value{1, 2, 3, 4, 5}; + EXPECT_EQ(value[0], 1); + EXPECT_EQ(value[4], 5); +} + +// Test At() method +TEST(Array, AtMethod) +{ + Array arr{10, 20, 30}; + EXPECT_EQ(arr.At(0), 10); + EXPECT_EQ(arr.At(1), 20); + EXPECT_EQ(arr.At(2), 30); + + // Test non-const At() for modification + arr.At(1) = 25; + EXPECT_EQ(arr.At(1), 25); +} + +// Test const At() method +TEST(Array, ConstAtMethod) +{ + const Array arr{10, 20, 30}; + EXPECT_EQ(arr.At(0), 10); + EXPECT_EQ(arr.At(1), 20); + EXPECT_EQ(arr.At(2), 30); +} + +// Test operator[] +TEST(Array, OperatorBracket) +{ + const Array arr{5, 10, 15, 20}; + EXPECT_EQ(arr[0], 5); + EXPECT_EQ(arr[1], 10); + EXPECT_EQ(arr[2], 15); + EXPECT_EQ(arr[3], 20); +} + +// Test operator() +TEST(Array, OperatorParenthesis) +{ + Array arr{1, 2, 3}; + EXPECT_EQ(arr(0), 1); + EXPECT_EQ(arr(1), 2); + EXPECT_EQ(arr(2), 3); + + // Test modification through operator() + arr(1) = 99; + EXPECT_EQ(arr(1), 99); +} + +// Test operator= assignment +TEST(Array, Assignment) +{ + Array arr1{1, 2, 3}; + Array arr2{0, 0, 0}; + + arr2 = arr1; + + EXPECT_EQ(arr2[0], 1); + EXPECT_EQ(arr2[1], 2); + EXPECT_EQ(arr2[2], 3); +} + +// Test iterators +TEST(Array, Iterators) +{ + Array arr{1, 2, 3, 4, 5}; + + // Test begin() and end() + int sum = 0; + for(auto it = arr.begin(); it != arr.end(); ++it) + { + sum += *it; + } + EXPECT_EQ(sum, 15); + + // Test range-based for loop + sum = 0; + for(auto val : arr) + { + sum += val; + } + EXPECT_EQ(sum, 15); +} + +// Test const iterators +TEST(Array, ConstIterators) +{ + const Array arr{10, 20, 30, 40}; + + int sum = 0; + for(auto it = arr.begin(); it != arr.end(); ++it) + { + sum += *it; + } + EXPECT_EQ(sum, 100); + + // Test const range-based for loop + sum = 0; + for(auto val : arr) + { + sum += val; + } + EXPECT_EQ(sum, 100); +} + +// Test make_array() helper function +TEST(Array, MakeArray) +{ + auto arr = make_array(1, 2, 3, 4, 5); + + EXPECT_EQ(arr.Size(), 5); + EXPECT_EQ(arr[0], 1); + EXPECT_EQ(arr[1], 2); + EXPECT_EQ(arr[2], 3); + EXPECT_EQ(arr[3], 4); + EXPECT_EQ(arr[4], 5); +} + +// Test make_array() with different types +TEST(Array, MakeArrayFloats) +{ + auto arr = make_array(1.5f, 2.5f, 3.5f); + + EXPECT_EQ(arr.Size(), 3); + EXPECT_FLOAT_EQ(arr[0], 1.5f); + EXPECT_FLOAT_EQ(arr[1], 2.5f); + EXPECT_FLOAT_EQ(arr[2], 3.5f); +} + +// Test empty Array +TEST(Array, EmptyArray) +{ + using EmptyArr = Array; + EXPECT_EQ(EmptyArr::Size(), 0); + + // Test make_array() for empty array + auto empty = make_array(); + EXPECT_EQ(empty.Size(), 0); +} + +// Test Array with different data types +TEST(Array, DifferentTypes) +{ + Array float_arr{1.1f, 2.2f, 3.3f}; + EXPECT_FLOAT_EQ(float_arr[0], 1.1f); + EXPECT_FLOAT_EQ(float_arr[1], 2.2f); + EXPECT_FLOAT_EQ(float_arr[2], 3.3f); + + Array double_arr{1.23, 4.56}; + EXPECT_DOUBLE_EQ(double_arr[0], 1.23); + EXPECT_DOUBLE_EQ(double_arr[1], 4.56); +} + +// Test Array modification through iterators +TEST(Array, ModifyThroughIterators) +{ + Array arr{1, 2, 3}; + + for(auto it = arr.begin(); it != arr.end(); ++it) + { + *it *= 2; + } + + EXPECT_EQ(arr[0], 2); + EXPECT_EQ(arr[1], 4); + EXPECT_EQ(arr[2], 6); +} + +// Test single element Array +TEST(Array, SingleElement) +{ + Array arr{42}; + EXPECT_EQ(arr.Size(), 1); + EXPECT_EQ(arr[0], 42); + + auto single = make_array(100); + EXPECT_EQ(single.Size(), 1); + EXPECT_EQ(single[0], 100); +}