From 811d0bd388ba8f4853c5560b6e16a5af3e8b895a Mon Sep 17 00:00:00 2001 From: Volodymyr Kysenko Date: Wed, 18 Mar 2026 18:25:58 -0700 Subject: [PATCH] Add simd wrappers for saturating_convert and saturating_rounding_convert. PiperOrigin-RevId: 885892423 --- ynnpack/base/simd/BUILD | 1 + ynnpack/base/simd/arm_neon_base.h | 58 +++++++++++++++ ynnpack/base/simd/generic.inc | 12 +++ ynnpack/base/simd/test/arm_neon.cc | 7 ++ ynnpack/base/simd/test/generic.h | 65 +++++++++++++++++ ynnpack/base/simd/test/x86_avx2.cc | 7 ++ ynnpack/base/simd/test/x86_avx512.cc | 7 ++ ynnpack/base/simd/test/x86_sse2.cc | 7 ++ ynnpack/base/simd/vec.h | 22 ++++++ ynnpack/base/simd/x86_avx2.h | 49 +++++++++++++ ynnpack/base/simd/x86_avx512.h | 62 ++++++++++++++++ ynnpack/base/simd/x86_sse2.h | 1 + .../base/simd/x86_sse2_saturating_convert.h | 73 +++++++++++++++++++ ynnpack/base/simd/x86_sse41.h | 1 + ynnpack/kernels/elementwise/compiler.py | 11 --- 15 files changed, 372 insertions(+), 11 deletions(-) create mode 100644 ynnpack/base/simd/x86_sse2_saturating_convert.h diff --git a/ynnpack/base/simd/BUILD b/ynnpack/base/simd/BUILD index 0cfea8b8327..74ceb5f0f13 100644 --- a/ynnpack/base/simd/BUILD +++ b/ynnpack/base/simd/BUILD @@ -37,6 +37,7 @@ cc_library( "x86_sse41_base.h", "x86_avx_partial_load_store.h", "x86_sse2_partial_load_store.h", + "x86_sse2_saturating_convert.h", "generic.inc", # For the most part, only one of these headers should be included. Multiple of these headers # may define the same operation and type, using a different implementation, depending on the diff --git a/ynnpack/base/simd/arm_neon_base.h b/ynnpack/base/simd/arm_neon_base.h index a21d7cb16ac..1d1f6f78d40 100644 --- a/ynnpack/base/simd/arm_neon_base.h +++ b/ynnpack/base/simd/arm_neon_base.h @@ -823,10 +823,31 @@ YNN_ALWAYS_INLINE std::array, 4> transpose( }}; } +namespace internal { + +YNN_ALWAYS_INLINE int32x4_t cast_f32_to_int32(float32x4_t f) { +#if defined(__ARM_ARCH) && __ARM_ARCH < 8 + return vcvtq_s32_f32(round(f32x4{f}).v); +#else + return vcvtnq_s32_f32(f); +#endif +} + +YNN_ALWAYS_INLINE uint32x4_t cast_f32_to_uint32(float32x4_t f) { +#if defined(__ARM_ARCH) && __ARM_ARCH < 8 + return vcvtq_u32_f32(round(f32x4{f}).v); +#else + return vcvtnq_u32_f32(f); +#endif +} + +} // namespace internal + using f32x8 = vec; using s32x8 = vec; using s16x16 = vec; using s32x16 = vec; +using f32x16 = vec; YNN_ALWAYS_INLINE f32x8 convert(bf16x8 a, float) { uint16x8x2_t a_u32 = vzipq_u16(vdupq_n_u16(0), a.v); @@ -873,6 +894,43 @@ YNN_ALWAYS_INLINE s32x4 convert(f32x4 x, int32_t) { return s32x4{vcvtq_s32_f32(x.v)}; } +YNN_ALWAYS_INLINE s16x8 saturating_convert(s32x8 a, int16_t) { + return s16x8{vcombine_s16(vqmovn_s32(a.lo().v), vqmovn_s32(a.hi().v))}; +} + +YNN_ALWAYS_INLINE s8x16 saturating_convert(s16x16 a, int8_t) { + return s8x16{vcombine_s8(vqmovn_s16(a.lo().v), vqmovn_s16(a.hi().v))}; +} + +YNN_ALWAYS_INLINE u8x16 saturating_convert(s16x16 a, uint8_t) { + return u8x16{vcombine_u8(vqmovun_s16(a.lo().v), vqmovun_s16(a.hi().v))}; +} + +YNN_ALWAYS_INLINE s16x8 saturating_rounding_convert(f32x8 f, int16_t) { + return s16x8{vcombine_s16(vqmovn_s32(internal::cast_f32_to_int32(f.lo().v)), + vqmovn_s32(internal::cast_f32_to_int32(f.hi().v)))}; +} + +YNN_ALWAYS_INLINE s8x16 saturating_rounding_convert(f32x16 f, int8_t) { + const int16x8_t i01 = + vcombine_s16(vqmovn_s32(internal::cast_f32_to_int32(f.lo().lo().v)), + vqmovn_s32(internal::cast_f32_to_int32(f.lo().hi().v))); + const int16x8_t i23 = + vcombine_s16(vqmovn_s32(internal::cast_f32_to_int32(f.hi().lo().v)), + vqmovn_s32(internal::cast_f32_to_int32(f.hi().hi().v))); + return s8x16{vcombine_s8(vqmovn_s16(i01), vqmovn_s16(i23))}; +} + +YNN_ALWAYS_INLINE u8x16 saturating_rounding_convert(f32x16 f, uint8_t) { + const uint16x8_t i01 = + vcombine_u16(vqmovn_u32(internal::cast_f32_to_uint32(f.lo().lo().v)), + vqmovn_u32(internal::cast_f32_to_uint32(f.lo().hi().v))); + const uint16x8_t i23 = + vcombine_u16(vqmovn_u32(internal::cast_f32_to_uint32(f.hi().lo().v)), + vqmovn_u32(internal::cast_f32_to_uint32(f.hi().hi().v))); + return u8x16{vcombine_u8(vqmovn_u16(i01), vqmovn_u16(i23))}; +} + } // namespace simd } // namespace ynn diff --git a/ynnpack/base/simd/generic.inc b/ynnpack/base/simd/generic.inc index 0ef6f22d17b..cf23ae2e2be 100644 --- a/ynnpack/base/simd/generic.inc +++ b/ynnpack/base/simd/generic.inc @@ -253,6 +253,18 @@ YNN_ALWAYS_INLINE vec convert(vec from, To) { return {convert(from.lo(), To()), convert(from.hi(), To())}; } +template +YNN_ALWAYS_INLINE vec saturating_convert(vec from, To) { + return {saturating_convert(from.lo(), To()), + saturating_convert(from.hi(), To())}; +} + +template +YNN_ALWAYS_INLINE vec saturating_rounding_convert(vec from, To) { + return {saturating_rounding_convert(from.lo(), To()), + saturating_rounding_convert(from.hi(), To())}; +} + template YNN_ALWAYS_INLINE T horizontal_sum(vec x) { return horizontal_sum(x.lo() + x.hi()); diff --git a/ynnpack/base/simd/test/arm_neon.cc b/ynnpack/base/simd/test/arm_neon.cc index ac23f5bec69..b03c0cc31cb 100644 --- a/ynnpack/base/simd/test/arm_neon.cc +++ b/ynnpack/base/simd/test/arm_neon.cc @@ -158,6 +158,13 @@ TEST_CONVERT(arm_neon, f32, s32x4); TEST_CONVERT(arm_neon, s32, f32x4); TEST_CONVERT(arm_neon, f32, bf16x8); +TEST_SATURATING_CONVERT(arm_neon, s16, s32x8); +TEST_SATURATING_CONVERT(arm_neon, u8, s16x16); +TEST_SATURATING_CONVERT(arm_neon, s8, s16x16); +TEST_SATURATING_ROUNDING_CONVERT(arm_neon, u8, f32x16); +TEST_SATURATING_ROUNDING_CONVERT(arm_neon, s8, f32x16); +TEST_SATURATING_ROUNDING_CONVERT(arm_neon, s16, f32x8); + TEST_EXTRACT(arm_neon, u8x16, 8); TEST_CONCAT(arm_neon, u8x8); diff --git a/ynnpack/base/simd/test/generic.h b/ynnpack/base/simd/test/generic.h index a4aa9a0901e..6bd5cb9b35a 100644 --- a/ynnpack/base/simd/test/generic.h +++ b/ynnpack/base/simd/test/generic.h @@ -16,6 +16,7 @@ #include #include +#include "ynnpack/base/arithmetic.h" #include "ynnpack/base/bfloat16.h" #include "ynnpack/base/half.h" #include "ynnpack/base/simd/vec.h" @@ -622,6 +623,70 @@ void test_convert() { #define TEST_CONVERT(test_class, to, from) \ TEST_F(test_class, convert_##to##_##from) { test_convert(); } +template +To saturating_convert_reference(From from, bool round = false) { + if constexpr (std::is_floating_point_v && std::is_integral_v) { + if (round) { + return round_float_to_int(from); + } + } + return ynn::saturate_cast(from); +} + +template +void test_saturating_convert() { + using FromScalar = typename From::value_type; + static constexpr size_t N = From::N; + using vector = vec; + + ReplicableRandomDevice rng; + for (auto _ : FuzzTest(std::chrono::milliseconds(100))) { + FromScalar src[N]; + fill_random(src, N, rng); + From from_v = load(src, vector::N); + auto to_v = saturating_convert(from_v, To{}); + + To dst[N]; + store(dst, to_v); + for (size_t i = 0; i < N; ++i) { + ASSERT_EQ(dst[i], + saturating_convert_reference(src[i], /*round=*/false)); + } + } +} + +#define TEST_SATURATING_CONVERT(test_class, to, from) \ + TEST_F(test_class, saturating_convert_##to##_##from) { \ + test_saturating_convert(); \ + } + +template +void test_saturating_rounding_convert() { + using FromScalar = typename From::value_type; + static constexpr size_t N = From::N; + using vector = vec; + + ReplicableRandomDevice rng; + for (auto _ : FuzzTest(std::chrono::milliseconds(100))) { + FromScalar src[N]; + fill_random(src, N, rng); + From from_v = load(src, vector::N); + auto to_v = saturating_rounding_convert(from_v, To{}); + + To dst[N]; + store(dst, to_v); + for (size_t i = 0; i < N; ++i) { + ASSERT_EQ(dst[i], + saturating_convert_reference(src[i], /*round=*/true)); + } + } +} + +#define TEST_SATURATING_ROUNDING_CONVERT(test_class, to, from) \ + TEST_F(test_class, saturating_rounding_convert_##to##_##from) { \ + test_saturating_rounding_convert(); \ + } + template void test_horizontal_sum() { using vector = vec; diff --git a/ynnpack/base/simd/test/x86_avx2.cc b/ynnpack/base/simd/test/x86_avx2.cc index 6d8b29ac421..b4c1f9312bb 100644 --- a/ynnpack/base/simd/test/x86_avx2.cc +++ b/ynnpack/base/simd/test/x86_avx2.cc @@ -71,5 +71,12 @@ TEST_CONVERT(x86_avx2, s32, s8x16); TEST_CONVERT(x86_avx2, f32, s32x8); TEST_CONVERT(x86_avx2, s32, f32x8); +TEST_SATURATING_CONVERT(x86_avx2, s16, s32x16); +TEST_SATURATING_CONVERT(x86_avx2, u8, s16x32); +TEST_SATURATING_CONVERT(x86_avx2, s8, s16x32); +TEST_SATURATING_ROUNDING_CONVERT(x86_avx2, u8, f32x32); +TEST_SATURATING_ROUNDING_CONVERT(x86_avx2, s8, f32x32); +TEST_SATURATING_ROUNDING_CONVERT(x86_avx2, s16, f32x16); + } // namespace simd } // namespace ynn diff --git a/ynnpack/base/simd/test/x86_avx512.cc b/ynnpack/base/simd/test/x86_avx512.cc index 205cf890fa1..3a34f0fdf27 100644 --- a/ynnpack/base/simd/test/x86_avx512.cc +++ b/ynnpack/base/simd/test/x86_avx512.cc @@ -174,5 +174,12 @@ TEST_CONVERT(x86_avx512, s32, u8x32); TEST_CONVERT(x86_avx512, f32, s32x16); TEST_CONVERT(x86_avx512, s32, f32x16); +TEST_SATURATING_CONVERT(x86_avx512, s16, s32x32); +TEST_SATURATING_CONVERT(x86_avx512, u8, s16x64); +TEST_SATURATING_CONVERT(x86_avx512, s8, s16x64); +TEST_SATURATING_ROUNDING_CONVERT(x86_avx512, u8, f32x64); +TEST_SATURATING_ROUNDING_CONVERT(x86_avx512, s8, f32x64); +TEST_SATURATING_ROUNDING_CONVERT(x86_avx512, s16, f32x32); + } // namespace simd } // namespace ynn diff --git a/ynnpack/base/simd/test/x86_sse2.cc b/ynnpack/base/simd/test/x86_sse2.cc index 97cfebbe0ee..c3854852f48 100644 --- a/ynnpack/base/simd/test/x86_sse2.cc +++ b/ynnpack/base/simd/test/x86_sse2.cc @@ -127,6 +127,13 @@ TEST_CONVERT(x86_sse2, s32, f32x4); TEST_CONVERT(x86_sse2, f64, f32x4); TEST_CONVERT(x86_sse2, f32, f64x4); +TEST_SATURATING_CONVERT(x86_sse2, s16, s32x8); +TEST_SATURATING_CONVERT(x86_sse2, u8, s16x16); +TEST_SATURATING_CONVERT(x86_sse2, s8, s16x16); +TEST_SATURATING_ROUNDING_CONVERT(x86_sse2, u8, f32x16); +TEST_SATURATING_ROUNDING_CONVERT(x86_sse2, s8, f32x16); +TEST_SATURATING_ROUNDING_CONVERT(x86_sse2, s16, f32x8); + TEST_FMA(x86_sse2, f32, 4); } // namespace simd diff --git a/ynnpack/base/simd/vec.h b/ynnpack/base/simd/vec.h index 174ad565833..94531f1ffa0 100644 --- a/ynnpack/base/simd/vec.h +++ b/ynnpack/base/simd/vec.h @@ -16,6 +16,7 @@ #include #include +#include "ynnpack/base/arithmetic.h" #include "ynnpack/base/base.h" namespace ynn { @@ -148,6 +149,12 @@ auto extract(vec, SliceN); template vec convert(vec from, To); +template +vec saturating_convert(vec from, To); + +template +vec saturating_rounding_convert(vec from, To); + namespace internal { template @@ -307,6 +314,21 @@ YNN_ALWAYS_INLINE vec convert(vec from, To) { return vec{static_cast(from.v)}; } +template +YNN_ALWAYS_INLINE vec saturating_convert(vec from, To) { + return vec{saturate_cast(from.v)}; +} + +template +YNN_ALWAYS_INLINE vec saturating_rounding_convert(vec from, + To) { + if constexpr (std::is_same_v) { + return vec{round_float_to_int(from.v)}; + } else { + return vec{saturate_cast(from.v)}; + } +} + template YNN_ALWAYS_INLINE vec operator&(vec a, vec b) { return vec{static_cast(a.v & b.v)}; diff --git a/ynnpack/base/simd/x86_avx2.h b/ynnpack/base/simd/x86_avx2.h index f225082ce04..e033d15601a 100644 --- a/ynnpack/base/simd/x86_avx2.h +++ b/ynnpack/base/simd/x86_avx2.h @@ -23,6 +23,8 @@ namespace simd { using f32x16 = vec; using s32x16 = vec; using s32x32 = vec; +using f32x32 = vec; +using s16x32 = vec; YNN_ALWAYS_INLINE s32x16 convert(s8x16 a, int32_t) { return { @@ -58,6 +60,53 @@ YNN_ALWAYS_INLINE bf16x16 convert(f32x16 a, bfloat16) { return bf16x16{_mm256_permute4x64_epi64(r, _MM_SHUFFLE(3, 1, 2, 0))}; } +YNN_ALWAYS_INLINE s16x16 saturating_convert(s32x16 a, int16_t) { + const __m256i r = _mm256_packs_epi32(a.lo().v, a.hi().v); + return s16x16{_mm256_permute4x64_epi64(r, _MM_SHUFFLE(3, 1, 2, 0))}; +} + +YNN_ALWAYS_INLINE s8x32 saturating_convert(s16x32 a, int8_t) { + const __m256i r = _mm256_packs_epi16(a.lo().v, a.hi().v); + return s8x32{_mm256_permute4x64_epi64(r, _MM_SHUFFLE(3, 1, 2, 0))}; +} + +YNN_ALWAYS_INLINE u8x32 saturating_convert(s16x32 a, uint8_t) { + const __m256i r = _mm256_packus_epi16(a.lo().v, a.hi().v); + return u8x32{_mm256_permute4x64_epi64(r, _MM_SHUFFLE(3, 1, 2, 0))}; +} + +YNN_ALWAYS_INLINE s16x16 saturating_rounding_convert(f32x16 f, int16_t) { + const __m256 max_int16 = _mm256_set1_ps((float)((1 << 15) - 1)); + const __m256i i0 = _mm256_cvtps_epi32(_mm256_min_ps(f.lo().v, max_int16)); + const __m256i i1 = _mm256_cvtps_epi32(_mm256_min_ps(f.hi().v, max_int16)); + return saturating_convert(s32x16(s32x8(i0), s32x8(i1)), int16_t()); +} + +YNN_ALWAYS_INLINE s8x32 saturating_rounding_convert(f32x32 f, int8_t) { + const s16x16 i01 = + saturating_rounding_convert(f32x16(f.lo().lo(), f.lo().hi()), int16_t()); + const s16x16 i23 = + saturating_rounding_convert(f32x16(f.hi().lo(), f.hi().hi()), int16_t()); + return saturating_convert(s16x32(i01, i23), int8_t()); +} + +YNN_ALWAYS_INLINE u8x32 saturating_rounding_convert(f32x32 f, uint8_t) { + const __m256 max_uint16 = _mm256_set1_ps((float)((1 << 16) - 1)); + const __m256i i0 = + _mm256_cvtps_epi32(_mm256_min_ps(f.lo().lo().v, max_uint16)); + const __m256i i1 = + _mm256_cvtps_epi32(_mm256_min_ps(f.lo().hi().v, max_uint16)); + const __m256i i2 = + _mm256_cvtps_epi32(_mm256_min_ps(f.hi().lo().v, max_uint16)); + const __m256i i3 = + _mm256_cvtps_epi32(_mm256_min_ps(f.hi().hi().v, max_uint16)); + const __m256i i01_16 = _mm256_packs_epi32(i0, i1); + const __m256i i23_16 = _mm256_packs_epi32(i2, i3); + const __m256i r = _mm256_packus_epi16(i01_16, i23_16); + return u8x32{_mm256_permutevar8x32_epi32( + r, _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7))}; +} + } // namespace simd } // namespace ynn diff --git a/ynnpack/base/simd/x86_avx512.h b/ynnpack/base/simd/x86_avx512.h index 563fbe734dd..f89db2b9f35 100644 --- a/ynnpack/base/simd/x86_avx512.h +++ b/ynnpack/base/simd/x86_avx512.h @@ -203,6 +203,7 @@ using u16x32 = vec; using s16x32 = vec; using u8x64 = vec; using s8x64 = vec; +using f32x64 = vec; YNN_ALWAYS_INLINE f32x16 load_aligned(const float* ptr, decltype(f32x16::N), f32x16 = {}) { @@ -904,6 +905,67 @@ YNN_ALWAYS_INLINE s32x16 convert(f32x16 x, int32_t) { return s32x16{_mm512_cvttps_epi32(x.v)}; } +YNN_ALWAYS_INLINE s16x32 saturating_convert(s32x32 a, int16_t) { + const __m512i r = _mm512_packs_epi32(a.lo().v, a.hi().v); + return s16x32{ + _mm512_permutexvar_epi64(_mm512_setr_epi64(0, 2, 4, 6, 1, 3, 5, 7), r)}; +} + +YNN_ALWAYS_INLINE s8x64 saturating_convert(s16x64 a, int8_t) { + const __m512i r = _mm512_packs_epi16(a.lo().v, a.hi().v); + return s8x64{ + _mm512_permutexvar_epi64(_mm512_setr_epi64(0, 2, 4, 6, 1, 3, 5, 7), r)}; +} + +YNN_ALWAYS_INLINE u8x64 saturating_convert(s16x64 a, uint8_t) { + const __m512i r = _mm512_packus_epi16(a.lo().v, a.hi().v); + return u8x64{ + _mm512_permutexvar_epi64(_mm512_setr_epi64(0, 2, 4, 6, 1, 3, 5, 7), r)}; +} + +YNN_ALWAYS_INLINE s16x32 saturating_rounding_convert(f32x32 f, int16_t) { + const __m512 max_int16 = _mm512_set1_ps((1 << 15) - 1); + const __m512i i0 = _mm512_cvtps_epi32(_mm512_min_ps(f.lo().v, max_int16)); + const __m512i i1 = _mm512_cvtps_epi32(_mm512_min_ps(f.hi().v, max_int16)); + return saturating_convert(s32x32(s32x16(i0), s32x16(i1)), int16_t()); +} + +YNN_ALWAYS_INLINE u8x64 saturating_rounding_convert(f32x64 f, uint8_t) { + const __m512 max_uint16 = _mm512_set1_ps((1 << 16) - 1); + const __m512i i0 = + _mm512_cvtps_epi32(_mm512_min_ps(f.lo().lo().v, max_uint16)); + const __m512i i1 = + _mm512_cvtps_epi32(_mm512_min_ps(f.lo().hi().v, max_uint16)); + const __m512i i2 = + _mm512_cvtps_epi32(_mm512_min_ps(f.hi().lo().v, max_uint16)); + const __m512i i3 = + _mm512_cvtps_epi32(_mm512_min_ps(f.hi().hi().v, max_uint16)); + const __m512i i01_16 = _mm512_packs_epi32(i0, i1); + const __m512i i23_16 = _mm512_packs_epi32(i2, i3); + const __m512i r = _mm512_packus_epi16(i01_16, i23_16); + const __m512i idx = + _mm512_setr_epi32(0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15); + return u8x64{_mm512_permutexvar_epi32(idx, r)}; +} + +YNN_ALWAYS_INLINE s8x64 saturating_rounding_convert(f32x64 f, int8_t) { + const __m512 max_int16 = _mm512_set1_ps((1 << 15) - 1); + const __m512i i0 = + _mm512_cvtps_epi32(_mm512_min_ps(f.lo().lo().v, max_int16)); + const __m512i i1 = + _mm512_cvtps_epi32(_mm512_min_ps(f.lo().hi().v, max_int16)); + const __m512i i2 = + _mm512_cvtps_epi32(_mm512_min_ps(f.hi().lo().v, max_int16)); + const __m512i i3 = + _mm512_cvtps_epi32(_mm512_min_ps(f.hi().hi().v, max_int16)); + const __m512i i01_16 = _mm512_packs_epi32(i0, i1); + const __m512i i23_16 = _mm512_packs_epi32(i2, i3); + const __m512i r = _mm512_packs_epi16(i01_16, i23_16); + const __m512i idx = + _mm512_setr_epi32(0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15); + return s8x64{_mm512_permutexvar_epi32(idx, r)}; +} + } // namespace simd } // namespace ynn diff --git a/ynnpack/base/simd/x86_sse2.h b/ynnpack/base/simd/x86_sse2.h index ee7dcc11af6..6d5a323b865 100644 --- a/ynnpack/base/simd/x86_sse2.h +++ b/ynnpack/base/simd/x86_sse2.h @@ -14,6 +14,7 @@ #include "ynnpack/base/simd/vec.h" #include "ynnpack/base/simd/x86_sse2_base.h" // IWYU pragma: export #include "ynnpack/base/simd/x86_sse2_partial_load_store.h" // IWYU pragma: export +#include "ynnpack/base/simd/x86_sse2_saturating_convert.h" // IWYU pragma: export namespace ynn { diff --git a/ynnpack/base/simd/x86_sse2_saturating_convert.h b/ynnpack/base/simd/x86_sse2_saturating_convert.h new file mode 100644 index 00000000000..2d033b49c29 --- /dev/null +++ b/ynnpack/base/simd/x86_sse2_saturating_convert.h @@ -0,0 +1,73 @@ +#ifndef XNNPACK_YNNPACK_BASE_SIMD_X86_SSE2_SATURATING_CONVERT_H_ +#define XNNPACK_YNNPACK_BASE_SIMD_X86_SSE2_SATURATING_CONVERT_H_ + +#include + +#include +#include +#include +#include + +#include "ynnpack/base/base.h" +#include "ynnpack/base/bfloat16.h" +#include "ynnpack/base/half.h" +#include "ynnpack/base/simd/vec.h" + +namespace ynn { + +namespace simd { + +using f32x8 = vec; +using s32x8 = vec; +using s16x16 = vec; +using bf16x16 = vec; +using f16x16 = vec; +using s8x32 = vec; +using u8x32 = vec; +using f64x4 = vec; +using f32x16 = vec; +using s32x16 = vec; + +YNN_ALWAYS_INLINE s16x8 saturating_convert(s32x8 a, int16_t) { + return s16x8{_mm_packs_epi32(a.lo().v, a.hi().v)}; +} + +YNN_ALWAYS_INLINE s8x16 saturating_convert(s16x16 a, int8_t) { + return s8x16{_mm_packs_epi16(a.lo().v, a.hi().v)}; +} + +YNN_ALWAYS_INLINE u8x16 saturating_convert(s16x16 a, uint8_t) { + return u8x16{_mm_packus_epi16(a.lo().v, a.hi().v)}; +} + +YNN_ALWAYS_INLINE s16x8 saturating_rounding_convert(f32x8 f, int16_t) { + const __m128 max_int16 = _mm_set1_ps((float)((1 << 15) - 1)); + const __m128i i0 = _mm_cvtps_epi32(_mm_min_ps(f.lo().v, max_int16)); + const __m128i i1 = _mm_cvtps_epi32(_mm_min_ps(f.hi().v, max_int16)); + return saturating_convert(s32x8(s32x4(i0), s32x4(i1)), int16_t()); +} + +YNN_ALWAYS_INLINE s8x16 saturating_rounding_convert(f32x16 f, int8_t) { + const s16x8 i01 = + saturating_rounding_convert(f32x8(f.lo().lo(), f.lo().hi()), int16_t()); + const s16x8 i23 = + saturating_rounding_convert(f32x8(f.hi().lo(), f.hi().hi()), int16_t()); + return saturating_convert(s16x16(i01, i23), int8_t()); +} + +YNN_ALWAYS_INLINE u8x16 saturating_rounding_convert(f32x16 f, uint8_t) { + const __m128 max_int16 = _mm_set1_ps((1 << 15) - 1); + const __m128i i0 = _mm_cvtps_epi32(_mm_min_ps(f.lo().lo().v, max_int16)); + const __m128i i1 = _mm_cvtps_epi32(_mm_min_ps(f.lo().hi().v, max_int16)); + const __m128i i2 = _mm_cvtps_epi32(_mm_min_ps(f.hi().lo().v, max_int16)); + const __m128i i3 = _mm_cvtps_epi32(_mm_min_ps(f.hi().hi().v, max_int16)); + const __m128i i01_16 = _mm_packs_epi32(i0, i1); + const __m128i i23_16 = _mm_packs_epi32(i2, i3); + return u8x16{_mm_packus_epi16(i01_16, i23_16)}; +} + +} // namespace simd + +} // namespace ynn + +#endif // XNNPACK_YNNPACK_BASE_SIMD_X86_SSE2_SATURATING_CONVERT_H_ diff --git a/ynnpack/base/simd/x86_sse41.h b/ynnpack/base/simd/x86_sse41.h index 853bc1d550a..2a68a3221e2 100644 --- a/ynnpack/base/simd/x86_sse41.h +++ b/ynnpack/base/simd/x86_sse41.h @@ -12,6 +12,7 @@ #include "ynnpack/base/simd/vec.h" #include "ynnpack/base/simd/x86_sse2_base.h" // IWYU pragma: export #include "ynnpack/base/simd/x86_sse2_partial_load_store.h" // IWYU pragma: export +#include "ynnpack/base/simd/x86_sse2_saturating_convert.h" // IWYU pragma: export #include "ynnpack/base/simd/x86_sse41_base.h" // IWYU pragma: export namespace ynn { diff --git a/ynnpack/kernels/elementwise/compiler.py b/ynnpack/kernels/elementwise/compiler.py index fa65b1b38d7..762b01fad58 100644 --- a/ynnpack/kernels/elementwise/compiler.py +++ b/ynnpack/kernels/elementwise/compiler.py @@ -764,17 +764,6 @@ def __init__( namespace ynn { namespace { -template -YNN_INTRINSIC T* offset_bytes(T* ptr, std::ptrdiff_t offset) { - return reinterpret_cast(reinterpret_cast(ptr) + offset); -} - -template -YNN_INTRINSIC const T* offset_bytes(const T* ptr, std::ptrdiff_t offset) { - return reinterpret_cast(reinterpret_cast(ptr) + - offset); -} - YNN_INTRINSIC std::size_t min(std::size_t a, std::size_t b) { return a < b ? a : b; }