diff --git a/arrow-buffer/src/bigint/mod.rs b/arrow-buffer/src/bigint/mod.rs index 92f6e291abba..aecf55f64bfd 100644 --- a/arrow-buffer/src/bigint/mod.rs +++ b/arrow-buffer/src/bigint/mod.rs @@ -19,13 +19,14 @@ use crate::arith::derive_arith; use crate::bigint::div::div_rem; use num_bigint::BigInt; use num_traits::{ - Bounded, CheckedAdd, CheckedDiv, CheckedMul, CheckedNeg, CheckedRem, CheckedSub, FromPrimitive, - Num, One, Signed, ToPrimitive, WrappingAdd, WrappingMul, WrappingNeg, WrappingShl, WrappingShr, - WrappingSub, Zero, cast::AsPrimitive, + Bounded, CheckedAdd, CheckedDiv, CheckedMul, CheckedNeg, CheckedRem, CheckedShl, CheckedShr, + CheckedSub, ConstOne, ConstZero, FromPrimitive, MulAdd, MulAddAssign, Num, One, SaturatingAdd, + SaturatingMul, SaturatingSub, Signed, ToPrimitive, WrappingAdd, WrappingMul, WrappingNeg, + WrappingShl, WrappingShr, WrappingSub, Zero, cast::AsPrimitive, }; use std::cmp::Ordering; use std::num::ParseIntError; -use std::ops::{BitAnd, BitOr, BitXor, Neg, Shl, Shr}; +use std::ops::{BitAnd, BitOr, BitXor, Neg, Not, Shl, Shr}; use std::str::FromStr; mod div; @@ -1061,6 +1062,22 @@ impl CheckedRem for i256 { } } +impl CheckedShl for i256 { + fn checked_shl(&self, rhs: u32) -> Option { + let rhs = u8::try_from(rhs).ok()?; + Some(self.shl(rhs)) + } +} + +impl CheckedShr for i256 { + fn checked_shr(&self, rhs: u32) -> Option { + let rhs = u8::try_from(rhs).ok()?; + Some(self.shr(rhs)) + } +} + +// num_traits wrapping implementations + impl WrappingAdd for i256 { fn wrapping_add(&self, v: &Self) -> Self { (*self).wrapping_add(*v) @@ -1085,6 +1102,58 @@ impl WrappingNeg for i256 { } } +// num_traits saturating implementations + +impl SaturatingAdd for i256 { + fn saturating_add(&self, v: &Self) -> Self { + self.checked_add(v).unwrap_or_else(|| { + if v.is_negative() { + i256::MIN + } else { + i256::MAX + } + }) + } +} + +impl SaturatingSub for i256 { + fn saturating_sub(&self, v: &Self) -> Self { + self.checked_sub(v).unwrap_or_else(|| { + if v.is_negative() { + i256::MAX + } else { + i256::MIN + } + }) + } +} + +impl SaturatingMul for i256 { + fn saturating_mul(&self, v: &Self) -> Self { + self.checked_mul(v).unwrap_or_else(|| { + if v.is_negative() { + i256::MIN + } else { + i256::MAX + } + }) + } +} + +impl MulAdd for i256 { + type Output = i256; + + fn mul_add(self, a: Self, b: Self) -> Self::Output { + (self * a) + b + } +} + +impl MulAddAssign for i256 { + fn mul_add_assign(&mut self, a: Self, b: Self) { + *self = self.mul_add(a, b) + } +} + impl Zero for i256 { fn zero() -> Self { i256::ZERO @@ -1095,6 +1164,10 @@ impl Zero for i256 { } } +impl ConstZero for i256 { + const ZERO: Self = i256::ZERO; +} + impl One for i256 { fn one() -> Self { i256::ONE @@ -1105,6 +1178,10 @@ impl One for i256 { } } +impl ConstOne for i256 { + const ONE: Self = i256::ONE; +} + impl Num for i256 { type FromStrRadixErr = ParseI256Error; @@ -1154,6 +1231,15 @@ impl Bounded for i256 { } } +impl Not for i256 { + type Output = i256; + + #[inline] + fn not(self) -> Self::Output { + Self::from_parts(!self.low, !self.high) + } +} + #[cfg(all(test, not(miri)))] // llvm.x86.subborrow.64 not supported by MIRI mod tests { use super::*; @@ -1362,6 +1448,8 @@ mod tests { i256::ZERO, i256::ONE, i256::MINUS_ONE, + ConstZero::ZERO, + ConstOne::ONE, i256::from_i128(2), i256::from_i128(-2), i256::from_parts(u128::MAX, 1), @@ -1686,6 +1774,36 @@ mod tests { let result = ::wrapping_add(&i256::MAX, &i256::ONE); assert_eq!(result, i256::MIN); + // Saturating operations + assert_eq!(i256::MAX.saturating_add(&i256::ONE), i256::MAX); + assert_eq!(i256::MIN.saturating_sub(&i256::ONE), i256::MIN); + assert_eq!(i256::MIN.saturating_add(&i256::MINUS_ONE), i256::MIN); + assert_eq!(i256::MAX.saturating_sub(&i256::MINUS_ONE), i256::MAX); + assert_eq!(i256::MAX.saturating_mul(&i256::MAX), i256::MAX); + assert_eq!( + i256::from(20).saturating_add(&i256::from(5)), + i256::from(25) + ); + assert_eq!( + i256::from(20).saturating_sub(&i256::from(5)), + i256::from(15) + ); + assert_eq!( + i256::from(20).saturating_mul(&i256::from(5)), + i256::from(100) + ); + + // Mul-add + assert_eq!( + i256::from(20).mul_add(i256::from(5), i256::from(10)), + i256::from(110) + ); + + let mut mul_add_value = i256::from(20); + mul_add_value.mul_add_assign(i256::from(5), i256::from(10)); + assert_eq!(mul_add_value, i256::from(110)); + + let value = i256::from(-5); assert_eq!(::abs(&value), i256::from(5)); assert_eq!(::one(), i256::from(1)); @@ -1693,6 +1811,11 @@ mod tests { assert_eq!(::min_value(), i256::MIN); assert_eq!(::max_value(), i256::MAX); + + // Bitwise not + assert_eq!(!i256::ZERO, i256::MINUS_ONE); + assert_eq!(!i256::MINUS_ONE, i256::ZERO); + assert_eq!(!i256::ONE, i256::from_parts(u128::MAX - 1, -1)); } #[should_panic]