diff --git a/Cargo.toml b/Cargo.toml index a692408..8480f82 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -42,3 +42,7 @@ features = ["plotters", "cargo_bench_support"] [[bench]] name = "numparse" harness = false + +[[bench]] +name = "simd_math" +harness = false diff --git a/README.md b/README.md index 9c7c2a7..b7e2973 100644 --- a/README.md +++ b/README.md @@ -24,15 +24,54 @@ Refer to the excellent [Intel Intrinsics Guide](https://software.intel.com/sites * Extract or set a single lane with the index operator: `let v1 = v[1];` * Falls all the way back to scalar code for platforms with no SIMD or unsupported SIMD -# Trig Functions via Sleef-sys +# SIMD math revival status -~~A number of trigonometric and other common math functions are provided~~ -~~in vectorized form via the Sleef-sys crate. This is an optional feature `sleef` that you can enable.~~ -~~Doing so currently requires nightly, as well as having CMake and Clang installed.~~ +SIMDeez now includes a native, pure-Rust math surface for the first restored SLEEF-style family: -⚠️ In simdeez V2.0, sleef is temporarily deprecated due to the maintenance complexity involved around it. We are open to contributions, and are undecided on whether we: -- Resume sleef support via the existing sleef-sys crate -- Re-implement sleef via simdeez primitives +- `log2_u35` +- `exp2_u35` +- `ln_u35` +- `exp_u35` + +These are exposed via extension traits in `simdeez::math` and re-exported in `simdeez::prelude`: + +```rust +use simdeez::prelude::*; + +fn apply_math(x: S::Vf32) -> S::Vf32 { + let y = x.log2_u35(); + y.exp2_u35() + x.ln_u35() + x.exp_u35() +} +``` + +The old `sleef-sys` feature remains historical/deprecated and is **not** the primary implementation path for this revived surface. + +### Kernel layering blueprint (v0.1) + +The restored `f32` path now demonstrates the intended extension architecture: + +1. 
**Portable SIMD kernels** (`src/math/f32/portable.rs`) implement reduction + polynomial logic with backend-agnostic simdeez primitives. +2. **Backend override dispatch** (`src/math/f32/mod.rs`) selects architecture-tuned kernels without changing the public `SimdMathF32` API. +3. **Hand-optimized backend implementation** (`src/math/f32/x86_avx2.rs`) provides a real AVX2/FMA override for `log2_u35`. +4. **Scalar fallback patching** remains centralized in the portable layer for exceptional lanes, preserving special-value semantics. + +To add the next SLEEF-style function, follow the same pattern: start portable, wire dispatch, then add optional backend overrides only where profiling justifies complexity. + +### Benchmarking restored math + +An in-repo Criterion benchmark target is available for this revived surface: + +```bash +cargo bench --bench simd_math +``` + +This benchmark reports per-function throughput for: + +- native scalar loop baseline (`f32::{log2, exp2, ln, exp}`) +- simdeez runtime-selected path +- forced backend variants (`scalar`, `sse2`, `sse41`, `avx2`, and `avx512` when available on host) + +Current expectation: `log2_u35` and `exp2_u35` should show clear speedups on SIMD-capable backends (notably AVX2 on x86 hosts), while `ln_u35`/`exp_u35` remain scalar-reference quality-first baselines. Use these benches to validate both performance and dispatch behavior as new kernels/overrides are added. 
# Compared to packed_simd diff --git a/benches/simd_math.rs b/benches/simd_math.rs new file mode 100644 index 0000000..fb1cbe8 --- /dev/null +++ b/benches/simd_math.rs @@ -0,0 +1,337 @@ +use criterion::{criterion_group, criterion_main, Criterion, Throughput}; +use rand::{Rng, SeedableRng}; +use rand_chacha::ChaCha8Rng; +#[cfg(not(any(target_arch = "x86_64", target_arch = "x86")))] +use simdeez::scalar::Scalar; +use simdeez::{prelude::*, simd_unsafe_generate_all}; +use std::{hint::black_box, time::Duration}; + +const INPUT_LEN: usize = 1 << 20; + +fn make_positive_log_inputs(len: usize, seed: u64) -> Vec { + let mut rng = ChaCha8Rng::seed_from_u64(seed); + (0..len) + .map(|_| { + let log2x = rng.gen_range(-20.0f32..20.0f32); + let mantissa = rng.gen_range(1.0f32..2.0f32); + mantissa * log2x.exp2() + }) + .collect() +} + +fn make_exp2_inputs(len: usize, seed: u64) -> Vec { + let mut rng = ChaCha8Rng::seed_from_u64(seed); + (0..len) + .map(|_| rng.gen_range(-100.0f32..100.0f32)) + .collect() +} + +fn make_exp_inputs(len: usize, seed: u64) -> Vec { + let mut rng = ChaCha8Rng::seed_from_u64(seed); + (0..len).map(|_| rng.gen_range(-80.0f32..80.0f32)).collect() +} + +#[inline(never)] +fn scalar_log2_sum(input: &[f32]) -> f32 { + input.iter().copied().map(f32::log2).sum() +} + +#[inline(never)] +fn scalar_exp2_sum(input: &[f32]) -> f32 { + input.iter().copied().map(f32::exp2).sum() +} + +#[inline(never)] +fn scalar_ln_sum(input: &[f32]) -> f32 { + input.iter().copied().map(f32::ln).sum() +} + +#[inline(never)] +fn scalar_exp_sum(input: &[f32]) -> f32 { + input.iter().copied().map(f32::exp).sum() +} + +#[inline(always)] +fn simdeez_log2_sum_impl(input: &[f32]) -> f32 { + let mut sum = 0.0f32; + let mut i = 0; + + while i + S::Vf32::WIDTH <= input.len() { + let v = S::Vf32::load_from_slice(&input[i..]); + sum += v.log2_u35().horizontal_add(); + i += S::Vf32::WIDTH; + } + + for &x in &input[i..] 
{ + sum += x.log2(); + } + + sum +} + +#[inline(always)] +fn simdeez_exp2_sum_impl(input: &[f32]) -> f32 { + let mut sum = 0.0f32; + let mut i = 0; + + while i + S::Vf32::WIDTH <= input.len() { + let v = S::Vf32::load_from_slice(&input[i..]); + sum += v.exp2_u35().horizontal_add(); + i += S::Vf32::WIDTH; + } + + for &x in &input[i..] { + sum += x.exp2(); + } + + sum +} + +#[inline(always)] +fn simdeez_ln_sum_impl(input: &[f32]) -> f32 { + let mut sum = 0.0f32; + let mut i = 0; + + while i + S::Vf32::WIDTH <= input.len() { + let v = S::Vf32::load_from_slice(&input[i..]); + sum += v.ln_u35().horizontal_add(); + i += S::Vf32::WIDTH; + } + + for &x in &input[i..] { + sum += x.ln(); + } + + sum +} + +#[inline(always)] +fn simdeez_exp_sum_impl(input: &[f32]) -> f32 { + let mut sum = 0.0f32; + let mut i = 0; + + while i + S::Vf32::WIDTH <= input.len() { + let v = S::Vf32::load_from_slice(&input[i..]); + sum += v.exp_u35().horizontal_add(); + i += S::Vf32::WIDTH; + } + + for &x in &input[i..] { + sum += x.exp(); + } + + sum +} + +simd_unsafe_generate_all!( + fn simdeez_log2_sum(input: &[f32]) -> f32 { + simdeez_log2_sum_impl::(input) + } +); + +simd_unsafe_generate_all!( + fn simdeez_exp2_sum(input: &[f32]) -> f32 { + simdeez_exp2_sum_impl::(input) + } +); + +simd_unsafe_generate_all!( + fn simdeez_ln_sum(input: &[f32]) -> f32 { + simdeez_ln_sum_impl::(input) + } +); + +simd_unsafe_generate_all!( + fn simdeez_exp_sum(input: &[f32]) -> f32 { + simdeez_exp_sum_impl::(input) + } +); + +#[cfg(not(any(target_arch = "x86_64", target_arch = "x86")))] +#[inline(never)] +fn simdeez_log2_sum_scalar(input: &[f32]) -> f32 { + simdeez_log2_sum_impl::(input) +} + +#[cfg(not(any(target_arch = "x86_64", target_arch = "x86")))] +#[inline(never)] +fn simdeez_exp2_sum_scalar(input: &[f32]) -> f32 { + simdeez_exp2_sum_impl::(input) +} + +#[cfg(not(any(target_arch = "x86_64", target_arch = "x86")))] +#[inline(never)] +fn simdeez_ln_sum_scalar(input: &[f32]) -> f32 { + 
simdeez_ln_sum_impl::(input) +} + +#[cfg(not(any(target_arch = "x86_64", target_arch = "x86")))] +#[inline(never)] +fn simdeez_exp_sum_scalar(input: &[f32]) -> f32 { + simdeez_exp_sum_impl::(input) +} + +struct BenchTargets { + scalar_native: fn(&[f32]) -> f32, + simdeez_runtime: fn(&[f32]) -> f32, + simdeez_scalar: fn(&[f32]) -> f32, + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + simdeez_sse2: unsafe fn(&[f32]) -> f32, + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + simdeez_sse41: unsafe fn(&[f32]) -> f32, + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + simdeez_avx2: unsafe fn(&[f32]) -> f32, + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + simdeez_avx512: unsafe fn(&[f32]) -> f32, +} + +fn bench_variants(c: &mut Criterion, group_name: &str, input: &[f32], targets: BenchTargets) { + let mut group = c.benchmark_group(group_name); + group.throughput(Throughput::Elements(input.len() as u64)); + + group.bench_function("scalar-native", |b| { + b.iter(|| black_box((targets.scalar_native)(black_box(input)))) + }); + + group.bench_function("simdeez-runtime", |b| { + b.iter(|| black_box((targets.simdeez_runtime)(black_box(input)))) + }); + + group.bench_function("simdeez-forced-scalar", |b| { + b.iter(|| black_box((targets.simdeez_scalar)(black_box(input)))) + }); + + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + { + if std::is_x86_feature_detected!("sse2") { + group.bench_function("simdeez-forced-sse2", |b| { + b.iter(|| unsafe { black_box((targets.simdeez_sse2)(black_box(input))) }) + }); + } else { + eprintln!("[bench] skipped simdeez-forced-sse2 for {group_name}: CPU lacks sse2"); + } + + if std::is_x86_feature_detected!("sse4.1") { + group.bench_function("simdeez-forced-sse41", |b| { + b.iter(|| unsafe { black_box((targets.simdeez_sse41)(black_box(input))) }) + }); + } else { + eprintln!("[bench] skipped simdeez-forced-sse41 for {group_name}: CPU lacks sse4.1"); + } + + if 
std::is_x86_feature_detected!("avx2") && std::is_x86_feature_detected!("fma") { + group.bench_function("simdeez-forced-avx2", |b| { + b.iter(|| unsafe { black_box((targets.simdeez_avx2)(black_box(input))) }) + }); + } else { + eprintln!("[bench] skipped simdeez-forced-avx2 for {group_name}: CPU lacks avx2/fma"); + } + + if std::is_x86_feature_detected!("avx512f") + && std::is_x86_feature_detected!("avx512bw") + && std::is_x86_feature_detected!("avx512dq") + { + group.bench_function("simdeez-forced-avx512", |b| { + b.iter(|| unsafe { black_box((targets.simdeez_avx512)(black_box(input))) }) + }); + } else { + eprintln!( + "[bench] skipped simdeez-forced-avx512 for {group_name}: CPU lacks avx512f+bw+dq" + ); + } + } + + group.finish(); +} + +fn criterion_benchmark(c: &mut Criterion) { + let log_inputs = make_positive_log_inputs(INPUT_LEN, 0xA11C_E001); + let exp2_inputs = make_exp2_inputs(INPUT_LEN, 0xA11C_E002); + let exp_inputs = make_exp_inputs(INPUT_LEN, 0xA11C_E003); + + bench_variants( + c, + "simd_math/f32/log2_u35", + &log_inputs, + BenchTargets { + scalar_native: scalar_log2_sum, + simdeez_runtime: simdeez_log2_sum, + simdeez_scalar: simdeez_log2_sum_scalar, + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + simdeez_sse2: simdeez_log2_sum_sse2, + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + simdeez_sse41: simdeez_log2_sum_sse41, + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + simdeez_avx2: simdeez_log2_sum_avx2, + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + simdeez_avx512: simdeez_log2_sum_avx512, + }, + ); + + bench_variants( + c, + "simd_math/f32/exp2_u35", + &exp2_inputs, + BenchTargets { + scalar_native: scalar_exp2_sum, + simdeez_runtime: simdeez_exp2_sum, + simdeez_scalar: simdeez_exp2_sum_scalar, + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + simdeez_sse2: simdeez_exp2_sum_sse2, + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + simdeez_sse41: simdeez_exp2_sum_sse41, + 
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + simdeez_avx2: simdeez_exp2_sum_avx2, + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + simdeez_avx512: simdeez_exp2_sum_avx512, + }, + ); + + bench_variants( + c, + "simd_math/f32/ln_u35", + &log_inputs, + BenchTargets { + scalar_native: scalar_ln_sum, + simdeez_runtime: simdeez_ln_sum, + simdeez_scalar: simdeez_ln_sum_scalar, + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + simdeez_sse2: simdeez_ln_sum_sse2, + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + simdeez_sse41: simdeez_ln_sum_sse41, + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + simdeez_avx2: simdeez_ln_sum_avx2, + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + simdeez_avx512: simdeez_ln_sum_avx512, + }, + ); + + bench_variants( + c, + "simd_math/f32/exp_u35", + &exp_inputs, + BenchTargets { + scalar_native: scalar_exp_sum, + simdeez_runtime: simdeez_exp_sum, + simdeez_scalar: simdeez_exp_sum_scalar, + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + simdeez_sse2: simdeez_exp_sum_sse2, + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + simdeez_sse41: simdeez_exp_sum_sse41, + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + simdeez_avx2: simdeez_exp_sum_avx2, + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + simdeez_avx512: simdeez_exp_sum_avx512, + }, + ); +} + +criterion_group! { + name = benches; + config = Criterion::default() + .sample_size(20) + .warm_up_time(Duration::from_secs(1)) + .measurement_time(Duration::from_secs(2)); + targets = criterion_benchmark +} +criterion_main!(benches); diff --git a/src/lib.rs b/src/lib.rs index 577758d..1a80b63 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -24,10 +24,15 @@ //! * Operator overloading: `let sum = va + vb` or `s *= s` //! * Extract or set a single lane with the index operator: `let v1 = v[1];` //! -//! # Trig Functions via Sleef-sys -//! 
A number of trigonometric and other common math functions are provided -//! in vectorized form via the Sleef-sys crate. This is an optional feature `sleef` that you can enable. -//! Doing so currently requires nightly, as well as having CMake and Clang installed. +//! # SIMD math revival status +//! SIMDeez now provides a native, pure-Rust first math family via extension traits: +//! `log2_u35`, `exp2_u35`, `ln_u35`, and `exp_u35`. +//! +//! These methods are available through `simdeez::math` and re-exported by `simdeez::prelude`. +//! The implementation follows a layered blueprint: portable kernels first, +//! backend-specific overrides where justified (currently a hand-tuned AVX2 `log2_u35`), +//! and scalar fallback patching for exceptional lanes. +//! The historical `sleef` feature remains deprecated and is not the primary implementation path. //! //! # Compared to stdsimd //! @@ -180,6 +185,9 @@ pub use base::*; mod libm_ext; +pub mod math; +pub use math::{SimdMathF32, SimdMathF64}; + mod engines; pub use engines::scalar; diff --git a/src/math/f32/mod.rs b/src/math/f32/mod.rs new file mode 100644 index 0000000..8692780 --- /dev/null +++ b/src/math/f32/mod.rs @@ -0,0 +1,54 @@ +//! f32 SIMD math kernel layering: +//! - `portable`: backend-agnostic reduction/polynomial kernels + scalar lane patching. +//! - `x86_avx2`: optional hand-optimized override(s) for specific functions. +//! - this module: dispatch glue selecting overrides without changing the public API. 
+ +mod portable; + +#[cfg(any(target_arch = "x86_64", target_arch = "x86"))] +mod x86_avx2; + +use crate::{Simd, SimdFloat32}; + +#[inline(always)] +pub(crate) fn log2_u35(input: V) -> V +where + V: SimdFloat32, + V::Engine: Simd, +{ + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + { + if is_avx2_engine::() { + return unsafe { x86_avx2::log2_u35(input) }; + } + } + + portable::log2_u35(input) +} + +#[inline(always)] +pub(crate) fn exp2_u35(input: V) -> V +where + V: SimdFloat32, + V::Engine: Simd, +{ + portable::exp2_u35(input) +} + +#[cfg(any(target_arch = "x86_64", target_arch = "x86"))] +#[inline(always)] +fn is_avx2_engine() -> bool { + core::any::TypeId::of::() == core::any::TypeId::of::() +} + +#[cfg(test)] +mod tests { + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + #[test] + fn avx2_dispatch_gate_matches_only_avx2_engine() { + use crate::engines::{avx2::Avx2, sse2::Sse2}; + + assert!(super::is_avx2_engine::()); + assert!(!super::is_avx2_engine::()); + } +} diff --git a/src/math/f32/portable.rs b/src/math/f32/portable.rs new file mode 100644 index 0000000..d6dc3e5 --- /dev/null +++ b/src/math/f32/portable.rs @@ -0,0 +1,158 @@ +use crate::math::scalar; +use crate::{Simd, SimdBaseIo, SimdBaseOps, SimdConsts, SimdFloat32, SimdInt, SimdInt32}; + +pub(super) type SimdI32 = <::Engine as Simd>::Vi32; + +pub(super) const F32_EXPONENT_MASK: i32 = 0x7F80_0000u32 as i32; +pub(super) const F32_MANTISSA_MASK: i32 = 0x007F_FFFF; +pub(super) const F32_LOG_NORM_MANTISSA: i32 = 0x3F00_0000; +pub(super) const F32_EXPONENT_BIAS_ADJUST: i32 = 126; + +#[inline(always)] +fn any_lane_nonzero(mask: SimdI32) -> bool +where + V: SimdFloat32, + V::Engine: Simd, +{ + unsafe { + let lanes = mask.as_array(); + for lane in 0..V::WIDTH { + if lanes[lane] != 0 { + return true; + } + } + } + + false +} + +#[inline(always)] +pub(super) fn log2_exceptional_mask(input: V) -> SimdI32 +where + V: SimdFloat32, + V::Engine: Simd, +{ + let bits = input.bitcast_i32(); + let 
exponent_bits = bits & F32_EXPONENT_MASK; + + let non_positive = input + .cmp_gt(V::zeroes()) + .bitcast_i32() + .cmp_eq(SimdI32::::zeroes()); + let subnormal_or_zero = exponent_bits.cmp_eq(SimdI32::::zeroes()); + let inf_or_nan = exponent_bits.cmp_eq(SimdI32::::set1(F32_EXPONENT_MASK)); + + non_positive | subnormal_or_zero | inf_or_nan +} + +#[inline(always)] +pub(super) fn patch_exceptional_lanes( + input: V, + output: V, + exceptional_mask: SimdI32, + scalar_fallback: fn(f32) -> f32, +) -> V +where + V: SimdFloat32, + V::Engine: Simd, +{ + if !any_lane_nonzero::(exceptional_mask) { + return output; + } + + unsafe { + let input_lanes = input.as_array(); + let mask_lanes = exceptional_mask.as_array(); + let mut output_lanes = output.as_array(); + + for lane in 0..V::WIDTH { + if mask_lanes[lane] != 0 { + output_lanes[lane] = scalar_fallback(input_lanes[lane]); + } + } + + V::load_from_ptr_unaligned(&output_lanes as *const V::ArrayRepresentation as *const f32) + } +} + +#[inline(always)] +pub(super) fn log2_u35(input: V) -> V +where + V: SimdFloat32, + V::Engine: Simd, +{ + let bits = input.bitcast_i32(); + let exponent_bits = bits & F32_EXPONENT_MASK; + let mantissa_bits = bits & F32_MANTISSA_MASK; + + let exceptional_mask = log2_exceptional_mask(input); + + let exponent = (exponent_bits.shr(23) - F32_EXPONENT_BIAS_ADJUST).cast_f32(); + let normalized_mantissa = (mantissa_bits | F32_LOG_NORM_MANTISSA).bitcast_f32(); + + let one = V::set1(1.0); + let half = V::set1(0.5); + let sqrt_half = V::set1(core::f32::consts::FRAC_1_SQRT_2); + + let adjust_mask = normalized_mantissa.cmp_lt(sqrt_half); + let exponent = exponent - adjust_mask.blendv(V::zeroes(), one); + let reduced = adjust_mask.blendv( + normalized_mantissa - one, + (normalized_mantissa + normalized_mantissa) - one, + ); + + let reduced_sq = reduced * reduced; + + let mut poly = V::set1(7.037_683_6e-2); + poly = (poly * reduced) + V::set1(-1.151_461e-1); + poly = (poly * reduced) + V::set1(1.167_699_9e-1); + 
poly = (poly * reduced) + V::set1(-1.242_014_1e-1); + poly = (poly * reduced) + V::set1(1.424_932_3e-1); + poly = (poly * reduced) + V::set1(-1.666_805_8e-1); + poly = (poly * reduced) + V::set1(2.000_071_5e-1); + poly = (poly * reduced) + V::set1(-2.499_999_4e-1); + poly = (poly * reduced) + V::set1(3.333_333e-1); + + poly *= reduced; + poly *= reduced_sq; + poly += exponent * V::set1(-2.121_944_4e-4); + poly -= half * reduced_sq; + + let ln_x = reduced + poly + (exponent * V::set1(0.693_359_4)); + let fast = ln_x * V::set1(core::f32::consts::LOG2_E); + + patch_exceptional_lanes(input, fast, exceptional_mask, scalar::log2_u35_f32) +} + +#[inline(always)] +pub(super) fn exp2_u35(input: V) -> V +where + V: SimdFloat32, + V::Engine: Simd, +{ + let finite_mask = input.cmp_eq(input).bitcast_i32(); + let in_lower_bound = input.cmp_gte(V::set1(-126.0)).bitcast_i32(); + let in_upper_bound = input.cmp_lte(V::set1(126.0)).bitcast_i32(); + let fast_mask = finite_mask & in_lower_bound & in_upper_bound; + let exceptional_mask = fast_mask.cmp_eq(SimdI32::::zeroes()); + + let integral = input.floor().cast_i32(); + let fractional = input - integral.cast_f32(); + let reduced = fractional * V::set1(core::f32::consts::LN_2); + + let mut poly = V::set1(1.987_569_1e-4); + poly = (poly * reduced) + V::set1(1.398_2e-3); + poly = (poly * reduced) + V::set1(8.333_452e-3); + poly = (poly * reduced) + V::set1(4.166_579_6e-2); + poly = (poly * reduced) + V::set1(1.666_666_5e-1); + poly = (poly * reduced) + V::set1(5e-1); + + let reduced_sq = reduced * reduced; + let exp_reduced = (poly * reduced_sq) + reduced + V::set1(1.0); + + let exp_bits = (integral + 127).shl(23); + let scale = exp_bits.bitcast_f32(); + let fast = exp_reduced * scale; + + patch_exceptional_lanes(input, fast, exceptional_mask, scalar::exp2_u35_f32) +} diff --git a/src/math/f32/x86_avx2.rs b/src/math/f32/x86_avx2.rs new file mode 100644 index 0000000..fba9db6 --- /dev/null +++ b/src/math/f32/x86_avx2.rs @@ -0,0 +1,93 @@ 
+//! Hand-optimized AVX2/FMA overrides for f32 math kernels. +//! +//! Keep these overrides semantically aligned with `portable` kernels and always +//! reuse scalar exceptional-lane patching to preserve special-case contracts. + +use crate::math::f32::portable; +use crate::math::scalar; +use crate::{Simd, SimdFloat32}; + +#[cfg(target_arch = "x86")] +use core::arch::x86::*; +#[cfg(target_arch = "x86_64")] +use core::arch::x86_64::*; + +#[inline(always)] +pub(super) unsafe fn log2_u35(input: V) -> V +where + V: SimdFloat32, + V::Engine: Simd, +{ + debug_assert!( + core::any::TypeId::of::() + == core::any::TypeId::of::() + ); + + let exceptional_mask = portable::log2_exceptional_mask(input); + let x = input.try_transmute_avx2(); + let fast = log2_u35_avx2_impl(x); + let fast = V::try_transmute_from_avx2(fast); + + portable::patch_exceptional_lanes(input, fast, exceptional_mask, scalar::log2_u35_f32) +} + +#[target_feature(enable = "avx2", enable = "fma")] +unsafe fn log2_u35_avx2_impl(x: __m256) -> __m256 { + let exponent_bits = _mm256_and_si256( + _mm256_castps_si256(x), + _mm256_set1_epi32(portable::F32_EXPONENT_MASK), + ); + let mantissa_bits = _mm256_and_si256( + _mm256_castps_si256(x), + _mm256_set1_epi32(portable::F32_MANTISSA_MASK), + ); + + let exponent = _mm256_cvtepi32_ps(_mm256_sub_epi32( + _mm256_srli_epi32(exponent_bits, 23), + _mm256_set1_epi32(portable::F32_EXPONENT_BIAS_ADJUST), + )); + + let normalized_mantissa = _mm256_castsi256_ps(_mm256_or_si256( + mantissa_bits, + _mm256_set1_epi32(portable::F32_LOG_NORM_MANTISSA), + )); + + let one = _mm256_set1_ps(1.0); + let adjust_mask = _mm256_cmp_ps( + normalized_mantissa, + _mm256_set1_ps(core::f32::consts::FRAC_1_SQRT_2), + _CMP_LT_OQ, + ); + + let exponent = _mm256_sub_ps(exponent, _mm256_and_ps(adjust_mask, one)); + + let reduced = _mm256_blendv_ps( + _mm256_sub_ps(normalized_mantissa, one), + _mm256_sub_ps(_mm256_add_ps(normalized_mantissa, normalized_mantissa), one), + adjust_mask, + ); + + let 
reduced_sq = _mm256_mul_ps(reduced, reduced); + + let mut poly = _mm256_set1_ps(7.037_683_6e-2); + poly = _mm256_fmadd_ps(poly, reduced, _mm256_set1_ps(-1.151_461e-1)); + poly = _mm256_fmadd_ps(poly, reduced, _mm256_set1_ps(1.167_699_9e-1)); + poly = _mm256_fmadd_ps(poly, reduced, _mm256_set1_ps(-1.242_014_1e-1)); + poly = _mm256_fmadd_ps(poly, reduced, _mm256_set1_ps(1.424_932_3e-1)); + poly = _mm256_fmadd_ps(poly, reduced, _mm256_set1_ps(-1.666_805_8e-1)); + poly = _mm256_fmadd_ps(poly, reduced, _mm256_set1_ps(2.000_071_5e-1)); + poly = _mm256_fmadd_ps(poly, reduced, _mm256_set1_ps(-2.499_999_4e-1)); + poly = _mm256_fmadd_ps(poly, reduced, _mm256_set1_ps(3.333_333e-1)); + + poly = _mm256_mul_ps(poly, reduced); + poly = _mm256_mul_ps(poly, reduced_sq); + poly = _mm256_fmadd_ps(exponent, _mm256_set1_ps(-2.121_944_4e-4), poly); + poly = _mm256_fnmadd_ps(_mm256_set1_ps(0.5), reduced_sq, poly); + + let ln_x = _mm256_fmadd_ps( + exponent, + _mm256_set1_ps(0.693_359_4), + _mm256_add_ps(reduced, poly), + ); + _mm256_mul_ps(ln_x, _mm256_set1_ps(core::f32::consts::LOG2_E)) +} diff --git a/src/math/mod.rs b/src/math/mod.rs new file mode 100644 index 0000000..06ec56a --- /dev/null +++ b/src/math/mod.rs @@ -0,0 +1,138 @@ +//! Portable SIMD math scaffolding for SLEEF-style transcendental families. +//! +//! Strategy C baseline: keep semantics in-tree and backend-agnostic by expressing +//! vector math over existing simdeez vector types. +//! +//! `f32` `log2_u35` / `exp2_u35` flow through a layered kernel stack: +//! portable SIMD kernels first, optional backend overrides where available, +//! and scalar-lane fallback for exceptional semantics. +//! `ln_u35` / `exp_u35` currently stay on deterministic scalar references. + +mod f32; +mod scalar; + +use crate::{Simd, SimdFloat32, SimdFloat64}; + +/// Accuracy contracts for currently restored math families. +pub mod contracts { + /// Maximum ULP error target for the f32 `log2_u35` kernel family. 
+ pub const LOG2_U35_F32_MAX_ULP: u32 = 35; + + /// Maximum ULP error target for the f32 `exp2_u35` kernel family. + pub const EXP2_U35_F32_MAX_ULP: u32 = 35; + + /// Maximum ULP error target for the f64 `log2_u35` kernel family. + pub const LOG2_U35_F64_MAX_ULP: u64 = 35; + + /// Maximum ULP error target for the f64 `exp2_u35` kernel family. + pub const EXP2_U35_F64_MAX_ULP: u64 = 35; + + /// Maximum ULP error target for f32 `ln_u35`. + pub const LN_U35_F32_MAX_ULP: u32 = 1; + + /// Maximum ULP error target for f32 `exp_u35`. + pub const EXP_U35_F32_MAX_ULP: u32 = 1; + + /// Maximum ULP error target for f64 `ln_u35`. + pub const LN_U35_F64_MAX_ULP: u64 = 1; + + /// Maximum ULP error target for f64 `exp_u35`. + pub const EXP_U35_F64_MAX_ULP: u64 = 1; +} + +#[inline(always)] +fn map_unary_f32(input: V, f: impl Fn(f32) -> f32) -> V { + unsafe { + let mut lanes = input.as_array(); + for i in 0..V::WIDTH { + lanes[i] = f(lanes[i]); + } + V::load_from_ptr_unaligned(&lanes as *const V::ArrayRepresentation as *const f32) + } +} + +#[inline(always)] +fn map_unary_f64(input: V, f: impl Fn(f64) -> f64) -> V { + unsafe { + let mut lanes = input.as_array(); + for i in 0..V::WIDTH { + lanes[i] = f(lanes[i]); + } + V::load_from_ptr_unaligned(&lanes as *const V::ArrayRepresentation as *const f64) + } +} + +/// SIMD math extension trait for `f32` vector types. +/// +/// `log2_u35`/`exp2_u35` use SIMD-native reduction/polynomial kernels. +/// `log2_u35` additionally demonstrates backend override dispatch with a +/// hand-tuned AVX2/FMA implementation. +/// `ln_u35`/`exp_u35` currently use deterministic lane-wise scalar references. +pub trait SimdMathF32: SimdFloat32 { + /// `log2(x)` with target `u35`-tier contract. + /// + /// Uses a SIMD-native mantissa/exponent reduction + polynomial kernel for + /// positive normal inputs, with scalar fallback for exceptional lanes. 
+ #[inline(always)] + fn log2_u35(self) -> Self + where + Self::Engine: Simd, + { + f32::log2_u35(self) + } + + /// `exp2(x)` with target `u35`-tier contract. + /// + /// Uses a SIMD-native floor/reduction + polynomial kernel in the finite + /// in-range domain, with scalar fallback for exceptional lanes. + #[inline(always)] + fn exp2_u35(self) -> Self + where + Self::Engine: Simd, + { + f32::exp2_u35(self) + } + + /// `ln(x)` with target `u35`-tier contract. + #[inline(always)] + fn ln_u35(self) -> Self { + map_unary_f32(self, scalar::ln_u35_f32) + } + + /// `exp(x)` with target `u35`-tier contract. + #[inline(always)] + fn exp_u35(self) -> Self { + map_unary_f32(self, scalar::exp_u35_f32) + } +} + +impl SimdMathF32 for T {} + +/// SIMD math extension trait for `f64` vector types. +pub trait SimdMathF64: SimdFloat64 { + /// `log2(x)` with target `u35`-tier contract. + #[inline(always)] + fn log2_u35(self) -> Self { + map_unary_f64(self, scalar::log2_u35_f64) + } + + /// `exp2(x)` with target `u35`-tier contract. + #[inline(always)] + fn exp2_u35(self) -> Self { + map_unary_f64(self, scalar::exp2_u35_f64) + } + + /// `ln(x)` with target `u35`-tier contract. + #[inline(always)] + fn ln_u35(self) -> Self { + map_unary_f64(self, scalar::ln_u35_f64) + } + + /// `exp(x)` with target `u35`-tier contract. 
+ #[inline(always)] + fn exp_u35(self) -> Self { + map_unary_f64(self, scalar::exp_u35_f64) + } +} + +impl SimdMathF64 for T {} diff --git a/src/math/scalar.rs b/src/math/scalar.rs new file mode 100644 index 0000000..47b711a --- /dev/null +++ b/src/math/scalar.rs @@ -0,0 +1,41 @@ +use crate::libm_ext::FloatExt; + +#[inline(always)] +pub fn log2_u35_f32(x: f32) -> f32 { + x.m_log2() +} + +#[inline(always)] +pub fn exp2_u35_f32(x: f32) -> f32 { + x.m_exp2() +} + +#[inline(always)] +pub fn ln_u35_f32(x: f32) -> f32 { + x.m_ln() +} + +#[inline(always)] +pub fn exp_u35_f32(x: f32) -> f32 { + x.m_exp() +} + +#[inline(always)] +pub fn log2_u35_f64(x: f64) -> f64 { + x.m_log2() +} + +#[inline(always)] +pub fn exp2_u35_f64(x: f64) -> f64 { + x.m_exp2() +} + +#[inline(always)] +pub fn ln_u35_f64(x: f64) -> f64 { + x.m_ln() +} + +#[inline(always)] +pub fn exp_u35_f64(x: f64) -> f64 { + x.m_exp() +} diff --git a/src/prelude.rs b/src/prelude.rs index b9c39a7..ebe2ba6 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -10,5 +10,6 @@ pub use crate::base::{ SimdFloat, SimdFloat32, SimdFloat64, SimdInt, SimdInt16, SimdInt32, SimdInt64, SimdInt8, SimdIter, }; +pub use crate::math::{SimdMathF32, SimdMathF64}; pub use paste::item as simdeez_paste_item; diff --git a/src/tests/lib/mod.rs b/src/tests/lib/mod.rs index 723ec11..a0a3f77 100644 --- a/src/tests/lib/mod.rs +++ b/src/tests/lib/mod.rs @@ -11,4 +11,7 @@ pub use tester::*; mod numbers; pub use numbers::*; +mod ulp; +pub use ulp::*; + mod constify; diff --git a/src/tests/lib/ulp.rs b/src/tests/lib/ulp.rs new file mode 100644 index 0000000..1db12ac --- /dev/null +++ b/src/tests/lib/ulp.rs @@ -0,0 +1,67 @@ +#[inline] +fn ordered_u32(bits: u32) -> u32 { + if bits & 0x8000_0000 != 0 { + !bits + } else { + bits | 0x8000_0000 + } +} + +#[inline] +fn ordered_u64(bits: u64) -> u64 { + if bits & 0x8000_0000_0000_0000 != 0 { + !bits + } else { + bits | 0x8000_0000_0000_0000 + } +} + +/// Returns the ULP distance between two f32 values, or 
`None` if either is NaN. +pub fn ulp_distance_f32(a: f32, b: f32) -> Option { + if a.is_nan() || b.is_nan() { + return None; + } + + let oa = ordered_u32(a.to_bits()); + let ob = ordered_u32(b.to_bits()); + Some(oa.abs_diff(ob)) +} + +/// Returns the ULP distance between two f64 values, or `None` if either is NaN. +pub fn ulp_distance_f64(a: f64, b: f64) -> Option { + if a.is_nan() || b.is_nan() { + return None; + } + + let oa = ordered_u64(a.to_bits()); + let ob = ordered_u64(b.to_bits()); + Some(oa.abs_diff(ob)) +} + +#[cfg(test)] +mod tests { + use super::{ulp_distance_f32, ulp_distance_f64}; + + #[test] + fn ulp_distance_handles_zero_signs() { + assert_eq!(ulp_distance_f32(0.0, -0.0), Some(1)); + assert_eq!(ulp_distance_f64(0.0, -0.0), Some(1)); + } + + #[test] + fn ulp_distance_matches_adjacent_bit_pattern_step() { + let a = 1.0f32; + let b = f32::from_bits(a.to_bits() + 1); + assert_eq!(ulp_distance_f32(a, b), Some(1)); + + let c = 1.0f64; + let d = f64::from_bits(c.to_bits() + 1); + assert_eq!(ulp_distance_f64(c, d), Some(1)); + } + + #[test] + fn ulp_distance_rejects_nan() { + assert_eq!(ulp_distance_f32(f32::NAN, 1.0), None); + assert_eq!(ulp_distance_f64(f64::NAN, 1.0), None); + } +} diff --git a/src/tests/mod.rs b/src/tests/mod.rs index 467877d..c05956d 100644 --- a/src/tests/mod.rs +++ b/src/tests/mod.rs @@ -11,4 +11,6 @@ mod i8_truthy_regressions; mod integer_edge_contracts; mod real_world; mod run; +mod simd_math; +mod simd_math_targeted_edges; mod wasm_unaligned_regressions; diff --git a/src/tests/simd_math.rs b/src/tests/simd_math.rs new file mode 100644 index 0000000..47eeff8 --- /dev/null +++ b/src/tests/simd_math.rs @@ -0,0 +1,336 @@ +#![allow(unused_imports)] + +use rand::Rng; +use rand_chacha::{rand_core::SeedableRng, ChaCha8Rng}; + +use super::*; + +#[cfg(any(target_arch = "x86_64", target_arch = "x86"))] +use crate::engines::avx512::Avx512; +#[cfg(target_arch = "aarch64")] +use crate::engines::neon::Neon; +use crate::engines::scalar::Scalar; 
+#[cfg(target_arch = "wasm32")] +use crate::engines::wasm32::Wasm; +#[cfg(any(target_arch = "x86_64", target_arch = "x86"))] +use crate::engines::{avx2::Avx2, sse2::Sse2, sse41::Sse41}; + +use crate::math::{contracts, SimdMathF32, SimdMathF64}; +use crate::{Simd, SimdBaseIo, SimdConsts}; + +fn f32_input_space() -> Vec { + let mut values = vec![ + f32::NAN, + f32::from_bits(0x7FC0_1234), + f32::INFINITY, + f32::NEG_INFINITY, + 0.0, + -0.0, + 1.0, + -1.0, + 2.0, + 0.5, + f32::MIN_POSITIVE, + f32::from_bits(1), + f32::MAX, + f32::MIN, + core::f32::consts::PI, + -core::f32::consts::PI, + core::f32::consts::E, + -100.0, + 100.0, + 128.0, + -150.0, + ]; + + let mut rng = ChaCha8Rng::seed_from_u64(0x51DE_EF32); + for _ in 0..10_000 { + values.push(f32::from_bits(rng.gen::())); + } + + values +} + +fn f64_input_space() -> Vec { + let mut values = vec![ + f64::NAN, + f64::from_bits(0x7FF8_0000_0000_1234), + f64::INFINITY, + f64::NEG_INFINITY, + 0.0, + -0.0, + 1.0, + -1.0, + 2.0, + 0.5, + f64::MIN_POSITIVE, + f64::from_bits(1), + f64::MAX, + f64::MIN, + core::f64::consts::PI, + -core::f64::consts::PI, + core::f64::consts::E, + -1000.0, + 1000.0, + 1024.0, + -2000.0, + ]; + + let mut rng = ChaCha8Rng::seed_from_u64(0x51DE_EF64); + for _ in 0..8_000 { + values.push(f64::from_bits(rng.gen::())); + } + + values +} + +fn assert_f32_contract( + fn_name: &str, + input: f32, + actual: f32, + expected: f32, + max_ulp: u32, +) -> Result<(), String> { + if expected.is_nan() { + if actual.is_nan() { + return Ok(()); + } + return Err(format!("{fn_name}({input:?}) expected NaN, got {actual:?}")); + } + + if expected.is_infinite() { + if actual.to_bits() == expected.to_bits() { + return Ok(()); + } + return Err(format!( + "{fn_name}({input:?}) expected {:?}, got {:?}", + expected, actual + )); + } + + if expected == 0.0 { + if actual.to_bits() == expected.to_bits() { + return Ok(()); + } + return Err(format!( + "{fn_name}({input:?}) expected signed zero bits {:08x}, got {:08x}", + 
+            expected.to_bits(),
+            actual.to_bits()
+        ));
+    }
+
+    if actual.is_nan() || actual.is_infinite() {
+        return Err(format!(
+            "{fn_name}({input:?}) expected finite {expected:?}, got {actual:?}"
+        ));
+    }
+
+    let ulp = ulp_distance_f32(actual, expected)
+        .ok_or_else(|| format!("{fn_name}({input:?}) failed to compute f32 ULP distance"))?;
+    if ulp > max_ulp {
+        return Err(format!(
+            "{fn_name}({input:?}) ULP distance {ulp} exceeds max {max_ulp} (actual={actual:?}, expected={expected:?})"
+        ));
+    }
+
+    Ok(())
+}
+
+fn assert_f64_contract(
+    fn_name: &str,
+    input: f64,
+    actual: f64,
+    expected: f64,
+    max_ulp: u64,
+) -> Result<(), String> {
+    if expected.is_nan() {
+        if actual.is_nan() {
+            return Ok(());
+        }
+        return Err(format!("{fn_name}({input:?}) expected NaN, got {actual:?}"));
+    }
+
+    if expected.is_infinite() {
+        if actual.to_bits() == expected.to_bits() {
+            return Ok(());
+        }
+        return Err(format!(
+            "{fn_name}({input:?}) expected {:?}, got {:?}",
+            expected, actual
+        ));
+    }
+
+    if expected == 0.0 {
+        if actual.to_bits() == expected.to_bits() {
+            return Ok(());
+        }
+        return Err(format!(
+            "{fn_name}({input:?}) expected signed zero bits {:016x}, got {:016x}",
+            expected.to_bits(),
+            actual.to_bits()
+        ));
+    }
+
+    if actual.is_nan() || actual.is_infinite() {
+        return Err(format!(
+            "{fn_name}({input:?}) expected finite {expected:?}, got {actual:?}"
+        ));
+    }
+
+    let ulp = ulp_distance_f64(actual, expected)
+        .ok_or_else(|| format!("{fn_name}({input:?}) failed to compute f64 ULP distance"))?;
+    if ulp > max_ulp {
+        return Err(format!(
+            "{fn_name}({input:?}) ULP distance {ulp} exceeds max {max_ulp} (actual={actual:?}, expected={expected:?})"
+        ));
+    }
+
+    Ok(())
+}
+
+fn check_unary_f32<S: Simd>(
+    fn_name: &str,
+    max_ulp: u32,
+    simd_fn: impl Fn(S::Vf32) -> S::Vf32,
+    scalar_ref: impl Fn(f32) -> f32,
+) {
+    for chunk in f32_input_space().chunks(S::Vf32::WIDTH) {
+        let input = S::Vf32::load_from_slice(chunk);
+        let output = simd_fn(input);
+
+        for (lane, &x) in
chunk.iter().enumerate() {
+            let actual = output[lane];
+            let expected = scalar_ref(x);
+            if let Err(err) = assert_f32_contract(fn_name, x, actual, expected, max_ulp) {
+                panic!("{err}");
+            }
+        }
+    }
+}
+
+fn check_unary_f64<S: Simd>(
+    fn_name: &str,
+    max_ulp: u64,
+    simd_fn: impl Fn(S::Vf64) -> S::Vf64,
+    scalar_ref: impl Fn(f64) -> f64,
+) {
+    for chunk in f64_input_space().chunks(S::Vf64::WIDTH) {
+        let input = S::Vf64::load_from_slice(chunk);
+        let output = simd_fn(input);
+
+        for (lane, &x) in chunk.iter().enumerate() {
+            let actual = output[lane];
+            let expected = scalar_ref(x);
+            if let Err(err) = assert_f64_contract(fn_name, x, actual, expected, max_ulp) {
+                panic!("{err}");
+            }
+        }
+    }
+}
+
+fn run_f32_log2_u35_contract<S: Simd>() {
+    check_unary_f32::<S>(
+        "log2_u35",
+        contracts::LOG2_U35_F32_MAX_ULP,
+        <S::Vf32 as SimdMathF32>::log2_u35,
+        f32::log2,
+    );
+}
+
+fn run_f32_exp2_u35_contract<S: Simd>() {
+    check_unary_f32::<S>(
+        "exp2_u35",
+        contracts::EXP2_U35_F32_MAX_ULP,
+        <S::Vf32 as SimdMathF32>::exp2_u35,
+        f32::exp2,
+    );
+}
+
+fn run_f32_ln_u35_contract<S: Simd>() {
+    check_unary_f32::<S>(
+        "ln_u35",
+        contracts::LN_U35_F32_MAX_ULP,
+        <S::Vf32 as SimdMathF32>::ln_u35,
+        f32::ln,
+    );
+}
+
+fn run_f32_exp_u35_contract<S: Simd>() {
+    check_unary_f32::<S>(
+        "exp_u35",
+        contracts::EXP_U35_F32_MAX_ULP,
+        <S::Vf32 as SimdMathF32>::exp_u35,
+        f32::exp,
+    );
+}
+
+fn run_f64_log2_u35_contract<S: Simd>() {
+    check_unary_f64::<S>(
+        "log2_u35",
+        contracts::LOG2_U35_F64_MAX_ULP,
+        <S::Vf64 as SimdMathF64>::log2_u35,
+        f64::log2,
+    );
+}
+
+fn run_f64_exp2_u35_contract<S: Simd>() {
+    check_unary_f64::<S>(
+        "exp2_u35",
+        contracts::EXP2_U35_F64_MAX_ULP,
+        <S::Vf64 as SimdMathF64>::exp2_u35,
+        f64::exp2,
+    );
+}
+
+fn run_f64_ln_u35_contract<S: Simd>() {
+    check_unary_f64::<S>(
+        "ln_u35",
+        contracts::LN_U35_F64_MAX_ULP,
+        <S::Vf64 as SimdMathF64>::ln_u35,
+        f64::ln,
+    );
+}
+
+fn run_f64_exp_u35_contract<S: Simd>() {
+    check_unary_f64::<S>(
+        "exp_u35",
+        contracts::EXP_U35_F64_MAX_ULP,
+        <S::Vf64 as SimdMathF64>::exp_u35,
+        f64::exp,
+    );
+}
+
+macro_rules! simd_math_backend_test {
+    ($name:ident, $simd:ident, $runner:ident) => {
+        crate::with_feature_flag!(
+            $simd,
+            paste::item!
{
+                #[test]
+                fn [<$name _ $simd:lower>]() {
+                    $runner::<$simd>();
+                }
+            }
+        );
+    };
+}
+
+macro_rules! simd_math_all_backends {
+    ($name:ident, $runner:ident) => {
+        simd_math_backend_test!($name, Scalar, $runner);
+        simd_math_backend_test!($name, Avx512, $runner);
+        simd_math_backend_test!($name, Avx2, $runner);
+        simd_math_backend_test!($name, Sse2, $runner);
+        simd_math_backend_test!($name, Sse41, $runner);
+        simd_math_backend_test!($name, Neon, $runner);
+        simd_math_backend_test!($name, Wasm, $runner);
+    };
+}
+
+simd_math_all_backends!(f32_log2_u35_contract, run_f32_log2_u35_contract);
+simd_math_all_backends!(f32_exp2_u35_contract, run_f32_exp2_u35_contract);
+simd_math_all_backends!(f32_ln_u35_contract, run_f32_ln_u35_contract);
+simd_math_all_backends!(f32_exp_u35_contract, run_f32_exp_u35_contract);
+simd_math_all_backends!(f64_log2_u35_contract, run_f64_log2_u35_contract);
+simd_math_all_backends!(f64_exp2_u35_contract, run_f64_exp2_u35_contract);
+simd_math_all_backends!(f64_ln_u35_contract, run_f64_ln_u35_contract);
+simd_math_all_backends!(f64_exp_u35_contract, run_f64_exp_u35_contract);
diff --git a/src/tests/simd_math_targeted_edges.rs b/src/tests/simd_math_targeted_edges.rs
new file mode 100644
index 0000000..c4487f0
--- /dev/null
+++ b/src/tests/simd_math_targeted_edges.rs
@@ -0,0 +1,261 @@
+#![allow(unused_imports)]
+
+use super::*;
+
+#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
+use crate::engines::avx512::Avx512;
+#[cfg(target_arch = "aarch64")]
+use crate::engines::neon::Neon;
+use crate::engines::scalar::Scalar;
+#[cfg(target_arch = "wasm32")]
+use crate::engines::wasm32::Wasm;
+#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
+use crate::engines::{avx2::Avx2, sse2::Sse2, sse41::Sse41};
+
+use crate::math::{contracts, SimdMathF32};
+use crate::{Simd, SimdBaseIo, SimdConsts};
+
+fn assert_f32_contract(
+    fn_name: &str,
+    input: f32,
+    actual: f32,
+    expected: f32,
+    max_ulp: u32,
+) -> Result<(), String> {
+    if
expected.is_nan() {
+        if actual.is_nan() {
+            return Ok(());
+        }
+        return Err(format!("{fn_name}({input:?}) expected NaN, got {actual:?}"));
+    }
+
+    if expected.is_infinite() {
+        if actual.to_bits() == expected.to_bits() {
+            return Ok(());
+        }
+        return Err(format!(
+            "{fn_name}({input:?}) expected {:?}, got {:?}",
+            expected, actual
+        ));
+    }
+
+    if expected == 0.0 {
+        if actual.to_bits() == expected.to_bits() {
+            return Ok(());
+        }
+        return Err(format!(
+            "{fn_name}({input:?}) expected signed zero bits {:08x}, got {:08x}",
+            expected.to_bits(),
+            actual.to_bits()
+        ));
+    }
+
+    if actual.is_nan() || actual.is_infinite() {
+        return Err(format!(
+            "{fn_name}({input:?}) expected finite {expected:?}, got {actual:?}"
+        ));
+    }
+
+    let ulp = ulp_distance_f32(actual, expected)
+        .ok_or_else(|| format!("{fn_name}({input:?}) failed to compute f32 ULP distance"))?;
+    if ulp > max_ulp {
+        return Err(format!(
+            "{fn_name}({input:?}) ULP distance {ulp} exceeds max {max_ulp} (actual={actual:?}, expected={expected:?})"
+        ));
+    }
+
+    Ok(())
+}
+
+fn check_targeted_unary_f32<S: Simd>(
+    fn_name: &str,
+    inputs: &[f32],
+    max_ulp: u32,
+    simd_fn: impl Fn(S::Vf32) -> S::Vf32,
+    scalar_ref: impl Fn(f32) -> f32,
+) {
+    for chunk in inputs.chunks(S::Vf32::WIDTH) {
+        let input = S::Vf32::load_from_slice(chunk);
+        let output = simd_fn(input);
+
+        for (lane, &x) in chunk.iter().enumerate() {
+            let actual = output[lane];
+            let expected = scalar_ref(x);
+            if let Err(err) = assert_f32_contract(fn_name, x, actual, expected, max_ulp) {
+                panic!("{err}");
+            }
+        }
+    }
+}
+
+fn run_f32_log2_u35_reduction_boundaries<S: Simd>() {
+    let mut inputs = vec![
+        f32::from_bits(0x3EFFFFFE),
+        f32::from_bits(0x3EFFFFFF),
+        f32::from_bits(0x3F000000),
+        f32::from_bits(0x3F000001),
+        f32::from_bits(0x3F7FFFFF),
+        f32::from_bits(0x3F800000),
+        f32::from_bits(0x3F800001),
+        f32::from_bits(0x3FFFFFFF),
+        f32::from_bits(0x40000000),
+        f32::from_bits(0x40000001),
+    ];
+
+    for &scale in &[0.5f32, 1.0, 2.0, 8.0] {
+        let pivot =
core::f32::consts::FRAC_1_SQRT_2 * scale;
+        inputs.push(f32::from_bits(pivot.to_bits() - 1));
+        inputs.push(pivot);
+        inputs.push(f32::from_bits(pivot.to_bits() + 1));
+    }
+
+    check_targeted_unary_f32::<S>(
+        "log2_u35",
+        &inputs,
+        contracts::LOG2_U35_F32_MAX_ULP,
+        <S::Vf32 as SimdMathF32>::log2_u35,
+        f32::log2,
+    );
+}
+
+fn run_f32_exp2_u35_fast_domain_boundaries<S: Simd>() {
+    let mut inputs = vec![
+        -126.0001,
+        -126.0,
+        -125.9999,
+        -1.0001,
+        -1.0,
+        -0.9999,
+        -0.0001,
+        -0.0,
+        0.0,
+        0.0001,
+        0.9999,
+        1.0,
+        1.0001,
+        125.9999,
+        126.0,
+        126.0001,
+        f32::NEG_INFINITY,
+        f32::INFINITY,
+        f32::NAN,
+    ];
+
+    for k in -4..=4 {
+        let center = k as f32;
+        inputs.push(center - 1.0 / 1024.0);
+        inputs.push(center);
+        inputs.push(center + 1.0 / 1024.0);
+    }
+
+    check_targeted_unary_f32::<S>(
+        "exp2_u35",
+        &inputs,
+        contracts::EXP2_U35_F32_MAX_ULP,
+        <S::Vf32 as SimdMathF32>::exp2_u35,
+        f32::exp2,
+    );
+}
+
+#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
+fn run_log2_u35_vector_apply_avx2(input: &[f32], output: &mut [f32]) {
+    assert_eq!(input.len(), output.len());
+
+    Avx2::invoke(|| {
+        let mut in_rest = input;
+        let mut out_rest = output;
+
+        while in_rest.len() >= <Avx2 as Simd>::Vf32::WIDTH {
+            let v = <Avx2 as Simd>::Vf32::load_from_slice(in_rest);
+            v.log2_u35().copy_to_slice(out_rest);
+
+            in_rest = &in_rest[<Avx2 as Simd>::Vf32::WIDTH..];
+            out_rest = &mut out_rest[<Avx2 as Simd>::Vf32::WIDTH..];
+        }
+
+        for (&x, out) in in_rest.iter().zip(out_rest.iter_mut()) {
+            *out = x.log2();
+        }
+    });
+}
+
+macro_rules! simd_math_backend_targeted_test {
+    ($name:ident, $simd:ident, $runner:ident) => {
+        crate::with_feature_flag!(
+            $simd,
+            paste::item! {
+                #[test]
+                fn [<$name _ $simd:lower>]() {
+                    $runner::<$simd>();
+                }
+            }
+        );
+    };
+}
+
+macro_rules!
simd_math_targeted_all_backends {
+    ($name:ident, $runner:ident) => {
+        simd_math_backend_targeted_test!($name, Scalar, $runner);
+        simd_math_backend_targeted_test!($name, Avx512, $runner);
+        simd_math_backend_targeted_test!($name, Avx2, $runner);
+        simd_math_backend_targeted_test!($name, Sse2, $runner);
+        simd_math_backend_targeted_test!($name, Sse41, $runner);
+        simd_math_backend_targeted_test!($name, Neon, $runner);
+        simd_math_backend_targeted_test!($name, Wasm, $runner);
+    };
+}
+
+simd_math_targeted_all_backends!(
+    f32_log2_u35_reduction_boundaries,
+    run_f32_log2_u35_reduction_boundaries
+);
+simd_math_targeted_all_backends!(
+    f32_exp2_u35_fast_domain_boundaries,
+    run_f32_exp2_u35_fast_domain_boundaries
+);
+
+#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
+#[test]
+fn f32_log2_u35_mixed_exception_lanes_avx2() {
+    let has_avx2 = std::is_x86_feature_detected!("avx2");
+    let has_fma = std::is_x86_feature_detected!("fma");
+    if !(has_avx2 && has_fma) {
+        eprintln!("[test] skipped avx2/fma mixed-lane log2_u35 test: CPU lacks avx2/fma");
+        return;
+    }
+
+    let input = vec![
+        1.0,
+        2.0,
+        -1.0,
+        0.0,
+        -0.0,
+        f32::from_bits(1),
+        f32::INFINITY,
+        f32::NAN,
+        0.75,
+        1.5,
+        3.0,
+        64.0,
+        1024.0,
+        0.25,
+        f32::from_bits(0x7FC0_1234),
+        f32::from_bits(0x0000_0100),
+    ];
+
+    let mut output = vec![0.0f32; input.len()];
+    run_log2_u35_vector_apply_avx2(&input, &mut output);
+
+    for (&x, &actual) in input.iter().zip(output.iter()) {
+        let expected = x.log2();
+        if let Err(err) = assert_f32_contract(
+            "log2_u35",
+            x,
+            actual,
+            expected,
+            contracts::LOG2_U35_F32_MAX_ULP,
+        ) {
+            panic!("{err}");
+        }
+    }
+}