Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 158 additions & 2 deletions benches/simd_math/hyperbolic.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use criterion::Criterion;
use simdeez::math::SimdMathF32Hyperbolic;
use criterion::{Criterion, Throughput};
use simdeez::math::{SimdMathF32Hyperbolic, SimdMathF64Hyperbolic};
#[cfg(not(any(target_arch = "x86_64", target_arch = "x86")))]
use simdeez::scalar::Scalar;
use simdeez::{prelude::*, simd_unsafe_generate_all};
use std::hint::black_box;

use crate::shared::{self, BenchTargets, INPUT_LEN};

Expand Down Expand Up @@ -39,6 +40,39 @@ simd_unsafe_generate_all!(
}
);

/// Baseline for the benchmark: sums `f64::sinh` over every element of `input` using the standard
/// library only.
#[inline(never)]
fn scalar_sinh_sum_f64(input: &[f64]) -> f64 {
    input.iter().map(|&x| x.sinh()).sum::<f64>()
}

/// Baseline for the benchmark: sums `f64::cosh` over every element of `input` using the standard
/// library only.
#[inline(never)]
fn scalar_cosh_sum_f64(input: &[f64]) -> f64 {
    input.iter().map(|&x| x.cosh()).sum::<f64>()
}

/// Baseline for the benchmark: sums `f64::tanh` over every element of `input` using the standard
/// library only.
#[inline(never)]
fn scalar_tanh_sum_f64(input: &[f64]) -> f64 {
    input.iter().map(|&x| x.tanh()).sum::<f64>()
}

// f64 sinh benchmark kernel: forwards `input` to `simdeez_sum_impl_f64` with a closure that
// applies `sinh_u35` to each loaded vector. `S` is not declared in this block; it is presumably
// the SIMD engine type parameter injected by `simd_unsafe_generate_all!` — confirm against the
// macro documentation.
simd_unsafe_generate_all!(
fn simdeez_sinh_sum_f64(input: &[f64]) -> f64 {
simdeez_sum_impl_f64::<S>(input, |v| v.sinh_u35())
}
);

// f64 cosh benchmark kernel: forwards `input` to `simdeez_sum_impl_f64` with a closure that
// applies `cosh_u35` to each loaded vector. `S` is not declared in this block; it is presumably
// the SIMD engine type parameter injected by `simd_unsafe_generate_all!` — confirm against the
// macro documentation.
simd_unsafe_generate_all!(
fn simdeez_cosh_sum_f64(input: &[f64]) -> f64 {
simdeez_sum_impl_f64::<S>(input, |v| v.cosh_u35())
}
);

// f64 tanh benchmark kernel: forwards `input` to `simdeez_sum_impl_f64` with a closure that
// applies `tanh_u35` to each loaded vector. `S` is not declared in this block; it is presumably
// the SIMD engine type parameter injected by `simd_unsafe_generate_all!` — confirm against the
// macro documentation.
simd_unsafe_generate_all!(
fn simdeez_tanh_sum_f64(input: &[f64]) -> f64 {
simdeez_sum_impl_f64::<S>(input, |v| v.tanh_u35())
}
);

#[cfg(not(any(target_arch = "x86_64", target_arch = "x86")))]
#[inline(never)]
fn simdeez_sinh_sum_scalar(input: &[f32]) -> f32 {
Expand All @@ -57,6 +91,73 @@ fn simdeez_tanh_sum_scalar(input: &[f32]) -> f32 {
shared::force_scalar_sum(input, |v: <Scalar as Simd>::Vf32| v.tanh_u35())
}

/// Non-x86 builds only: runs `simdeez_sum_impl_f64` with the `Scalar` engine selected explicitly,
/// applying `op` to each loaded `Scalar` f64 vector and returning the accumulated sum.
#[cfg(not(any(target_arch = "x86_64", target_arch = "x86")))]
#[inline(never)]
fn force_scalar_sum_f64(
input: &[f64],
op: impl Fn(<Scalar as Simd>::Vf64) -> <Scalar as Simd>::Vf64,
) -> f64 {
simdeez_sum_impl_f64::<Scalar>(input, op)
}

/// Non-x86 builds only: sums `sinh_u35` over `input` through `force_scalar_sum_f64`, i.e. with
/// the `Scalar` engine rather than the generated dispatch function.
#[cfg(not(any(target_arch = "x86_64", target_arch = "x86")))]
#[inline(never)]
fn simdeez_sinh_sum_scalar_f64(input: &[f64]) -> f64 {
force_scalar_sum_f64(input, |v| v.sinh_u35())
}

/// Non-x86 builds only: sums `cosh_u35` over `input` through `force_scalar_sum_f64`, i.e. with
/// the `Scalar` engine rather than the generated dispatch function.
#[cfg(not(any(target_arch = "x86_64", target_arch = "x86")))]
#[inline(never)]
fn simdeez_cosh_sum_scalar_f64(input: &[f64]) -> f64 {
force_scalar_sum_f64(input, |v| v.cosh_u35())
}

/// Non-x86 builds only: sums `tanh_u35` over `input` through `force_scalar_sum_f64`, i.e. with
/// the `Scalar` engine rather than the generated dispatch function.
#[cfg(not(any(target_arch = "x86_64", target_arch = "x86")))]
#[inline(never)]
fn simdeez_tanh_sum_scalar_f64(input: &[f64]) -> f64 {
force_scalar_sum_f64(input, |v| v.tanh_u35())
}

/// x86/x86_64 builds only: stand-in for the "forced scalar" sinh kernel.
///
/// NOTE(review): this simply forwards to `simdeez_sinh_sum_f64`, the same function that is
/// registered as the runtime variant, so on x86 the "simdeez-forced-scalar" benchmark measures
/// the same code as "simdeez-runtime". If `simd_unsafe_generate_all!` emits a dedicated
/// scalar-engine variant of `simdeez_sinh_sum_f64`, this should presumably call that instead —
/// TODO confirm against the macro's generated names.
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[inline(never)]
fn simdeez_sinh_sum_scalar_f64(input: &[f64]) -> f64 {
simdeez_sinh_sum_f64(input)
}

/// x86/x86_64 builds only: stand-in for the "forced scalar" cosh kernel.
///
/// NOTE(review): this simply forwards to `simdeez_cosh_sum_f64`, the same function that is
/// registered as the runtime variant, so on x86 the "simdeez-forced-scalar" benchmark measures
/// the same code as "simdeez-runtime". If `simd_unsafe_generate_all!` emits a dedicated
/// scalar-engine variant of `simdeez_cosh_sum_f64`, this should presumably call that instead —
/// TODO confirm against the macro's generated names.
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[inline(never)]
fn simdeez_cosh_sum_scalar_f64(input: &[f64]) -> f64 {
simdeez_cosh_sum_f64(input)
}

/// x86/x86_64 builds only: stand-in for the "forced scalar" tanh kernel.
///
/// NOTE(review): this simply forwards to `simdeez_tanh_sum_f64`, the same function that is
/// registered as the runtime variant, so on x86 the "simdeez-forced-scalar" benchmark measures
/// the same code as "simdeez-runtime". If `simd_unsafe_generate_all!` emits a dedicated
/// scalar-engine variant of `simdeez_tanh_sum_f64`, this should presumably call that instead —
/// TODO confirm against the macro's generated names.
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[inline(never)]
fn simdeez_tanh_sum_scalar_f64(input: &[f64]) -> f64 {
simdeez_tanh_sum_f64(input)
}

/// Applies `op` to `input` one full vector at a time and returns the sum of all resulting lanes.
///
/// Only complete groups of `S::Vf64::WIDTH` elements are processed; any trailing elements that
/// do not fill a whole vector are ignored and do not contribute to the result.
#[inline(always)]
fn simdeez_sum_impl_f64<S: Simd>(input: &[f64], op: impl Fn(S::Vf64) -> S::Vf64) -> f64 {
    input
        .chunks_exact(S::Vf64::WIDTH)
        .fold(0.0f64, |running_total, chunk| {
            let lanes = S::Vf64::load_from_slice(chunk);
            running_total + op(lanes).horizontal_add()
        })
}

/// Builds `len` pseudo-random `f64` benchmark inputs drawn from `range`.
///
/// The generator is a `ChaCha8Rng` seeded with `seed`, so the same arguments always yield the
/// same sequence.
fn make_unary_inputs_f64(len: usize, seed: u64, range: core::ops::Range<f64>) -> Vec<f64> {
    use rand::{Rng, SeedableRng};
    use rand_chacha::ChaCha8Rng;

    let mut generator = ChaCha8Rng::seed_from_u64(seed);
    let mut samples = Vec::with_capacity(len);
    for _ in 0..len {
        samples.push(generator.gen_range(range.clone()));
    }
    samples
}

pub fn register(c: &mut Criterion) {
let sinh_inputs = shared::make_unary_inputs(INPUT_LEN, 0xA11C_E006, -5.0..5.0);
let cosh_inputs = shared::make_unary_inputs(INPUT_LEN, 0xA11C_E007, -5.0..5.0);
Expand Down Expand Up @@ -118,4 +219,59 @@ pub fn register(c: &mut Criterion) {
simdeez_avx512: simdeez_tanh_sum_avx512,
},
);

let sinh_inputs_f64 = make_unary_inputs_f64(INPUT_LEN, 0xA11C_E106, -5.0..5.0);
let cosh_inputs_f64 = make_unary_inputs_f64(INPUT_LEN, 0xA11C_E107, -5.0..5.0);
let tanh_inputs_f64 = make_unary_inputs_f64(INPUT_LEN, 0xA11C_E108, -20.0..20.0);

bench_variants_f64(
c,
"simd_math/f64/sinh_u35",
&sinh_inputs_f64,
scalar_sinh_sum_f64,
simdeez_sinh_sum_f64,
simdeez_sinh_sum_scalar_f64,
);
bench_variants_f64(
c,
"simd_math/f64/cosh_u35",
&cosh_inputs_f64,
scalar_cosh_sum_f64,
simdeez_cosh_sum_f64,
simdeez_cosh_sum_scalar_f64,
);
bench_variants_f64(
c,
"simd_math/f64/tanh_u35",
&tanh_inputs_f64,
scalar_tanh_sum_f64,
simdeez_tanh_sum_f64,
simdeez_tanh_sum_scalar_f64,
);
}

/// Registers one Criterion benchmark group named `group_name` that runs three kernels over the
/// same `input` slice.
///
/// The group's throughput is reported in elements (`input.len()`). The kernels are benchmarked in
/// this order and under these ids: `scalar_native` as "scalar-native", `simdeez_runtime` as
/// "simdeez-runtime", and `simdeez_scalar` as "simdeez-forced-scalar". Both the input and each
/// kernel's result pass through `black_box`.
fn bench_variants_f64(
    c: &mut Criterion,
    group_name: &str,
    input: &[f64],
    scalar_native: fn(&[f64]) -> f64,
    simdeez_runtime: fn(&[f64]) -> f64,
    simdeez_scalar: fn(&[f64]) -> f64,
) {
    let variants: [(&str, fn(&[f64]) -> f64); 3] = [
        ("scalar-native", scalar_native),
        ("simdeez-runtime", simdeez_runtime),
        ("simdeez-forced-scalar", simdeez_scalar),
    ];

    let element_count = input.len() as u64;
    let mut group = c.benchmark_group(group_name);
    group.throughput(Throughput::Elements(element_count));

    for &(label, kernel) in variants.iter() {
        group.bench_function(label, |b| b.iter(|| black_box(kernel(black_box(input)))));
    }

    group.finish();
}
172 changes: 167 additions & 5 deletions src/math/f64/hyperbolic.rs
Original file line number Diff line number Diff line change
@@ -1,26 +1,188 @@
use crate::math::{map, scalar};
use crate::SimdFloat64;
use crate::math::scalar;
use crate::{Simd, SimdBaseIo, SimdBaseOps, SimdConsts, SimdFloat64};

/// The `Vi64` vector type of the engine that `V` belongs to; used in this file for the integer
/// lane masks produced by `bitcast_i64`.
type SimdI64<V> = <<V as SimdConsts>::Engine as Simd>::Vi64;

/// `sinh_u35` / `cosh_u35`: lanes with `|x|` strictly below this value take the polynomial
/// (`sinh_small` / `cosh_small`) result.
const SINH_COSH_SMALL_ABS: f64 = 0.125;
/// `sinh_u35` / `cosh_u35`: lanes with `|x|` above this value (or failing `x == x`) are flagged
/// as exceptional and recomputed by the scalar fallback.
const SINH_COSH_FAST_ABS_MAX: f64 = 0.125;
/// `tanh_u35`: lanes with `|x|` strictly below this value take the polynomial-ratio result.
// NOTE(review): at 0.0 the `|x| < TANH_SMALL_ABS` test presumably never selects a lane — confirm
// this is an intentional placeholder.
const TANH_SMALL_ABS: f64 = 0.0;
/// `tanh_u35`: lanes with `|x|` above this value (or failing `x == x`) are flagged as exceptional
/// and recomputed by the scalar fallback.
// NOTE(review): at 0.0 only zero lanes pass `|x| <= TANH_FAST_ABS_MAX`, so every nonzero lane is
// routed to the scalar fallback — confirm this is intended.
const TANH_FAST_ABS_MAX: f64 = 0.0;

/// Returns `true` when at least one of the first `V::WIDTH` lanes of `mask` is nonzero.
#[inline(always)]
fn any_lane_nonzero<V>(mask: SimdI64<V>) -> bool
where
    V: SimdFloat64,
{
    // The lanes are copied out once and then scanned with ordinary indexing.
    let lane_values = unsafe { mask.as_array() };

    (0..V::WIDTH).any(|lane| lane_values[lane] != 0)
}

/// Recomputes the lanes flagged in `exceptional_mask` with `scalar_fallback`.
///
/// Lanes whose mask entry is zero keep their value from `output`; every other lane is replaced by
/// `scalar_fallback` applied to the matching lane of `input`. When no lane is flagged, `output`
/// is returned unchanged without touching the lane arrays.
#[inline(always)]
fn patch_exceptional_lanes<V>(
    input: V,
    output: V,
    exceptional_mask: SimdI64<V>,
    scalar_fallback: fn(f64) -> f64,
) -> V
where
    V: SimdFloat64,
{
    if any_lane_nonzero::<V>(exceptional_mask) {
        unsafe {
            let source = input.as_array();
            let flags = exceptional_mask.as_array();
            let mut patched = output.as_array();

            for lane in (0..V::WIDTH).filter(|&lane| flags[lane] != 0) {
                patched[lane] = scalar_fallback(source[lane]);
            }

            // Reload the vector from the patched local copy; this assumes the array representation
            // stores its lanes as contiguous `f64` values — TODO confirm.
            V::load_from_ptr_unaligned(&patched as *const V::ArrayRepresentation as *const f64)
        }
    } else {
        output
    }
}

/// Lane-wise `exp` used by the hyperbolic kernels in this file.
///
/// Temporary family-local bridge: each lane is pushed through the scalar `exp_u35_f64` routine,
/// so that the final hyperbolic functions themselves do not have to be scalar lane mappings.
#[inline(always)]
fn exp_u35<V>(input: V) -> V
where
    V: SimdFloat64,
{
    unsafe {
        let mut buffer = input.as_array();

        for idx in 0..V::WIDTH {
            let lane_value = buffer[idx];
            buffer[idx] = scalar::exp_u35_f64(lane_value);
        }

        V::load_from_ptr_unaligned(&buffer as *const V::ArrayRepresentation as *const f64)
    }
}

/// Polynomial `sinh` evaluation used for small-magnitude lanes.
///
/// Computes `x + x^3 * (1/3! + x^2/5! + x^4/7! + x^6/9! + x^8/11!)` with the bracket evaluated in
/// Horner form over `input_sq`. Callers in this file pass `input * input` as `input_sq`.
#[inline(always)]
fn sinh_small<V>(input: V, input_sq: V) -> V
where
    V: SimdFloat64,
{
    // Horner steps, highest-order coefficient first.
    let mut acc = V::set1(1.0 / 39916800.0);
    acc = acc * input_sq + V::set1(1.0 / 362880.0);
    acc = acc * input_sq + V::set1(1.0 / 5040.0);
    acc = acc * input_sq + V::set1(1.0 / 120.0);
    acc = acc * input_sq + V::set1(1.0 / 6.0);

    let cubic = input * input_sq;
    input + cubic * acc
}

/// Polynomial `cosh` evaluation used for small-magnitude lanes.
///
/// Computes `1 + x^2 * (1/2! + x^2/4! + x^4/6! + x^6/8!)` with the bracket evaluated in Horner
/// form over `input_sq`. Callers in this file pass `input * input` as `input_sq`.
#[inline(always)]
fn cosh_small<V>(input_sq: V) -> V
where
    V: SimdFloat64,
{
    // Horner steps, highest-order coefficient first.
    let mut acc = V::set1(1.0 / 40320.0);
    acc = acc * input_sq + V::set1(1.0 / 720.0);
    acc = acc * input_sq + V::set1(1.0 / 24.0);
    acc = acc * input_sq + V::set1(0.5);

    let correction = input_sq * acc;
    V::set1(1.0) + correction
}

/// Exponential-based `(sinh, cosh)` pair for the magnitude `abs_input`.
///
/// With `e = exp_u35(abs_input)` the result is `((e - 1/e) / 2, (e + 1/e) / 2)`. Because the
/// argument is the absolute value, the first component carries no sign information from the
/// original input.
#[inline(always)]
fn sinh_cosh_medium<V>(abs_input: V) -> (V, V)
where
    V: SimdFloat64,
{
    let grow = exp_u35(abs_input);
    let decay = V::set1(1.0) / grow;
    let one_half = V::set1(0.5);

    let sinh_part = (grow - decay) * one_half;
    let cosh_part = (grow + decay) * one_half;

    (sinh_part, cosh_part)
}

/// Shared preamble for `sinh_u35` and `cosh_u35`.
///
/// Returns, in order:
/// * the integer "fast" mask: lanes where `input == input` holds (which rules out NaN) and
///   `|input| <= SINH_COSH_FAST_ABS_MAX`,
/// * `|input|`,
/// * `input * input`.
#[inline(always)]
fn sinh_cosh_masks<V>(input: V) -> (SimdI64<V>, V, V)
where
    V: SimdFloat64,
{
    let magnitude = input.abs();
    let fast_bound = V::set1(SINH_COSH_FAST_ABS_MAX);

    // Self-comparison only fails for NaN lanes; infinities are rejected by the range test.
    let not_nan = input.cmp_eq(input).bitcast_i64();
    let in_range = magnitude.cmp_lte(fast_bound).bitcast_i64();
    let squared = input * input;

    (not_nan & in_range, magnitude, squared)
}

/// Lane-wise hyperbolic sine of `input`.
///
/// Every lane is first evaluated on the vector path: the `sinh_small` polynomial where
/// `|x| < SINH_COSH_SMALL_ABS`, otherwise `(exp(x) - 1/exp(x)) / 2`. Lanes that fall outside the
/// fast mask from `sinh_cosh_masks` (NaN, or `|x| > SINH_COSH_FAST_ABS_MAX`) are then recomputed
/// with `scalar::sinh_u35_f64` by `patch_exceptional_lanes`.
#[inline(always)]
pub(crate) fn sinh_u35<V>(input: V) -> V
where
    V: SimdFloat64,
{
    let (fast_mask, abs_input, input_sq) = sinh_cosh_masks(input);
    // A lane is exceptional exactly when its fast-mask entry is zero.
    let exceptional_mask = fast_mask.cmp_eq(SimdI64::<V>::zeroes());
    let small_mask = abs_input.cmp_lt(V::set1(SINH_COSH_SMALL_ABS));

    let fast_small = sinh_small(input, input_sq);
    let exp_input = exp_u35(input);
    let exp_neg_input = V::set1(1.0) / exp_input;
    let sinh_medium = (exp_input - exp_neg_input) * V::set1(0.5);
    // Per-lane choice between the two candidates; presumably `blendv` yields its second argument
    // where the mask is set — confirm against the simdeez `blendv` documentation.
    let fast = small_mask.blendv(sinh_medium, fast_small);
    // Lanes equal to zero are blended with the raw input, presumably so that the sign of zero is
    // passed through unchanged.
    let zero_mask = input.cmp_eq(V::set1(0.0));
    let fast = zero_mask.blendv(fast, input);

    patch_exceptional_lanes(input, fast, exceptional_mask, scalar::sinh_u35_f64)
}

/// Lane-wise hyperbolic cosine of `input`.
///
/// Every lane is first evaluated on the vector path: the `cosh_small` polynomial where
/// `|x| < SINH_COSH_SMALL_ABS`, otherwise the exponential form from `sinh_cosh_medium` applied to
/// `|x|`. Lanes that fall outside the fast mask from `sinh_cosh_masks` (NaN, or
/// `|x| > SINH_COSH_FAST_ABS_MAX`) are then recomputed with `scalar::cosh_u35_f64` by
/// `patch_exceptional_lanes`.
#[inline(always)]
pub(crate) fn cosh_u35<V>(input: V) -> V
where
    V: SimdFloat64,
{
    let (fast_mask, abs_input, input_sq) = sinh_cosh_masks(input);
    // A lane is exceptional exactly when its fast-mask entry is zero.
    let exceptional_mask = fast_mask.cmp_eq(SimdI64::<V>::zeroes());
    let small_mask = abs_input.cmp_lt(V::set1(SINH_COSH_SMALL_ABS));

    let fast_small = cosh_small(input_sq);
    let (_, cosh_medium) = sinh_cosh_medium(abs_input);
    // Per-lane choice between the two candidates; presumably `blendv` yields its second argument
    // where the mask is set — confirm against the simdeez `blendv` documentation.
    let fast = small_mask.blendv(cosh_medium, fast_small);

    patch_exceptional_lanes(input, fast, exceptional_mask, scalar::cosh_u35_f64)
}

/// Lane-wise hyperbolic tangent of `input`.
///
/// Every lane is first evaluated on the vector path: `sinh_small / cosh_small` where
/// `|x| < TANH_SMALL_ABS`, otherwise `(e - 1/e) / (e + 1/e)` with `e = exp_u35(x)`. Lanes that are
/// NaN or have `|x| > TANH_FAST_ABS_MAX` are then recomputed with `scalar::tanh_u35_f64` by
/// `patch_exceptional_lanes`.
#[inline(always)]
pub(crate) fn tanh_u35<V>(input: V) -> V
where
    V: SimdFloat64,
{
    let abs_input = input.abs();
    // Self-comparison only fails for NaN lanes; infinities are rejected by the range test below.
    let finite_mask = input.cmp_eq(input).bitcast_i64();
    let within_fast_range = abs_input.cmp_lte(V::set1(TANH_FAST_ABS_MAX)).bitcast_i64();
    // A lane is exceptional exactly when it fails either of the two tests above.
    let exceptional_mask = (finite_mask & within_fast_range).cmp_eq(SimdI64::<V>::zeroes());
    let small_mask = abs_input.cmp_lt(V::set1(TANH_SMALL_ABS));

    let input_sq = input * input;
    let fast_small = sinh_small(input, input_sq) / cosh_small(input_sq);

    let exp_input = exp_u35(input);
    let exp_neg_input = V::set1(1.0) / exp_input;
    let tanh_medium = (exp_input - exp_neg_input) / (exp_input + exp_neg_input);
    // Per-lane choice between the two candidates; presumably `blendv` yields its second argument
    // where the mask is set — confirm against the simdeez `blendv` documentation.
    let fast = small_mask.blendv(tanh_medium, fast_small);
    // Lanes equal to zero are blended with the raw input, presumably so that the sign of zero is
    // passed through unchanged.
    let zero_mask = input.cmp_eq(V::set1(0.0));
    let fast = zero_mask.blendv(fast, input);

    patch_exceptional_lanes(input, fast, exceptional_mask, scalar::tanh_u35_f64)
}
Loading
Loading