Skip to content

Commit 28dcd17

Browse files
committed
correct vpdpbusd asserts in miri
In a24022a we changed the argument types to be more accurate, so the Miri asserts on the SIMD element type and lane count now need to reflect that.
1 parent 8155209 commit 28dcd17

File tree

1 file changed

+13
-11
lines changed

1 file changed

+13
-11
lines changed

src/tools/miri/src/shims/x86/avx512.rs

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -188,23 +188,25 @@ fn vpdpbusd<'tcx>(
188188
let (b, b_len) = ecx.project_to_simd(b)?;
189189
let (dest, dest_len) = ecx.project_to_simd(dest)?;
190190

191-
// fn vpdpbusd(src: i32x16, a: i32x16, b: i32x16) -> i32x16;
192-
// fn vpdpbusd256(src: i32x8, a: i32x8, b: i32x8) -> i32x8;
193-
// fn vpdpbusd128(src: i32x4, a: i32x4, b: i32x4) -> i32x4;
191+
// fn vpdpbusd(src: i32x16, a: u8x64, b: i8x64) -> i32x16;
192+
// fn vpdpbusd256(src: i32x8, a: u8x32, b: i8x32) -> i32x8;
193+
// fn vpdpbusd128(src: i32x4, a: u8x16, b: i8x16) -> i32x4;
194194
assert_eq!(dest_len, src_len);
195-
assert_eq!(dest_len, a_len);
196-
assert_eq!(dest_len, b_len);
195+
assert_eq!(a_len, dest_len * 4);
196+
assert_eq!(a_len, b_len);
197197

198198
for i in 0..dest_len {
199199
let src = ecx.read_scalar(&ecx.project_index(&src, i)?)?.to_i32()?;
200-
let a = ecx.read_scalar(&ecx.project_index(&a, i)?)?.to_u32()?;
201-
let b = ecx.read_scalar(&ecx.project_index(&b, i)?)?.to_u32()?;
202200
let dest = ecx.project_index(&dest, i)?;
203201

204-
let zipped = a.to_le_bytes().into_iter().zip(b.to_le_bytes());
205-
let intermediate_sum: i32 = zipped
206-
.map(|(a, b)| i32::from(a).strict_mul(i32::from(b.cast_signed())))
207-
.fold(0, |x, y| x.strict_add(y));
202+
let mut intermediate_sum: i32 = 0;
203+
for j in 0..4 {
204+
let a = ecx.read_scalar(&ecx.project_index(&a, i * 4 + j)?)?.to_u8()?;
205+
let b = ecx.read_scalar(&ecx.project_index(&b, i * 4 + j)?)?.to_i8()?;
206+
207+
let product = i32::from(a).strict_mul(i32::from(b));
208+
intermediate_sum = intermediate_sum.strict_add(product);
209+
}
208210

209211
// Use `wrapping_add` because `src` is an arbitrary i32 and the addition can overflow.
210212
let res = Scalar::from_i32(intermediate_sum.wrapping_add(src));

0 commit comments

Comments
 (0)