Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ynnpack/base/simd/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ cc_library(
"x86_sse41_base.h",
"x86_avx_partial_load_store.h",
"x86_sse2_partial_load_store.h",
"x86_sse2_saturating_convert.h",
"generic.inc",
# For the most part, only one of these headers should be included. Multiple of these headers
# may define the same operation and type, using a different implementation, depending on the
Expand Down
58 changes: 58 additions & 0 deletions ynnpack/base/simd/arm_neon_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -823,10 +823,31 @@ YNN_ALWAYS_INLINE std::array<vec<T, 4>, 4> transpose(
}};
}

namespace internal {

// Converts four f32 lanes to int32 lanes with rounding.
//
// On ARMv8 and later (or when __ARM_ARCH is not defined) this is a single
// `vcvtnq_s32_f32`, which rounds to nearest with ties to even. On older
// architectures the input is first passed through the vector `round` helper
// (presumably rounding to an integral value -- verify against its definition)
// and then converted with the truncating `vcvtq_s32_f32`.
YNN_ALWAYS_INLINE int32x4_t cast_f32_to_int32(float32x4_t f) {
#if !defined(__ARM_ARCH) || __ARM_ARCH >= 8
  return vcvtnq_s32_f32(f);
#else
  const auto rounded = round(f32x4{f});
  return vcvtq_s32_f32(rounded.v);
#endif
}

// Converts four f32 lanes to uint32 lanes with rounding.
//
// On ARMv8 and later (or when __ARM_ARCH is not defined) this is a single
// `vcvtnq_u32_f32`, which rounds to nearest with ties to even. On older
// architectures the input is first passed through the vector `round` helper
// (presumably rounding to an integral value -- verify against its definition)
// and then converted with the truncating `vcvtq_u32_f32`.
YNN_ALWAYS_INLINE uint32x4_t cast_f32_to_uint32(float32x4_t f) {
#if !defined(__ARM_ARCH) || __ARM_ARCH >= 8
  return vcvtnq_u32_f32(f);
#else
  const auto rounded = round(f32x4{f});
  return vcvtq_u32_f32(rounded.v);
#endif
}

} // namespace internal

// Vector aliases named <element type>x<lane count>. These are wider than a
// single 128-bit NEON register; the conversions in this file split them with
// lo()/hi() before handing the halves to NEON intrinsics.
using f32x8 = vec<float, 8>;
using s32x8 = vec<int32_t, 8>;
using s16x16 = vec<int16_t, 16>;
using s32x16 = vec<int32_t, 16>;
using f32x16 = vec<float, 16>;

YNN_ALWAYS_INLINE f32x8 convert(bf16x8 a, float) {
uint16x8x2_t a_u32 = vzipq_u16(vdupq_n_u16(0), a.v);
Expand Down Expand Up @@ -873,6 +894,43 @@ YNN_ALWAYS_INLINE s32x4 convert(f32x4 x, int32_t) {
return s32x4{vcvtq_s32_f32(x.v)};
}

// Narrows eight int32 lanes to int16. `vqmovn_s32` clamps each lane to the
// int16 range; the two 4-lane halves are then joined into one register.
YNN_ALWAYS_INLINE s16x8 saturating_convert(s32x8 a, int16_t) {
  const auto narrow_lo = vqmovn_s32(a.lo().v);
  const auto narrow_hi = vqmovn_s32(a.hi().v);
  return s16x8{vcombine_s16(narrow_lo, narrow_hi)};
}

// Narrows sixteen int16 lanes to int8. `vqmovn_s16` clamps each lane to the
// int8 range; the two 8-lane halves are then joined into one register.
YNN_ALWAYS_INLINE s8x16 saturating_convert(s16x16 a, int8_t) {
  const auto narrow_lo = vqmovn_s16(a.lo().v);
  const auto narrow_hi = vqmovn_s16(a.hi().v);
  return s8x16{vcombine_s8(narrow_lo, narrow_hi)};
}

// Narrows sixteen int16 lanes to uint8. `vqmovun_s16` is the signed-to-unsigned
// saturating narrow, so negative lanes become 0 and lanes above 255 become
// 255; the two 8-lane halves are then joined into one register.
YNN_ALWAYS_INLINE u8x16 saturating_convert(s16x16 a, uint8_t) {
  const auto narrow_lo = vqmovun_s16(a.lo().v);
  const auto narrow_hi = vqmovun_s16(a.hi().v);
  return u8x16{vcombine_u8(narrow_lo, narrow_hi)};
}

// Rounds eight f32 lanes to int32 (via internal::cast_f32_to_int32) and then
// narrows them to int16 with signed saturation.
YNN_ALWAYS_INLINE s16x8 saturating_rounding_convert(f32x8 f, int16_t) {
  const auto lo_s32 = internal::cast_f32_to_int32(f.lo().v);
  const auto hi_s32 = internal::cast_f32_to_int32(f.hi().v);
  return s16x8{vcombine_s16(vqmovn_s32(lo_s32), vqmovn_s32(hi_s32))};
}

// Rounds sixteen f32 lanes to integers and narrows them to int8 in three
// stages: f32 -> s32 (rounding), s32 -> s16 and s16 -> s8 (both saturating).
YNN_ALWAYS_INLINE s8x16 saturating_rounding_convert(f32x16 f, int8_t) {
  // Round each 4-lane quarter of the input to int32.
  const auto q0 = internal::cast_f32_to_int32(f.lo().lo().v);
  const auto q1 = internal::cast_f32_to_int32(f.lo().hi().v);
  const auto q2 = internal::cast_f32_to_int32(f.hi().lo().v);
  const auto q3 = internal::cast_f32_to_int32(f.hi().hi().v);

  // 32 -> 16 bits, keeping the quarters in their original order.
  const auto lower_s16 = vcombine_s16(vqmovn_s32(q0), vqmovn_s32(q1));
  const auto upper_s16 = vcombine_s16(vqmovn_s32(q2), vqmovn_s32(q3));

  // 16 -> 8 bits.
  return s8x16{vcombine_s8(vqmovn_s16(lower_s16), vqmovn_s16(upper_s16))};
}

// Rounds sixteen f32 lanes to integers and narrows them to uint8 in three
// stages: f32 -> u32 (rounding), u32 -> u16 and u16 -> u8 (both saturating).
YNN_ALWAYS_INLINE u8x16 saturating_rounding_convert(f32x16 f, uint8_t) {
  // Round each 4-lane quarter of the input to uint32.
  const auto q0 = internal::cast_f32_to_uint32(f.lo().lo().v);
  const auto q1 = internal::cast_f32_to_uint32(f.lo().hi().v);
  const auto q2 = internal::cast_f32_to_uint32(f.hi().lo().v);
  const auto q3 = internal::cast_f32_to_uint32(f.hi().hi().v);

  // 32 -> 16 bits, keeping the quarters in their original order.
  const auto lower_u16 = vcombine_u16(vqmovn_u32(q0), vqmovn_u32(q1));
  const auto upper_u16 = vcombine_u16(vqmovn_u32(q2), vqmovn_u32(q3));

  // 16 -> 8 bits.
  return u8x16{vcombine_u8(vqmovn_u16(lower_u16), vqmovn_u16(upper_u16))};
}

} // namespace simd

} // namespace ynn
Expand Down
12 changes: 12 additions & 0 deletions ynnpack/base/simd/generic.inc
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,18 @@ YNN_ALWAYS_INLINE vec<To, N> convert(vec<From, N> from, To) {
return {convert(from.lo(), To()), convert(from.hi(), To())};
}

// Generic fallback: converts each half of `from` separately and joins the
// results. Used when no overload exists for the full vector width; the
// recursion ends at whichever narrower overload matches a half.
template <typename To, typename From, size_t N>
YNN_ALWAYS_INLINE vec<To, N> saturating_convert(vec<From, N> from, To) {
  const auto lower_half = saturating_convert(from.lo(), To{});
  const auto upper_half = saturating_convert(from.hi(), To{});
  return {lower_half, upper_half};
}

// Generic fallback: converts each half of `from` separately and joins the
// results. Used when no overload exists for the full vector width; the
// recursion ends at whichever narrower overload matches a half.
template <typename To, typename From, size_t N>
YNN_ALWAYS_INLINE vec<To, N> saturating_rounding_convert(vec<From, N> from,
                                                         To) {
  const auto lower_half = saturating_rounding_convert(from.lo(), To{});
  const auto upper_half = saturating_rounding_convert(from.hi(), To{});
  return {lower_half, upper_half};
}

template <typename T, size_t N>
YNN_ALWAYS_INLINE T horizontal_sum(vec<T, N> x) {
return horizontal_sum(x.lo() + x.hi());
Expand Down
7 changes: 7 additions & 0 deletions ynnpack/base/simd/test/arm_neon.cc
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,13 @@ TEST_CONVERT(arm_neon, f32, s32x4);
TEST_CONVERT(arm_neon, s32, f32x4);
TEST_CONVERT(arm_neon, f32, bf16x8);

// Saturating and saturating+rounding narrowing conversions. Macro arguments
// are: test fixture, destination element type, source vector type.
TEST_SATURATING_CONVERT(arm_neon, s16, s32x8);
TEST_SATURATING_CONVERT(arm_neon, u8, s16x16);
TEST_SATURATING_CONVERT(arm_neon, s8, s16x16);
TEST_SATURATING_ROUNDING_CONVERT(arm_neon, u8, f32x16);
TEST_SATURATING_ROUNDING_CONVERT(arm_neon, s8, f32x16);
TEST_SATURATING_ROUNDING_CONVERT(arm_neon, s16, f32x8);

TEST_EXTRACT(arm_neon, u8x16, 8);

TEST_CONCAT(arm_neon, u8x8);
Expand Down
65 changes: 65 additions & 0 deletions ynnpack/base/simd/test/generic.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "ynnpack/base/arithmetic.h"
#include "ynnpack/base/bfloat16.h"
#include "ynnpack/base/half.h"
#include "ynnpack/base/simd/vec.h"
Expand Down Expand Up @@ -622,6 +623,70 @@ void test_convert() {
// Registers a gtest case named convert_<to>_<from> on fixture `test_class`
// whose body runs test_convert<to, from>().
#define TEST_CONVERT(test_class, to, from) \
TEST_F(test_class, convert_##to##_##from) { test_convert<to, from>(); }

// Scalar reference for the vector saturating conversions under test.
//
// When `round` is true and the conversion is floating point -> integer, the
// result comes from round_float_to_int<To>; in every other case it comes from
// ynn::saturate_cast<To>.
template <typename To, typename From>
To saturating_convert_reference(From from, bool round = false) {
  constexpr bool kFloatToInt =
      std::is_floating_point_v<From> && std::is_integral_v<To>;
  if constexpr (kFloatToInt) {
    if (round) return round_float_to_int<To>(from);
  }
  return ynn::saturate_cast<To>(from);
}

template <typename To, typename From>
void test_saturating_convert() {
using FromScalar = typename From::value_type;
static constexpr size_t N = From::N;
using vector = vec<FromScalar, N>;

ReplicableRandomDevice rng;
for (auto _ : FuzzTest(std::chrono::milliseconds(100))) {
FromScalar src[N];
fill_random(src, N, rng);
From from_v = load(src, vector::N);
auto to_v = saturating_convert(from_v, To{});

To dst[N];
store(dst, to_v);
for (size_t i = 0; i < N; ++i) {
ASSERT_EQ(dst[i],
saturating_convert_reference<To>(src[i], /*round=*/false));
}
}
}

// Registers a gtest case named saturating_convert_<to>_<from> on fixture
// `test_class` whose body runs test_saturating_convert<to, from>().
#define TEST_SATURATING_CONVERT(test_class, to, from) \
TEST_F(test_class, saturating_convert_##to##_##from) { \
test_saturating_convert<to, from>(); \
}

template <typename To, typename From>
void test_saturating_rounding_convert() {
using FromScalar = typename From::value_type;
static constexpr size_t N = From::N;
using vector = vec<FromScalar, N>;

ReplicableRandomDevice rng;
for (auto _ : FuzzTest(std::chrono::milliseconds(100))) {
FromScalar src[N];
fill_random(src, N, rng);
From from_v = load(src, vector::N);
auto to_v = saturating_rounding_convert(from_v, To{});

To dst[N];
store(dst, to_v);
for (size_t i = 0; i < N; ++i) {
ASSERT_EQ(dst[i],
saturating_convert_reference<To>(src[i], /*round=*/true));
}
}
}

// Registers a gtest case named saturating_rounding_convert_<to>_<from> on
// fixture `test_class` whose body runs
// test_saturating_rounding_convert<to, from>().
#define TEST_SATURATING_ROUNDING_CONVERT(test_class, to, from) \
TEST_F(test_class, saturating_rounding_convert_##to##_##from) { \
test_saturating_rounding_convert<to, from>(); \
}

template <typename scalar, size_t N>
void test_horizontal_sum() {
using vector = vec<scalar, N>;
Expand Down
7 changes: 7 additions & 0 deletions ynnpack/base/simd/test/x86_avx2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,5 +71,12 @@ TEST_CONVERT(x86_avx2, s32, s8x16);
TEST_CONVERT(x86_avx2, f32, s32x8);
TEST_CONVERT(x86_avx2, s32, f32x8);

// Saturating and saturating+rounding narrowing conversions. Macro arguments
// are: test fixture, destination element type, source vector type.
TEST_SATURATING_CONVERT(x86_avx2, s16, s32x16);
TEST_SATURATING_CONVERT(x86_avx2, u8, s16x32);
TEST_SATURATING_CONVERT(x86_avx2, s8, s16x32);
TEST_SATURATING_ROUNDING_CONVERT(x86_avx2, u8, f32x32);
TEST_SATURATING_ROUNDING_CONVERT(x86_avx2, s8, f32x32);
TEST_SATURATING_ROUNDING_CONVERT(x86_avx2, s16, f32x16);

} // namespace simd
} // namespace ynn
7 changes: 7 additions & 0 deletions ynnpack/base/simd/test/x86_avx512.cc
Original file line number Diff line number Diff line change
Expand Up @@ -174,5 +174,12 @@ TEST_CONVERT(x86_avx512, s32, u8x32);
TEST_CONVERT(x86_avx512, f32, s32x16);
TEST_CONVERT(x86_avx512, s32, f32x16);

// Saturating and saturating+rounding narrowing conversions. Macro arguments
// are: test fixture, destination element type, source vector type.
TEST_SATURATING_CONVERT(x86_avx512, s16, s32x32);
TEST_SATURATING_CONVERT(x86_avx512, u8, s16x64);
TEST_SATURATING_CONVERT(x86_avx512, s8, s16x64);
TEST_SATURATING_ROUNDING_CONVERT(x86_avx512, u8, f32x64);
TEST_SATURATING_ROUNDING_CONVERT(x86_avx512, s8, f32x64);
TEST_SATURATING_ROUNDING_CONVERT(x86_avx512, s16, f32x32);

} // namespace simd
} // namespace ynn
7 changes: 7 additions & 0 deletions ynnpack/base/simd/test/x86_sse2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,13 @@ TEST_CONVERT(x86_sse2, s32, f32x4);
TEST_CONVERT(x86_sse2, f64, f32x4);
TEST_CONVERT(x86_sse2, f32, f64x4);

TEST_SATURATING_CONVERT(x86_sse2, s16, s32x8);
TEST_SATURATING_CONVERT(x86_sse2, u8, s16x16);
TEST_SATURATING_CONVERT(x86_sse2, s8, s16x16);
TEST_SATURATING_ROUNDING_CONVERT(x86_sse2, u8, f32x16);
TEST_SATURATING_ROUNDING_CONVERT(x86_sse2, s8, f32x16);
TEST_SATURATING_ROUNDING_CONVERT(x86_sse2, s16, f32x8);

TEST_FMA(x86_sse2, f32, 4);

} // namespace simd
Expand Down
22 changes: 22 additions & 0 deletions ynnpack/base/simd/vec.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include <limits>
#include <type_traits>

#include "ynnpack/base/arithmetic.h"
#include "ynnpack/base/base.h"

namespace ynn {
Expand Down Expand Up @@ -148,6 +149,12 @@ auto extract(vec<T, N>, SliceN);
// Converts every element of `from` to element type `To`. The unnamed second
// parameter only carries the destination type for overload selection.
template <typename To, typename From, size_t N>
vec<To, N> convert(vec<From, N> from, To);

// Converts every element of `from` to element type `To` with saturation (the
// single-element fallback in this file uses saturate_cast<To>). The unnamed
// second parameter only carries the destination type for overload selection.
template <typename To, typename From, size_t N>
vec<To, N> saturating_convert(vec<From, N> from, To);

// Converts every element of `from` to element type `To` with rounding and
// saturation (the single-element fallback in this file uses
// round_float_to_int<To> for float sources and saturate_cast<To> otherwise).
// The unnamed second parameter only carries the destination type for overload
// selection.
template <typename To, typename From, size_t N>
vec<To, N> saturating_rounding_convert(vec<From, N> from, To);

namespace internal {

template <typename T, size_t N>
Expand Down Expand Up @@ -307,6 +314,21 @@ YNN_ALWAYS_INLINE vec<To, 1> convert(vec<From, 1> from, To) {
return vec<To, 1>{static_cast<To>(from.v)};
}

// Single-element base case: clamps the one stored value into the range of
// `To` via saturate_cast.
template <typename To, typename From>
YNN_ALWAYS_INLINE vec<To, 1> saturating_convert(vec<From, 1> from, To) {
  const To clamped = saturate_cast<To>(from.v);
  return vec<To, 1>{clamped};
}

// Single-element base case. A `float` source goes through
// round_float_to_int<To>; every other source type goes through
// saturate_cast<To>.
template <typename To, typename From>
YNN_ALWAYS_INLINE vec<To, 1> saturating_rounding_convert(vec<From, 1> from,
                                                         To) {
  if constexpr (!std::is_same_v<From, float>) {
    return vec<To, 1>{saturate_cast<To>(from.v)};
  } else {
    return vec<To, 1>{round_float_to_int<To>(from.v)};
  }
}

template <typename T>
YNN_ALWAYS_INLINE vec<T, 1> operator&(vec<T, 1> a, vec<T, 1> b) {
return vec<T, 1>{static_cast<T>(a.v & b.v)};
Expand Down
49 changes: 49 additions & 0 deletions ynnpack/base/simd/x86_avx2.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ namespace simd {
// Vector aliases named <element type>x<lane count>. These are wider than a
// single 256-bit AVX register; the conversions in this file split them with
// lo()/hi() before handing the halves to AVX2 intrinsics.
using f32x16 = vec<float, 16>;
using s32x16 = vec<int32_t, 16>;
using s32x32 = vec<int32_t, 32>;
using f32x32 = vec<float, 32>;
using s16x32 = vec<int16_t, 32>;

YNN_ALWAYS_INLINE s32x16 convert(s8x16 a, int32_t) {
return {
Expand Down Expand Up @@ -58,6 +60,53 @@ YNN_ALWAYS_INLINE bf16x16 convert(f32x16 a, bfloat16) {
return bf16x16{_mm256_permute4x64_epi64(r, _MM_SHUFFLE(3, 1, 2, 0))};
}

// Narrows sixteen int32 lanes to int16 with signed saturation.
//
// `_mm256_packs_epi32` packs within each 128-bit lane, so the 64-bit chunks of
// its result are ordered [lo.0, hi.0, lo.1, hi.1]. The 64-bit permute swaps the
// middle two chunks to restore source order.
YNN_ALWAYS_INLINE s16x16 saturating_convert(s32x16 a, int16_t) {
  const __m256i lane_packed = _mm256_packs_epi32(a.lo().v, a.hi().v);
  const __m256i in_order =
      _mm256_permute4x64_epi64(lane_packed, _MM_SHUFFLE(3, 1, 2, 0));
  return s16x16{in_order};
}

// Narrows thirty-two int16 lanes to int8 with signed saturation.
//
// `_mm256_packs_epi16` packs within each 128-bit lane, so the 64-bit chunks of
// its result are ordered [lo.0, hi.0, lo.1, hi.1]. The 64-bit permute swaps the
// middle two chunks to restore source order.
YNN_ALWAYS_INLINE s8x32 saturating_convert(s16x32 a, int8_t) {
  const __m256i lane_packed = _mm256_packs_epi16(a.lo().v, a.hi().v);
  const __m256i in_order =
      _mm256_permute4x64_epi64(lane_packed, _MM_SHUFFLE(3, 1, 2, 0));
  return s8x32{in_order};
}

// Narrows thirty-two int16 lanes to uint8 with unsigned saturation (negative
// values become 0, values above 255 become 255).
//
// `_mm256_packus_epi16` packs within each 128-bit lane, so the 64-bit chunks
// of its result are ordered [lo.0, hi.0, lo.1, hi.1]. The 64-bit permute swaps
// the middle two chunks to restore source order.
YNN_ALWAYS_INLINE u8x32 saturating_convert(s16x32 a, uint8_t) {
  const __m256i lane_packed = _mm256_packus_epi16(a.lo().v, a.hi().v);
  const __m256i in_order =
      _mm256_permute4x64_epi64(lane_packed, _MM_SHUFFLE(3, 1, 2, 0));
  return u8x32{in_order};
}

// Rounds sixteen f32 lanes to integers and narrows them to int16 with
// saturation.
//
// `_mm256_cvtps_epi32` returns INT32_MIN for inputs outside the int32 range,
// so large positive inputs are clamped to INT16_MAX before converting. Large
// negative inputs need no clamp: INT32_MIN saturates to INT16_MIN in the final
// pack.
YNN_ALWAYS_INLINE s16x16 saturating_rounding_convert(f32x16 f, int16_t) {
  const __m256 upper_bound =
      _mm256_set1_ps(static_cast<float>((1 << 15) - 1));
  const __m256 lo_clamped = _mm256_min_ps(f.lo().v, upper_bound);
  const __m256 hi_clamped = _mm256_min_ps(f.hi().v, upper_bound);
  const s32x8 lo_s32(_mm256_cvtps_epi32(lo_clamped));
  const s32x8 hi_s32(_mm256_cvtps_epi32(hi_clamped));
  return saturating_convert(s32x16(lo_s32, hi_s32), int16_t());
}

// Rounds thirty-two f32 lanes to integers and narrows them to int8 with
// saturation. Each 16-lane half is first rounded and narrowed to int16, then
// the two int16 halves are narrowed together to int8.
YNN_ALWAYS_INLINE s8x32 saturating_rounding_convert(f32x32 f, int8_t) {
  const f32x16 lower_f32(f.lo().lo(), f.lo().hi());
  const f32x16 upper_f32(f.hi().lo(), f.hi().hi());
  const s16x16 lower_s16 = saturating_rounding_convert(lower_f32, int16_t());
  const s16x16 upper_s16 = saturating_rounding_convert(upper_f32, int16_t());
  return saturating_convert(s16x32(lower_s16, upper_s16), int8_t());
}

// Rounds thirty-two f32 lanes to integers and narrows them to uint8 with
// saturation.
//
// Inputs are clamped from above before `_mm256_cvtps_epi32`, which returns
// INT32_MIN for values outside the int32 range. The int32 values are packed to
// int16 with signed saturation and then to uint8 with unsigned saturation.
// Both packs work within each 128-bit lane, so one 32-bit permute at the end
// restores source order.
YNN_ALWAYS_INLINE u8x32 saturating_rounding_convert(f32x32 f, uint8_t) {
  const __m256 upper_bound =
      _mm256_set1_ps(static_cast<float>((1 << 16) - 1));

  // Clamp and round each 8-lane quarter to int32.
  const __m256i q0 =
      _mm256_cvtps_epi32(_mm256_min_ps(f.lo().lo().v, upper_bound));
  const __m256i q1 =
      _mm256_cvtps_epi32(_mm256_min_ps(f.lo().hi().v, upper_bound));
  const __m256i q2 =
      _mm256_cvtps_epi32(_mm256_min_ps(f.hi().lo().v, upper_bound));
  const __m256i q3 =
      _mm256_cvtps_epi32(_mm256_min_ps(f.hi().hi().v, upper_bound));

  // 32 -> 16 -> 8 bits.
  const __m256i packed_01 = _mm256_packs_epi32(q0, q1);
  const __m256i packed_23 = _mm256_packs_epi32(q2, q3);
  const __m256i packed_u8 = _mm256_packus_epi16(packed_01, packed_23);

  // After the two packs the 32-bit chunks hold
  // [q0.lo, q1.lo, q2.lo, q3.lo, q0.hi, q1.hi, q2.hi, q3.hi].
  const __m256i dword_order = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7);
  return u8x32{_mm256_permutevar8x32_epi32(packed_u8, dword_order)};
}

} // namespace simd

} // namespace ynn
Expand Down
Loading
Loading