diff --git a/sqlitegraph-core/src/hnsw/simd.rs b/sqlitegraph-core/src/hnsw/simd.rs index ad917da..0c1f0cf 100644 --- a/sqlitegraph-core/src/hnsw/simd.rs +++ b/sqlitegraph-core/src/hnsw/simd.rs @@ -2,38 +2,49 @@ #![allow(unused_unsafe)] //! //! This module provides SIMD-optimized implementations of vector distance -//! calculations using CPU intrinsics (AVX2 on x86_64). Functions automatically -//! dispatch to SIMD or scalar implementations based on runtime CPU feature detection. +//! calculations using CPU intrinsics (AVX-512F and AVX2 on x86_64). Functions +//! automatically dispatch to the best available SIMD path based on runtime CPU +//! feature detection. //! //! # Architecture //! -//! - **Scalar fallback**: Pure Rust implementation, always available -//! - **AVX2 path**: x86_64 intrinsics with 256-bit registers (8 floats per iteration) -//! - **Runtime dispatch**: One-time CPU feature detection with cached result +//! - **Scalar fallback**: Pure Rust, always available +//! - **AVX2 path**: 256-bit registers, 8 f32 per iteration +//! - **AVX-512F path**: 512-bit registers, 16 f32 per iteration +//! - **Runtime dispatch**: One-time CPU feature detection cached in `OnceLock` +//! +//! Dispatch order: `AVX-512F → AVX2 → Scalar`. AVX-512 is detected via +//! `is_x86_feature_detected!("avx512f")` and only used when present — +//! older x86 CPUs and non-x86 platforms transparently fall through. //! //! # Safety Guarantees //! //! All unsafe blocks are contained within this module and only use: -//! - Unaligned loads (`_mm256_loadu_ps`) - no alignment requirements -//! - Standard SIMD intrinsics - well-defined behavior for any f32 input -//! - Proper remainder handling - scalar loop processes trailing elements +//! - Unaligned loads (`_mm256_loadu_ps`, `_mm512_loadu_ps`) — no alignment requirements +//! - Standard SIMD intrinsics — well-defined behavior for any f32 input +//! - Proper remainder handling — scalar loop processes trailing elements //! //! 
use std::sync::OnceLock;

/// Widest SIMD instruction set the distance kernels may use on this CPU.
///
/// The value is produced once per process by [`simd_level`] and cached in a
/// [`OnceLock`]. Variants are listed from widest to narrowest. On targets
/// other than x86_64 the detected level is always [`SimdLevel::Scalar`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SimdLevel {
    /// AVX-512F: 512-bit registers, 16 `f32` lanes.
    ///
    /// The fused multiply-add used by the 512-bit kernels (`_mm512_fmadd_ps`)
    /// is itself an AVX-512F instruction, so no separate `fma` probe is needed
    /// for this level.
    Avx512,
    /// AVX2: 256-bit registers, 8 `f32` lanes.
    ///
    /// Detection checks only the `avx2` feature flag; `fma` is not probed.
    Avx2,
    /// No SIMD — plain scalar Rust loops.
    Scalar,
}

/// Process-wide cache of the detected [`SimdLevel`], filled on first use.
static SIMD_LEVEL: OnceLock<SimdLevel> = OnceLock::new();

/// Returns the best SIMD level available on the current CPU.
///
/// The first call runs the feature probes (`avx512f` first, then `avx2`) and
/// stores the answer in [`SIMD_LEVEL`]; every later call returns the stored
/// value without probing again.
///
/// # Returns
///
/// - [`SimdLevel::Avx512`] when `avx512f` is detected,
/// - otherwise [`SimdLevel::Avx2`] when `avx2` is detected,
/// - otherwise (and always on non-x86_64 targets) [`SimdLevel::Scalar`].
pub fn simd_level() -> SimdLevel {
    *SIMD_LEVEL.get_or_init(|| {
        // The probes only exist on x86_64; other targets fall straight
        // through to the scalar answer below.
        #[cfg(target_arch = "x86_64")]
        {
            if std::arch::is_x86_feature_detected!("avx512f") {
                return SimdLevel::Avx512;
            }
            if std::arch::is_x86_feature_detected!("avx2") {
                return SimdLevel::Avx2;
            }
        }
        SimdLevel::Scalar
    })
}
+/// # Performance +/// +/// - **Aligned vectors (16+ elements)**: ~2× faster than AVX2 +/// - **Small vectors (< 16 elements)**: Similar to AVX2 (overhead dominates) +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx512f")] +#[inline] +unsafe fn dot_product_avx512(a: &[f32], b: &[f32]) -> f32 { + use std::arch::x86_64::*; + + assert_eq!(a.len(), b.len(), "Vectors must have the same length"); + + let len = a.len(); + let simd_len = len & !15; // round down to nearest multiple of 16 + let mut i = 0; + let mut result = 0.0f32; + + if simd_len > 0 { + let mut sum = unsafe { _mm512_setzero_ps() }; + while i < simd_len { + let va = unsafe { _mm512_loadu_ps(a.as_ptr().add(i)) }; + let vb = unsafe { _mm512_loadu_ps(b.as_ptr().add(i)) }; + sum = unsafe { _mm512_fmadd_ps(va, vb, sum) }; + i += 16; + } + result = unsafe { _mm512_reduce_add_ps(sum) }; + } + + // Remainder + while i < len { + result += a[i] * b[i]; + i += 1; + } + result +} + +/// Runtime-dispatched dot product with AVX-512 / AVX2 acceleration. +/// +/// Selects the best available SIMD path (AVX-512F → AVX2 → Scalar) based on +/// CPU features detected at runtime. The detection happens once and is cached +/// in [`SIMD_LEVEL`] for minimal per-call overhead. 
+/// +/// # Behavior +/// +/// - **AVX-512F**: 16 floats per iteration via `dot_product_avx512` (~2× AVX2) +/// - **AVX2**: 8 floats per iteration via `dot_product_avx2` (~4-6× scalar) +/// - **Other**: scalar fallback /// /// # Arguments /// @@ -216,31 +304,21 @@ unsafe fn dot_product_avx2(a: &[f32], b: &[f32]) -> f32 { /// assert_eq!(product, 70.0); // 1*5 + 2*6 + 3*7 + 4*8 /// # Ok::<(), Box>(()) /// ``` -/// -/// # Performance -/// -/// - **First call**: O(1) CPU feature detection + O(n) computation -/// - **Subsequent calls**: O(n) computation only (cached detection) -/// - **Large vectors (100+ elements)**: 4-6x speedup with AVX2 -/// - **Small vectors (< 8 elements)**: Minimal difference (scalar overhead similar) #[inline] pub fn dot_product(a: &[f32], b: &[f32]) -> f32 { #[cfg(target_arch = "x86_64")] { - // Check AVX2 support once and cache the result - let has_avx2 = HAS_AVX2.get_or_init(|| std::arch::is_x86_feature_detected!("avx2")); - - // SAFETY: We've verified AVX2 is available before calling the unsafe function - if *has_avx2 { - unsafe { dot_product_avx2(a, b) } - } else { - dot_product_scalar(a, b) + // SAFETY: simd_level() only returns Avx512/Avx2 after verifying + // the corresponding CPU feature via is_x86_feature_detected!. + match simd_level() { + SimdLevel::Avx512 => unsafe { dot_product_avx512(a, b) }, + SimdLevel::Avx2 => unsafe { dot_product_avx2(a, b) }, + SimdLevel::Scalar => dot_product_scalar(a, b), } } #[cfg(not(target_arch = "x86_64"))] { - // Non-x86_64 platforms always use scalar fallback dot_product_scalar(a, b) } } @@ -352,15 +430,47 @@ unsafe fn compute_norm_squared_avx2(v: &[f32]) -> f32 { } } -/// Runtime-dispatched squared norm computation with AVX2 acceleration +/// AVX-512F implementation of squared norm computation. /// -/// This function automatically selects the best implementation based on -/// CPU features detected at runtime. +/// Processes 16 floats per iteration via `_mm512_fmadd_ps(v, v, sum)`. 
+/// Falls back to scalar for `len % 16` remainder. /// -/// # Behavior +/// # Safety /// -/// - **x86_64 with AVX2**: Uses `compute_norm_squared_avx2` (4-6x faster for large vectors) -/// - **Other platforms**: Uses `compute_norm_squared_scalar` (baseline performance) +/// Must only be called on CPUs that support AVX-512F. +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx512f")] +#[inline] +unsafe fn compute_norm_squared_avx512(v: &[f32]) -> f32 { + use std::arch::x86_64::*; + + let len = v.len(); + let simd_len = len & !15; + let mut i = 0; + let mut result = 0.0f32; + + if simd_len > 0 { + let mut sum = unsafe { _mm512_setzero_ps() }; + while i < simd_len { + let vv = unsafe { _mm512_loadu_ps(v.as_ptr().add(i)) }; + sum = unsafe { _mm512_fmadd_ps(vv, vv, sum) }; + i += 16; + } + result = unsafe { _mm512_reduce_add_ps(sum) }; + } + + while i < len { + let val = v[i]; + result += val * val; + i += 1; + } + result +} + +/// Runtime-dispatched squared norm computation with AVX-512 / AVX2 acceleration. +/// +/// Selects the best available SIMD path (AVX-512F → AVX2 → Scalar). The +/// dispatch decision is cached in [`SIMD_LEVEL`]. /// /// # Arguments /// @@ -386,12 +496,10 @@ unsafe fn compute_norm_squared_avx2(v: &[f32]) -> f32 { pub fn compute_norm_squared(v: &[f32]) -> f32 { #[cfg(target_arch = "x86_64")] { - let has_avx2 = HAS_AVX2.get_or_init(|| std::arch::is_x86_feature_detected!("avx2")); - - if *has_avx2 { - unsafe { compute_norm_squared_avx2(v) } - } else { - compute_norm_squared_scalar(v) + match simd_level() { + SimdLevel::Avx512 => unsafe { compute_norm_squared_avx512(v) }, + SimdLevel::Avx2 => unsafe { compute_norm_squared_avx2(v) }, + SimdLevel::Scalar => compute_norm_squared_scalar(v), } } @@ -489,16 +597,67 @@ unsafe fn cosine_similarity_avx2(a: &[f32], b: &[f32]) -> f32 { } } -/// Runtime-dispatched cosine similarity with AVX2 acceleration +/// AVX-512F implementation of cosine similarity. 
/// -/// This function automatically selects the best implementation based on -/// CPU features detected at runtime. Uses SIMD-accelerated dot product -/// and norm computation for maximum performance. +/// Fuses the dot product and the two squared-norm computations into a single +/// loop over 16-float chunks, sharing loads of both vectors. Uses three +/// independent FMA accumulators (`dot`, `norm_a`, `norm_b`). /// -/// # Behavior +/// # Safety +/// +/// Must only be called on CPUs that support AVX-512F. +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx512f")] +#[inline] +unsafe fn cosine_similarity_avx512(a: &[f32], b: &[f32]) -> f32 { + use std::arch::x86_64::*; + + assert!(!a.is_empty(), "Vectors cannot be empty"); + + let len = a.len(); + let simd_len = len & !15; + let mut i = 0; + + let mut dot_sum = 0.0f32; + let mut norm_a_sum = 0.0f32; + let mut norm_b_sum = 0.0f32; + + if simd_len > 0 { + let mut dot = unsafe { _mm512_setzero_ps() }; + let mut na = unsafe { _mm512_setzero_ps() }; + let mut nb = unsafe { _mm512_setzero_ps() }; + while i < simd_len { + let va = unsafe { _mm512_loadu_ps(a.as_ptr().add(i)) }; + let vb = unsafe { _mm512_loadu_ps(b.as_ptr().add(i)) }; + dot = unsafe { _mm512_fmadd_ps(va, vb, dot) }; + na = unsafe { _mm512_fmadd_ps(va, va, na) }; + nb = unsafe { _mm512_fmadd_ps(vb, vb, nb) }; + i += 16; + } + dot_sum = unsafe { _mm512_reduce_add_ps(dot) }; + norm_a_sum = unsafe { _mm512_reduce_add_ps(na) }; + norm_b_sum = unsafe { _mm512_reduce_add_ps(nb) }; + } + + while i < len { + dot_sum += a[i] * b[i]; + norm_a_sum += a[i] * a[i]; + norm_b_sum += b[i] * b[i]; + i += 1; + } + + let norm_a = norm_a_sum.sqrt(); + let norm_b = norm_b_sum.sqrt(); + assert!(norm_a > f32::EPSILON, "First vector has zero magnitude"); + assert!(norm_b > f32::EPSILON, "Second vector has zero magnitude"); + dot_sum / (norm_a * norm_b) +} + +/// Runtime-dispatched cosine similarity with AVX-512 / AVX2 acceleration. 
/// -/// - **x86_64 with AVX2**: Uses `cosine_similarity_avx2` (4-6x faster for large vectors) -/// - **Other platforms**: Uses `cosine_similarity_scalar` (baseline performance) +/// Selects the best available SIMD path (AVX-512F → AVX2 → Scalar). On AVX-512 +/// the dot product and both squared norms are fused into a single 16-wide loop; +/// on AVX2 they are computed via the 8-wide kernels in sequence. /// /// # Arguments /// @@ -538,12 +697,10 @@ pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 { #[cfg(target_arch = "x86_64")] { - let has_avx2 = HAS_AVX2.get_or_init(|| std::arch::is_x86_feature_detected!("avx2")); - - if *has_avx2 { - unsafe { cosine_similarity_avx2(a, b) } - } else { - cosine_similarity_scalar(a, b) + match simd_level() { + SimdLevel::Avx512 => unsafe { cosine_similarity_avx512(a, b) }, + SimdLevel::Avx2 => unsafe { cosine_similarity_avx2(a, b) }, + SimdLevel::Scalar => cosine_similarity_scalar(a, b), } } @@ -678,17 +835,51 @@ unsafe fn euclidean_distance_avx2(a: &[f32], b: &[f32]) -> f32 { } } -/// Runtime-dispatched Euclidean (L2) distance computation +/// AVX-512F implementation of Euclidean (L2) distance. /// -/// Automatically selects AVX2 or scalar implementation based on CPU -/// feature detection at runtime. Provides optimal performance on AVX2-capable -/// hardware while maintaining compatibility with all platforms. +/// Processes 16 floats per iteration. The squared differences are accumulated +/// in-register via `_mm512_fmadd_ps(diff, diff, sum)`, then horizontally +/// reduced once at the end (vs. one reduction per chunk in the AVX2 path). /// -/// # Behavior +/// # Safety /// -/// - **x86_64 with AVX2**: Uses `euclidean_distance_avx2` (8x parallelism) -/// - **Other platforms**: Uses `euclidean_distance_scalar` (fallback) -/// - **Detection**: Cached after first check (no repeated overhead) +/// Must only be called on CPUs that support AVX-512F. 
+#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx512f")] +#[inline] +unsafe fn euclidean_distance_avx512(a: &[f32], b: &[f32]) -> f32 { + use std::arch::x86_64::*; + + assert_eq!(a.len(), b.len(), "Vectors must have the same length"); + + let len = a.len(); + let simd_len = len & !15; + let mut i = 0; + let mut sum = 0.0f32; + + if simd_len > 0 { + let mut acc = unsafe { _mm512_setzero_ps() }; + while i < simd_len { + let va = unsafe { _mm512_loadu_ps(a.as_ptr().add(i)) }; + let vb = unsafe { _mm512_loadu_ps(b.as_ptr().add(i)) }; + let diff = unsafe { _mm512_sub_ps(va, vb) }; + acc = unsafe { _mm512_fmadd_ps(diff, diff, acc) }; + i += 16; + } + sum = unsafe { _mm512_reduce_add_ps(acc) }; + } + + while i < len { + let d = a[i] - b[i]; + sum += d * d; + i += 1; + } + sum.sqrt() +} + +/// Runtime-dispatched Euclidean (L2) distance with AVX-512 / AVX2 acceleration. +/// +/// Selects the best available SIMD path (AVX-512F → AVX2 → Scalar). /// /// # Arguments /// @@ -701,21 +892,19 @@ unsafe fn euclidean_distance_avx2(a: &[f32], b: &[f32]) -> f32 { /// /// # Performance /// -/// - AVX2: ~8x speedup for large vectors -/// - Scalar: Baseline performance (same as iterator-based) -/// - Detection overhead: O(1) after first call +/// - AVX-512: ~16× scalar throughput per iteration +/// - AVX2: ~8× scalar throughput per iteration +/// - Detection overhead: O(1) after first call (cached) #[inline] pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 { assert_eq!(a.len(), b.len(), "Vectors must have the same length"); #[cfg(target_arch = "x86_64")] { - let has_avx2 = HAS_AVX2.get_or_init(|| std::arch::is_x86_feature_detected!("avx2")); - - if *has_avx2 { - unsafe { euclidean_distance_avx2(a, b) } - } else { - euclidean_distance_scalar(a, b) + match simd_level() { + SimdLevel::Avx512 => unsafe { euclidean_distance_avx512(a, b) }, + SimdLevel::Avx2 => unsafe { euclidean_distance_avx2(a, b) }, + SimdLevel::Scalar => euclidean_distance_scalar(a, b), } } @@ -1275,4 
+1464,184 @@ mod tests { let b = vec![1.0, 2.0]; euclidean_distance(&a, &b); } + + // ------------------------------------------------------------------------- + // SimdLevel detection tests + // ------------------------------------------------------------------------- + + #[test] + fn test_simd_level_detection_succeeds() { + // Never panics, always returns a valid variant. + let level = simd_level(); + eprintln!("Detected SIMD level: {:?}", level); + + // Idempotent + cached: calling again returns the same value. + assert_eq!(level, simd_level()); + assert_eq!(level, simd_level()); + } + + #[test] + #[cfg(target_arch = "x86_64")] + fn test_simd_level_matches_cpu_features() { + // Verify simd_level() picks the highest available path. + let level = simd_level(); + let has_avx512 = std::arch::is_x86_feature_detected!("avx512f"); + let has_avx2 = std::arch::is_x86_feature_detected!("avx2"); + match level { + SimdLevel::Avx512 => assert!(has_avx512), + SimdLevel::Avx2 => { + assert!(has_avx2); + assert!(!has_avx512, "should have picked Avx512 if available"); + } + SimdLevel::Scalar => { + assert!(!has_avx2); + assert!(!has_avx512); + } + } + } + + // ------------------------------------------------------------------------- + // AVX-512 vs scalar correctness + // ------------------------------------------------------------------------- + + #[test] + #[cfg(target_arch = "x86_64")] + fn test_avx512_dot_product_matches_scalar() { + if !std::arch::is_x86_feature_detected!("avx512f") { + eprintln!("AVX-512 not available, skipping"); + return; + } + let a: Vec = (0..384).map(|i| (i as f32) * 0.01).collect(); + let b: Vec = (0..384).map(|i| (i as f32) * 0.02 - 0.5).collect(); + let scalar = dot_product_scalar(&a, &b); + let avx512 = unsafe { dot_product_avx512(&a, &b) }; + let abs_diff = (scalar - avx512).abs(); + let rel_error = abs_diff / scalar.abs().max(f32::EPSILON); + assert!( + rel_error < 1e-5, + "scalar={scalar}, avx512={avx512}, rel_error={rel_error}" + ); + } + + 
#[test] + #[cfg(target_arch = "x86_64")] + fn test_avx512_norm_squared_matches_scalar() { + if !std::arch::is_x86_feature_detected!("avx512f") { + return; + } + let v: Vec = (0..512).map(|i| (i as f32) * 0.03).collect(); + let scalar = compute_norm_squared_scalar(&v); + let avx512 = unsafe { compute_norm_squared_avx512(&v) }; + let rel = (scalar - avx512).abs() / scalar.abs(); + assert!(rel < 1e-5, "scalar={scalar}, avx512={avx512}, rel={rel}"); + } + + #[test] + #[cfg(target_arch = "x86_64")] + fn test_avx512_cosine_similarity_matches_scalar() { + if !std::arch::is_x86_feature_detected!("avx512f") { + return; + } + let a: Vec = (0..768).map(|i| (i as f32) * 0.01 + 1.0).collect(); + let b: Vec = (0..768).map(|i| (i as f32) * 0.02 + 0.5).collect(); + let scalar = cosine_similarity_scalar(&a, &b); + let avx512 = unsafe { cosine_similarity_avx512(&a, &b) }; + assert!( + (scalar - avx512).abs() < 1e-5, + "scalar={scalar}, avx512={avx512}" + ); + } + + #[test] + #[cfg(target_arch = "x86_64")] + fn test_avx512_euclidean_distance_matches_scalar() { + if !std::arch::is_x86_feature_detected!("avx512f") { + return; + } + let a: Vec = (0..1024).map(|i| (i as f32) * 0.5).collect(); + let b: Vec = (0..1024).map(|i| (i as f32) * 0.5 + 0.25).collect(); + let scalar = euclidean_distance_scalar(&a, &b); + let avx512 = unsafe { euclidean_distance_avx512(&a, &b) }; + let rel = (scalar - avx512).abs() / scalar.abs().max(f32::EPSILON); + assert!(rel < 1e-5, "scalar={scalar}, avx512={avx512}, rel={rel}"); + } + + // ------------------------------------------------------------------------- + // AVX-512 remainder handling (len % 16 != 0) + // ------------------------------------------------------------------------- + + #[test] + #[cfg(target_arch = "x86_64")] + fn test_avx512_remainder_handling() { + if !std::arch::is_x86_feature_detected!("avx512f") { + return; + } + // Sizes that exercise every possible (len / 16, len % 16) bucket. 
+ for size in [1usize, 7, 15, 16, 17, 31, 32, 33, 48, 137, 255] { + let a: Vec = (0..size).map(|i| i as f32).collect(); + let b: Vec = (0..size).map(|i| (i as f32) * 0.5).collect(); + + let scalar = dot_product_scalar(&a, &b); + let avx512 = unsafe { dot_product_avx512(&a, &b) }; + let abs = (scalar - avx512).abs(); + let rel = if scalar.abs() > f32::EPSILON { + abs / scalar.abs() + } else { + abs + }; + assert!( + rel < 1e-5 || abs < 1e-3, + "dot size={size}: scalar={scalar}, avx512={avx512}, rel={rel}" + ); + + let scalar_e = euclidean_distance_scalar(&a, &b); + let avx512_e = unsafe { euclidean_distance_avx512(&a, &b) }; + let abs_e = (scalar_e - avx512_e).abs(); + let rel_e = if scalar_e.abs() > f32::EPSILON { + abs_e / scalar_e.abs() + } else { + abs_e + }; + assert!( + rel_e < 1e-5 || abs_e < 1e-3, + "euclidean size={size}: scalar={scalar_e}, avx512={avx512_e}, rel={rel_e}" + ); + } + } + + // ------------------------------------------------------------------------- + // Typical embedding dimensions through the public dispatch. 
+ // ------------------------------------------------------------------------- + + #[test] + fn test_dispatch_typical_embedding_dims() { + // 384 = MiniLM, 768 = BERT-base, 1024 = Voyage-2/Mistral, 1536 = OpenAI ada-002 + for dim in [384usize, 768, 1024, 1536] { + let a: Vec = (0..dim).map(|i| (i as f32) * 0.001 + 0.1).collect(); + let b: Vec = (0..dim).map(|i| (i as f32) * 0.002 - 0.05).collect(); + + let scalar_dot = dot_product_scalar(&a, &b); + let auto_dot = dot_product(&a, &b); + let rel_dot = (scalar_dot - auto_dot).abs() / scalar_dot.abs().max(f32::EPSILON); + assert!( + rel_dot < 1e-5, + "dot dim={dim}: scalar={scalar_dot}, auto={auto_dot}, rel={rel_dot}" + ); + + let scalar_cos = cosine_similarity_scalar(&a, &b); + let auto_cos = cosine_similarity(&a, &b); + assert!( + (scalar_cos - auto_cos).abs() < 1e-5, + "cosine dim={dim}: scalar={scalar_cos}, auto={auto_cos}" + ); + + let scalar_eu = euclidean_distance_scalar(&a, &b); + let auto_eu = euclidean_distance(&a, &b); + let rel_eu = (scalar_eu - auto_eu).abs() / scalar_eu.abs().max(f32::EPSILON); + assert!( + rel_eu < 1e-5, + "euclidean dim={dim}: scalar={scalar_eu}, auto={auto_eu}, rel={rel_eu}" + ); + } + } }