@@ -23,6 +23,8 @@ namespace simd {
2323using f32x16 = vec<float , 16 >;
2424using s32x16 = vec<int32_t , 16 >;
2525using s32x32 = vec<int32_t , 32 >;
// 32-lane float vector.
26+ using f32x32 = vec<float , 32 >;
// 32-lane int16_t vector.
27+ using s16x32 = vec<int16_t , 32 >;
2628
2729YNN_ALWAYS_INLINE s32x16 convert (s8x16 a, int32_t ) {
2830 return {
@@ -58,6 +60,53 @@ YNN_ALWAYS_INLINE bf16x16 convert(f32x16 a, bfloat16) {
5860 return bf16x16{_mm256_permute4x64_epi64 (r, _MM_SHUFFLE (3 , 1 , 2 , 0 ))};
5961}
6062
63+ YNN_ALWAYS_INLINE s16x16 saturating_convert (s32x16 a, int16_t ) {
64+ const __m256i r = _mm256_packs_epi32 (a.lo ().v , a.hi ().v );
65+ return s16x16{_mm256_permute4x64_epi64 (r, _MM_SHUFFLE (3 , 1 , 2 , 0 ))};
66+ }
67+
68+ YNN_ALWAYS_INLINE s8x32 saturating_convert (s16x32 a, int8_t ) {
69+ const __m256i r = _mm256_packs_epi16 (a.lo ().v , a.hi ().v );
70+ return s8x32{_mm256_permute4x64_epi64 (r, _MM_SHUFFLE (3 , 1 , 2 , 0 ))};
71+ }
72+
73+ YNN_ALWAYS_INLINE u8x32 saturating_convert (s16x32 a, uint8_t ) {
74+ const __m256i r = _mm256_packus_epi16 (a.lo ().v , a.hi ().v );
75+ return u8x32{_mm256_permute4x64_epi64 (r, _MM_SHUFFLE (3 , 1 , 2 , 0 ))};
76+ }
77+
78+ YNN_ALWAYS_INLINE s16x16 saturating_rounding_convert (f32x16 f, int16_t ) {
79+ const __m256 max_int16 = _mm256_set1_ps ((float )((1 << 15 ) - 1 ));
80+ const __m256i i0 = _mm256_cvtps_epi32 (_mm256_min_ps (f.lo ().v , max_int16));
81+ const __m256i i1 = _mm256_cvtps_epi32 (_mm256_min_ps (f.hi ().v , max_int16));
82+ return saturating_convert (s32x16 (s32x8 (i0), s32x8 (i1)), int16_t ());
83+ }
84+
85+ YNN_ALWAYS_INLINE s8x32 saturating_rounding_convert (f32x32 f, int8_t ) {
86+ const s16x16 i01 =
87+ saturating_rounding_convert (f32x16 (f.lo ().lo (), f.lo ().hi ()), int16_t ());
88+ const s16x16 i23 =
89+ saturating_rounding_convert (f32x16 (f.hi ().lo (), f.hi ().hi ()), int16_t ());
90+ return saturating_convert (s16x32 (i01, i23), int8_t ());
91+ }
92+
93+ YNN_ALWAYS_INLINE u8x32 saturating_rounding_convert (f32x32 f, uint8_t ) {
94+ const __m256 max_uint16 = _mm256_set1_ps ((float )((1 << 16 ) - 1 ));
95+ const __m256i i0 =
96+ _mm256_cvtps_epi32 (_mm256_min_ps (f.lo ().lo ().v , max_uint16));
97+ const __m256i i1 =
98+ _mm256_cvtps_epi32 (_mm256_min_ps (f.lo ().hi ().v , max_uint16));
99+ const __m256i i2 =
100+ _mm256_cvtps_epi32 (_mm256_min_ps (f.hi ().lo ().v , max_uint16));
101+ const __m256i i3 =
102+ _mm256_cvtps_epi32 (_mm256_min_ps (f.hi ().hi ().v , max_uint16));
103+ const __m256i i01_16 = _mm256_packs_epi32 (i0, i1);
104+ const __m256i i23_16 = _mm256_packs_epi32 (i2, i3);
105+ const __m256i r = _mm256_packus_epi16 (i01_16, i23_16);
106+ return u8x32{_mm256_permutevar8x32_epi32 (
107+ r, _mm256_setr_epi32 (0 , 4 , 1 , 5 , 2 , 6 , 3 , 7 ))};
108+ }
109+
61110} // namespace simd
62111
63112} // namespace ynn
0 commit comments