Skip to content

Commit 9e0f4a5

Browse files
vksnk authored and xnnpack-bot committed
Add simd wrappers for saturating_convert and saturating_rounding_convert.
PiperOrigin-RevId: 885269784
1 parent 0e7a25e commit 9e0f4a5

15 files changed

Lines changed: 372 additions & 11 deletions

File tree

ynnpack/base/simd/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ cc_library(
3737
"x86_sse41_base.h",
3838
"x86_avx_partial_load_store.h",
3939
"x86_sse2_partial_load_store.h",
40+
"x86_sse2_saturating_convert.h",
4041
"generic.inc",
4142
# For the most part, only one of these headers should be included. Multiple of these headers
4243
# may define the same operation and type, using a different implementation, depending on the

ynnpack/base/simd/arm_neon_base.h

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -823,10 +823,31 @@ YNN_ALWAYS_INLINE std::array<vec<T, 4>, 4> transpose(
823823
}};
824824
}
825825

826+
namespace internal {
827+
828+
YNN_ALWAYS_INLINE int32x4_t cast_f32_to_int32(float32x4_t f) {
829+
#if defined(__ARM_ARCH) && __ARM_ARCH < 8
830+
return vcvtq_s32_f32(round(f32x4{f}).v);
831+
#else
832+
return vcvtnq_s32_f32(f);
833+
#endif
834+
}
835+
836+
YNN_ALWAYS_INLINE uint32x4_t cast_f32_to_uint32(float32x4_t f) {
837+
#if defined(__ARM_ARCH) && __ARM_ARCH < 8
838+
return vcvtq_u32_f32(round(f32x4{f}).v);
839+
#else
840+
return vcvtnq_u32_f32(f);
841+
#endif
842+
}
843+
844+
} // namespace internal
845+
826846
using f32x8 = vec<float, 8>;
827847
using s32x8 = vec<int32_t, 8>;
828848
using s16x16 = vec<int16_t, 16>;
829849
using s32x16 = vec<int32_t, 16>;
850+
using f32x16 = vec<float, 16>;
830851

831852
YNN_ALWAYS_INLINE f32x8 convert(bf16x8 a, float) {
832853
uint16x8x2_t a_u32 = vzipq_u16(vdupq_n_u16(0), a.v);
@@ -873,6 +894,43 @@ YNN_ALWAYS_INLINE s32x4 convert(f32x4 x, int32_t) {
873894
return s32x4{vcvtq_s32_f32(x.v)};
874895
}
875896

897+
// Narrows eight int32 lanes to int16. Each 4-lane half goes through the
// saturating narrow vqmovn_s32 and the two results are joined in lo/hi order.
YNN_ALWAYS_INLINE s16x8 saturating_convert(s32x8 a, int16_t) {
  const int16x4_t lo_narrow = vqmovn_s32(a.lo().v);
  const int16x4_t hi_narrow = vqmovn_s32(a.hi().v);
  return s16x8{vcombine_s16(lo_narrow, hi_narrow)};
}
900+
901+
// Narrows sixteen int16 lanes to int8. Each 8-lane half goes through the
// saturating narrow vqmovn_s16 and the two results are joined in lo/hi order.
YNN_ALWAYS_INLINE s8x16 saturating_convert(s16x16 a, int8_t) {
  const int8x8_t lo_narrow = vqmovn_s16(a.lo().v);
  const int8x8_t hi_narrow = vqmovn_s16(a.hi().v);
  return s8x16{vcombine_s8(lo_narrow, hi_narrow)};
}
904+
905+
// Narrows sixteen signed int16 lanes to uint8. vqmovun_s16 is the
// signed-to-unsigned saturating narrow; the two 8-lane halves are joined in
// lo/hi order.
YNN_ALWAYS_INLINE u8x16 saturating_convert(s16x16 a, uint8_t) {
  const uint8x8_t lo_narrow = vqmovun_s16(a.lo().v);
  const uint8x8_t hi_narrow = vqmovun_s16(a.hi().v);
  return u8x16{vcombine_u8(lo_narrow, hi_narrow)};
}
908+
909+
// Converts eight floats to int16: each 4-lane half is converted to int32 by
// internal::cast_f32_to_int32 and then narrowed with the saturating
// vqmovn_s32.
YNN_ALWAYS_INLINE s16x8 saturating_rounding_convert(f32x8 f, int16_t) {
  const int32x4_t lo_i32 = internal::cast_f32_to_int32(f.lo().v);
  const int32x4_t hi_i32 = internal::cast_f32_to_int32(f.hi().v);
  return s16x8{vcombine_s16(vqmovn_s32(lo_i32), vqmovn_s32(hi_i32))};
}
913+
914+
// Converts sixteen floats to int8 in three stages: float -> int32 via
// internal::cast_f32_to_int32, int32 -> int16 via vqmovn_s32, and
// int16 -> int8 via vqmovn_s16. Both narrowing steps saturate.
YNN_ALWAYS_INLINE s8x16 saturating_rounding_convert(f32x16 f, int8_t) {
  // Stage 1: the four 4-lane quarters, in element order.
  const int32x4_t q0 = internal::cast_f32_to_int32(f.lo().lo().v);
  const int32x4_t q1 = internal::cast_f32_to_int32(f.lo().hi().v);
  const int32x4_t q2 = internal::cast_f32_to_int32(f.hi().lo().v);
  const int32x4_t q3 = internal::cast_f32_to_int32(f.hi().hi().v);

  // Stage 2: pairs of quarters become 8-lane int16 vectors.
  const int16x8_t lower_s16 = vcombine_s16(vqmovn_s32(q0), vqmovn_s32(q1));
  const int16x8_t upper_s16 = vcombine_s16(vqmovn_s32(q2), vqmovn_s32(q3));

  // Stage 3: final narrow to int8.
  return s8x16{vcombine_s8(vqmovn_s16(lower_s16), vqmovn_s16(upper_s16))};
}
923+
924+
// Converts sixteen floats to uint8 in three stages: float -> uint32 via
// internal::cast_f32_to_uint32, uint32 -> uint16 via vqmovn_u32, and
// uint16 -> uint8 via vqmovn_u16. Both narrowing steps saturate.
YNN_ALWAYS_INLINE u8x16 saturating_rounding_convert(f32x16 f, uint8_t) {
  // Stage 1: the four 4-lane quarters, in element order.
  const uint32x4_t q0 = internal::cast_f32_to_uint32(f.lo().lo().v);
  const uint32x4_t q1 = internal::cast_f32_to_uint32(f.lo().hi().v);
  const uint32x4_t q2 = internal::cast_f32_to_uint32(f.hi().lo().v);
  const uint32x4_t q3 = internal::cast_f32_to_uint32(f.hi().hi().v);

  // Stage 2: pairs of quarters become 8-lane uint16 vectors.
  const uint16x8_t lower_u16 = vcombine_u16(vqmovn_u32(q0), vqmovn_u32(q1));
  const uint16x8_t upper_u16 = vcombine_u16(vqmovn_u32(q2), vqmovn_u32(q3));

  // Stage 3: final narrow to uint8.
  return u8x16{vcombine_u8(vqmovn_u16(lower_u16), vqmovn_u16(upper_u16))};
}
933+
876934
} // namespace simd
877935

878936
} // namespace ynn

ynnpack/base/simd/generic.inc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,18 @@ YNN_ALWAYS_INLINE vec<To, N> convert(vec<From, N> from, To) {
253253
return {convert(from.lo(), To()), convert(from.hi(), To())};
254254
}
255255

256+
template <typename To, typename From, size_t N>
257+
YNN_ALWAYS_INLINE vec<To, N> saturating_convert(vec<From, N> from, To) {
258+
return {saturating_convert(from.lo(), To()),
259+
saturating_convert(from.hi(), To())};
260+
}
261+
262+
template <typename To, typename From, size_t N>
263+
YNN_ALWAYS_INLINE vec<To, N> saturating_rounding_convert(vec<From, N> from, To) {
264+
return {saturating_rounding_convert(from.lo(), To()),
265+
saturating_rounding_convert(from.hi(), To())};
266+
}
267+
256268
template <typename T, size_t N>
257269
YNN_ALWAYS_INLINE T horizontal_sum(vec<T, N> x) {
258270
return horizontal_sum(x.lo() + x.hi());

ynnpack/base/simd/test/arm_neon.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,13 @@ TEST_CONVERT(arm_neon, f32, s32x4);
158158
TEST_CONVERT(arm_neon, s32, f32x4);
159159
TEST_CONVERT(arm_neon, f32, bf16x8);
160160

161+
// Saturating conversion tests for the NEON wrappers. Macro arguments are
// (fixture, destination scalar type, source vector type).
TEST_SATURATING_CONVERT(arm_neon, s16, s32x8);
162+
TEST_SATURATING_CONVERT(arm_neon, u8, s16x16);
163+
TEST_SATURATING_CONVERT(arm_neon, s8, s16x16);
164+
TEST_SATURATING_ROUNDING_CONVERT(arm_neon, u8, f32x16);
165+
TEST_SATURATING_ROUNDING_CONVERT(arm_neon, s8, f32x16);
166+
TEST_SATURATING_ROUNDING_CONVERT(arm_neon, s16, f32x8);
167+
161168
TEST_EXTRACT(arm_neon, u8x16, 8);
162169

163170
TEST_CONCAT(arm_neon, u8x8);

ynnpack/base/simd/test/generic.h

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
#include <gmock/gmock.h>
1818
#include <gtest/gtest.h>
19+
#include "ynnpack/base/arithmetic.h"
1920
#include "ynnpack/base/bfloat16.h"
2021
#include "ynnpack/base/half.h"
2122
#include "ynnpack/base/simd/vec.h"
@@ -622,6 +623,70 @@ void test_convert() {
622623
#define TEST_CONVERT(test_class, to, from) \
623624
TEST_F(test_class, convert_##to##_##from) { test_convert<to, from>(); }
624625

626+
template <typename To, typename From>
627+
To saturating_convert_reference(From from, bool round = false) {
628+
if constexpr (std::is_floating_point_v<From> && std::is_integral_v<To>) {
629+
if (round) {
630+
return round_float_to_int<To>(from);
631+
}
632+
}
633+
return ynn::saturate_cast<To>(from);
634+
}
635+
636+
template <typename To, typename From>
637+
void test_saturating_convert() {
638+
using FromScalar = typename From::value_type;
639+
static constexpr size_t N = From::N;
640+
using vector = vec<FromScalar, N>;
641+
642+
ReplicableRandomDevice rng;
643+
for (auto _ : FuzzTest(std::chrono::milliseconds(100))) {
644+
FromScalar src[N];
645+
fill_random(src, N, rng);
646+
From from_v = load(src, vector::N);
647+
auto to_v = saturating_convert(from_v, To{});
648+
649+
To dst[N];
650+
store(dst, to_v);
651+
for (size_t i = 0; i < N; ++i) {
652+
ASSERT_EQ(dst[i],
653+
saturating_convert_reference<To>(src[i], /*round=*/false));
654+
}
655+
}
656+
}
657+
658+
#define TEST_SATURATING_CONVERT(test_class, to, from) \
659+
TEST_F(test_class, saturating_convert_##to##_##from) { \
660+
test_saturating_convert<to, from>(); \
661+
}
662+
663+
template <typename To, typename From>
664+
void test_saturating_rounding_convert() {
665+
using FromScalar = typename From::value_type;
666+
static constexpr size_t N = From::N;
667+
using vector = vec<FromScalar, N>;
668+
669+
ReplicableRandomDevice rng;
670+
for (auto _ : FuzzTest(std::chrono::milliseconds(100))) {
671+
FromScalar src[N];
672+
fill_random(src, N, rng);
673+
From from_v = load(src, vector::N);
674+
auto to_v = saturating_rounding_convert(from_v, To{});
675+
676+
To dst[N];
677+
store(dst, to_v);
678+
for (size_t i = 0; i < N; ++i) {
679+
ASSERT_EQ(dst[i],
680+
saturating_convert_reference<To>(src[i], /*round=*/true));
681+
}
682+
}
683+
}
684+
685+
#define TEST_SATURATING_ROUNDING_CONVERT(test_class, to, from) \
686+
TEST_F(test_class, saturating_rounding_convert_##to##_##from) { \
687+
test_saturating_rounding_convert<to, from>(); \
688+
}
689+
625690
template <typename scalar, size_t N>
626691
void test_horizontal_sum() {
627692
using vector = vec<scalar, N>;

ynnpack/base/simd/test/x86_avx2.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,5 +71,12 @@ TEST_CONVERT(x86_avx2, s32, s8x16);
7171
TEST_CONVERT(x86_avx2, f32, s32x8);
7272
TEST_CONVERT(x86_avx2, s32, f32x8);
7373

74+
// Saturating conversion tests for the AVX2 wrappers. Macro arguments are
// (fixture, destination scalar type, source vector type).
TEST_SATURATING_CONVERT(x86_avx2, s16, s32x16);
75+
TEST_SATURATING_CONVERT(x86_avx2, u8, s16x32);
76+
TEST_SATURATING_CONVERT(x86_avx2, s8, s16x32);
77+
TEST_SATURATING_ROUNDING_CONVERT(x86_avx2, u8, f32x32);
78+
TEST_SATURATING_ROUNDING_CONVERT(x86_avx2, s8, f32x32);
79+
TEST_SATURATING_ROUNDING_CONVERT(x86_avx2, s16, f32x16);
80+
7481
} // namespace simd
7582
} // namespace ynn

ynnpack/base/simd/test/x86_avx512.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,5 +174,12 @@ TEST_CONVERT(x86_avx512, s32, u8x32);
174174
TEST_CONVERT(x86_avx512, f32, s32x16);
175175
TEST_CONVERT(x86_avx512, s32, f32x16);
176176

177+
// Saturating conversion tests for the AVX-512 wrappers. Macro arguments are
// (fixture, destination scalar type, source vector type).
TEST_SATURATING_CONVERT(x86_avx512, s16, s32x32);
178+
TEST_SATURATING_CONVERT(x86_avx512, u8, s16x64);
179+
TEST_SATURATING_CONVERT(x86_avx512, s8, s16x64);
180+
TEST_SATURATING_ROUNDING_CONVERT(x86_avx512, u8, f32x64);
181+
TEST_SATURATING_ROUNDING_CONVERT(x86_avx512, s8, f32x64);
182+
TEST_SATURATING_ROUNDING_CONVERT(x86_avx512, s16, f32x32);
183+
177184
} // namespace simd
178185
} // namespace ynn

ynnpack/base/simd/test/x86_sse2.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,13 @@ TEST_CONVERT(x86_sse2, s32, f32x4);
127127
TEST_CONVERT(x86_sse2, f64, f32x4);
128128
TEST_CONVERT(x86_sse2, f32, f64x4);
129129

130+
// Saturating conversion tests for the SSE2 wrappers. Macro arguments are
// (fixture, destination scalar type, source vector type).
TEST_SATURATING_CONVERT(x86_sse2, s16, s32x8);
131+
TEST_SATURATING_CONVERT(x86_sse2, u8, s16x16);
132+
TEST_SATURATING_CONVERT(x86_sse2, s8, s16x16);
133+
TEST_SATURATING_ROUNDING_CONVERT(x86_sse2, u8, f32x16);
134+
TEST_SATURATING_ROUNDING_CONVERT(x86_sse2, s8, f32x16);
135+
TEST_SATURATING_ROUNDING_CONVERT(x86_sse2, s16, f32x8);
136+
130137
TEST_FMA(x86_sse2, f32, 4);
131138

132139
} // namespace simd

ynnpack/base/simd/vec.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include <limits>
1717
#include <type_traits>
1818

19+
#include "ynnpack/base/arithmetic.h"
1920
#include "ynnpack/base/base.h"
2021

2122
namespace ynn {
@@ -148,6 +149,12 @@ auto extract(vec<T, N>, SliceN);
148149
template <typename To, typename From, size_t N>
149150
vec<To, N> convert(vec<From, N> from, To);
150151

152+
template <typename To, typename From, size_t N>
153+
vec<To, N> saturating_convert(vec<From, N> from, To);
154+
155+
template <typename To, typename From, size_t N>
156+
vec<To, N> saturating_rounding_convert(vec<From, N> from, To);
157+
151158
namespace internal {
152159

153160
template <typename T, size_t N>
@@ -307,6 +314,21 @@ YNN_ALWAYS_INLINE vec<To, 1> convert(vec<From, 1> from, To) {
307314
return vec<To, 1>{static_cast<To>(from.v)};
308315
}
309316

317+
template <typename To, typename From>
318+
YNN_ALWAYS_INLINE vec<To, 1> saturating_convert(vec<From, 1> from, To) {
319+
return vec<To, 1>{saturate_cast<To>(from.v)};
320+
}
321+
322+
template <typename To, typename From>
323+
YNN_ALWAYS_INLINE vec<To, 1> saturating_rounding_convert(vec<From, 1> from,
324+
To) {
325+
if constexpr (std::is_same_v<From, float>) {
326+
return vec<To, 1>{round_float_to_int<To>(from.v)};
327+
} else {
328+
return vec<To, 1>{saturate_cast<To>(from.v)};
329+
}
330+
}
331+
310332
template <typename T>
311333
YNN_ALWAYS_INLINE vec<T, 1> operator&(vec<T, 1> a, vec<T, 1> b) {
312334
return vec<T, 1>{static_cast<T>(a.v & b.v)};

ynnpack/base/simd/x86_avx2.h

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ namespace simd {
2323
using f32x16 = vec<float, 16>;
2424
using s32x16 = vec<int32_t, 16>;
2525
using s32x32 = vec<int32_t, 32>;
26+
using f32x32 = vec<float, 32>;
27+
using s16x32 = vec<int16_t, 32>;
2628

2729
YNN_ALWAYS_INLINE s32x16 convert(s8x16 a, int32_t) {
2830
return {
@@ -58,6 +60,53 @@ YNN_ALWAYS_INLINE bf16x16 convert(f32x16 a, bfloat16) {
5860
return bf16x16{_mm256_permute4x64_epi64(r, _MM_SHUFFLE(3, 1, 2, 0))};
5961
}
6062

63+
// Narrows sixteen int32 lanes to int16 with signed saturation.
YNN_ALWAYS_INLINE s16x16 saturating_convert(s32x16 a, int16_t) {
  // The 256-bit pack works on each 128-bit lane independently, so the 64-bit
  // quarters of `packed` are ordered [lo.0, hi.0, lo.1, hi.1].
  const __m256i packed = _mm256_packs_epi32(a.lo().v, a.hi().v);
  // Swap the two middle quarters to restore element order.
  return s16x16{_mm256_permute4x64_epi64(packed, _MM_SHUFFLE(3, 1, 2, 0))};
}
67+
68+
// Narrows thirty-two int16 lanes to int8 with signed saturation.
YNN_ALWAYS_INLINE s8x32 saturating_convert(s16x32 a, int8_t) {
  // The 256-bit pack works on each 128-bit lane independently, so the 64-bit
  // quarters of `packed` are ordered [lo.0, hi.0, lo.1, hi.1].
  const __m256i packed = _mm256_packs_epi16(a.lo().v, a.hi().v);
  // Swap the two middle quarters to restore element order.
  return s8x32{_mm256_permute4x64_epi64(packed, _MM_SHUFFLE(3, 1, 2, 0))};
}
72+
73+
// Narrows thirty-two signed int16 lanes to uint8 with unsigned saturation.
YNN_ALWAYS_INLINE u8x32 saturating_convert(s16x32 a, uint8_t) {
  // The 256-bit pack works on each 128-bit lane independently, so the 64-bit
  // quarters of `packed` are ordered [lo.0, hi.0, lo.1, hi.1].
  const __m256i packed = _mm256_packus_epi16(a.lo().v, a.hi().v);
  // Swap the two middle quarters to restore element order.
  return u8x32{_mm256_permute4x64_epi64(packed, _MM_SHUFFLE(3, 1, 2, 0))};
}
77+
78+
// Converts sixteen floats to int16: clamp from above, convert to int32 with
// _mm256_cvtps_epi32, then narrow with the saturating int32 -> int16
// overload of saturating_convert.
YNN_ALWAYS_INLINE s16x16 saturating_rounding_convert(f32x16 f, int16_t) {
  // 32767 == (1 << 15) - 1, the largest int16 value.
  // NOTE(review): the upper clamp is presumably needed because
  // _mm256_cvtps_epi32 returns INT32_MIN for inputs that overflow int32,
  // which would otherwise saturate to the wrong end. Confirm against the
  // Intel docs.
  const __m256 upper_bound = _mm256_set1_ps(32767.0f);
  const __m256 lo_clamped = _mm256_min_ps(f.lo().v, upper_bound);
  const __m256 hi_clamped = _mm256_min_ps(f.hi().v, upper_bound);
  const s32x8 lo_i32(_mm256_cvtps_epi32(lo_clamped));
  const s32x8 hi_i32(_mm256_cvtps_epi32(hi_clamped));
  return saturating_convert(s32x16(lo_i32, hi_i32), int16_t());
}
84+
85+
// Converts thirty-two floats to int8 by composing two existing steps:
// float -> int16 (saturating_rounding_convert on each 16-lane half) followed
// by the saturating int16 -> int8 narrow.
YNN_ALWAYS_INLINE s8x32 saturating_rounding_convert(f32x32 f, int8_t) {
  const f32x16 front(f.lo().lo(), f.lo().hi());
  const f32x16 back(f.hi().lo(), f.hi().hi());
  const s16x16 front_s16 = saturating_rounding_convert(front, int16_t());
  const s16x16 back_s16 = saturating_rounding_convert(back, int16_t());
  return saturating_convert(s16x32(front_s16, back_s16), int8_t());
}
92+
93+
// Converts thirty-two floats to uint8: clamp from above, convert to int32,
// pack to int16 with signed saturation, pack to uint8 with unsigned
// saturation, then undo the lane interleaving left by the two packs.
YNN_ALWAYS_INLINE u8x32 saturating_rounding_convert(f32x32 f, uint8_t) {
  // 65535 == (1 << 16) - 1. The clamp keeps inputs inside the int32 range
  // ahead of _mm256_cvtps_epi32.
  const __m256 upper_bound = _mm256_set1_ps(65535.0f);

  // The four 8-lane groups, in element order.
  const __m256 c0 = _mm256_min_ps(f.lo().lo().v, upper_bound);
  const __m256 c1 = _mm256_min_ps(f.lo().hi().v, upper_bound);
  const __m256 c2 = _mm256_min_ps(f.hi().lo().v, upper_bound);
  const __m256 c3 = _mm256_min_ps(f.hi().hi().v, upper_bound);

  const __m256i w0 = _mm256_cvtps_epi32(c0);
  const __m256i w1 = _mm256_cvtps_epi32(c1);
  const __m256i w2 = _mm256_cvtps_epi32(c2);
  const __m256i w3 = _mm256_cvtps_epi32(c3);

  // Both packs work on each 128-bit lane independently.
  const __m256i h01 = _mm256_packs_epi32(w0, w1);
  const __m256i h23 = _mm256_packs_epi32(w2, w3);
  const __m256i bytes = _mm256_packus_epi16(h01, h23);

  // The 32-bit groups of `bytes` now hold, in order:
  //   w0[0..3], w1[0..3], w2[0..3], w3[0..3], w0[4..7], w1[4..7], w2[4..7],
  //   w3[4..7]
  // Picking groups 0,4,1,5,2,6,3,7 restores element order.
  const __m256i dword_order = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7);
  return u8x32{_mm256_permutevar8x32_epi32(bytes, dword_order)};
}
109+
61110
} // namespace simd
62111

63112
} // namespace ynn

0 commit comments

Comments
 (0)