@@ -188,23 +188,25 @@ fn vpdpbusd<'tcx>(
188188 let ( b, b_len) = ecx. project_to_simd ( b) ?;
189189 let ( dest, dest_len) = ecx. project_to_simd ( dest) ?;
190190
191- // fn vpdpbusd(src: i32x16, a: i32x16 , b: i32x16 ) -> i32x16;
192- // fn vpdpbusd256(src: i32x8, a: i32x8 , b: i32x8 ) -> i32x8;
193- // fn vpdpbusd128(src: i32x4, a: i32x4 , b: i32x4 ) -> i32x4;
191+ // fn vpdpbusd(src: i32x16, a: u8x64 , b: i8x64 ) -> i32x16;
192+ // fn vpdpbusd256(src: i32x8, a: u8x32 , b: i8x32 ) -> i32x8;
193+ // fn vpdpbusd128(src: i32x4, a: u8x16 , b: i8x16 ) -> i32x4;
194194 assert_eq ! ( dest_len, src_len) ;
195- assert_eq ! ( dest_len , a_len ) ;
196- assert_eq ! ( dest_len , b_len) ;
195+ assert_eq ! ( a_len , dest_len * 4 ) ;
196+ assert_eq ! ( a_len , b_len) ;
197197
198198 for i in 0 ..dest_len {
199199 let src = ecx. read_scalar ( & ecx. project_index ( & src, i) ?) ?. to_i32 ( ) ?;
200- let a = ecx. read_scalar ( & ecx. project_index ( & a, i) ?) ?. to_u32 ( ) ?;
201- let b = ecx. read_scalar ( & ecx. project_index ( & b, i) ?) ?. to_u32 ( ) ?;
202200 let dest = ecx. project_index ( & dest, i) ?;
203201
204- let zipped = a. to_le_bytes ( ) . into_iter ( ) . zip ( b. to_le_bytes ( ) ) ;
205- let intermediate_sum: i32 = zipped
206- . map ( |( a, b) | i32:: from ( a) . strict_mul ( i32:: from ( b. cast_signed ( ) ) ) )
207- . fold ( 0 , |x, y| x. strict_add ( y) ) ;
202+ let mut intermediate_sum: i32 = 0 ;
203+ for j in 0 ..4 {
204+ let a = ecx. read_scalar ( & ecx. project_index ( & a, i * 4 + j) ?) ?. to_u8 ( ) ?;
205+ let b = ecx. read_scalar ( & ecx. project_index ( & b, i * 4 + j) ?) ?. to_i8 ( ) ?;
206+
207+ let product = i32:: from ( a) . strict_mul ( i32:: from ( b) ) ;
208+ intermediate_sum = intermediate_sum. strict_add ( product) ;
209+ }
208210
209211 // Use `wrapping_add` because `src` is an arbitrary i32 and the addition can overflow.
210212 let res = Scalar :: from_i32 ( intermediate_sum. wrapping_add ( src) ) ;
0 commit comments