Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions include/ck/utility/array.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
#ifndef CK_ARRAY_HPP
#define CK_ARRAY_HPP

#include "functional2.hpp"
#include "sequence.hpp"
#include "type.hpp"

namespace ck {

Expand All @@ -32,7 +31,11 @@ struct Array
{
static_assert(T::Size() == Size(), "wrong! size not the same");

static_for<0, Size(), 1>{}([&](auto i) { operator()(i) = a[i]; });
#pragma unroll
for(int i = 0; i < NSize; i++)
{
mData[i] = a[i];
}
Comment on lines +34 to +38
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed the dependency on static_for. Hence, only need to include type.hpp.

And the code is easier to read.


return *this;
}
Expand Down
51 changes: 34 additions & 17 deletions include/ck/utility/sequence.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include <ostream>
#endif

#include "ck/utility/integral_constant.hpp"
#include "ck/utility/array.hpp"
#include "ck/utility/type.hpp"
#include "ck/utility/functional.hpp"
#include "ck/utility/math.hpp"
Expand Down Expand Up @@ -296,29 +296,46 @@ struct uniform_sequence_gen
};

// reverse inclusive scan (with init) sequence
template <typename, typename, index_t>
struct sequence_reverse_inclusive_scan;
namespace impl {
template <typename Seq, typename Reduce, index_t Init>
struct sequence_reverse_inclusive_scan_impl;

template <index_t I, index_t... Is, typename Reduce, index_t Init>
struct sequence_reverse_inclusive_scan<Sequence<I, Is...>, Reduce, Init>
template <index_t... Is, typename Reduce, index_t Init>
struct sequence_reverse_inclusive_scan_impl<Sequence<Is...>, Reduce, Init>
{
using old_scan = typename sequence_reverse_inclusive_scan<Sequence<Is...>, Reduce, Init>::type;

static constexpr index_t new_reduce = Reduce{}(I, old_scan{}.Front());
template <index_t... Indices>
static constexpr auto compute(Sequence<Indices...>)
{
constexpr index_t size = sizeof...(Is);

using type = typename sequence_merge<Sequence<new_reduce>, old_scan>::type;
};
if constexpr(size == 0)
{
return Sequence<>{};
}
else
{
constexpr Array<index_t, size> arr = []() {
Array<index_t, size> values = {Is...};
Array<index_t, size> result = {0};
result.At(size - 1) = Reduce{}(values[size - 1], Init);
for(index_t i = size - 1; i > 0; --i)
{
result.At(i - 1) = Reduce{}(values[i - 1], result[i]);
}
return result;
}();
return Sequence<arr[Indices]...>{};
}
}

template <index_t I, typename Reduce, index_t Init>
struct sequence_reverse_inclusive_scan<Sequence<I>, Reduce, Init>
{
using type = Sequence<Reduce{}(I, Init)>;
using type = decltype(compute(make_index_sequence<sizeof...(Is)>{}));
};
} // namespace impl

template <typename Reduce, index_t Init>
struct sequence_reverse_inclusive_scan<Sequence<>, Reduce, Init>
template <typename Seq, typename Reduce, index_t Init>
struct sequence_reverse_inclusive_scan
{
using type = Sequence<>;
using type = typename impl::sequence_reverse_inclusive_scan_impl<Seq, Reduce, Init>::type;
};

// split sequence
Expand Down
2 changes: 2 additions & 0 deletions test/util/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
# SPDX-License-Identifier: MIT

add_gtest_executable(unit_sequence unit_sequence.cpp)
add_gtest_executable(unit_array unit_array.cpp)
if(result EQUAL 0)
target_link_libraries(unit_sequence PRIVATE utility)
target_link_libraries(unit_array PRIVATE utility)
endif()
198 changes: 198 additions & 0 deletions test/util/unit_array.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#include <gtest/gtest.h>
#include "ck/utility/array.hpp"

using namespace ck;

// Test basic Sequence construction and properties
TEST(Array, BasicConstruction)
{
using Arr = Array<index_t, 5>;
EXPECT_EQ(Arr::Size(), 5);
}

TEST(Array, InitListConstruction)
{
using Arr = Array<index_t, 5>;
Arr value{1, 2, 3, 4, 5};
EXPECT_EQ(value[0], 1);
EXPECT_EQ(value[4], 5);
}

// Test At() method
TEST(Array, AtMethod)
{
Array<int, 3> arr{10, 20, 30};
EXPECT_EQ(arr.At(0), 10);
EXPECT_EQ(arr.At(1), 20);
EXPECT_EQ(arr.At(2), 30);

// Test non-const At() for modification
arr.At(1) = 25;
EXPECT_EQ(arr.At(1), 25);
}

// Test const At() method
TEST(Array, ConstAtMethod)
{
const Array<int, 3> arr{10, 20, 30};
EXPECT_EQ(arr.At(0), 10);
EXPECT_EQ(arr.At(1), 20);
EXPECT_EQ(arr.At(2), 30);
}

// Test operator[]
TEST(Array, OperatorBracket)
{
const Array<int, 4> arr{5, 10, 15, 20};
EXPECT_EQ(arr[0], 5);
EXPECT_EQ(arr[1], 10);
EXPECT_EQ(arr[2], 15);
EXPECT_EQ(arr[3], 20);
}

// Test operator()
TEST(Array, OperatorParenthesis)
{
Array<int, 3> arr{1, 2, 3};
EXPECT_EQ(arr(0), 1);
EXPECT_EQ(arr(1), 2);
EXPECT_EQ(arr(2), 3);

// Test modification through operator()
arr(1) = 99;
EXPECT_EQ(arr(1), 99);
}

// Test operator= assignment
TEST(Array, Assignment)
{
Array<int, 3> arr1{1, 2, 3};
Array<int, 3> arr2{0, 0, 0};

arr2 = arr1;

EXPECT_EQ(arr2[0], 1);
EXPECT_EQ(arr2[1], 2);
EXPECT_EQ(arr2[2], 3);
}

// Test iterators
TEST(Array, Iterators)
{
Array<int, 5> arr{1, 2, 3, 4, 5};

// Test begin() and end()
int sum = 0;
for(auto it = arr.begin(); it != arr.end(); ++it)
{
sum += *it;
}
EXPECT_EQ(sum, 15);

// Test range-based for loop
sum = 0;
for(auto val : arr)
{
sum += val;
}
EXPECT_EQ(sum, 15);
}

// Test const iterators
TEST(Array, ConstIterators)
{
const Array<int, 4> arr{10, 20, 30, 40};

int sum = 0;
for(auto it = arr.begin(); it != arr.end(); ++it)
{
sum += *it;
}
EXPECT_EQ(sum, 100);

// Test const range-based for loop
sum = 0;
for(auto val : arr)
{
sum += val;
}
EXPECT_EQ(sum, 100);
}

// Test make_array() helper function
TEST(Array, MakeArray)
{
auto arr = make_array(1, 2, 3, 4, 5);

EXPECT_EQ(arr.Size(), 5);
EXPECT_EQ(arr[0], 1);
EXPECT_EQ(arr[1], 2);
EXPECT_EQ(arr[2], 3);
EXPECT_EQ(arr[3], 4);
EXPECT_EQ(arr[4], 5);
}

// Test make_array() with different types
TEST(Array, MakeArrayFloats)
{
auto arr = make_array(1.5f, 2.5f, 3.5f);

EXPECT_EQ(arr.Size(), 3);
EXPECT_FLOAT_EQ(arr[0], 1.5f);
EXPECT_FLOAT_EQ(arr[1], 2.5f);
EXPECT_FLOAT_EQ(arr[2], 3.5f);
}

// Test empty Array<T, 0>
TEST(Array, EmptyArray)
{
using EmptyArr = Array<int, 0>;
EXPECT_EQ(EmptyArr::Size(), 0);

// Test make_array() for empty array
auto empty = make_array<int>();
EXPECT_EQ(empty.Size(), 0);
}

// Test Array with different data types
TEST(Array, DifferentTypes)
{
Array<float, 3> float_arr{1.1f, 2.2f, 3.3f};
EXPECT_FLOAT_EQ(float_arr[0], 1.1f);
EXPECT_FLOAT_EQ(float_arr[1], 2.2f);
EXPECT_FLOAT_EQ(float_arr[2], 3.3f);

Array<double, 2> double_arr{1.23, 4.56};
EXPECT_DOUBLE_EQ(double_arr[0], 1.23);
EXPECT_DOUBLE_EQ(double_arr[1], 4.56);
}

// Test Array modification through iterators
TEST(Array, ModifyThroughIterators)
{
Array<int, 3> arr{1, 2, 3};

for(auto it = arr.begin(); it != arr.end(); ++it)
{
*it *= 2;
}

EXPECT_EQ(arr[0], 2);
EXPECT_EQ(arr[1], 4);
EXPECT_EQ(arr[2], 6);
}

// Test single element Array
TEST(Array, SingleElement)
{
Array<int, 1> arr{42};
EXPECT_EQ(arr.Size(), 1);
EXPECT_EQ(arr[0], 42);

auto single = make_array(100);
EXPECT_EQ(single.Size(), 1);
EXPECT_EQ(single[0], 100);
}