From 4306478b76a0fd385b5ad95d1ea2f1d926e875e8 Mon Sep 17 00:00:00 2001 From: "Aditya Krishnan (from Dev Box)" Date: Fri, 5 Jun 2026 16:22:55 -0400 Subject: [PATCH] optimize spherical heap compression + fix global threadpool in exhasutive benchmark --- .../src/backend/exhaustive/spherical.rs | 16 +- .../src/spherical/quantizer.rs | 201 ++++++++++-------- 2 files changed, 124 insertions(+), 93 deletions(-) diff --git a/diskann-benchmark/src/backend/exhaustive/spherical.rs b/diskann-benchmark/src/backend/exhaustive/spherical.rs index be29865cd..5be716153 100644 --- a/diskann-benchmark/src/backend/exhaustive/spherical.rs +++ b/diskann-benchmark/src/backend/exhaustive/spherical.rs @@ -116,13 +116,19 @@ mod imp { // Compressing let start = std::time::Instant::now(); let store = { + let threadpool = rayon::ThreadPoolBuilder::new() + .num_threads(input.compression_threads.get()) + .build()?; + let compression_progress = make_progress_bar("compressing", data.nrows(), output.draw_target())?; - let store = Store::new( - data.as_view(), - diskann_quantization::spherical::iface::Impl::::new(quantizer)?, - &compression_progress, - )?; + let store = threadpool.install(|| { + Store::new( + data.as_view(), + diskann_quantization::spherical::iface::Impl::::new(quantizer)?, + &compression_progress, + ) + })?; compression_progress.finish(); store }; diff --git a/diskann-quantization/src/spherical/quantizer.rs b/diskann-quantization/src/spherical/quantizer.rs index ec82b48ce..a4e228264 100644 --- a/diskann-quantization/src/spherical/quantizer.rs +++ b/diskann-quantization/src/spherical/quantizer.rs @@ -20,10 +20,7 @@ use super::{ }; use crate::{ AsFunctor, CompressIntoWith, - algorithms::{ - heap::SliceHeap, - transforms::{NewTransformError, Transform, TransformFailed, TransformKind}, - }, + algorithms::transforms::{NewTransformError, Transform, TransformFailed, TransformKind}, alloc::{Allocator, AllocatorError, GlobalAllocator, Poly, ScopedAllocator, TryClone}, bits::{PermutationStrategy, Representation, Unsigned}, num::Positive, @@ -988,6 +985,12 @@ fn maximize_cosine_similarity( num_bits: NonZeroUsize, allocator: ScopedAllocator<'_>, ) -> Result { + // Lint: This is a private method and all the callers have an invariant that they check + // for non-empty inputs. + #[allow(clippy::expect_used)] + let _: NonZeroUsize = + NonZeroUsize::new(v.len()).expect("calling code should not allow the slice to be empty"); + // Initially, the lattice element has the value `0.5` for all dimensions. // This means the initial inner product between `v` and the rounded term is simply // `0.5 * sum(abs.(v))`. The absolute value is used because the latice element is @@ -995,99 +998,121 @@ fn maximize_cosine_similarity( let mut current_ip = 0.5 * v.iter().map(|i| i.abs() as f64).sum::(); let mut current_square_norm = 0.25 * (v.len() as f64); - // Book keeping for the current value of the rounded vector. - // The true numeric value is 0.5 less than this (in the direction of `v`), but we use - // integers for a smaller memory footprint. - let mut rounded = Poly::broadcast(1u16, v.len(), allocator)?; - - // Compute the critical values and store them on a heap. + // Enumerate every critical value into a single flat buffer; which is then sorted. + // + // Critical values for dimension `i` are the scaling factors at which `rounded[i]` + // advances from one integer level to the next. // - // The binary heap will keep track of the minimum critical value. Multiplying `v` by the - // minimum critical value `s` means that `s * v` will only change `rounded` from its - // current value at a single index (the position associated with `s`). + // For `num_bits >= 2`, the encoding for dimension `i` saturates + // once `rounded[i] == stop`, so dimension `i` contributes `stop - 1` critical values. + // + // For `num_bits == 1` the encoding only has the levels `0` and + // `1`, so each dimension contributes a single critical value at the transition from + // `r = 1` to `r = 2`; `visits_per_dim` is clamped to 1 to capture that case. + // + // The `eps` term breaks ties between dimensions that would otherwise transition at the + // exact same scaling factor. + let stop: usize = 1usize << (num_bits.get() - 1); + let visits_per_dim: usize = stop.max(2) - 1; + let total = v.len() * visits_per_dim; let eps = 0.0001f32; - let one_and_change = 1.0 + eps; - let mut base = Poly::from_iter( - v.iter().enumerate().map(|(position, value)| { - let value = one_and_change / value.abs(); - Pair { - value, - position: position as u32, + + let mut crits = Poly::<[Pair], _>::new_uninit_slice(total, allocator)?; + { + let buf = crits.as_mut(); + let mut k = 0usize; + for (position, value) in v.iter().enumerate() { + // Hoist the reciprocal so the inner loop only multiplies. + let inv = 1.0f32 / value.abs(); + for r in 1..=visits_per_dim { + // SAFETY: `k` is bounded above by `v.len() * visits_per_dim = total`, + // which is exactly the length of `buf`. + unsafe { + buf.get_unchecked_mut(k).write(Pair { + value: (r as f32 + eps) * inv, + position: position as u32, + }); + } + k += 1; } - }), - allocator, - )?; + } + debug_assert_eq!(k, total); + } + // SAFETY: The loop above initialized exactly `total` entries, matching the slice's + // length. + let mut crits = unsafe { crits.assume_init() }; + + // Sort critical values in ascending order so that walking the slice corresponds to + // sweeping `s` from `0` to `+inf`. `Pair`'s `Ord` impl is reversed so it's + // in ascending order. + crits.sort_unstable_by(|a, b| { + a.value + .partial_cmp(&b.value) + .unwrap_or(std::cmp::Ordering::Equal) + }); - // Lint: This is a private method and all the callers have an invariant that they check - // for non-empty inputs. - #[allow(clippy::expect_used)] - let mut critical_values = - SliceHeap::new(&mut base).expect("calling code should not allow the slice to be empty"); + // Book keeping for the current value of the rounded vector. + // The true numeric value is 0.5 less than this (in the direction of `v`), but we use + // integers for a smaller memory footprint. + let mut rounded = Poly::broadcast(1u16, v.len(), allocator)?; - let mut max_similarity = f64::NEG_INFINITY; + // `max_ip_sq / max_sn` is the squared best cosine similarity seen so far. + // See the cosine-similarity comparison below for why squared quantities are tracked + // instead of the cosine similarity directly. + let mut max_ip_sq = f64::NEG_INFINITY; + let mut max_sn = 1.0f64; let mut optimal_scale = f32::default(); - let stop = (2usize).pow(num_bits.get() as u32 - 1) as u16; - - loop { - let mut should_break = false; - critical_values.update_root(|pair| { - let Pair { value, position } = *pair; - if value == f32::MAX { - should_break = true; - return; - } - let r = &mut rounded[position as usize]; - let vp = &v[position as usize]; + for &Pair { value, position } in crits.iter() { + // SAFETY: `position` is in `0..v.len()` by construction above and is never + // modified. + let r = unsafe { rounded.get_unchecked_mut(position as usize) }; + // SAFETY: Same as above. + let vp = unsafe { *v.get_unchecked(position as usize) }; - let old_r = *r; - // By the nature of cricital values, only `r` will change in `rounded` when - // multiplying by `value`. And that change will be to increase by 1. - *r += 1; - - // The inner product estimate simply increases by `vp.abs()` because: - // - // * `r` is the only value in `rounded` that changes. - // * `r` is increased by 1. - current_ip += vp.abs() as f64; - - // This uses the formula - // ```math - // (x + 1)^2 - x^2 = x^2 + 2x + 1 - x^2 - // = 2x + 1 - // ``` - // substitute `x = y - 1/2` to obtain the true value associated with rounded and - // we get - // ```math - // 2 ( y - 1/2 ) + 1 = 2y - 1 + 1 - // = 2y - // ``` - // Therefore, the change in the estimate for the square norm of `rounded` is - // `2 * old_r`. - current_square_norm += (2 * old_r) as f64; - - // Compute the current cosine similarity and update max if needed. - let similarity = current_ip / current_square_norm.sqrt(); - if similarity > max_similarity { - max_similarity = similarity; - optimal_scale = value; - } + let old_r = *r; + // By the nature of critical values, only `r` will change in `rounded` when + // multiplying by `value`. And that change will be to increase by 1. + *r += 1; - // Compute the scaling factor that will change this dimension to the next value. - if *r < stop { - *pair = Pair { - value: (*r as f32 + eps) / vp.abs(), - position, - }; - } else { - *pair = Pair { - value: f32::MAX, - position, - }; - } - }); - if should_break { - break; + // The inner product estimate simply increases by `vp.abs()` because: + // + // * `r` is the only value in `rounded` that changes. + // * `r` is increased by 1. + current_ip += vp.abs() as f64; + + // This uses the formula + // ```math + // (x + 1)^2 - x^2 = x^2 + 2x + 1 - x^2 + // = 2x + 1 + // ``` + // substitute `x = y - 1/2` to obtain the true value associated with rounded and we + // get + // ```math + // 2 ( y - 1/2 ) + 1 = 2y - 1 + 1 + // = 2y + // ``` + // Therefore, the change in the estimate for the square norm of `rounded` is + // `2 * old_r`. + current_square_norm += (2 * old_r) as f64; + + // Compare cosine similarities without taking square roots. The cosine similarity + // is `ip / sqrt(sn)`, so the comparison we want to make is + // ```math + // ip / sqrt(sn) > max_ip / sqrt(max_sn) + // ``` + // Both sides are non-negative: `current_ip` starts at `0.5 * sum(|v|)` and only + // grows, and `current_square_norm` is strictly positive. Squaring is monotonic on + // non-negative reals, so the comparison is equivalent to + // ```math + // ip^2 * max_sn > max_ip^2 * sn + // ``` + // which avoids the `sqrt` per critical value. + let ip_sq = current_ip * current_ip; + if ip_sq * max_sn > max_ip_sq * current_square_norm { + max_ip_sq = ip_sq; + max_sn = current_square_norm; + optimal_scale = value; } }