Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions ynnpack/base/arithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,16 @@ constexpr T saturate_cast(U x) noexcept {
return static_cast<T>(x);
}

// Scalar saturating addition: returns a + b narrowed back to T through
// saturate_cast (declared earlier in this header) rather than wrapping.
//
// The sum is formed in int64_t, so the result is only exact for integer types
// narrower than 64 bits; 64-bit operands can still overflow the intermediate.
//
// Declared constexpr/noexcept to match saturate_cast, so it can be evaluated
// in constant expressions.
template <typename T>
constexpr T add_sat(T a, T b) noexcept {
  return saturate_cast<T>(static_cast<int64_t>(a) + static_cast<int64_t>(b));
}

// Scalar saturating subtraction: returns a - b narrowed back to T through
// saturate_cast (declared earlier in this header) rather than wrapping.
//
// The difference is formed in int64_t, so the result is only exact for integer
// types narrower than 64 bits; 64-bit operands can still overflow the
// intermediate.
//
// Declared constexpr/noexcept to match saturate_cast, so it can be evaluated
// in constant expressions.
template <typename T>
constexpr T sub_sat(T a, T b) noexcept {
  return saturate_cast<T>(static_cast<int64_t>(a) - static_cast<int64_t>(b));
}

} // namespace ynn

#endif // XNNPACK_YNNPACK_BASE_ARITHMETIC_H_
2 changes: 1 addition & 1 deletion ynnpack/base/simd/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ cc_library(
"x86_sse41_base.h",
"x86_avx_partial_load_store.h",
"x86_sse2_partial_load_store.h",
"x86_sse2_saturating_convert.h",
"x86_sse2_saturate_cast.h",
"generic.inc",
# For the most part, only one of these headers should be included. Multiple of these headers
# may define the same operation and type, using a different implementation, depending on the
Expand Down
60 changes: 30 additions & 30 deletions ynnpack/base/simd/arm_neon_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -423,47 +423,47 @@ YNN_ALWAYS_INLINE s8x16 operator*(s8x16 a, s8x16 b) {
return s8x16{vmulq_s8(a.v, b.v)};
}

// Lane-wise saturating addition, one overload per NEON integer vector type
// (add_sat is the new spelling of saturating_add). Each overload passes the
// wrapped registers `a.v` / `b.v` to the vqadd intrinsic of the matching
// element type and re-wraps the result in the same vector type.
YNN_ALWAYS_INLINE s32x4 saturating_add(s32x4 a, s32x4 b) {
YNN_ALWAYS_INLINE s32x4 add_sat(s32x4 a, s32x4 b) {
return s32x4{vqaddq_s32(a.v, b.v)};
}
YNN_ALWAYS_INLINE u32x4 saturating_add(u32x4 a, u32x4 b) {
YNN_ALWAYS_INLINE u32x4 add_sat(u32x4 a, u32x4 b) {
return u32x4{vqaddq_u32(a.v, b.v)};
}
YNN_ALWAYS_INLINE s16x8 saturating_add(s16x8 a, s16x8 b) {
YNN_ALWAYS_INLINE s16x8 add_sat(s16x8 a, s16x8 b) {
return s16x8{vqaddq_s16(a.v, b.v)};
}
YNN_ALWAYS_INLINE u16x8 saturating_add(u16x8 a, u16x8 b) {
YNN_ALWAYS_INLINE u16x8 add_sat(u16x8 a, u16x8 b) {
return u16x8{vqaddq_u16(a.v, b.v)};
}
YNN_ALWAYS_INLINE s8x16 saturating_add(s8x16 a, s8x16 b) {
YNN_ALWAYS_INLINE s8x16 add_sat(s8x16 a, s8x16 b) {
return s8x16{vqaddq_s8(a.v, b.v)};
}
YNN_ALWAYS_INLINE u8x16 saturating_add(u8x16 a, u8x16 b) {
YNN_ALWAYS_INLINE u8x16 add_sat(u8x16 a, u8x16 b) {
return u8x16{vqaddq_u8(a.v, b.v)};
}
// 8-lane variant: the only overload here that uses the non-`q` intrinsic form.
YNN_ALWAYS_INLINE u8x8 saturating_add(u8x8 a, u8x8 b) {
YNN_ALWAYS_INLINE u8x8 add_sat(u8x8 a, u8x8 b) {
return u8x8{vqadd_u8(a.v, b.v)};
}

// Lane-wise saturating subtraction (a - b), one overload per NEON integer
// vector type (sub_sat is the new spelling of saturating_sub). Each overload
// passes the wrapped registers `a.v` / `b.v` to the vqsub intrinsic of the
// matching element type and re-wraps the result in the same vector type.
YNN_ALWAYS_INLINE s32x4 saturating_sub(s32x4 a, s32x4 b) {
YNN_ALWAYS_INLINE s32x4 sub_sat(s32x4 a, s32x4 b) {
return s32x4{vqsubq_s32(a.v, b.v)};
}
YNN_ALWAYS_INLINE u32x4 saturating_sub(u32x4 a, u32x4 b) {
YNN_ALWAYS_INLINE u32x4 sub_sat(u32x4 a, u32x4 b) {
return u32x4{vqsubq_u32(a.v, b.v)};
}
YNN_ALWAYS_INLINE s16x8 saturating_sub(s16x8 a, s16x8 b) {
YNN_ALWAYS_INLINE s16x8 sub_sat(s16x8 a, s16x8 b) {
return s16x8{vqsubq_s16(a.v, b.v)};
}
YNN_ALWAYS_INLINE u16x8 saturating_sub(u16x8 a, u16x8 b) {
YNN_ALWAYS_INLINE u16x8 sub_sat(u16x8 a, u16x8 b) {
return u16x8{vqsubq_u16(a.v, b.v)};
}
YNN_ALWAYS_INLINE s8x16 saturating_sub(s8x16 a, s8x16 b) {
YNN_ALWAYS_INLINE s8x16 sub_sat(s8x16 a, s8x16 b) {
return s8x16{vqsubq_s8(a.v, b.v)};
}
YNN_ALWAYS_INLINE u8x16 saturating_sub(u8x16 a, u8x16 b) {
YNN_ALWAYS_INLINE u8x16 sub_sat(u8x16 a, u8x16 b) {
return u8x16{vqsubq_u8(a.v, b.v)};
}
// 8-lane variant: the only overload here that uses the non-`q` intrinsic form.
YNN_ALWAYS_INLINE u8x8 saturating_sub(u8x8 a, u8x8 b) {
YNN_ALWAYS_INLINE u8x8 sub_sat(u8x8 a, u8x8 b) {
return u8x8{vqsub_u8(a.v, b.v)};
}

Expand Down Expand Up @@ -849,69 +849,69 @@ using s16x16 = vec<int16_t, 16>;
using s32x16 = vec<int32_t, 16>;
using f32x16 = vec<float, 16>;

// Element-type conversions between NEON vectors (cast is the new spelling of
// convert). The unnamed second parameter is a tag whose type selects the
// destination element type; its value is never read. Widening conversions
// return a two-half vector built from the low and high halves of the input.

// bf16x8 -> f32x8. vzipq_u16 interleaves a zero 16-bit lane in front of every
// bf16 lane, and each resulting 32-bit pair is reinterpreted as a float.
// NOTE(review): this places the bf16 bits in the upper half of each float
// only on a little-endian lane layout - confirm for big-endian targets.
// `a_u32` is still a uint16x8x2_t; the name refers to how it is reinterpreted.
YNN_ALWAYS_INLINE f32x8 convert(bf16x8 a, float) {
YNN_ALWAYS_INLINE f32x8 cast(bf16x8 a, float) {
uint16x8x2_t a_u32 = vzipq_u16(vdupq_n_u16(0), a.v);
return {
f32x4{vreinterpretq_f32_u32(vreinterpretq_u32_u16(a_u32.val[0]))},
f32x4{vreinterpretq_f32_u32(vreinterpretq_u32_u16(a_u32.val[1]))},
};
}

// s8x16 -> s16x16: widens the low and high 8 lanes separately with vmovl_s8.
YNN_ALWAYS_INLINE s16x16 convert(s8x16 b, int16_t) {
YNN_ALWAYS_INLINE s16x16 cast(s8x16 b, int16_t) {
return {
s16x8{vmovl_s8(vget_low_s8(b.v))},
s16x8{vmovl_s8(vget_high_s8(b.v))},
};
}

// u8x16 -> s16x16: widens as unsigned with vmovl_u8, then reinterprets the
// 16-bit lanes as signed (no value change, only the lane type).
YNN_ALWAYS_INLINE s16x16 convert(u8x16 b, int16_t) {
YNN_ALWAYS_INLINE s16x16 cast(u8x16 b, int16_t) {
return {
s16x8{vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(b.v)))},
s16x8{vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(b.v)))},
};
}

// s16x8 -> s32x8: widens the low and high 4 lanes separately with vmovl_s16.
YNN_ALWAYS_INLINE s32x8 convert(s16x8 b, int32_t) {
YNN_ALWAYS_INLINE s32x8 cast(s16x8 b, int32_t) {
return {
s32x4{vmovl_s16(vget_low_s16(b.v))},
s32x4{vmovl_s16(vget_high_s16(b.v))},
};
}

// s8x16 -> s32x16: done in two steps, first to int16 and then to int32.
YNN_ALWAYS_INLINE s32x16 convert(s8x16 b, int32_t) {
return convert(convert(b, int16_t{}), int32_t{});
YNN_ALWAYS_INLINE s32x16 cast(s8x16 b, int32_t) {
return cast(cast(b, int16_t{}), int32_t{});
}

// u8x16 -> s32x16: done in two steps, first to int16 and then to int32.
YNN_ALWAYS_INLINE s32x16 convert(u8x16 b, int32_t) {
return convert(convert(b, int16_t{}), int32_t{});
YNN_ALWAYS_INLINE s32x16 cast(u8x16 b, int32_t) {
return cast(cast(b, int16_t{}), int32_t{});
}

// s32x4 -> f32x4 via vcvtq_f32_s32.
YNN_ALWAYS_INLINE f32x4 convert(s32x4 x, float) {
YNN_ALWAYS_INLINE f32x4 cast(s32x4 x, float) {
return f32x4{vcvtq_f32_s32(x.v)};
}

// f32x4 -> s32x4 via vcvtq_s32_f32; rounding and out-of-range behavior are
// whatever that intrinsic defines (see the Arm intrinsics reference).
YNN_ALWAYS_INLINE s32x4 convert(f32x4 x, int32_t) {
YNN_ALWAYS_INLINE s32x4 cast(f32x4 x, int32_t) {
return s32x4{vcvtq_s32_f32(x.v)};
}

// Saturating narrowing conversions (saturate_cast is the new spelling of
// saturating_convert). Each half of the two-half input is narrowed with a
// vqmovn/vqmovun intrinsic and the two results are combined into one register.
// The unnamed second parameter is a tag selecting the destination element type.

// s32x8 -> s16x8.
YNN_ALWAYS_INLINE s16x8 saturating_convert(s32x8 a, int16_t) {
YNN_ALWAYS_INLINE s16x8 saturate_cast(s32x8 a, int16_t) {
return s16x8{vcombine_s16(vqmovn_s32(a.lo().v), vqmovn_s32(a.hi().v))};
}

// s16x16 -> s8x16.
YNN_ALWAYS_INLINE s8x16 saturating_convert(s16x16 a, int8_t) {
YNN_ALWAYS_INLINE s8x16 saturate_cast(s16x16 a, int8_t) {
return s8x16{vcombine_s8(vqmovn_s16(a.lo().v), vqmovn_s16(a.hi().v))};
}

// s16x16 -> u8x16: uses vqmovun_s16, the signed-input / unsigned-output form.
YNN_ALWAYS_INLINE u8x16 saturating_convert(s16x16 a, uint8_t) {
YNN_ALWAYS_INLINE u8x16 saturate_cast(s16x16 a, uint8_t) {
return u8x16{vcombine_u8(vqmovun_s16(a.lo().v), vqmovun_s16(a.hi().v))};
}

// f32x8 -> s16x8 (round_float_to_int is the new spelling of
// saturating_rounding_convert). Each 4-lane half is converted to int32 by
// internal::cast_f32_to_int32, narrowed to int16 with vqmovn_s32, and the two
// halves are combined.
// NOTE(review): the rounding mode is decided inside cast_f32_to_int32, which
// is not visible here - confirm against that helper.
YNN_ALWAYS_INLINE s16x8 saturating_rounding_convert(f32x8 f, int16_t) {
YNN_ALWAYS_INLINE s16x8 round_float_to_int(f32x8 f, int16_t) {
return s16x8{vcombine_s16(vqmovn_s32(internal::cast_f32_to_int32(f.lo().v)),
vqmovn_s32(internal::cast_f32_to_int32(f.hi().v)))};
}

YNN_ALWAYS_INLINE s8x16 saturating_rounding_convert(f32x16 f, int8_t) {
YNN_ALWAYS_INLINE s8x16 round_float_to_int(f32x16 f, int8_t) {
const int16x8_t i01 =
vcombine_s16(vqmovn_s32(internal::cast_f32_to_int32(f.lo().lo().v)),
vqmovn_s32(internal::cast_f32_to_int32(f.lo().hi().v)));
Expand All @@ -921,7 +921,7 @@ YNN_ALWAYS_INLINE s8x16 saturating_rounding_convert(f32x16 f, int8_t) {
return s8x16{vcombine_s8(vqmovn_s16(i01), vqmovn_s16(i23))};
}

YNN_ALWAYS_INLINE u8x16 saturating_rounding_convert(f32x16 f, uint8_t) {
YNN_ALWAYS_INLINE u8x16 round_float_to_int(f32x16 f, uint8_t) {
const uint16x8_t i01 =
vcombine_u16(vqmovn_u32(internal::cast_f32_to_uint32(f.lo().lo().v)),
vqmovn_u32(internal::cast_f32_to_uint32(f.lo().hi().v)));
Expand Down
8 changes: 4 additions & 4 deletions ynnpack/base/simd/arm_neonfp16.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,22 @@ namespace simd {

using f32x8 = vec<float, 8>;

// Half <-> float conversions (cast is the new spelling of convert). The f16
// vector types hold their lanes as u16 registers, so every conversion
// reinterprets between u16 and f16 around the vcvt intrinsic. The unnamed
// second parameter is a tag selecting the destination element type.

// f16x4 -> f32x4.
YNN_ALWAYS_INLINE f32x4 convert(f16x4 a, float) {
YNN_ALWAYS_INLINE f32x4 cast(f16x4 a, float) {
return f32x4{vcvt_f32_f16(vreinterpret_f16_u16(a.v))};
}

// f16x8 -> f32x8: converts the low and high 4 lanes separately.
YNN_ALWAYS_INLINE f32x8 convert(f16x8 a, float) {
YNN_ALWAYS_INLINE f32x8 cast(f16x8 a, float) {
return {
f32x4{vcvt_f32_f16(vreinterpret_f16_u16(vget_low_u16(a.v)))},
f32x4{vcvt_f32_f16(vreinterpret_f16_u16(vget_high_u16(a.v)))},
};
}

// f32x4 -> f16x4.
YNN_ALWAYS_INLINE f16x4 convert(f32x4 a, half) {
YNN_ALWAYS_INLINE f16x4 cast(f32x4 a, half) {
return f16x4{vreinterpret_u16_f16(vcvt_f16_f32(a.v))};
}

// f32x8 -> f16x8: converts each half and combines them into one register.
YNN_ALWAYS_INLINE f16x8 convert(f32x8 a, half) {
YNN_ALWAYS_INLINE f16x8 cast(f32x8 a, half) {
return f16x8{vreinterpretq_u16_f16(
vcombine_f16(vcvt_f16_f32(a.lo().v), vcvt_f16_f32(a.hi().v)))};
}
Expand Down
26 changes: 13 additions & 13 deletions ynnpack/base/simd/generic.inc
Original file line number Diff line number Diff line change
Expand Up @@ -191,12 +191,12 @@ YNN_ALWAYS_INLINE vec<T, N> abs(vec<T, N> a) {
return {abs(a.lo()), abs(a.hi())};
}
// Generic fallback for add_sat (new spelling of saturating_add): applies the
// operation to the lo() and hi() halves of both operands independently and
// returns the pair as one vector.
template <typename T, size_t N>
YNN_ALWAYS_INLINE vec<T, N> saturating_add(vec<T, N> a, vec<T, N> b) {
return {saturating_add(a.lo(), b.lo()), saturating_add(a.hi(), b.hi())};
YNN_ALWAYS_INLINE vec<T, N> add_sat(vec<T, N> a, vec<T, N> b) {
return {add_sat(a.lo(), b.lo()), add_sat(a.hi(), b.hi())};
}
// Generic fallback for sub_sat (new spelling of saturating_sub): applies the
// operation to the lo() and hi() halves of both operands independently and
// returns the pair as one vector.
template <typename T, size_t N>
YNN_ALWAYS_INLINE vec<T, N> saturating_sub(vec<T, N> a, vec<T, N> b) {
return {saturating_sub(a.lo(), b.lo()), saturating_sub(a.hi(), b.hi())};
YNN_ALWAYS_INLINE vec<T, N> sub_sat(vec<T, N> a, vec<T, N> b) {
return {sub_sat(a.lo(), b.lo()), sub_sat(a.hi(), b.hi())};
}
template <typename T, size_t N>
YNN_ALWAYS_INLINE vec<T, N> floor(vec<T, N> a) {
Expand Down Expand Up @@ -245,24 +245,24 @@ YNN_ALWAYS_INLINE vec<T, N*2> concat(vec<T, N> a, vec<T, N> b) {
}

// Generic cast (new spelling of convert). The second argument is a tag whose
// type names the destination element type.

// Same source and destination element type: returns the input unchanged.
template <typename T, size_t N>
YNN_ALWAYS_INLINE vec<T, N> convert(vec<T, N> from, T) {
YNN_ALWAYS_INLINE vec<T, N> cast(vec<T, N> from, T) {
return from;
}
// Different element types: casts the lo() and hi() halves separately,
// passing a default-constructed To as the tag.
template <typename To, typename From, size_t N>
YNN_ALWAYS_INLINE vec<To, N> convert(vec<From, N> from, To) {
return {convert(from.lo(), To()), convert(from.hi(), To())};
YNN_ALWAYS_INLINE vec<To, N> cast(vec<From, N> from, To) {
return {cast(from.lo(), To()), cast(from.hi(), To())};
}

// Generic fallback for saturate_cast (new spelling of saturating_convert):
// converts the lo() and hi() halves separately, passing a default-constructed
// To as the destination-type tag, and returns the pair as one vector.
template <typename To, typename From, size_t N>
YNN_ALWAYS_INLINE vec<To, N> saturating_convert(vec<From, N> from, To) {
return {saturating_convert(from.lo(), To()),
saturating_convert(from.hi(), To())};
YNN_ALWAYS_INLINE vec<To, N> saturate_cast(vec<From, N> from, To) {
return {saturate_cast(from.lo(), To()),
saturate_cast(from.hi(), To())};
}

// Generic fallback for round_float_to_int (new spelling of
// saturating_rounding_convert): converts the lo() and hi() halves separately,
// passing a default-constructed To as the destination-type tag, and returns
// the pair as one vector.
template <typename To, typename From, size_t N>
YNN_ALWAYS_INLINE vec<To, N> saturating_rounding_convert(vec<From, N> from, To) {
return {saturating_rounding_convert(from.lo(), To()),
saturating_rounding_convert(from.hi(), To())};
YNN_ALWAYS_INLINE vec<To, N> round_float_to_int(vec<From, N> from, To) {
return {round_float_to_int(from.lo(), To()),
round_float_to_int(from.hi(), To())};
}

template <typename T, size_t N>
Expand Down
18 changes: 9 additions & 9 deletions ynnpack/base/simd/hexagon_hvx.h
Original file line number Diff line number Diff line change
Expand Up @@ -604,29 +604,29 @@ using s32x64 = simd::vec<int32_t, 64>;
using s32x128 = simd::vec<int32_t, 128>;
using s16x128 = simd::vec<int16_t, 128>;

// Widening integer conversions for HVX vectors (cast is the new spelling of
// convert). Each vunpack intrinsic returns an HVX_VectorPair; its lo and hi
// vectors become the two halves of the wider result. The unnamed second
// parameter is a tag selecting the destination element type.

// s8x128 -> s16x128 using the signed byte-to-halfword unpack.
YNN_ALWAYS_INLINE s16x128 convert(s8x128 x, int16_t) {
YNN_ALWAYS_INLINE s16x128 cast(s8x128 x, int16_t) {
HVX_VectorPair result = Q6_Wh_vunpack_Vb(x.v);
return {s16x64{Q6_V_lo_W(result)}, s16x64{Q6_V_hi_W(result)}};
}
// u8x128 -> s16x128 using the unsigned byte-to-halfword unpack.
YNN_ALWAYS_INLINE s16x128 convert(u8x128 x, int16_t) {
YNN_ALWAYS_INLINE s16x128 cast(u8x128 x, int16_t) {
HVX_VectorPair result = Q6_Wuh_vunpack_Vub(x.v);
return {s16x64{Q6_V_lo_W(result)}, s16x64{Q6_V_hi_W(result)}};
}

// s16x64 -> s32x64 using the halfword-to-word unpack.
YNN_ALWAYS_INLINE s32x64 convert(s16x64 x, int32_t) {
YNN_ALWAYS_INLINE s32x64 cast(s16x64 x, int32_t) {
HVX_VectorPair result = Q6_Ww_vunpack_Vh(x.v);
return {s32x32{Q6_V_lo_W(result)}, s32x32{Q6_V_hi_W(result)}};
}

// s8x128 -> s32x128: unpacks to 16-bit first, then widens each 16-bit half
// to 32-bit with the s16x64 overload.
YNN_ALWAYS_INLINE s32x128 convert(s8x128 x, int32_t) {
YNN_ALWAYS_INLINE s32x128 cast(s8x128 x, int32_t) {
HVX_VectorPair s16 = Q6_Wh_vunpack_Vb(x.v);
return {convert(s16x64{Q6_V_lo_W(s16)}, int32_t{}),
convert(s16x64{Q6_V_hi_W(s16)}, int32_t{})};
return {cast(s16x64{Q6_V_lo_W(s16)}, int32_t{}),
cast(s16x64{Q6_V_hi_W(s16)}, int32_t{})};
}
// u8x128 -> s32x128: same two-step scheme, starting from the unsigned unpack.
YNN_ALWAYS_INLINE s32x128 convert(u8x128 x, int32_t) {
YNN_ALWAYS_INLINE s32x128 cast(u8x128 x, int32_t) {
HVX_VectorPair s16 = Q6_Wuh_vunpack_Vub(x.v);
return {convert(s16x64{Q6_V_lo_W(s16)}, int32_t{}),
convert(s16x64{Q6_V_hi_W(s16)}, int32_t{})};
return {cast(s16x64{Q6_V_lo_W(s16)}, int32_t{}),
cast(s16x64{Q6_V_hi_W(s16)}, int32_t{})};
}

template <typename ElemSizeBits>
Expand Down
52 changes: 26 additions & 26 deletions ynnpack/base/simd/test/arm_neon.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,19 +72,19 @@ TEST_SUBTRACT(arm_neon, s8, 16);
TEST_SUBTRACT(arm_neon, f32, 4);
TEST_SUBTRACT(arm_neon, s32, 4);

// Saturating add/sub tests, one instantiation per integer vector type
// (TEST_ADD_SAT / TEST_SUB_SAT are the new spellings of TEST_SATURATING_ADD /
// TEST_SATURATING_SUB). Arguments are presumably (fixture, element type,
// lane count) - confirm against the macro definitions.
TEST_SATURATING_ADD(arm_neon, u8, 16);
TEST_SATURATING_ADD(arm_neon, s8, 16);
TEST_SATURATING_ADD(arm_neon, u16, 8);
TEST_SATURATING_ADD(arm_neon, s16, 8);
TEST_SATURATING_ADD(arm_neon, u32, 4);
TEST_SATURATING_ADD(arm_neon, s32, 4);

TEST_SATURATING_SUB(arm_neon, u8, 16);
TEST_SATURATING_SUB(arm_neon, s8, 16);
TEST_SATURATING_SUB(arm_neon, u16, 8);
TEST_SATURATING_SUB(arm_neon, s16, 8);
TEST_SATURATING_SUB(arm_neon, u32, 4);
TEST_SATURATING_SUB(arm_neon, s32, 4);
TEST_ADD_SAT(arm_neon, u8, 16);
TEST_ADD_SAT(arm_neon, s8, 16);
TEST_ADD_SAT(arm_neon, u16, 8);
TEST_ADD_SAT(arm_neon, s16, 8);
TEST_ADD_SAT(arm_neon, u32, 4);
TEST_ADD_SAT(arm_neon, s32, 4);

TEST_SUB_SAT(arm_neon, u8, 16);
TEST_SUB_SAT(arm_neon, s8, 16);
TEST_SUB_SAT(arm_neon, u16, 8);
TEST_SUB_SAT(arm_neon, s16, 8);
TEST_SUB_SAT(arm_neon, u32, 4);
TEST_SUB_SAT(arm_neon, s32, 4);

TEST_MULTIPLY(arm_neon, u8, 16);
TEST_MULTIPLY(arm_neon, s8, 16);
Expand Down Expand Up @@ -151,19 +151,19 @@ TEST_HORIZONTAL_MAX(arm_neon, s16, 8);
TEST_HORIZONTAL_MAX(arm_neon, f32, 4);
TEST_HORIZONTAL_MAX(arm_neon, s32, 4);

// Conversion tests (TEST_CAST, TEST_SATURATE_CAST and TEST_ROUND_FLOAT_TO_INT
// are the new spellings of TEST_CONVERT, TEST_SATURATING_CONVERT and
// TEST_SATURATING_ROUNDING_CONVERT). Arguments are presumably (fixture,
// destination element type, source vector type) - confirm against the macro
// definitions.
TEST_CONVERT(arm_neon, s32, s8x16);
TEST_CONVERT(arm_neon, s32, u8x16);
TEST_CONVERT(arm_neon, s32, s16x8);
TEST_CONVERT(arm_neon, f32, s32x4);
TEST_CONVERT(arm_neon, s32, f32x4);
TEST_CONVERT(arm_neon, f32, bf16x8);

TEST_SATURATING_CONVERT(arm_neon, s16, s32x8);
TEST_SATURATING_CONVERT(arm_neon, u8, s16x16);
TEST_SATURATING_CONVERT(arm_neon, s8, s16x16);
TEST_SATURATING_ROUNDING_CONVERT(arm_neon, u8, f32x16);
TEST_SATURATING_ROUNDING_CONVERT(arm_neon, s8, f32x16);
TEST_SATURATING_ROUNDING_CONVERT(arm_neon, s16, f32x8);
TEST_CAST(arm_neon, s32, s8x16);
TEST_CAST(arm_neon, s32, u8x16);
TEST_CAST(arm_neon, s32, s16x8);
TEST_CAST(arm_neon, f32, s32x4);
TEST_CAST(arm_neon, s32, f32x4);
TEST_CAST(arm_neon, f32, bf16x8);

TEST_SATURATE_CAST(arm_neon, s16, s32x8);
TEST_SATURATE_CAST(arm_neon, u8, s16x16);
TEST_SATURATE_CAST(arm_neon, s8, s16x16);
TEST_ROUND_FLOAT_TO_INT(arm_neon, u8, f32x16);
TEST_ROUND_FLOAT_TO_INT(arm_neon, s8, f32x16);
TEST_ROUND_FLOAT_TO_INT(arm_neon, s16, f32x8);

TEST_EXTRACT(arm_neon, u8x16, 8);

Expand Down
8 changes: 4 additions & 4 deletions ynnpack/base/simd/test/arm_neonfp16.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ class arm_neonfp16 : public ::testing::Test {
}
};

// Half <-> float conversion tests in both directions and at both vector
// widths (TEST_CAST is the new spelling of TEST_CONVERT). Arguments are
// presumably (fixture, destination element type, source vector type) -
// confirm against the macro definition.
TEST_CONVERT(arm_neonfp16, f32, f16x4);
TEST_CONVERT(arm_neonfp16, f32, f16x8);
TEST_CONVERT(arm_neonfp16, f16, f32x4);
TEST_CONVERT(arm_neonfp16, f16, f32x8);
TEST_CAST(arm_neonfp16, f32, f16x4);
TEST_CAST(arm_neonfp16, f32, f16x8);
TEST_CAST(arm_neonfp16, f16, f32x4);
TEST_CAST(arm_neonfp16, f16, f32x8);

} // namespace simd
} // namespace ynn
Loading
Loading