From a9c5b50dc61e077684b69a3e27b68e7a1bc867dc Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Sat, 6 Jun 2026 16:08:03 -0700 Subject: [PATCH 01/45] This is a lot of work. --- Cargo.lock | 11 + Cargo.toml | 2 +- diskann-inmem/Cargo.toml | 18 + diskann-inmem/src/arbiter/buffer.rs | 200 ++++++++++ diskann-inmem/src/arbiter/epoch.rs | 156 ++++++++ diskann-inmem/src/arbiter/freelist.rs | 117 ++++++ diskann-inmem/src/arbiter/generation.rs | 121 ++++++ diskann-inmem/src/arbiter/mod.rs | 15 + diskann-inmem/src/layers/full.rs | 211 +++++++++++ diskann-inmem/src/layers/mod.rs | 46 +++ diskann-inmem/src/lib.rs | 15 + diskann-inmem/src/neighbors.rs | 225 ++++++++++++ diskann-inmem/src/num.rs | 10 + diskann-inmem/src/provider.rs | 346 ++++++++++++++++++ diskann-inmem/src/store.rs | 209 +++++++++++ .../src/distance/distance_provider.rs | 18 +- 16 files changed, 1718 insertions(+), 2 deletions(-) create mode 100644 diskann-inmem/Cargo.toml create mode 100644 diskann-inmem/src/arbiter/buffer.rs create mode 100644 diskann-inmem/src/arbiter/epoch.rs create mode 100644 diskann-inmem/src/arbiter/freelist.rs create mode 100644 diskann-inmem/src/arbiter/generation.rs create mode 100644 diskann-inmem/src/arbiter/mod.rs create mode 100644 diskann-inmem/src/layers/full.rs create mode 100644 diskann-inmem/src/layers/mod.rs create mode 100644 diskann-inmem/src/lib.rs create mode 100644 diskann-inmem/src/neighbors.rs create mode 100644 diskann-inmem/src/num.rs create mode 100644 diskann-inmem/src/provider.rs create mode 100644 diskann-inmem/src/store.rs diff --git a/Cargo.lock b/Cargo.lock index 1ecde9f9f..3195695f9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -806,6 +806,17 @@ dependencies = [ "tokio", ] +[[package]] +name = "diskann-inmem" +version = "0.53.0" +dependencies = [ + "bytemuck", + "diskann", + "diskann-utils", + "diskann-vector", + "thiserror 2.0.17", +] + [[package]] name = "diskann-label-filter" version = "0.54.0" diff --git a/Cargo.toml b/Cargo.toml index 
b285a94f6..5ba752b97 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,7 +24,7 @@ members = [ "diskann-benchmark", "diskann-tools", "vectorset", - "diskann-bftree", + "diskann-bftree", "diskann-inmem", ] default-members = [ diff --git a/diskann-inmem/Cargo.toml b/diskann-inmem/Cargo.toml new file mode 100644 index 000000000..2516abf2c --- /dev/null +++ b/diskann-inmem/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "diskann-inmem" +version.workspace = true +description.workspace = true +authors.workspace = true +documentation.workspace = true +license.workspace = true +edition.workspace = true + +[dependencies] +bytemuck = { workspace = true, features = ["must_cast"] } +diskann = { workspace = true } +diskann-utils = { workspace = true, default-features = false } +diskann-vector.workspace = true +thiserror.workspace = true + +[lints] +workspace = true diff --git a/diskann-inmem/src/arbiter/buffer.rs b/diskann-inmem/src/arbiter/buffer.rs new file mode 100644 index 000000000..7249414e4 --- /dev/null +++ b/diskann-inmem/src/arbiter/buffer.rs @@ -0,0 +1,200 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. 
+ */ + +use std::{alloc::Layout, marker::PhantomData, ptr::NonNull, sync::atomic::AtomicU64}; + +use crate::num::{Align, Bytes}; + +#[derive(Debug)] +pub struct Buffer { + ptr: NonNull, + stride: Bytes, + entries: usize, + layout: Layout, +} + +impl Buffer { + pub fn new(entries: usize, bytes_per_entry: Bytes, align: Align) -> Self { + let size = bytes_per_entry.0.checked_mul(entries).unwrap(); + let layout = std::alloc::Layout::from_size_align(size, align.0).unwrap(); + + let ptr = unsafe { std::alloc::alloc_zeroed(layout) }; + let ptr = match NonNull::new(ptr) { + Some(ptr) => ptr, + None => std::alloc::handle_alloc_error(layout), + }; + + Self { + ptr, + stride: bytes_per_entry, + entries, + layout, + } + } + + #[inline] + pub fn len(&self) -> usize { + self.entries + } + + #[inline] + pub fn stride(&self) -> Bytes { + self.stride + } + + #[inline] + pub fn align(&self) -> Align { + Align(self.layout.align()) + } + + #[inline] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Issue prefetch hints for the entry at index `i`. + /// + /// `bytes` controls how many bytes to prefetch (clamped to `stride`). + /// Uses `wrapping_add` to avoid UB on out-of-bounds indices — prefetching + /// a bad address is architecturally harmless. + #[inline(always)] + pub fn prefetch(&self, i: usize, bytes: usize) { + let offset = self.stride.0.wrapping_mul(i); + let ptr = self.ptr.as_ptr().wrapping_add(offset); + let bytes = bytes.min(self.stride.0); + prefetch_cachelines(ptr, bytes); + } + + #[inline] + pub fn get(&self, i: usize) -> Option> { + if i >= self.entries { + None + } else { + // SAFETY: We have validated that `i < self.entries`. This does two things: + // + // 1. Ensure that the multiplication will not overflow. + // 2. Ensures that the computed offset is within the original allocation. + Some(unsafe { self.get_unchecked(i) }) + } + } + + /// Get the slice for entry `i` without bounds checking. 
+ /// + /// # Safety + /// + /// `i` must be less than [`len`](Self::len). + #[inline] + pub unsafe fn get_unchecked(&self, i: usize) -> Slice<'_> { + debug_assert!(i < self.entries); + let ptr = unsafe { self.ptr.add(self.stride.0 * i) }; + Slice { + ptr, + len: self.stride.0, + _lifetime: PhantomData, + } + } +} + +impl Drop for Buffer { + fn drop(&mut self) { + // SAFETY: This is the same pointer and allocation that was previously returned + // from a successful `alloc_zeroed`. + unsafe { std::alloc::dealloc(self.ptr.as_ptr(), self.layout) } + } +} + +// SAFETY: We're safe to pass around the `Buffer`. It's just use of the returned `Slice` +// that needs to be arbitrated. +unsafe impl Send for Buffer {} + +// SAFETY: We're safe to pass around the `Buffer`. It's just use of the returned `Slice` +// that needs to be arbitrated. +unsafe impl Sync for Buffer {} + +#[derive(Debug, Clone, Copy)] +pub struct Slice<'a> { + ptr: NonNull, + len: usize, + _lifetime: PhantomData<&'a ()>, +} + +impl<'a> Slice<'a> { + unsafe fn new(ptr: NonNull, len: usize) -> Self { + Self { + ptr, + len, + _lifetime: PhantomData, + } + } + + #[inline] + pub fn truncate(&self, n: usize) -> Slice<'a> { + unsafe { Self::new(self.ptr, self.len.min(n)) } + } + + #[inline] + pub fn skip(&self, n: usize) -> Slice<'a> { + let advance_by = self.len.min(n); + unsafe { Self::new(self.ptr.add(advance_by), self.len - advance_by) } + } + + #[inline] + pub fn split(&self, n: usize) -> (Slice<'a>, Slice<'a>) { + let n = self.len.min(n); + unsafe { + ( + Self::new(self.ptr, n), + Self::new(self.ptr.add(n), self.len - n), + ) + } + } + + #[inline] + pub fn len(&self) -> usize { + self.len + } + + #[inline] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + pub fn as_ptr(&self) -> NonNull { + self.ptr + } + + #[inline] + pub unsafe fn as_slice(&self) -> &'a [u8] { + unsafe { std::slice::from_raw_parts(self.ptr.as_ptr(), self.len) } + } + + #[inline] + pub unsafe fn as_mut_slice(&mut self) -> &'a mut 
[u8] { + unsafe { std::slice::from_raw_parts_mut(self.ptr.as_ptr(), self.len) } + } +} + +/// Issue prefetch hints for `bytes` starting at `ptr`. +/// +/// This is purely a performance hint and cannot cause undefined behavior, +/// even if `ptr` is invalid or out of bounds. +#[inline(always)] +pub fn prefetch_cachelines(ptr: *const u8, bytes: usize) { + #[cfg(target_arch = "x86_64")] + { + use std::arch::x86_64::{_mm_prefetch, _MM_HINT_T0}; + let lines = bytes.div_ceil(64); + for i in 0..lines { + // SAFETY: _mm_prefetch is a hint; invalid addresses are silently ignored. + unsafe { _mm_prefetch(ptr.wrapping_add(i * 64) as *const i8, _MM_HINT_T0) }; + } + } +} + +/// Issue a prefetch hint for a single generation tag. +#[inline(always)] +pub fn prefetch_tag(tag: &AtomicU64) { + prefetch_cachelines(tag as *const AtomicU64 as *const u8, 8); +} diff --git a/diskann-inmem/src/arbiter/epoch.rs b/diskann-inmem/src/arbiter/epoch.rs new file mode 100644 index 000000000..0fcd052a7 --- /dev/null +++ b/diskann-inmem/src/arbiter/epoch.rs @@ -0,0 +1,156 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +use std::sync::{ + atomic::{AtomicU32, AtomicU64, Ordering}, + Mutex, +}; + +use crate::arbiter::Generation; + +const CAPACITY: usize = 256; + +#[derive(Debug)] +pub struct Registry { + /// A record of the active generations. + /// + /// * 0 = "available". + /// * non-zero: generation is active. 
+ slots: Box<[AtomicU64]>, + generation: AtomicU64, + barrier: Mutex, +} + +impl Registry { + pub fn new() -> Self { + Self::with_capacity(CAPACITY) + } + + pub fn with_capacity(capacity: usize) -> Self { + Self { + slots: (0..capacity).map(|_| AtomicU64::new(0)).collect(), + generation: AtomicU64::new(u64::MAX), + barrier: Mutex::new(Hint(0)), + } + } + + pub fn generation(&self) -> Generation { + Generation::new(self.generation.load(Ordering::Acquire)) + } + + pub fn register(&self) -> Guard<'_> { + let mut barrier = self.barrier.lock().unwrap(); + + // No synchronization happens on the global generation tag. + let generation = self.generation.load(Ordering::Acquire); + let hint = barrier.increment(); + + let nslots = self.slots.len(); + for i in 0..nslots { + let slot = (hint + i) % nslots; + if let Ok(_) = self.slots[slot].compare_exchange( + 0, + generation, + Ordering::Release, + Ordering::Relaxed, + ) { + return Guard { + registry: self, + slot, + generation: Generation::new(generation), + }; + } + } + + panic!("Let's turn this into a proper error."); + } + + pub fn advance(&self) -> Generation { + // TODO: What to do on the unlikely event of a wrap-around? + Generation::new(self.generation.fetch_sub(1, Ordering::AcqRel)) + } + + fn wait_for(&self, generation: Generation) { + let generation = generation.value(); + let wait_list = { + let _barrier = self.barrier.lock().unwrap(); + let mut wait_list = Vec::new(); + for (i, s) in self.slots.iter().enumerate() { + let g = s.load(Ordering::Relaxed); + if g != 0 && g >= generation { + wait_list.push(i); + } + } + + wait_list + }; + + for slot in wait_list { + let s = &self.slots[slot]; + loop { + let g = s.load(Ordering::Relaxed); + if g == 0 || g < generation { + break; + } + std::hint::spin_loop(); + } + } + + // This barrier synchronizes with all the relaxed loads on the slots, which are + // set with `Release` semantics. 
+ std::sync::atomic::fence(Ordering::Acquire); + } + + /// Return the oldest generation that is currently being protected. + /// + /// Generations decrement from `Generation::MAX`, so the oldest is the numerically highest. + /// + /// This is a synchronizing operation. + pub fn waiting(&self) -> Generation { + let _barrier = self.barrier.lock().unwrap(); + let mut highest = 0; + for s in self.slots.iter() { + let g = s.load(Ordering::Relaxed); + highest = highest.max(g); + } + + // `acquires` with respect to all previous relaxed loads. + std::sync::atomic::fence(Ordering::Acquire); + + Generation::new(highest) + } +} + +#[derive(Debug)] +pub struct Guard<'a> { + registry: &'a Registry, + slot: usize, + generation: Generation, +} + +impl Guard<'_> { + /// Return the generation associated with the [`Guard`]'s creation. + #[inline] + pub fn generation(&self) -> Generation { + self.generation + } +} + +impl Drop for Guard<'_> { + fn drop(&mut self) { + self.registry.slots[self.slot].store(0, Ordering::Release) + } +} + +#[derive(Debug)] +struct Hint(usize); + +impl Hint { + fn increment(&mut self) -> usize { + let x = self.0; + self.0 += 1; + x + } +} diff --git a/diskann-inmem/src/arbiter/freelist.rs b/diskann-inmem/src/arbiter/freelist.rs new file mode 100644 index 000000000..b59e18522 --- /dev/null +++ b/diskann-inmem/src/arbiter/freelist.rs @@ -0,0 +1,117 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +use std::{ + num::NonZeroU32, + sync::{ + atomic::{AtomicBool, AtomicU32, AtomicUsize, Ordering}, + Mutex, + }, +}; + +#[derive(Debug)] +pub struct Freelist { + recycled: Mutex>, + capacity: NonZeroU32, + have_recycled: AtomicBool, + + /// The highest ID the freelist manages. This is used when in "append" to determine the + /// maximum ID we can return this way. + max: u32, + /// The number of "unallocated" IDs remaining. 
+ unallocated: AtomicU32, +} + +#[derive(Debug, Clone, Copy)] +pub enum Id { + Found(u32), + Scan, +} + +impl Freelist { + pub fn new(max: u32, capacity: NonZeroU32) -> Self { + Self { + recycled: Mutex::new(Vec::with_capacity(capacity.get() as usize)), + capacity, + have_recycled: AtomicBool::new(false), + max, + unallocated: AtomicU32::new(max), + } + } + + pub fn pop(&self) -> Id { + // Small performance optimization - avoid locking the mutex if it looks like that won't + // succeed anyway. + if self.have_recycled.load(Ordering::Relaxed) { + let mut recycled = self.recycled.lock().unwrap(); + if let Some(id) = recycled.pop() { + return Id::Found(id); + } + self.have_recycled.store(false, Ordering::Relaxed); + } + + // Missed in the recycled buffer. Try pulling from the high-water mark. + let mut unallocated = self.unallocated.load(Ordering::Relaxed); + while unallocated != 0 { + match self.unallocated.compare_exchange( + unallocated, + unallocated - 1, + Ordering::Relaxed, + Ordering::Relaxed, + ) { + Ok(unallocated) => return Id::Found(self.max - unallocated), + Err(actual) => { + unallocated = actual; + } + } + } + + // Missed in the recycle bin and from unallocated IDs. Time to indicate a scan. + Id::Scan + } + + /// Attempt to push `id` into the recycled list. Return `true` if `id` was + /// inserted. If `false` is returned, it is likely because the internal recycle + /// buffer is full. + pub fn push(&self, id: u32) -> bool { + let mut recycled = self.recycled.lock().unwrap(); + if recycled.len() < self.capacity() { + recycled.push(id); + self.have_recycled.store(true, Ordering::Relaxed); + true + } else { + false + } + } + + /// Append items from `itr` into the recycled buffer. Return the number of items + /// actually added. 
+ pub fn append(&self, itr: I) -> usize + where + I: IntoIterator, + { + let mut recycled = self.recycled.lock().unwrap(); + let available = self.capacity() - recycled.len(); + let mut count = 0; + itr.into_iter().take(available).for_each(|id| { + count += 1; + recycled.push(id); + }); + + if count > 0 { + self.have_recycled.store(true, Ordering::Relaxed); + } + + count + } + + //----------// + // Internal // + //----------// + + fn capacity(&self) -> usize { + self.capacity.get() as usize + } +} diff --git a/diskann-inmem/src/arbiter/generation.rs b/diskann-inmem/src/arbiter/generation.rs new file mode 100644 index 000000000..6d82a8cd7 --- /dev/null +++ b/diskann-inmem/src/arbiter/generation.rs @@ -0,0 +1,121 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +use std::sync::atomic::{AtomicU64, Ordering}; + +#[derive(Debug)] +#[repr(transparent)] +pub struct Tag(AtomicU64); + +impl Tag { + pub const fn new(generation: Generation) -> Self { + Self(AtomicU64::new(generation.value())) + } + + pub fn as_ref(&self) -> Ref<'_> { + Ref::new(&self.0) + } + + pub fn as_mut(&self) -> Mut<'_> { + Mut::new(&self.0) + } + + pub unsafe fn from_ptr<'a>(ptr: *mut Tag) -> &'a Self { + unsafe { &*ptr } + } +} + +#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +#[repr(transparent)] +pub struct Generation(u64); + +impl Generation { + pub const MAX: Self = Self::new(u64::MAX); + + // Reserved generations. + // + // These all have small values, with `0` marking the "available" state. + // In this way, zeroed allocations for tags naturally begin in the "available" state and + // don't require additional initialization. + // + // If you add states - make sure to increment the `RESERVED` marker! 
+ pub(crate) const AVAILABLE: Self = Self::new(0); + pub(crate) const OWNED: Self = Self::new(1); + pub(crate) const FROZEN: Self = Self::new(2); + const RESERVED: Self = Self::FROZEN; + + #[must_use = "this function has no side-effects"] + pub(crate) fn is_reserved(self) -> bool { + self <= Self::RESERVED + } + + #[inline] + pub const fn new(value: u64) -> Self { + Self(value) + } + + #[inline] + pub const fn value(self) -> u64 { + self.0 + } +} + +#[derive(Debug, Clone, Copy)] +#[repr(transparent)] +pub struct Ref<'a>(&'a AtomicU64); + +impl<'a> Ref<'a> { + #[inline] + pub(crate) fn new(slot: &'a AtomicU64) -> Self { + Self(slot) + } + + #[inline] + fn inner(&self) -> &'a AtomicU64 { + self.0 + } + + #[inline] + pub fn get(&self, ordering: Ordering) -> Generation { + Generation::new(self.0.load(ordering)) + } +} + +#[derive(Debug, Clone, Copy)] +#[repr(transparent)] +pub struct Mut<'a>(Ref<'a>); + +impl<'a> Mut<'a> { + #[inline] + pub(crate) fn new(slot: &'a AtomicU64) -> Self { + Self(Ref::new(slot)) + } + + #[inline] + pub fn try_set( + &self, + current: Generation, + new: Generation, + success: Ordering, + failure: Ordering, + ) -> Result { + self.inner() + .compare_exchange(current.value(), new.value(), success, failure) + .map(Generation::new) + .map_err(Generation::new) + } + + #[inline] + pub fn set(&self, generation: Generation, ordering: Ordering) { + self.inner().store(generation.value(), ordering) + } +} + +impl<'a> std::ops::Deref for Mut<'a> { + type Target = Ref<'a>; + fn deref(&self) -> &Self::Target { + &self.0 + } +} diff --git a/diskann-inmem/src/arbiter/mod.rs b/diskann-inmem/src/arbiter/mod.rs new file mode 100644 index 000000000..097962ac1 --- /dev/null +++ b/diskann-inmem/src/arbiter/mod.rs @@ -0,0 +1,15 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. 
+ */ + +pub(crate) mod buffer; +pub use buffer::{prefetch_cachelines, Buffer, Slice}; + +pub mod epoch; + +mod freelist; +pub use freelist::Freelist; + +pub mod generation; +pub use generation::Generation; diff --git a/diskann-inmem/src/layers/full.rs b/diskann-inmem/src/layers/full.rs new file mode 100644 index 000000000..1a5c40d14 --- /dev/null +++ b/diskann-inmem/src/layers/full.rs @@ -0,0 +1,211 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +use diskann::{ANNError, ANNResult}; +use diskann_vector::{ + AsUnaligned, + distance::{self, DistanceProvider, Metric}, + UnalignedSlice, +}; +use thiserror::Error; + +use crate::layers; + +/// Full-precision data layer. +#[derive(Debug)] +pub struct Full +where + T: 'static, +{ + distance: Distance, + _type: std::marker::PhantomData, +} + +impl Full +where + T: 'static, +{ + pub fn new(dim: usize, metric: Metric) -> Self + where + T: DistanceProvider, + { + let distance = Distance { + f: T::distance_comparer(metric, Some(dim)), + dim, + }; + + Self { + distance, + _type: std::marker::PhantomData, + } + } + + pub fn dim(&self) -> usize { + self.distance.dim + } +} + +impl layers::AsDistance for Full +where + T: std::fmt::Debug + 'static, +{ + fn as_distance(&self) -> &dyn layers::Distance { + &self.distance + } +} + +impl<'a, T> layers::AsQueryDistance<'a, &'a [T]> for Full +where + T: std::fmt::Debug + Sync + 'static, +{ + fn as_query_distance( + &'a self, + query: &'a [T], + ) -> ANNResult> { + Ok(Box::new(QueryDistance::new(self.distance, query))) + } +} + +////////////// +// Distance // +////////////// + +#[derive(Debug)] +struct Distance +where + T: 'static, +{ + f: distance::Distance, + dim: usize, +} + +impl Clone for Distance { + fn clone(&self) -> Self { + *self + } +} + +impl Copy for Distance {} + +impl Distance +where + T: 'static, +{ + #[cold] + #[inline(never)] + fn error(&self, x: &[u8], y: &[u8]) -> ANNResult { + let error = DistanceError { + expected: 
self.bytes(), + xlen: x.len(), + ylen: y.len(), + }; + + Err(ANNError::opaque(error)) + } + + fn dim(&self) -> usize { + self.dim + } + + fn bytes(&self) -> usize { + self.dim * std::mem::size_of::() + } +} + +impl layers::Distance for Distance +where + T: std::fmt::Debug + 'static, +{ + fn evaluate(&self, x: &[u8], y: &[u8]) -> ANNResult { + let bytes = self.bytes(); + if x.len() != bytes || y.len() != bytes { + self.error(x, y) + } else { + Ok(self.f.call_unaligned( + unsafe { UnalignedSlice::new(x.as_ptr().cast::(), self.dim) }, + unsafe { UnalignedSlice::new(y.as_ptr().cast::(), self.dim) }, + )) + } + } +} + +#[derive(Debug, Error)] +#[error( + "expected slices of lenght {} - instead got {} and {}", + self.expected, + self.xlen, + self.ylen +)] +struct DistanceError { + expected: usize, + xlen: usize, + ylen: usize, +} + +/////////////////// +// QueryDistance // +/////////////////// + +#[derive(Debug)] +struct QueryDistance<'a, T> +where + T: 'static, +{ + distance: Distance, + query: &'a [T], +} + +impl<'a, T> QueryDistance<'a, T> +where + T: 'static, +{ + fn new(distance: Distance, query: &'a [T]) -> Self { + if query.len() != distance.dim() { + panic!("oops"); + } + + Self { distance, query } + } + + #[cold] + #[inline(never)] + fn error(&self, x: &[u8]) -> ANNResult { + let error = QueryDistanceError { + expected: self.distance.bytes(), + xlen: x.len(), + }; + + Err(ANNError::opaque(error)) + } +} + +impl layers::QueryDistance for QueryDistance<'_, T> +where + T: std::fmt::Debug + Sync + 'static, +{ + fn evaluate(&self, x: &[u8]) -> ANNResult { + if x.len() != self.distance.bytes() { + self.error(x) + } else { + Ok(self.distance.f.call_unaligned( + unsafe { + UnalignedSlice::new(self.query.as_ptr().cast::(), self.distance.dim) + }, + unsafe { UnalignedSlice::new(x.as_ptr().cast::(), self.distance.dim) }, + )) + } + } +} + +#[derive(Debug, Error)] +#[error( + "expected slice of lenght {} - instead got {}", + self.expected, + self.xlen, +)] +struct 
QueryDistanceError { + expected: usize, + xlen: usize, +} diff --git a/diskann-inmem/src/layers/mod.rs b/diskann-inmem/src/layers/mod.rs new file mode 100644 index 000000000..f75967acf --- /dev/null +++ b/diskann-inmem/src/layers/mod.rs @@ -0,0 +1,46 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +use diskann::{error::StandardError, utils::VectorRepr, ANNResult}; +use diskann_vector::{distance::Metric, DistanceFunction}; + +mod full; + +pub trait Distance: Send + Sync + std::fmt::Debug { + fn evaluate(&self, x: &[u8], y: &[u8]) -> ANNResult; +} + +impl DistanceFunction<&[u8], &[u8]> for &dyn Distance { + fn evaluate_similarity(&self, x: &[u8], y: &[u8]) -> f32 { + self.evaluate(x, y).unwrap() + } +} + +pub trait AsDistance { + fn as_distance(&self) -> &dyn Distance; +} + +pub trait QueryDistance: Send + Sync + std::fmt::Debug { + fn evaluate(&self, x: &[u8]) -> ANNResult; +} + +pub trait AsQueryDistance<'a, T> { + fn as_query_distance(&'a self, query: T) -> ANNResult>; +} + +pub trait Set { + /// Return the number of bytes needed by this layer representation. + /// + /// To be well-behaved, this function must be idempotent. + fn bytes(&self) -> usize; + + /// Write into the stored representation. + fn into_bytes<'a>(&self, element: T, bytes: &'a mut [u8]) -> ANNResult<()>; +} + +// Meta traits for `Search` and `Insert` compatibility. +pub trait Layer: Send + Sync + 'static {} +pub trait Search<'a, T>: Layer + AsQueryDistance<'a, T> {} +pub trait Insert<'a, T>: Search<'a, T> + Set + AsDistance {} diff --git a/diskann-inmem/src/lib.rs b/diskann-inmem/src/lib.rs new file mode 100644 index 000000000..17c89f53e --- /dev/null +++ b/diskann-inmem/src/lib.rs @@ -0,0 +1,15 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. 
+ */ + +mod arbiter; + +mod layers; +mod store; + +pub mod neighbors; +pub mod num; +mod provider; + +pub use neighbors::Neighbors; diff --git a/diskann-inmem/src/neighbors.rs b/diskann-inmem/src/neighbors.rs new file mode 100644 index 000000000..b493f6a3c --- /dev/null +++ b/diskann-inmem/src/neighbors.rs @@ -0,0 +1,225 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +use std::sync::RwLock; + +use diskann::{graph::AdjacencyList, utils::IntoUsize}; +use thiserror::Error; + +use crate::{ + arbiter::Buffer, + num::{Align, Bytes}, +}; + +type Id = u32; + +const LOCK_GRANULARITY: usize = 16; + +fn lock_index(i: usize) -> usize { + i / LOCK_GRANULARITY +} + +#[derive(Debug)] +pub struct Neighbors { + neighbors: Buffer, + // One lock per `LOCK_GRANULARITY` consecutive slots in `neighbors`. + locks: Vec>, +} + +impl Neighbors { + pub fn new(entries: usize, max_length: usize) -> Self { + let bytes = Bytes((max_length + 1) * std::mem::size_of::()); + let neighbors = Buffer::new(entries, bytes, Align(128)); + let locks = std::iter::repeat_with(|| RwLock::new(())) + .take(entries.div_ceil(LOCK_GRANULARITY)) + .collect(); + + Self { neighbors, locks } + } + + /// Return the maximum length for any adjacency list. + pub fn max_length(&self) -> usize { + // We reserve 4 bytes at the beginning for the length of the adjacency list. + (self.neighbors.stride().0 - std::mem::size_of::()) / std::mem::size_of::() + } + + pub fn entries(&self) -> usize { + self.neighbors.len() + } + + pub fn get(&self, i: usize, neighbors: &mut AdjacencyList) -> Result<(), OutOfBounds> { + self.check(i)?; + + let lock = unsafe { self.locks.get_unchecked(lock_index(i)) }; + + let _guard = dismiss_poison(lock.read()); + + // SAFETY: `self.check(i)` above verified that `i` is in-bounds for + // `self.neighbors`, so this unchecked access stays within the buffer. 
+ let (prefix, rest) = + unsafe { self.neighbors.get_unchecked(i) }.split(std::mem::size_of::()); + + debug_assert_eq!(prefix.len(), std::mem::size_of::()); + debug_assert!(prefix.as_ptr().is_aligned()); + + // SAFETY: We hold the read-lock, so reading is safe. From our bounds checks, we + // know that this pointer is valid. + let len: usize = unsafe { prefix.as_ptr().cast::().read() } + .into_usize() + .min(self.max_length()); + + let mut resizer = neighbors.resize(len); + unsafe { + std::ptr::copy_nonoverlapping( + rest.as_ptr().as_ptr(), + resizer.as_mut_ptr().cast::(), + len * std::mem::size_of::(), + ) + }; + resizer.finish(len); + Ok(()) + } + + pub fn lock(&self, i: usize) -> Result, OutOfBounds> { + self.check(i)?; + Ok(unsafe { self.lock_unchecked(i) }) + } + + unsafe fn lock_unchecked(&self, i: usize) -> Lock<'_> { + let lock = dismiss_poison(unsafe { self.locks.get_unchecked(lock_index(i)) }.write()); + + // SAFETY: The caller of this `unsafe fn` has already checked (via `check`) that + // `i` is in-bounds for `self.neighbors`. + let slice = unsafe { self.neighbors.get_unchecked(i) }; + + debug_assert!(slice.as_ptr().is_aligned()); + + let raw = unsafe { + std::slice::from_raw_parts_mut( + slice.as_ptr().as_ptr().cast::(), + slice.len() / std::mem::size_of::(), + ) + }; + + Lock { raw, lock } + } + + pub fn set(&self, i: usize, neighbors: &[u32]) -> Result<(), SetError> { + self.check(i).map_err(SetError::OutOfBounds)?; + + // We can check the length of `neighbors` before acquiring any locks as an early exit. 
+ if neighbors.len() > self.max_length() { + return Err(SetError::TooLong(TooLong)); + } + + let lock = unsafe { self.lock_unchecked(i) }; + unsafe { lock.write_unchecked(neighbors) }; + Ok(()) + } + + fn check(&self, i: usize) -> Result<(), OutOfBounds> { + if i >= self.entries() { + Err(OutOfBounds(i)) + } else { + Ok(()) + } + } +} + +#[derive(Debug, Clone, Copy, Error)] +#[error("index {} is out-of-bounds", self.0)] +pub struct OutOfBounds(usize); + +#[derive(Debug, Clone, Copy, Error)] +pub enum SetError { + #[error(transparent)] + OutOfBounds(OutOfBounds), + #[error(transparent)] + TooLong(TooLong), +} + +// We carefully guard where locks are acquired in this function, so that panicking while +// holding a lock won't happen and if it does, we know we're still in decent shape. +fn dismiss_poison(r: std::sync::LockResult) -> T { + match r { + Ok(v) => v, + Err(poison) => poison.into_inner(), + } +} + +/// A locked adjacency list to implement atomic read-modify-write operations. +#[derive(Debug)] +pub struct Lock<'a> { + // The raw adjacency list with the actual length stored as the first element. + // + // This **must** have a length of at least one. + // + // Also, `raw.len()` must be less than `u32::MAX`. + raw: &'a mut [u32], + // VERY IMPORTANT: `lock` has to be **after** `raw` because `lock` is guarding `raw` + // and thus must be dropped **after** `raw`. + lock: std::sync::RwLockWriteGuard<'a, ()>, +} + +impl Lock<'_> { + /// Return the capacity of the neighbor buffer. + pub fn capacity(&self) -> usize { + self.raw.len() - 1 + } + + /// Return the current length of the neighbor list. + /// + /// This is guaranteed to be at most [`capacity`](Self::capacity). + pub fn len(&self) -> usize { + // SAFETY: By construction, `self.raw` has a length of at least 1. + // + // The `min` operation is to be conservative. 
+ unsafe { self.raw.get_unchecked(0) } + .into_usize() + .min(self.capacity()) + } + + /// View the current contents of the locked adjacency list as a slice. + pub fn as_slice(&self) -> &[u32] { + let len = self.len(); + unsafe { self.raw.get_unchecked(1..len + 1) } + } + + /// Consume the [`Lock`] - copying the contents of `neighbors`. + /// + /// Returns an error if `neighbors.len() > self.capacity()` without copying any of the + /// contents of `neighbors`. + pub fn write(self, neighbors: &[u32]) -> Result<(), TooLong> { + if neighbors.len() > self.capacity() { + return Err(TooLong); + } + + unsafe { self.write_unchecked(neighbors) }; + Ok(()) + } + + unsafe fn write_unchecked(self, neighbors: &[u32]) { + let len = neighbors.len(); + debug_assert!(len <= self.capacity()); + unsafe { + std::ptr::copy_nonoverlapping(neighbors.as_ptr(), self.raw.as_mut_ptr().add(1), len) + } + *unsafe { self.raw.get_unchecked_mut(0) } = len as u32; + // `self.raw` is dropped first, then `self.lock` which was guarding it. + } +} + +#[derive(Debug, Clone, Copy, Error)] +#[error("too long")] +pub struct TooLong; + +/////////// +// Tests // +/////////// + +#[cfg(test)] +mod tests { + use super::*; +} diff --git a/diskann-inmem/src/num.rs b/diskann-inmem/src/num.rs new file mode 100644 index 000000000..e806a9113 --- /dev/null +++ b/diskann-inmem/src/num.rs @@ -0,0 +1,10 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +#[derive(Debug, Clone, Copy)] +pub struct Bytes(pub usize); + +#[derive(Debug, Clone, Copy)] +pub struct Align(pub usize); diff --git a/diskann-inmem/src/provider.rs b/diskann-inmem/src/provider.rs new file mode 100644 index 000000000..5a265a472 --- /dev/null +++ b/diskann-inmem/src/provider.rs @@ -0,0 +1,346 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. 
+ */ + +use diskann::{ + error::Infallible, + graph::{ + glue::{self, HybridPredicate}, + workingset, AdjacencyList, + }, + provider, + utils::IntoUsize, + ANNError, ANNErrorKind, ANNResult, +}; +use diskann_utils::future::{AsyncFriendly, SendFuture}; + +use crate::{ + layers::{self, Distance, QueryDistance}, + store::{self, Primary}, +}; + +#[derive(Debug)] +pub struct Provider { + primary: Primary, + layer: T, +} + +#[derive(Debug, Clone)] +pub struct Context {} + +impl diskann::provider::ExecutionContext for Context {} + +impl diskann::provider::DataProvider for Provider +where + T: layers::Layer, +{ + type Context = Context; + type InternalId = u32; + type ExternalId = u32; + type Error = diskann::error::Infallible; + type Guard = diskann::provider::NoopGuard; + + fn to_internal_id( + &self, + context: &Self::Context, + gid: &Self::ExternalId, + ) -> Result { + Ok(*gid) + } + + /// Translate an internal id to its corresponding external id. + fn to_external_id( + &self, + context: &Self::Context, + id: Self::InternalId, + ) -> Result { + Ok(id) + } +} + +fn ready(f: F) -> std::future::Ready +where + F: FnOnce() -> R, +{ + std::future::ready(f()) +} + +//////////// +// Search // +//////////// + +const fn start_point() -> u32 { + 0 +} + +#[derive(Debug)] +pub struct SearchAccessor<'a> { + reader: store::Reader<'a>, + distance: Box, + ids: AdjacencyList, +} + +impl diskann::provider::HasId for SearchAccessor<'_> { + type Id = u32; +} + +impl glue::SearchAccessor for SearchAccessor<'_> { + fn starting_points( + &self, + ) -> impl std::future::Future>> + Send { + std::future::ready(Ok(vec![start_point()])) + } + + fn start_point_distances( + &mut self, + mut f: F, + ) -> impl std::future::Future> + Send + where + F: FnMut(Self::Id, f32) + Send, + { + let work = move || { + let start = start_point(); + match self.reader.read(start.into_usize()) { + Some(point) => { + f(start, self.distance.evaluate(point)?); + Ok(()) + } + // TODO: "lock" start points. 
+ None => Err(ANNError::message( + ANNErrorKind::Opaque, + "could not retrieve start point", + )), + } + }; + + ready(work) + } + + fn expand_beam( + &mut self, + ids: Itr, + mut pred: P, + mut on_neighbors: F, + ) -> impl std::future::Future> + Send + where + Itr: Iterator + Send, + P: HybridPredicate + Send + Sync, + F: FnMut(Self::Id, f32) + Send, + { + let work = move || -> ANNResult<()> { + for i in ids { + self.reader + .neighbors() + .get(i.into_usize(), &mut self.ids) + .unwrap(); + for neighbor in self.ids.iter().filter(|i| pred.eval_mut(i)) { + if let Some(data) = self.reader.read(i.into_usize()) { + on_neighbors(*neighbor, self.distance.evaluate(data)?) + } + } + } + + Ok(()) + }; + + ready(work) + } +} + +impl diskann::provider::SetElement for Provider +where + L: layers::Layer + layers::Set +{ + type SetError = ANNError; + + fn set_element( + &self, + context: &Self::Context, + id: &Self::ExternalId, + element: T, + ) -> impl std::future::Future> + Send { + let work = move || { + let mut write = self.primary.write(id.into_usize()).unwrap(); + >::into_bytes(&self.layer, element, write.as_mut_slice())?; + Ok(diskann::provider::NoopGuard::new(*id)) + }; + + ready(work) + } +} + +//////////// +// Insert // +//////////// + +#[derive(Debug)] +pub struct PruneAccessor<'a> { + reader: store::Reader<'a>, + set: workingset::Map>, + distance: &'a dyn Distance, + ids: AdjacencyList, +} + +impl diskann::provider::HasId for PruneAccessor<'_> { + type Id = u32; +} + +impl glue::PruneAccessor for PruneAccessor<'_> { + type Neighbors<'a> + = provider::Neighbors<'a, Self> + where + Self: 'a; + + type ElementRef<'a> = &'a [u8]; + + type View<'a> + = workingset::map::View<'a, u32, Box<[u8]>> + where + Self: 'a; + + type Distance<'a> + = &'a dyn Distance + where + Self: 'a; + + fn neighbors(&mut self) -> Self::Neighbors<'_> { + provider::Neighbors(self) + } + + async fn fill<'a, Itr>( + &'a mut self, + itr: Itr, + ) -> ANNResult<(Self::View<'a>, Self::Distance<'a>)> + 
where + Itr: ExactSizeIterator + Clone + Send + Sync, + { + let v = self + .set + .fill(itr, |i| -> Result<_, Infallible> { + Ok(self.reader.read(i.into_usize()).map(|v| v.into())) + }) + .unwrap(); + + Ok((v, &*self.distance)) + } +} + +impl provider::NeighborAccessor for PruneAccessor<'_> { + fn get_neighbors( + &mut self, + id: Self::Id, + neighbors: &mut AdjacencyList, + ) -> impl std::future::Future> + Send { + let work = move || { + Ok(self + .reader + .neighbors() + .get(id.into_usize(), neighbors) + .unwrap()) + }; + ready(work) + } +} + +impl provider::NeighborAccessorMut for PruneAccessor<'_> { + fn set_neighbors( + &mut self, + id: Self::Id, + neighbors: &[Self::Id], + ) -> impl std::future::Future> + Send { + let work = move || { + Ok(self + .reader + .neighbors() + .set(id.into_usize(), neighbors) + .unwrap()) + }; + ready(work) + } + + fn append_vector( + &mut self, + id: Self::Id, + neighbors: &[Self::Id], + ) -> impl std::future::Future> + Send { + let work = move || -> ANNResult<()> { + let current = self.reader.neighbors().lock(id.into_usize()).unwrap(); + + // Copy out the current neighbors. + let mut resize = self.ids.resize(current.len()); + resize.copy_from_slice(current.as_slice()); + resize.finish(current.len()); + + // Append the new neighbors. 
+ self.ids.extend_from_slice(neighbors); + current.write(&self.ids).unwrap(); + Ok(()) + }; + + ready(work) + } +} + +//////////////// +// Strategies // +//////////////// + +#[derive(Debug, Clone, Copy)] +pub struct Strategy; + +impl<'a, T, L> glue::SearchStrategy<'a, Provider, T> for Strategy +where + L: layers::Search<'a, T>, +{ + type SearchAccessor = SearchAccessor<'a>; + type SearchAccessorError = ANNError; + + fn search_accessor( + &'a self, + provider: &'a Provider, + context: &'a Context, + query: T + ) -> ANNResult> { + let distance = >::as_query_distance(&provider.layer, query)?; + let accessor = SearchAccessor { + reader: provider.primary.reader(), + distance, + ids: AdjacencyList::new(), + }; + Ok(accessor) + } +} + +impl glue::PruneStrategy> for Strategy +where + L: layers::Layer + layers::AsDistance, +{ + type PruneAccessor<'a> = PruneAccessor<'a>; + type PruneAccessorError = diskann::error::Infallible; + + fn prune_accessor<'a>( + &self, + provider: &'a Provider, + context: &'a Context, + capacity: usize, + ) -> Result, diskann::error::Infallible> { + let set = workingset::map::Builder::new(workingset::map::Capacity::Default).build(capacity); + Ok(PruneAccessor { + reader: provider.primary.reader(), + set, + distance: ::as_distance(&provider.layer), + ids: AdjacencyList::new(), + }) + } +} + +impl<'a, L, T> glue::InsertStrategy<'a, Provider, T> for Strategy +where + L: layers::Insert<'a, T>, +{ + type PruneStrategy = Self; + fn prune_strategy(&self) -> Self::PruneStrategy { + *self + } +} diff --git a/diskann-inmem/src/store.rs b/diskann-inmem/src/store.rs new file mode 100644 index 000000000..d6cc5f4e5 --- /dev/null +++ b/diskann-inmem/src/store.rs @@ -0,0 +1,209 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. 
+ */ + +use std::{ + iter::repeat_n, + num::NonZeroU32, + sync::{ + atomic::{AtomicU64, Ordering}, + Mutex, + }, +}; + +use crate::{ + arbiter::{self, buffer, epoch, generation, Buffer, Freelist, Generation, Slice}, + num::{Align, Bytes}, + Neighbors, +}; + +#[derive(Debug)] +pub struct Primary { + // The invasive store where concurrency tags are stored inline with the data. + // + // These tags are mirrored from `tags` - with the latter being used for secondary scans + // offering slightly better locality. + buffer: Buffer, + tags: Vec, + freelist: Freelist, + registry: epoch::Registry, + neighbors: Neighbors, + drain: Mutex>, +} + +const SPLIT: usize = std::mem::size_of::(); + +impl Primary { + pub fn new(entries: usize, bytes: Bytes, max_neighbors: usize) -> Self { + Self { + buffer: Buffer::new(entries, Bytes(bytes.0 + SPLIT), Align(128)), + tags: repeat_n(Generation::AVAILABLE, entries) + .map(|v| generation::Tag::new(v)) + .collect(), + freelist: Freelist::new(entries.try_into().unwrap(), NonZeroU32::new(1024).unwrap()), + registry: epoch::Registry::new(), + neighbors: Neighbors::new(entries, max_neighbors), + drain: Mutex::new(Vec::new()), + } + } + + #[inline] + fn tag(&self, i: usize) -> Option> { + self.tags.get(i).map(|v| v.as_ref()) + } + + pub fn drain(&self) -> usize { + let mut drain = self.drain.lock().unwrap(); + let waiter = self.registry.waiting(); + let before = drain.len(); + drain.retain(|(i, generation)| { + if waiter < *generation { + self.freelist.push(*i); + false + } else { + true + } + }); + before - drain.len() + } + + pub fn reader(&self) -> Reader<'_> { + Reader { + buffer: &self.buffer, + neighbors: &self.neighbors, + epoch: self.registry.register(), + } + } + + pub(crate) fn write(&self, i: usize) -> Option> { + let tag = self.tag_mut(i)?; + match tag.try_set( + Generation::AVAILABLE, + Generation::OWNED, + Ordering::Acquire, + Ordering::Relaxed, + ) { + Ok(_) => { + let (mirror, data) = unsafe { self.data(i) }; + let write = Write 
{ + tag, + mirror, + generation: self.registry.generation(), + data, + }; + Some(write) + } + Err(_) => None, + } + } + + pub(crate) fn delete(&self, i: usize) -> bool { + let tag = self.tag_mut(i).unwrap(); + let current = tag.get(Ordering::Relaxed); + + // We can only perform a deletion if the generation is not in a reserved state. + if current.is_reserved() { + return false; + } + + let owned = Generation::OWNED; + + // Even if we make this change, we can't access any data until we wait for the + // epoch to be bumped. As such, relaxed semantics are fine. + match tag.try_set(current, owned, Ordering::Relaxed, Ordering::Relaxed) { + Ok(current) => { + // Set the metadata in the mirror as well. + let (mirror, _) = unsafe { self.data(i) }; + mirror.set(owned, Ordering::Relaxed); + let wait_for = self.registry.advance(); + self.drain + .lock() + .unwrap() + .push((i.try_into().unwrap(), wait_for)); + true + } + Err(_) => false, + } + } + + unsafe fn data(&self, i: usize) -> (generation::Mut<'_>, Slice<'_>) { + let (mirror, data) = unsafe { self.buffer.get_unchecked(i) }.split(SPLIT); + ( + unsafe { generation::Tag::from_ptr(mirror.as_ptr().as_ptr().cast()) }.as_mut(), + data, + ) + } + + /// Creating a `Mut` is impossible for user code. Exposing this functionality would + /// allow user code to break all safety invariantes this data structure relies on. + fn tag_mut(&self, i: usize) -> Option> { + self.tags.get(i).map(|v| v.as_mut()) + } +} + +#[derive(Debug)] +pub struct Reader<'a> { + buffer: &'a Buffer, + neighbors: &'a Neighbors, + epoch: epoch::Guard<'a>, +} + +impl<'a> Reader<'a> { + /// Attempt to read the value at index `i`. This can fail for any of the + /// following reasons: + /// + /// 1. Index `i` is out-of-bounds. + /// 2. The read cannot be guaranteed to be race-free. 
+ #[inline] + pub fn read(&self, i: usize) -> Option<&[u8]> { + let (generation, rest) = match self.buffer.get(i) { + Some(slice) => slice.split(SPLIT), + None => return None, + }; + + // NOTE: Must be `Acquire` to correctly synchronize with writes. + let generation = unsafe { generation::Tag::from_ptr(generation.as_ptr().as_ptr().cast()) } + .as_ref() + .get(Ordering::Acquire); + + if generation >= self.epoch.generation() { + // SAFETY: tags and buffer always have the same length, and we + // verified i < tags.len() above. + Some(unsafe { rest.as_slice() }) + } else { + None + } + } + + // TODO: We may want to lock `Neighbors` in some way to enable exclusive access during + // operations like snapshots. + pub(crate) fn neighbors(&self) -> &Neighbors { + &self.neighbors + } +} + +#[derive(Debug)] +pub struct Write<'a> { + tag: generation::Mut<'a>, + mirror: generation::Mut<'a>, + generation: Generation, + data: Slice<'a>, +} + +impl<'a> Write<'a> { + pub fn raw_slice(&mut self) -> Slice<'_> { + self.data + } + + pub fn as_mut_slice(&mut self) -> &mut [u8] { + unsafe { self.data.as_mut_slice() } + } +} + +impl Drop for Write<'_> { + fn drop(&mut self) { + self.mirror.set(self.generation, Ordering::Release); + self.tag.set(self.generation, Ordering::Release); + } +} diff --git a/diskann-vector/src/distance/distance_provider.rs b/diskann-vector/src/distance/distance_provider.rs index 2ec7f4b88..2bee175a3 100644 --- a/diskann-vector/src/distance/distance_provider.rs +++ b/diskann-vector/src/distance/distance_provider.rs @@ -58,7 +58,7 @@ where /// A function pointer-like type for computing distances between `&[T]` and `&[U]`. /// /// See: [`DistanceProvider`]. 
-#[derive(Debug, Clone, Copy)] +#[derive(Debug)] pub struct Distance where T: 'static, @@ -99,6 +99,22 @@ where } } +impl Clone for Distance +where + T: 'static, + U: 'static, +{ + fn clone(&self) -> Self { + *self + } +} + +impl Copy for Distance +where + T: 'static, + U: 'static, +{} + impl crate::DistanceFunction<&[T], &[U], f32> for Distance where T: 'static, From 5ecae48f24f993fcc56d5ee77944c3ab0625f4a1 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Tue, 9 Jun 2026 17:52:35 -0700 Subject: [PATCH 02/45] Checkpoint. --- Cargo.lock | 2 + diskann-benchmark/Cargo.toml | 1 + diskann-inmem/Cargo.toml | 3 + diskann-inmem/src/layers/full.rs | 64 +++++++-- diskann-inmem/src/layers/mod.rs | 26 ++-- diskann-inmem/src/provider.rs | 124 ++++++++++++++---- diskann-inmem/src/store.rs | 19 ++- .../src/distance/distance_provider.rs | 3 +- 8 files changed, 186 insertions(+), 56 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 3195695f9..efc0e7f38 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -658,6 +658,7 @@ dependencies = [ "diskann-benchmark-runner", "diskann-bftree", "diskann-disk", + "diskann-inmem", "diskann-label-filter", "diskann-providers", "diskann-quantization", @@ -815,6 +816,7 @@ dependencies = [ "diskann-utils", "diskann-vector", "thiserror 2.0.17", + "tokio", ] [[package]] diff --git a/diskann-benchmark/Cargo.toml b/diskann-benchmark/Cargo.toml index ce5018aad..79a944cba 100644 --- a/diskann-benchmark/Cargo.toml +++ b/diskann-benchmark/Cargo.toml @@ -39,6 +39,7 @@ opentelemetry_sdk = { workspace = true, optional = true } scopeguard = { version = "1.2", optional = true } diskann-benchmark-core = { workspace = true, features = ["bigann"] } itertools.workspace = true +diskann-inmem = { version = "0.53.0", path = "../diskann-inmem" } [lints] clippy.undocumented_unsafe_blocks = "warn" diff --git a/diskann-inmem/Cargo.toml b/diskann-inmem/Cargo.toml index 2516abf2c..3147434e0 100644 --- a/diskann-inmem/Cargo.toml +++ b/diskann-inmem/Cargo.toml @@ -16,3 +16,6 
@@ thiserror.workspace = true [lints] workspace = true + +[dev-dependencies] +tokio = { workspace = true, features = ["macros"] } diff --git a/diskann-inmem/src/layers/full.rs b/diskann-inmem/src/layers/full.rs index 1a5c40d14..74c6149f0 100644 --- a/diskann-inmem/src/layers/full.rs +++ b/diskann-inmem/src/layers/full.rs @@ -5,13 +5,12 @@ use diskann::{ANNError, ANNResult}; use diskann_vector::{ - AsUnaligned, distance::{self, DistanceProvider, Metric}, - UnalignedSlice, + AsUnaligned, UnalignedSlice, }; use thiserror::Error; -use crate::layers; +use crate::{layers, num::Bytes}; /// Full-precision data layer. #[derive(Debug)] @@ -45,22 +44,37 @@ where pub fn dim(&self) -> usize { self.distance.dim } + + pub fn bytes(&self) -> Bytes { + Bytes(self.dim() * std::mem::size_of::()) + } } -impl layers::AsDistance for Full +impl layers::Layer for Full where - T: std::fmt::Debug + 'static, + T: bytemuck::Pod + Send + Sync, { - fn as_distance(&self) -> &dyn layers::Distance { - &self.distance + fn bytes(&self) -> usize { + >::bytes(self).0 } } -impl<'a, T> layers::AsQueryDistance<'a, &'a [T]> for Full +impl layers::Set<&[T]> for Full where - T: std::fmt::Debug + Sync + 'static, + T: bytemuck::Pod + Send + Sync, { - fn as_query_distance( + fn into_bytes(&self, v: &[T], bytes: &mut [u8]) -> ANNResult<()> { + assert_eq!(self.dim(), v.len()); + bytes.copy_from_slice(bytemuck::must_cast_slice::(v)); + Ok(()) + } +} + +impl<'a, T> layers::Search<'a, &'a [T]> for Full +where + T: std::fmt::Debug + Send + Sync + 'static, +{ + fn query_distance( &'a self, query: &'a [T], ) -> ANNResult> { @@ -68,6 +82,32 @@ where } } +impl layers::AsDistance for Full +where + T: std::fmt::Debug + Send + Sync + 'static, +{ + fn as_distance(&self) -> &dyn layers::Distance { + &self.distance + } +} + +impl<'a, T> layers::Insert<'a, &'a [T]> for Full +where + T: bytemuck::Pod + std::fmt::Debug + Send + Sync, +{} + +// impl<'a, T> layers::Insert<'a, &'a [T]> for Full +// where +// T: bytemuck::Pod + 
std::fmt::Debug + Send + Sync, +// { +// fn search_distance( +// &'a self, +// query: &'a [T], +// ) -> ANNResult> { +// Ok(Box::new(QueryDistance::new(self.distance, query))) +// } +// } + ////////////// // Distance // ////////////// @@ -190,9 +230,7 @@ where self.error(x) } else { Ok(self.distance.f.call_unaligned( - unsafe { - UnalignedSlice::new(self.query.as_ptr().cast::(), self.distance.dim) - }, + unsafe { UnalignedSlice::new(self.query.as_ptr().cast::(), self.distance.dim) }, unsafe { UnalignedSlice::new(x.as_ptr().cast::(), self.distance.dim) }, )) } diff --git a/diskann-inmem/src/layers/mod.rs b/diskann-inmem/src/layers/mod.rs index f75967acf..f7fc303c0 100644 --- a/diskann-inmem/src/layers/mod.rs +++ b/diskann-inmem/src/layers/mod.rs @@ -6,41 +6,43 @@ use diskann::{error::StandardError, utils::VectorRepr, ANNResult}; use diskann_vector::{distance::Metric, DistanceFunction}; -mod full; +pub(crate) mod full; +pub use full::Full; pub trait Distance: Send + Sync + std::fmt::Debug { fn evaluate(&self, x: &[u8], y: &[u8]) -> ANNResult; } +pub trait AsDistance: Send + Sync + std::fmt::Debug { + fn as_distance(&self) -> &dyn Distance; +} + impl DistanceFunction<&[u8], &[u8]> for &dyn Distance { fn evaluate_similarity(&self, x: &[u8], y: &[u8]) -> f32 { self.evaluate(x, y).unwrap() } } -pub trait AsDistance { - fn as_distance(&self) -> &dyn Distance; -} - pub trait QueryDistance: Send + Sync + std::fmt::Debug { fn evaluate(&self, x: &[u8]) -> ANNResult; } -pub trait AsQueryDistance<'a, T> { - fn as_query_distance(&'a self, query: T) -> ANNResult>; -} - -pub trait Set { +pub trait Layer: Send + Sync + 'static { /// Return the number of bytes needed by this layer representation. /// /// To be well-behaved, this function must be idempotent. fn bytes(&self) -> usize; +} +pub trait Set: Layer { /// Write into the stored representation. fn into_bytes<'a>(&self, element: T, bytes: &'a mut [u8]) -> ANNResult<()>; } // Meta traits for `Search` and `Insert` compatibility. 
-pub trait Layer: Send + Sync + 'static {} -pub trait Search<'a, T>: Layer + AsQueryDistance<'a, T> {} +pub trait Search<'a, T>: Send + Sync + 'static { + fn query_distance(&'a self, query: T) -> ANNResult>; +} + pub trait Insert<'a, T>: Search<'a, T> + Set + AsDistance {} + diff --git a/diskann-inmem/src/provider.rs b/diskann-inmem/src/provider.rs index 5a265a472..b6832f0ae 100644 --- a/diskann-inmem/src/provider.rs +++ b/diskann-inmem/src/provider.rs @@ -17,6 +17,7 @@ use diskann_utils::future::{AsyncFriendly, SendFuture}; use crate::{ layers::{self, Distance, QueryDistance}, + num::Bytes, store::{self, Primary}, }; @@ -26,14 +27,47 @@ pub struct Provider { layer: T, } +impl Provider { + pub fn new(layer: T, capacity: usize, start_points: I) -> Self + where + I: IntoIterator, + T: layers::Set, + { + let start_points: Vec<_> = start_points.into_iter().collect(); + let bytes = layers::Layer::bytes(&layer); + let primary = Primary::new( + capacity.checked_add(start_points.len()).unwrap(), + Bytes(bytes), + 32, + ); + + let mut i = capacity; + for v in start_points.into_iter() { + let mut writer = primary.write(i).unwrap(); + layers::Set::into_bytes(&layer, v, writer.as_mut_slice()).unwrap(); + i += 1; + } + + Self { primary, layer } + } + + fn reader(&self) -> store::Reader<'_> { + self.primary.reader() + } +} + +/////////////////// +// Data Provider // +/////////////////// + #[derive(Debug, Clone)] -pub struct Context {} +pub struct Context; impl diskann::provider::ExecutionContext for Context {} impl diskann::provider::DataProvider for Provider where - T: layers::Layer, + T: Send + Sync + 'static, { type Context = Context; type InternalId = u32; @@ -66,6 +100,28 @@ where std::future::ready(f()) } +impl diskann::provider::SetElement for Provider +where + L: layers::Layer + layers::Set, +{ + type SetError = ANNError; + + fn set_element( + &self, + context: &Self::Context, + id: &Self::ExternalId, + element: T, + ) -> impl std::future::Future> + Send { + let work = 
move || { + let mut write = self.primary.write(id.into_usize()).unwrap(); + >::into_bytes(&self.layer, element, write.as_mut_slice())?; + Ok(diskann::provider::NoopGuard::new(*id)) + }; + + ready(work) + } +} + //////////// // Search // //////////// @@ -135,7 +191,7 @@ impl glue::SearchAccessor for SearchAccessor<'_> { .get(i.into_usize(), &mut self.ids) .unwrap(); for neighbor in self.ids.iter().filter(|i| pred.eval_mut(i)) { - if let Some(data) = self.reader.read(i.into_usize()) { + if let Some(data) = self.reader.read(neighbor.into_usize()) { on_neighbors(*neighbor, self.distance.evaluate(data)?) } } @@ -148,28 +204,6 @@ impl glue::SearchAccessor for SearchAccessor<'_> { } } -impl diskann::provider::SetElement for Provider -where - L: layers::Layer + layers::Set -{ - type SetError = ANNError; - - fn set_element( - &self, - context: &Self::Context, - id: &Self::ExternalId, - element: T, - ) -> impl std::future::Future> + Send { - let work = move || { - let mut write = self.primary.write(id.into_usize()).unwrap(); - >::into_bytes(&self.layer, element, write.as_mut_slice())?; - Ok(diskann::provider::NoopGuard::new(*id)) - }; - - ready(work) - } -} - //////////// // Insert // //////////// @@ -300,9 +334,10 @@ where &'a self, provider: &'a Provider, context: &'a Context, - query: T + query: T, ) -> ANNResult> { - let distance = >::as_query_distance(&provider.layer, query)?; + let distance = + >::query_distance(&provider.layer, query)?; let accessor = SearchAccessor { reader: provider.primary.reader(), distance, @@ -344,3 +379,38 @@ where *self } } + +/////////// +// Tests // +/////////// + +#[cfg(test)] +mod tests { + use super::*; + + use diskann::graph::DiskANNIndex; + use diskann_vector::distance::Metric; + + use crate::layers::Full; + + #[tokio::test] + async fn smoke() { + let full = Full::::new(1, Metric::L2); + let start_points: [&[f32]; _] = [&[1.0], &[2.0]]; + + let provider = Provider::new(full, 10, start_points); + + let config = 
diskann::graph::config::Builder::new( + 10, + diskann::graph::config::MaxDegree::Same, + 100, + (Metric::L2).into() + ).build().unwrap(); + + let index = DiskANNIndex::new(config, provider, None); + + index.insert(&Strategy, &Context, &0, &[3.0]).await.unwrap(); + index.insert(&Strategy, &Context, &1, &[4.0]).await.unwrap(); + index.insert(&Strategy, &Context, &2, &[5.0]).await.unwrap(); + } +} diff --git a/diskann-inmem/src/store.rs b/diskann-inmem/src/store.rs index d6cc5f4e5..0bf2385d8 100644 --- a/diskann-inmem/src/store.rs +++ b/diskann-inmem/src/store.rs @@ -25,6 +25,7 @@ pub struct Primary { // These tags are mirrored from `tags` - with the latter being used for secondary scans // offering slightly better locality. buffer: Buffer, + unpadded: usize, tags: Vec, freelist: Freelist, registry: epoch::Registry, @@ -36,8 +37,12 @@ const SPLIT: usize = std::mem::size_of::(); impl Primary { pub fn new(entries: usize, bytes: Bytes, max_neighbors: usize) -> Self { + let unpadded = bytes.0 + SPLIT; + let padded_bytes = unpadded.checked_next_multiple_of(SPLIT).unwrap(); + Self { - buffer: Buffer::new(entries, Bytes(bytes.0 + SPLIT), Align(128)), + buffer: Buffer::new(entries, Bytes(padded_bytes), Align(128)), + unpadded, tags: repeat_n(Generation::AVAILABLE, entries) .map(|v| generation::Tag::new(v)) .collect(), @@ -53,6 +58,10 @@ impl Primary { self.tags.get(i).map(|v| v.as_ref()) } + pub fn capacity(&self) -> usize { + self.buffer.len() + } + pub fn drain(&self) -> usize { let mut drain = self.drain.lock().unwrap(); let waiter = self.registry.waiting(); @@ -71,6 +80,7 @@ impl Primary { pub fn reader(&self) -> Reader<'_> { Reader { buffer: &self.buffer, + unpadded: self.unpadded, neighbors: &self.neighbors, epoch: self.registry.register(), } @@ -128,7 +138,9 @@ impl Primary { } unsafe fn data(&self, i: usize) -> (generation::Mut<'_>, Slice<'_>) { - let (mirror, data) = unsafe { self.buffer.get_unchecked(i) }.split(SPLIT); + let (mirror, data) = unsafe { 
self.buffer.get_unchecked(i) } + .truncate(self.unpadded) + .split(SPLIT); ( unsafe { generation::Tag::from_ptr(mirror.as_ptr().as_ptr().cast()) }.as_mut(), data, @@ -145,6 +157,7 @@ impl Primary { #[derive(Debug)] pub struct Reader<'a> { buffer: &'a Buffer, + unpadded: usize, neighbors: &'a Neighbors, epoch: epoch::Guard<'a>, } @@ -158,7 +171,7 @@ impl<'a> Reader<'a> { #[inline] pub fn read(&self, i: usize) -> Option<&[u8]> { let (generation, rest) = match self.buffer.get(i) { - Some(slice) => slice.split(SPLIT), + Some(slice) => slice.truncate(self.unpadded).split(SPLIT), None => return None, }; diff --git a/diskann-vector/src/distance/distance_provider.rs b/diskann-vector/src/distance/distance_provider.rs index 2bee175a3..60cd02079 100644 --- a/diskann-vector/src/distance/distance_provider.rs +++ b/diskann-vector/src/distance/distance_provider.rs @@ -113,7 +113,8 @@ impl Copy for Distance where T: 'static, U: 'static, -{} +{ +} impl crate::DistanceFunction<&[T], &[U], f32> for Distance where From 1a367cb020cee84c6e869706e230e90a9af03fa8 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Wed, 10 Jun 2026 11:59:24 -0700 Subject: [PATCH 03/45] Checkpoint. 
--- Cargo.lock | 2 +- Cargo.toml | 1 + diskann-benchmark/Cargo.toml | 2 +- diskann-benchmark/src/backend/mod.rs | 23 ++ diskann-benchmark/src/index/benchmarks.rs | 68 +++--- diskann-benchmark/src/index/inmem2.rs | 271 ++++++++++++++++++++++ diskann-benchmark/src/index/mod.rs | 4 +- diskann-inmem/Cargo.toml | 2 +- diskann-inmem/src/lib.rs | 5 +- diskann-inmem/src/provider.rs | 9 +- 10 files changed, 346 insertions(+), 41 deletions(-) create mode 100644 diskann-benchmark/src/backend/mod.rs create mode 100644 diskann-benchmark/src/index/inmem2.rs diff --git a/Cargo.lock b/Cargo.lock index efc0e7f38..9553e67f4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -809,7 +809,7 @@ dependencies = [ [[package]] name = "diskann-inmem" -version = "0.53.0" +version = "0.54.0" dependencies = [ "bytemuck", "diskann", diff --git a/Cargo.toml b/Cargo.toml index 5ba752b97..a62586882 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -59,6 +59,7 @@ diskann-platform = { path = "diskann-platform", version = "0.54.0" } diskann = { path = "diskann", version = "0.54.0" } # Providers diskann-providers = { path = "diskann-providers", default-features = false, version = "0.54.0" } +diskann-inmem = { path = "diskann-inmem", default-features = false, version = "0.54.0" } diskann-disk = { path = "diskann-disk", version = "0.54.0" } diskann-label-filter = { path = "diskann-label-filter", version = "0.54.0" } # Infra diff --git a/diskann-benchmark/Cargo.toml b/diskann-benchmark/Cargo.toml index 79a944cba..b4c29cc5a 100644 --- a/diskann-benchmark/Cargo.toml +++ b/diskann-benchmark/Cargo.toml @@ -39,7 +39,7 @@ opentelemetry_sdk = { workspace = true, optional = true } scopeguard = { version = "1.2", optional = true } diskann-benchmark-core = { workspace = true, features = ["bigann"] } itertools.workspace = true -diskann-inmem = { version = "0.53.0", path = "../diskann-inmem" } +diskann-inmem = { workspace = true } [lints] clippy.undocumented_unsafe_blocks = "warn" diff --git a/diskann-benchmark/src/backend/mod.rs 
b/diskann-benchmark/src/backend/mod.rs new file mode 100644 index 000000000..6bde405ad --- /dev/null +++ b/diskann-benchmark/src/backend/mod.rs @@ -0,0 +1,23 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +use diskann_benchmark_runner::Registry; + +mod disk_index; +mod exhaustive; +mod filters; +mod index; +mod inmem2; +mod multi_vector; + +pub(crate) fn register_benchmarks(registry: &mut Registry) -> anyhow::Result<()> { + exhaustive::register_benchmarks(registry)?; + disk_index::register_benchmarks(registry)?; + index::register_benchmarks(registry)?; + filters::register_benchmarks(registry)?; + multi_vector::register_benchmarks(registry)?; + inmem2::register_benchmarks(registry)?; + Ok(()) +} diff --git a/diskann-benchmark/src/index/benchmarks.rs b/diskann-benchmark/src/index/benchmarks.rs index f229557dd..d5e4d944b 100644 --- a/diskann-benchmark/src/index/benchmarks.rs +++ b/diskann-benchmark/src/index/benchmarks.rs @@ -81,40 +81,40 @@ pub(crate) fn register_benchmarks(registry: &mut Registry) -> anyhow::Result<()> .search(plugins::TopkInlineFilter), )?; - registry.register( - "graph-index-full-precision-f16", - FullPrecision::::new().search(plugins::Topk), - )?; - registry.register( - "graph-index-full-precision-u8", - FullPrecision::::new().search(plugins::Topk), - )?; - registry.register( - "graph-index-full-precision-i8", - FullPrecision::::new().search(plugins::Topk), - )?; - - // Dynamic Full Precision - registry.register( - "graph-index-dynamic-full-precision-f32", - DynamicFullPrecision::::new(), - )?; - registry.register( - "graph-index-dynamic-full-precision-f16", - DynamicFullPrecision::::new(), - )?; - registry.register( - "graph-index-dynamic-full-precision-u8", - DynamicFullPrecision::::new(), - )?; - registry.register( - "graph-index-dynamic-full-precision-i8", - DynamicFullPrecision::::new(), - )?; - - product::register_benchmarks(registry)?; - scalar::register_benchmarks(registry)?; - 
spherical::register_benchmarks(registry)?; + // registry.register( + // "graph-index-full-precision-f16", + // FullPrecision::::new().search(plugins::Topk), + // )?; + // registry.register( + // "graph-index-full-precision-u8", + // FullPrecision::::new().search(plugins::Topk), + // )?; + // registry.register( + // "graph-index-full-precision-i8", + // FullPrecision::::new().search(plugins::Topk), + // )?; + + // // Dynamic Full Precision + // registry.register( + // "graph-index-dynamic-full-precision-f32", + // DynamicFullPrecision::::new(), + // )?; + // registry.register( + // "graph-index-dynamic-full-precision-f16", + // DynamicFullPrecision::::new(), + // )?; + // registry.register( + // "graph-index-dynamic-full-precision-u8", + // DynamicFullPrecision::::new(), + // )?; + // registry.register( + // "graph-index-dynamic-full-precision-i8", + // DynamicFullPrecision::::new(), + // )?; + + // product::register_benchmarks(registry)?; + // scalar::register_benchmarks(registry)?; + // spherical::register_benchmarks(registry)?; Ok(()) } diff --git a/diskann-benchmark/src/index/inmem2.rs b/diskann-benchmark/src/index/inmem2.rs new file mode 100644 index 000000000..f9720ae8a --- /dev/null +++ b/diskann-benchmark/src/index/inmem2.rs @@ -0,0 +1,271 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Benchmark backend for the `diskann-inmem` (inmem2) provider. +//! +//! This wires up the inmem2 `Provider>` to the standard build and search +//! infrastructure in `diskann-benchmark-core`, giving us parallel insertion via +//! [`SingleInsert`] and KNN search with recall/latency reporting via [`KNN`]. 
+ +use std::{io::Write, num::NonZeroUsize, sync::Arc}; + +use diskann::graph::{self, DiskANNIndex}; +use diskann_benchmark_core::{ + self as benchmark_core, + build as build_core, + recall::GroundTruthMode, + search as core_search, +}; +use diskann_benchmark_runner::{ + benchmark::{FailureScore, MatchScore}, + files::InputFile, + output::Output, + Benchmark, Checker, Checkpoint, Input, Registry, +}; +use diskann_inmem::{layers::Full, Provider, Strategy}; +use diskann_utils::views::Matrix; +use diskann_vector::distance::Metric; +use serde::{Deserialize, Serialize}; + +use crate::{backend::index::build::ProgressMeter, utils::datafiles}; + +pub(crate) fn register_benchmarks(registry: &mut Registry) -> anyhow::Result<()> { + registry.register("inmem2-f32", Inmem2)?; + Ok(()) +} + +/////////// +// Input // +/////////// + +#[derive(Debug, Serialize, Deserialize)] +pub(crate) struct Inmem2Build { + data: InputFile, + queries: InputFile, + groundtruth: InputFile, + + max_degree: usize, + l_build: usize, + alpha: f32, + + search_n: usize, + search_l: Vec, + recall_k: usize, + + num_threads: usize, + reps: NonZeroUsize, +} + +impl Input for Inmem2Build { + type Raw = Inmem2Build; + + fn tag() -> &'static str { + "inmem2" + } + + fn from_raw(mut raw: Self::Raw, checker: &mut Checker) -> anyhow::Result { + raw.data.resolve(checker)?; + raw.queries.resolve(checker)?; + raw.groundtruth.resolve(checker)?; + Ok(raw) + } + + fn serialize(&self) -> anyhow::Result { + Ok(serde_json::to_value(self)?) 
+ } + + fn example() -> Self { + Self { + data: InputFile::new("path/to/base.bin"), + queries: InputFile::new("path/to/query.bin"), + groundtruth: InputFile::new("path/to/groundtruth.bin"), + max_degree: 64, + l_build: 100, + alpha: 1.2, + search_n: 10, + search_l: vec![10, 20, 50, 100], + recall_k: 10, + num_threads: 4, + reps: NonZeroUsize::new(3).unwrap(), + } + } +} + +impl std::fmt::Display for Inmem2Build { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + writeln!(f, "inmem2 f32 benchmark")?; + writeln!(f, " max_degree: {}", self.max_degree)?; + writeln!(f, " l_build: {}", self.l_build)?; + writeln!(f, " alpha: {}", self.alpha)?; + writeln!(f, " search_n: {}", self.search_n)?; + writeln!(f, " search_l: {:?}", self.search_l)?; + writeln!(f, " recall_k: {}", self.recall_k)?; + writeln!(f, " num_threads: {}", self.num_threads)?; + writeln!(f, " reps: {}", self.reps) + } +} + +/////////////// +// Benchmark // +/////////////// + +#[derive(Debug)] +struct Inmem2; + +impl Benchmark for Inmem2 { + type Input = Inmem2Build; + type Output = (); + + fn try_match(&self, _input: &Inmem2Build) -> Result { + Ok(MatchScore(0)) + } + + fn description( + &self, + f: &mut std::fmt::Formatter<'_>, + input: Option<&Inmem2Build>, + ) -> std::fmt::Result { + match input { + Some(i) => write!(f, "{i}"), + None => write!(f, "inmem2 f32 benchmark"), + } + } + + fn run( + &self, + input: &Inmem2Build, + checkpoint: Checkpoint<'_>, + mut output: &mut dyn Output, + ) -> anyhow::Result<()> { + writeln!(output, "{input}")?; + + // Load data. + let data: Arc> = + Arc::new(datafiles::load_dataset(datafiles::BinFile(&input.data))?); + + let dim = data.ncols(); + let num_points = data.nrows(); + writeln!(output, "Loaded {num_points} points, dim={dim}")?; + + // Compute a start point as the centroid of the first min(1000, N) points. 
+ let sample = num_points.min(1000); + let mut centroid = vec![0.0f32; dim]; + for i in 0..sample { + for (c, &v) in centroid.iter_mut().zip(data.row(i)) { + *c += v; + } + } + let inv = 1.0 / sample as f32; + centroid.iter_mut().for_each(|c| *c *= inv); + + // Build inmem2 provider. + let metric = Metric::L2; + let exact_max_degree = (input.max_degree as f32 * 1.3) as usize; + let layer = Full::::new(dim, metric); + let start_points: [&[f32]; 1] = [¢roid]; + let provider = Provider::new(layer, num_points, start_points); + + let config = graph::config::Builder::new_with( + input.max_degree, + graph::config::MaxDegree::new(exact_max_degree), + input.l_build, + metric.into(), + |b| { + b.alpha(input.alpha); + }, + ) + .build()?; + + let index = Arc::new(DiskANNIndex::new(config, provider, None)); + + // Build via SingleInsert. + let rt = benchmark_core::tokio::runtime(input.num_threads)?; + let builder = build_core::graph::SingleInsert::new( + index.clone(), + data, + Strategy, + build_core::ids::Identity::::new(), + ); + + writeln!(output, "Building index with {} threads...", input.num_threads)?; + let build_results = build_core::build_tracked( + builder, + build_core::Parallelism::dynamic( + diskann::utils::ONE, + NonZeroUsize::new(input.num_threads).unwrap(), + ), + &rt, + Some(&ProgressMeter::new(output)), + )?; + + let total_build_time = build_results.end_to_end_latency(); + writeln!(output, "Build complete in {:.2}s", total_build_time.as_seconds())?; + checkpoint.checkpoint(&total_build_time)?; + + // Search. 
+ let queries: Arc> = + Arc::new(datafiles::load_dataset(datafiles::BinFile(&input.queries))?); + let max_k = input.recall_k; + let groundtruth = + datafiles::load_groundtruth(datafiles::BinFile(&input.groundtruth), Some(max_k))?; + + writeln!(output, "Loaded {} queries", queries.nrows())?; + + let knn = benchmark_core::search::graph::KNN::new( + index, + queries, + benchmark_core::search::graph::Strategy::broadcast(Strategy), + )?; + + let num_threads = NonZeroUsize::new(input.num_threads).unwrap(); + + for &search_l in &input.search_l { + let params = + graph::search::Knn::new(input.search_n, search_l, None)?; + + let setup = core_search::Setup { + threads: num_threads, + tasks: num_threads, + reps: input.reps, + }; + + let run = core_search::Run::new(params, setup); + + let summaries = core_search::search_all( + knn.clone(), + std::iter::once(run), + benchmark_core::search::graph::knn::Aggregator::new( + &groundtruth, + input.recall_k, + input.search_n, + GroundTruthMode::Fixed, + ), + )?; + + for summary in &summaries { + let qps: Vec = summary + .end_to_end_latencies + .iter() + .map(|lat| summary.recall.num_queries as f64 / lat.as_seconds()) + .collect(); + let max_qps = qps.iter().cloned().fold(0.0f64, f64::max); + let mean_qps = qps.iter().sum::() / qps.len() as f64; + + writeln!( + output, + " L={:<4} recall={:.4} QPS={:.0} (max {:.0}) cmps={:.1} hops={:.1}", + search_l, + summary.recall.average, + mean_qps, + max_qps, + summary.mean_cmps, + summary.mean_hops, + )?; + } + } + + Ok(()) + } +} diff --git a/diskann-benchmark/src/index/mod.rs b/diskann-benchmark/src/index/mod.rs index 3900dc337..e902d8b1b 100644 --- a/diskann-benchmark/src/index/mod.rs +++ b/diskann-benchmark/src/index/mod.rs @@ -5,19 +5,21 @@ use diskann_benchmark_runner::Registry; -mod build; +pub(crate) mod build; mod search; mod streaming; mod benchmarks; mod inmem; mod result; +mod inmem2; #[cfg(feature = "bftree")] mod bftree; pub(crate) fn register_benchmarks(registry: &mut Registry) -> 
anyhow::Result<()> { benchmarks::register_benchmarks(registry)?; + inmem2::register_benchmarks(registry)?; #[cfg(feature = "bftree")] bftree::register_benchmarks(registry)?; diff --git a/diskann-inmem/Cargo.toml b/diskann-inmem/Cargo.toml index 3147434e0..235a70ea1 100644 --- a/diskann-inmem/Cargo.toml +++ b/diskann-inmem/Cargo.toml @@ -3,7 +3,7 @@ name = "diskann-inmem" version.workspace = true description.workspace = true authors.workspace = true -documentation.workspace = true +repository.workspace = true license.workspace = true edition.workspace = true diff --git a/diskann-inmem/src/lib.rs b/diskann-inmem/src/lib.rs index 17c89f53e..b77470612 100644 --- a/diskann-inmem/src/lib.rs +++ b/diskann-inmem/src/lib.rs @@ -5,11 +5,12 @@ mod arbiter; -mod layers; +pub mod layers; mod store; pub mod neighbors; pub mod num; -mod provider; +pub mod provider; pub use neighbors::Neighbors; +pub use provider::{Context, Provider, Strategy}; diff --git a/diskann-inmem/src/provider.rs b/diskann-inmem/src/provider.rs index b6832f0ae..cc84ea43e 100644 --- a/diskann-inmem/src/provider.rs +++ b/diskann-inmem/src/provider.rs @@ -60,7 +60,7 @@ impl Provider { // Data Provider // /////////////////// -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Default)] pub struct Context; impl diskann::provider::ExecutionContext for Context {} @@ -347,6 +347,13 @@ where } } +impl<'a, T, L> glue::DefaultPostProcessor<'a, Provider, T> for Strategy +where + L: layers::Search<'a, T>, +{ + diskann::default_post_processor!(glue::CopyIds); +} + impl glue::PruneStrategy> for Strategy where L: layers::Layer + layers::AsDistance, From 3203d3253d2fa88666c822ba747d66cda7ea382a Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Wed, 10 Jun 2026 17:10:15 -0700 Subject: [PATCH 04/45] More stuff. 
--- diskann-benchmark/src/index/inmem2.rs | 20 ++-- diskann-inmem/src/layers/full.rs | 13 +-- diskann-inmem/src/layers/mod.rs | 1 - diskann-inmem/src/num.rs | 2 +- diskann-inmem/src/provider.rs | 132 ++++++++++++++++++++++++-- diskann-inmem/src/store.rs | 24 ++++- diskann-vector/src/lib.rs | 2 +- 7 files changed, 164 insertions(+), 30 deletions(-) diff --git a/diskann-benchmark/src/index/inmem2.rs b/diskann-benchmark/src/index/inmem2.rs index f9720ae8a..2587bcadd 100644 --- a/diskann-benchmark/src/index/inmem2.rs +++ b/diskann-benchmark/src/index/inmem2.rs @@ -13,10 +13,7 @@ use std::{io::Write, num::NonZeroUsize, sync::Arc}; use diskann::graph::{self, DiskANNIndex}; use diskann_benchmark_core::{ - self as benchmark_core, - build as build_core, - recall::GroundTruthMode, - search as core_search, + self as benchmark_core, build as build_core, recall::GroundTruthMode, search as core_search, }; use diskann_benchmark_runner::{ benchmark::{FailureScore, MatchScore}, @@ -189,7 +186,11 @@ impl Benchmark for Inmem2 { build_core::ids::Identity::::new(), ); - writeln!(output, "Building index with {} threads...", input.num_threads)?; + writeln!( + output, + "Building index with {} threads...", + input.num_threads + )?; let build_results = build_core::build_tracked( builder, build_core::Parallelism::dynamic( @@ -201,7 +202,11 @@ impl Benchmark for Inmem2 { )?; let total_build_time = build_results.end_to_end_latency(); - writeln!(output, "Build complete in {:.2}s", total_build_time.as_seconds())?; + writeln!( + output, + "Build complete in {:.2}s", + total_build_time.as_seconds() + )?; checkpoint.checkpoint(&total_build_time)?; // Search. 
@@ -222,8 +227,7 @@ impl Benchmark for Inmem2 { let num_threads = NonZeroUsize::new(input.num_threads).unwrap(); for &search_l in &input.search_l { - let params = - graph::search::Knn::new(input.search_n, search_l, None)?; + let params = graph::search::Knn::new(input.search_n, search_l, None)?; let setup = core_search::Setup { threads: num_threads, diff --git a/diskann-inmem/src/layers/full.rs b/diskann-inmem/src/layers/full.rs index 74c6149f0..d964f3015 100644 --- a/diskann-inmem/src/layers/full.rs +++ b/diskann-inmem/src/layers/full.rs @@ -74,10 +74,7 @@ impl<'a, T> layers::Search<'a, &'a [T]> for Full where T: std::fmt::Debug + Send + Sync + 'static, { - fn query_distance( - &'a self, - query: &'a [T], - ) -> ANNResult> { + fn query_distance(&'a self, query: &'a [T]) -> ANNResult> { Ok(Box::new(QueryDistance::new(self.distance, query))) } } @@ -91,10 +88,10 @@ where } } -impl<'a, T> layers::Insert<'a, &'a [T]> for Full -where - T: bytemuck::Pod + std::fmt::Debug + Send + Sync, -{} +impl<'a, T> layers::Insert<'a, &'a [T]> for Full where + T: bytemuck::Pod + std::fmt::Debug + Send + Sync +{ +} // impl<'a, T> layers::Insert<'a, &'a [T]> for Full // where diff --git a/diskann-inmem/src/layers/mod.rs b/diskann-inmem/src/layers/mod.rs index f7fc303c0..a7edd0290 100644 --- a/diskann-inmem/src/layers/mod.rs +++ b/diskann-inmem/src/layers/mod.rs @@ -45,4 +45,3 @@ pub trait Search<'a, T>: Send + Sync + 'static { } pub trait Insert<'a, T>: Search<'a, T> + Set + AsDistance {} - diff --git a/diskann-inmem/src/num.rs b/diskann-inmem/src/num.rs index e806a9113..8843fd17a 100644 --- a/diskann-inmem/src/num.rs +++ b/diskann-inmem/src/num.rs @@ -3,7 +3,7 @@ * Licensed under the MIT license. 
*/ -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] pub struct Bytes(pub usize); #[derive(Debug, Clone, Copy)] diff --git a/diskann-inmem/src/provider.rs b/diskann-inmem/src/provider.rs index cc84ea43e..598aa2212 100644 --- a/diskann-inmem/src/provider.rs +++ b/diskann-inmem/src/provider.rs @@ -135,6 +135,7 @@ pub struct SearchAccessor<'a> { reader: store::Reader<'a>, distance: Box, ids: AdjacencyList, + expand_beam: FExpandBeam, } impl diskann::provider::HasId for SearchAccessor<'_> { @@ -190,11 +191,20 @@ impl glue::SearchAccessor for SearchAccessor<'_> { .neighbors() .get(i.into_usize(), &mut self.ids) .unwrap(); - for neighbor in self.ids.iter().filter(|i| pred.eval_mut(i)) { - if let Some(data) = self.reader.read(neighbor.into_usize()) { - on_neighbors(*neighbor, self.distance.evaluate(data)?) - } - } + + // Filter out unvisited IDs and ensure that all the IDs we are about to read are in-bounds. + self.ids + .retain(|i| pred.eval_mut(i) && self.reader.is_in_bounds(i.into_usize())); + + unsafe { + (self.expand_beam)( + &self.ids, + 8, + &self.reader, + &*self.distance, + &mut on_neighbors, + ) + }?; } Ok(()) @@ -204,6 +214,104 @@ impl glue::SearchAccessor for SearchAccessor<'_> { } } +type FExpandBeam = unsafe fn( + &[u32], + usize, + &store::Reader<'_>, + &dyn layers::QueryDistance, + &mut dyn FnMut(u32, f32), +) -> ANNResult<()>; + +fn dispatch_expand_beam(bytes: Bytes) -> FExpandBeam { + if bytes <= Bytes(CACHE_LINE_SIZE) { + expand_beam_inner::<1> + } else if bytes <= Bytes(2 * CACHE_LINE_SIZE) { + expand_beam_inner::<2> + } else if bytes <= Bytes(3 * CACHE_LINE_SIZE) { + expand_beam_inner::<3> + } else if bytes <= Bytes(4 * CACHE_LINE_SIZE) { + expand_beam_inner::<4> + } else if bytes <= Bytes(5 * CACHE_LINE_SIZE) { + expand_beam_inner::<5> + } else if bytes <= Bytes(6 * CACHE_LINE_SIZE) { + expand_beam_inner::<6> + } else if bytes <= Bytes(7 * CACHE_LINE_SIZE) { + expand_beam_inner::<7> + } else if bytes <= Bytes(16 * 
CACHE_LINE_SIZE) { + expand_beam_inner::<8> + } else { + expand_beam_inner::<16> + } +} + +const CACHE_LINE_SIZE: usize = 64; + +pub unsafe fn test_function( + list: &[u32], + lookahead: usize, + reader: &store::Reader<'_>, + distance: &dyn layers::QueryDistance, + f: &mut dyn FnMut(u32, f32), +) -> ANNResult<()> { + unsafe { expand_beam_inner::<4>(list, lookahead, reader, distance, f) } +} + +/// Safety (no # yet because we need to revisit this - clippy will lint) +/// +/// * All items in `list` must be in-bounds with respect to `reader`. +/// * `N * CACHE_LINE_SIZE` must not exceed `reader.bytes()` rounded up to a multiple of `CACHE_LINE_SIZE`. +unsafe fn expand_beam_inner( + list: &[u32], + lookahead: usize, + reader: &store::Reader<'_>, + distance: &dyn layers::QueryDistance, + f: &mut dyn FnMut(u32, f32), +) -> ANNResult<()> { + debug_assert!( + N * CACHE_LINE_SIZE <= reader.bytes().0.next_multiple_of(CACHE_LINE_SIZE), + "we really rely on this: {}, bytes = {}", N, reader.bytes().0 + ); + + let len = list.len(); + let lookahead = lookahead.min(len); + + for j in 0..lookahead { + unsafe { + diskann_vector::prefetch_exactly::( + reader + .read_raw_unchecked(list.get_unchecked(j).into_usize()) + .as_ptr() + .as_ptr() + .cast_const() + .cast(), + ) + } + } + + let mut j = lookahead; + for &i in list.iter() { + if j != len { + unsafe { + diskann_vector::prefetch_exactly::( + reader + .read_raw_unchecked(list.get_unchecked(j).into_usize()) + .as_ptr() + .as_ptr() + .cast_const() + .cast(), + ) + } + j += 1; + } + + if let Some(data) = reader.read(i.into_usize()) { + f(i, distance.evaluate(data)?) 
+ } + } + + Ok(()) +} + //////////// // Insert // //////////// @@ -336,12 +444,14 @@ where context: &'a Context, query: T, ) -> ANNResult> { - let distance = - >::query_distance(&provider.layer, query)?; + let distance = >::query_distance(&provider.layer, query)?; + let reader = provider.primary.reader(); + let expand_beam = dispatch_expand_beam(reader.bytes()); let accessor = SearchAccessor { - reader: provider.primary.reader(), + reader, distance, ids: AdjacencyList::new(), + expand_beam, }; Ok(accessor) } @@ -411,8 +521,10 @@ mod tests { 10, diskann::graph::config::MaxDegree::Same, 100, - (Metric::L2).into() - ).build().unwrap(); + (Metric::L2).into(), + ) + .build() + .unwrap(); let index = DiskANNIndex::new(config, provider, None); diff --git a/diskann-inmem/src/store.rs b/diskann-inmem/src/store.rs index 0bf2385d8..c0c69e149 100644 --- a/diskann-inmem/src/store.rs +++ b/diskann-inmem/src/store.rs @@ -38,7 +38,7 @@ const SPLIT: usize = std::mem::size_of::(); impl Primary { pub fn new(entries: usize, bytes: Bytes, max_neighbors: usize) -> Self { let unpadded = bytes.0 + SPLIT; - let padded_bytes = unpadded.checked_next_multiple_of(SPLIT).unwrap(); + let padded_bytes = unpadded.checked_next_multiple_of(64).unwrap(); Self { buffer: Buffer::new(entries, Bytes(padded_bytes), Align(128)), @@ -189,6 +189,28 @@ impl<'a> Reader<'a> { } } + /// Return `true` if the index `i` is in-bounds. + #[inline] + #[must_use = "this function has no side-effects"] + pub fn is_in_bounds(&self, i: usize) -> bool { + i < self.buffer.len() + } + + /// Return the raw data slice for index `i` without any race guarantees. + /// + /// # Safety + /// + /// The index `i` must be in-bounds. + #[inline] + pub(crate) unsafe fn read_raw_unchecked(&self, i: usize) -> Slice<'_> { + unsafe { self.buffer.get_unchecked(i) }.truncate(self.unpadded) + } + + /// Return the number of bytes for each entry. 
+ pub(crate) fn bytes(&self) -> Bytes { + Bytes(self.unpadded) + } + // TODO: We may want to lock `Neighbors` in some way to enable exclusive access during // operations like snapshots. pub(crate) fn neighbors(&self) -> &Neighbors { diff --git a/diskann-vector/src/lib.rs b/diskann-vector/src/lib.rs index e88dc12c9..69db308a3 100644 --- a/diskann-vector/src/lib.rs +++ b/diskann-vector/src/lib.rs @@ -42,7 +42,7 @@ cfg_if::cfg_if! { const CACHE_LINE_SIZE: usize = 64; #[inline(always)] - unsafe fn prefetch_exactly(ptr: *const i8) { + pub unsafe fn prefetch_exactly(ptr: *const i8) { use std::arch::x86_64::*; for i in 0..N { _mm_prefetch(ptr.add(i * CACHE_LINE_SIZE), _MM_HINT_T0); From 6ab25c0a8d006d9b173fb8d0c13d1eab6d9c0058 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Thu, 11 Jun 2026 18:19:52 -0700 Subject: [PATCH 05/45] Checkopint. --- diskann-inmem/src/arbiter/buffer.rs | 514 ++++++++++++++++++++---- diskann-inmem/src/arbiter/epoch.rs | 2 +- diskann-inmem/src/arbiter/freelist.rs | 2 +- diskann-inmem/src/arbiter/generation.rs | 153 ++++++- diskann-inmem/src/arbiter/mod.rs | 2 +- diskann-inmem/src/layers/full.rs | 8 +- diskann-inmem/src/layers/mod.rs | 8 +- diskann-inmem/src/neighbors.rs | 16 +- diskann-inmem/src/num.rs | 319 ++++++++++++++- diskann-inmem/src/provider.rs | 39 +- diskann-inmem/src/store.rs | 68 ++-- 11 files changed, 987 insertions(+), 144 deletions(-) diff --git a/diskann-inmem/src/arbiter/buffer.rs b/diskann-inmem/src/arbiter/buffer.rs index 7249414e4..b3c3ad95d 100644 --- a/diskann-inmem/src/arbiter/buffer.rs +++ b/diskann-inmem/src/arbiter/buffer.rs @@ -3,10 +3,18 @@ * Licensed under the MIT license. */ -use std::{alloc::Layout, marker::PhantomData, ptr::NonNull, sync::atomic::AtomicU64}; +use std::{alloc::Layout, marker::PhantomData, ptr::NonNull}; use crate::num::{Align, Bytes}; +/// An unsynchronized row-store for raw data. 
+/// +/// The backing data is stored as raw pointers and interacted with via [`RawSlice`], which +/// is also raw pointer based. Careful use of this struct enables safe use of +/// [`RawSlice::as_slice`], [`RawSlice::as_mut_slice`], and other accesses from multiple +/// threads without undefined behavior. +/// +/// Note that `Buffer` is unconditionally `Send` and `Sync`. #[derive(Debug)] pub struct Buffer { ptr: NonNull, @@ -16,59 +24,73 @@ pub struct Buffer { } impl Buffer { - pub fn new(entries: usize, bytes_per_entry: Bytes, align: Align) -> Self { - let size = bytes_per_entry.0.checked_mul(entries).unwrap(); - let layout = std::alloc::Layout::from_size_align(size, align.0).unwrap(); + /// Construct a new [`Buffer`] capable of holding `entries` with each entry occupying + /// exactly `bytes_per_entry`. Subsequent entries are separated by exactly + /// `bytes_per_entry` bytes. The base pointer will be aligned to at least `align`. + /// + /// # Errors + /// + /// Returns an error if the number of bytes `bytes_per_entry * entries` rounded up to + /// the next multiple of `align` exceeds `isize::MAX`. + pub fn new(entries: usize, bytes_per_entry: Bytes, align: Align) -> Result { + // If we overflow `usize::MAX`, we will definitely overflow `isize::MAX`. + let bytes = bytes_per_entry.checked_mul(entries).ok_or(BufferError)?; + + // Since `align` is constrained to be a power of two, the only way this fails is + // if we overflow `isize::MAX`. + let layout = std::alloc::Layout::from_size_align(bytes.value(), align.value()) + .map_err(|_: std::alloc::LayoutError| BufferError)?; + + let ptr = if layout.size() == 0 { + std::ptr::dangling_mut() + } else { + // SAFETY: `layout.size()` is non-zero. 
+ unsafe { std::alloc::alloc_zeroed(layout) } + }; - let ptr = unsafe { std::alloc::alloc_zeroed(layout) }; let ptr = match NonNull::new(ptr) { Some(ptr) => ptr, None => std::alloc::handle_alloc_error(layout), }; - Self { + Ok(Self { ptr, stride: bytes_per_entry, entries, layout, - } + }) } + /// Return the number of entries in this [`Buffer`]. #[inline] pub fn len(&self) -> usize { self.entries } + /// Return the number of bytes for each entry. #[inline] pub fn stride(&self) -> Bytes { self.stride } + /// Return the minimum alignment of the base pointer for the buffer. #[inline] pub fn align(&self) -> Align { - Align(self.layout.align()) + Align::from_layout(self.layout) } + /// Return the result of `self.len() == 0`. #[inline] pub fn is_empty(&self) -> bool { self.len() == 0 } - /// Issue prefetch hints for the entry at index `i`. + /// Return the `i`th entry if `i < self.len()`. /// - /// `bytes` controls how many bytes to prefetch (clamped to `stride`). - /// Uses `wrapping_add` to avoid UB on out-of-bounds indices — prefetching - /// a bad address is architecturally harmless. - #[inline(always)] - pub fn prefetch(&self, i: usize, bytes: usize) { - let offset = self.stride.0.wrapping_mul(i); - let ptr = self.ptr.as_ptr().wrapping_add(offset); - let bytes = bytes.min(self.stride.0); - prefetch_cachelines(ptr, bytes); - } - + /// The returned [`RawSlice`] is guaranteed to have a length of [`Self::stride`] and + /// begin at `self.as_ptr().add(self.stride().value() * i)`. #[inline] - pub fn get(&self, i: usize) -> Option> { + pub fn get(&self, i: usize) -> Option> { if i >= self.entries { None } else { @@ -80,48 +102,90 @@ impl Buffer { } } - /// Get the slice for entry `i` without bounds checking. + /// Return the `i`th entry without bounds checking. + /// + /// The returned [`RawSlice`] is guaranteed to have a length of [`Self::stride`] and + /// begin at `self.as_ptr().add(self.stride().value() * i)`. 
/// /// # Safety /// /// `i` must be less than [`len`](Self::len). #[inline] - pub unsafe fn get_unchecked(&self, i: usize) -> Slice<'_> { + pub unsafe fn get_unchecked(&self, i: usize) -> RawSlice<'_> { debug_assert!(i < self.entries); - let ptr = unsafe { self.ptr.add(self.stride.0 * i) }; - Slice { + let ptr = unsafe { self.ptr.add(self.stride().value() * i) }; + RawSlice { ptr, - len: self.stride.0, + len: self.stride, _lifetime: PhantomData, } } + + /// Return the base pointer of the [`Buffer`]. + /// + /// If the requested allocation was non-zero, this is guaranteed to be a multiple of the + /// requested alignment. + #[inline] + pub fn as_ptr(&self) -> *const u8 { + self.ptr.as_ptr().cast_const() + } } impl Drop for Buffer { fn drop(&mut self) { - // SAFETY: This is the same pointer and allocation that was previously returned - // from a successful `alloc_zeroed`. - unsafe { std::alloc::dealloc(self.ptr.as_ptr(), self.layout) } + // If the layout size is zero, there's nothing to do because we hold a dangling pointer. + if self.layout.size() != 0 { + // SAFETY: This is the same pointer and allocation that was previously returned + // from a successful `alloc_zeroed`. + unsafe { std::alloc::dealloc(self.ptr.as_ptr(), self.layout) } + } } } -// SAFETY: We're safe to pass around the `Buffer`. It's just use of the returned `Slice` -// the needs to be arbitrated. +#[derive(Debug, Clone, Copy)] +#[non_exhaustive] +pub struct BufferError; + +impl std::fmt::Display for BufferError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "requested allocation exceeds `isize::MAX`") + } +} + +impl std::error::Error for BufferError {} + +// SAFETY: We're safe to pass around the `Buffer`. It's just use of the returned `RawSlice` +// that needs to be arbitrated. unsafe impl Send for Buffer {} -// SAFETY: We're safe to pass around the `Buffer`. It's just use of the returned `Slice` -// the needs to be arbitrated. 
+// SAFETY: We're safe to pass around the `Buffer`. It's just use of the returned `RawSlice` +// that needs to be arbitrated. unsafe impl Sync for Buffer {} -#[derive(Debug, Clone, Copy)] -pub struct Slice<'a> { +/// A raw entry in [`Buffer`]. +/// +/// The memory in the range `[RawSlice::as_ptr(), RawSlice::as_ptr().add(slice.len()))` is +/// guaranteed to be within a single alive allocation. +/// +/// This has the borrowing semantics of a raw pointer. +#[derive(Debug)] +pub struct RawSlice<'a> { ptr: NonNull, - len: usize, + len: Bytes, _lifetime: PhantomData<&'a ()>, } -impl<'a> Slice<'a> { - unsafe fn new(ptr: NonNull, len: usize) -> Self { +impl<'a> RawSlice<'a> { + /// Create a new [`RawSlice`]. + /// + /// # Safety + /// + /// The memory `[ptr, ptr.add(len.value()))` must be part of a single allocation for + /// the duration of the lifetime `'a`. + /// + /// However, this has the semantics of a pointer: multiple threads can hold a [`RawSlice`] + /// to the same piece of memory without undefined behavior. + unsafe fn new(ptr: NonNull, len: Bytes) -> Self { Self { ptr, len, @@ -129,72 +193,386 @@ impl<'a> Slice<'a> { } } + /// Create a new slice to the first `n.min(self.len())` bytes of `self`. + #[inline] + pub fn truncate(&self, n: Bytes) -> RawSlice<'a> { + // SAFETY: The `min` operation ensures we provide an argument <= `self.len()`. + unsafe { self.truncate_unchecked(self.len.min(n)) } + } + + /// Shorten the slice to its first `n` bytes. + /// + /// # Safety + /// + /// `n` must be less than or equal to `self.len()`. #[inline] - pub fn truncate(&self, n: usize) -> Slice<'a> { - unsafe { Self::new(self.ptr, self.len.min(n)) } + pub(crate) unsafe fn truncate_unchecked(&self, n: Bytes) -> RawSlice<'a> { + debug_assert!(n <= self.len); + + // SAFETY: Inherited from the caller. + unsafe { Self::new(self.ptr, n) } } + /// Create a new slice skipping the first `n.min(self.len())` bytes of self. 
#[inline] - pub fn skip(&self, n: usize) -> Slice<'a> { + pub fn skip(&self, n: Bytes) -> RawSlice<'a> { let advance_by = self.len.min(n); - unsafe { Self::new(self.ptr.add(advance_by), self.len - advance_by) } + + // SAFETY: `advance_by <= self.len()`, so the pointer offset is valid and the + // `unchecked_sub` cannot underflow. + unsafe { + Self::new( + self.ptr.add(advance_by.value()), + self.len.unchecked_sub(advance_by), + ) + } + } + + /// Split `self` into two as `([ptr, ptr.add(m)), [ptr.add(m), ptr.add(self.len())))` + /// where `m = n.min(self.len())`. + #[inline] + pub fn split(&self, n: Bytes) -> (RawSlice<'a>, RawSlice<'a>) { + // SAFETY: The argument is <= `self.len()`. + unsafe { self.split_unchecked(self.len.min(n)) } } + /// Split `self` into two as `([ptr, ptr.add(n)), [ptr.add(n), ptr.add(self.len())))` + /// + /// # Safety + /// + /// `n` must be less than or equal to `self.len()`. #[inline] - pub fn split(&self, n: usize) -> (Slice<'a>, Slice<'a>) { - let n = self.len.min(n); + pub(crate) unsafe fn split_unchecked(&self, n: Bytes) -> (RawSlice<'a>, RawSlice<'a>) { + debug_assert!(n <= self.len); unsafe { ( Self::new(self.ptr, n), - Self::new(self.ptr.add(n), self.len - n), + Self::new(self.ptr.add(n.value()), self.len.unchecked_sub(n)), ) } } + /// Return the length of the slice in bytes. #[inline] - pub fn len(&self) -> usize { + pub fn len(&self) -> Bytes { self.len } + /// Return the result of `self.len() == 0`. #[inline] pub fn is_empty(&self) -> bool { - self.len() == 0 + self.len() == Bytes::new(0) } - pub fn as_ptr(&self) -> NonNull { + /// Return the base [`NonNull`] pointer of the slice. + pub fn as_non_null(&self) -> NonNull { self.ptr } + /// Return the base pointer of the slice as `*const u8`. + pub fn as_ptr(&self) -> *const u8 { + self.ptr.as_ptr().cast_const() + } + + /// Return the base pointer of the slice as `*mut u8`. 
+ /// + /// This returns a mutable pointer regardless of the receiver's mutability, matching + /// the raw-pointer semantics of [`RawSlice`]. + pub fn as_mut_ptr(&self) -> *mut u8 { + self.ptr.as_ptr() + } + + /// Materialize the raw slice as a true shared slice. + /// + /// # Safety + /// + /// Correct adherence to the API of [`RawSlice`] will ensure that the memory behind the + /// materialized slice resides within a single allocation. + /// + /// However, it is the responsibility of the caller to ensure that materializing this + /// slice does not violate Rust's borrowing rules. #[inline] pub unsafe fn as_slice(&self) -> &'a [u8] { - unsafe { std::slice::from_raw_parts(self.ptr.as_ptr(), self.len) } + unsafe { std::slice::from_raw_parts(self.ptr.as_ptr(), self.len.value()) } } + /// Materialize the raw slice as a true mutable slice. + /// + /// # Safety + /// + /// Correct adherence to the API of [`RawSlice`] will ensure that the memory behind the + /// materialized slice resides within a single allocation. + /// + /// However, it is the responsibility of the caller to ensure that materializing this + /// slice does not violate Rust's borrowing rules. #[inline] pub unsafe fn as_mut_slice(&mut self) -> &'a mut [u8] { - unsafe { std::slice::from_raw_parts_mut(self.ptr.as_ptr(), self.len) } + unsafe { std::slice::from_raw_parts_mut(self.ptr.as_ptr(), self.len.value()) } } } -/// Issue prefetch hints for `bytes` starting at `ptr`. -/// -/// This is purely a performance hint and cannot cause undefined behavior, -/// even if `ptr` is invalid or out of bounds. -#[inline(always)] -pub fn prefetch_cachelines(ptr: *const u8, bytes: usize) { - #[cfg(target_arch = "x86_64")] - { - use std::arch::x86_64::{_mm_prefetch, _MM_HINT_T0}; - let lines = bytes.div_ceil(64); - for i in 0..lines { - // SAFETY: _mm_prefetch is a hint; invalid addresses are silently ignored. 
- unsafe { _mm_prefetch(ptr.wrapping_add(i * 64) as *const i8, _MM_HINT_T0) }; +/////////// +// Tests // +/////////// + +#[cfg(test)] +mod tests { + use super::*; + + use std::{thread, sync::Barrier}; + + #[derive(Debug)] + struct Ctx { + entries: usize, + bytes_per_entry: Bytes, + align: Align, + } + + impl std::fmt::Display for Ctx { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "entries = {}, bytes_per_entry = {}, align = {}", + self.entries, self.bytes_per_entry, self.align + ) } } -} -/// Issue a prefetch hint for a single generation tag. -#[inline(always)] -pub fn prefetch_tag(tag: &AtomicU64) { - prefetch_cachelines(tag as *const AtomicU64 as *const u8, 8); + fn test_buffer_inner(entries: usize, bytes_per_entry: Bytes, align: Align) { + let ctx = Ctx { + entries, + bytes_per_entry, + align, + }; + let mut buffer = Buffer::new(entries, bytes_per_entry, align).unwrap(); + + // Initial Checks + assert_eq!(buffer.len(), entries, "{}", ctx); + assert_eq!(buffer.stride(), bytes_per_entry, "{}", ctx); + assert_eq!(buffer.align(), align, "{}", ctx); + + if entries != 0 && !bytes_per_entry.is_zero() { + let addr = buffer.as_ptr() as usize; + assert!( + addr.is_multiple_of(align.value()), + "pointer address {:#x} must be a multiple of the requested alignment: {}", + addr, + ctx, + ); + } + + if entries == 0 { + assert!(buffer.is_empty(), "{}", ctx); + } else { + assert!(!buffer.is_empty(), "{}", ctx); + } + + // Verify zero initialization + assert_is_zeroed(&mut buffer, &ctx); + + // Check Slice Methods + check_slice_methods(&mut buffer, &ctx); + + + // Check that concurrent mutation is allowed. + // + // This is mainly a Miri check. + zero(&mut buffer); + check_threaded(&mut buffer, &ctx); + } + + fn zero(buffer: &mut Buffer) { + // SAFETY NOTE: Exclusive reference to `buffer` guarantees no concurrent mutation. 
+ for i in 0..buffer.len() { + let mut raw_slice = buffer.get(i).unwrap(); + assert_eq!(raw_slice.len(), buffer.stride()); + + let slice = unsafe { raw_slice.as_mut_slice() }; + assert_eq!(slice.len(), buffer.stride().value()); + slice.fill(0); + } + } + + fn assert_is_zeroed(buffer: &mut Buffer, ctx: &Ctx) { + // SAFETY NOTE: Exclusive reference to `buffer` guarantees no concurrent mutation. + // All `unsafe` calls below rely on this guarantee. + + for i in 0..buffer.len() { + let raw_slice = buffer.get(i).unwrap(); + assert_eq!(raw_slice.len(), buffer.stride()); + + assert_eq!(raw_slice.as_non_null().as_ptr(), raw_slice.as_mut_ptr()); + assert_eq!(raw_slice.as_non_null().as_ptr().cast_const(), raw_slice.as_ptr()); + + assert_eq!( + raw_slice.as_ptr(), + buffer + .as_ptr() + .wrapping_add(buffer.stride().checked_mul(i).unwrap().value()), + "stride mismatch - {}", + ctx + ); + + if raw_slice.len().is_zero() { + assert!(raw_slice.is_empty()); + } else { + assert!(!raw_slice.is_empty()); + } + + let slice = unsafe { raw_slice.as_slice() }; + assert_eq!(slice.len(), buffer.stride().value()); + assert!(slice.iter().all(|&i| i == 0), "{}", ctx); + } + + // Verify that bounds-checking works. + assert!(buffer.get(buffer.len()).is_none(), "{}", ctx); + } + + fn check_slice_methods(buffer: &mut Buffer, ctx: &Ctx) { + // SAFETY NOTE: We take `buffer` by exclusive reference to guarantee that there + // is no possibility of concurrent mutation outside this method. All `unsafe` calls + // below rely on this guarantee unless otherwise noted. 
+ + if buffer.len() == 0 { + return; + } + + let mut raw = buffer.get(0).unwrap(); + let base: u8 = 5; + let base_usize: usize = base.into(); + + // truncate // + + iota(unsafe { raw.as_mut_slice() }, base); + for i in 0..raw.len().value() + base_usize { + let expected = i.min(raw.len().value()); + + let truncated = raw.truncate(Bytes::new(i)); + assert_eq!(truncated.len().value(), expected, "{}", ctx); + assert!(is_iota(unsafe { truncated.as_slice() }, base), "{}", ctx); + } + + // skip // + + for i in 0..raw.len().value() + base_usize { + let expected = raw.len().value() - i.min(raw.len().value()); + let skipped = raw.skip(Bytes::new(i)); + assert_eq!(skipped.len().value(), expected, "{}", ctx); + assert!( + is_iota(unsafe { skipped.as_slice() }, base.wrapping_add(i as u8)), + "{}", + ctx + ); + } + + // split // + + for i in 0..raw.len().value() + base_usize { + let first = i.min(raw.len().value()); + let last = raw.len().value() - first; + + let (mut prefix, mut suffix) = raw.split(Bytes::new(i)); + + assert_eq!(prefix.len().value(), first, "{}", ctx); + assert_eq!(suffix.len().value(), last, "{}", ctx); + + assert!(is_iota(unsafe { prefix.as_slice() }, base), "{}", ctx); + assert!( + is_iota(unsafe { suffix.as_slice() }, base.wrapping_add(i as u8)), + "{}", + ctx + ); + + // Verify it's okay to mutate two disjoint slices concurrently. + // + // SAFETY: `prefix` and `suffix` are non-overlapping sub-ranges of the same + // entry, so materializing both as mutable is sound. 
+ { + let prefix = unsafe { prefix.as_mut_slice() }; + let suffix = unsafe { suffix.as_mut_slice() }; + suffix.fill(0); + prefix.fill(0); + } + + assert!(unsafe { raw.as_slice() }.iter().all(|i| *i == 0), "{}", ctx); + iota(unsafe { raw.as_mut_slice() }, base); + } + } + + fn check_threaded(buffer: &mut Buffer, ctx: &Ctx) { + let spawns = buffer.len(); + + // The goal here is to ensure that threads hold concurrent mutable references to + // disjoint entries within the `Buffer` and that when the mutate them concurrently, + // we get a coherent result. + let pre = &Barrier::new(spawns); + let post = &Barrier::new(spawns); + { + let borrowed: &Buffer = buffer; + thread::scope(|s| { + for i in 0..spawns { + s.spawn(move || { + // SAFETY: The top level method has an exclusive reference to the buffer. + // + // This loop by construction accesses disjoint offsets. This is sufficient + // to guarantee exclusivity for this thread. + let slice = unsafe { borrowed.get(i).unwrap().as_mut_slice() }; + pre.wait(); + iota(slice, i as u8); + post.wait(); + }); + } + }); + } + + for i in 0..spawns { + let slice = unsafe { buffer.get(i).unwrap().as_slice() }; + assert!(is_iota(slice, i as u8), "i = {} -- {}", i, ctx); + } + } + + fn iota(x: &mut [u8], base: u8) { + for (i, v) in x.iter_mut().enumerate() { + *v = base.wrapping_add(i as u8); + } + } + + #[must_use] + fn is_iota(x: &[u8], base: u8) -> bool { + for (i, v) in x.iter().enumerate() { + if *v != base.wrapping_add(i as u8) { + return false; + } + } + true + } + + #[test] + fn test_buffer() { + let entries = [0, 1, 2, 5]; + let bytes_per_entry = [0, 1, 2, 5, 10].map(Bytes::new); + let align = [Align::_1, Align::_64]; + + for entries in entries { + for bytes_per_entry in bytes_per_entry { + for align in align { + test_buffer_inner(entries, bytes_per_entry, align); + } + } + } + } + + #[test] + fn test_buffer_overflow_mul() { + // entries * bytes_per_entry overflows usize. 
+ let result = Buffer::new(usize::MAX, Bytes::new(2), Align::_1); + assert!(result.is_err()); + } + + #[test] + fn test_buffer_overflow_layout() { + // Total size exceeds isize::MAX (Layout rejects this). + let result = Buffer::new(isize::MAX as usize, Bytes::new(2), Align::_1); + assert!(result.is_err()); + } } diff --git a/diskann-inmem/src/arbiter/epoch.rs b/diskann-inmem/src/arbiter/epoch.rs index 0fcd052a7..bc5547b5a 100644 --- a/diskann-inmem/src/arbiter/epoch.rs +++ b/diskann-inmem/src/arbiter/epoch.rs @@ -4,7 +4,7 @@ */ use std::sync::{ - atomic::{AtomicU32, AtomicU64, Ordering}, + atomic::{AtomicU64, Ordering}, Mutex, }; diff --git a/diskann-inmem/src/arbiter/freelist.rs b/diskann-inmem/src/arbiter/freelist.rs index b59e18522..04bc24bae 100644 --- a/diskann-inmem/src/arbiter/freelist.rs +++ b/diskann-inmem/src/arbiter/freelist.rs @@ -6,7 +6,7 @@ use std::{ num::NonZeroU32, sync::{ - atomic::{AtomicBool, AtomicU32, AtomicUsize, Ordering}, + atomic::{AtomicBool, AtomicU32, Ordering}, Mutex, }, }; diff --git a/diskann-inmem/src/arbiter/generation.rs b/diskann-inmem/src/arbiter/generation.rs index 6d82a8cd7..4da9d4576 100644 --- a/diskann-inmem/src/arbiter/generation.rs +++ b/diskann-inmem/src/arbiter/generation.rs @@ -5,33 +5,75 @@ use std::sync::atomic::{AtomicU64, Ordering}; +/// An atomic [`Generation`] tag. +/// +/// Access is performed through [`Ref`] and [`Mut`]. #[derive(Debug)] #[repr(transparent)] pub struct Tag(AtomicU64); impl Tag { + /// Construct a new [`Tag`] initialized to `generation`. pub const fn new(generation: Generation) -> Self { Self(AtomicU64::new(generation.value())) } + /// Return a read-only [`Ref`] to `self`. pub fn as_ref(&self) -> Ref<'_> { Ref::new(&self.0) } + /// Return a read-write [`Mut`] to `self`. pub fn as_mut(&self) -> Mut<'_> { Mut::new(&self.0) } + /// Creates a new reference to a `Tag` from a raw pointer. + /// + /// # Safety + /// + /// * `ptr` must be aligned to `align_of::()`. 
+ /// * `ptr` must be valid for both reads and writes for the whole lifetime `'a`. + /// * This must adhere to the memory model for atomic accesses. In particular, it must + /// not admit conflicting atomic and non-atomic accesses, or atomic accesses of + /// different sizes without synchronization. + /// + /// See: pub unsafe fn from_ptr<'a>(ptr: *mut Tag) -> &'a Self { unsafe { &*ptr } } } +/// A generation tag for controlling concurrent access to data. +/// +/// Generally, generations are decremented from `Generation::MAX`, with higher values +/// representing older generations. This allows zero to stand for "unused" as it is newer +/// than any valid generation. +/// +/// Certain low-numbered generations are reserved for special uses. Any generation less +/// than or equal to [`Generation::RESERVED`] is reserved. +/// +/// # Reserved Generations +/// +/// * [`Generation::AVAILABLE`]: The associated slot is not currently storing valid data +/// and is available to use. +/// +/// To acquire ownership, an atomic compare-exchange must be used away from this state. +/// +/// * [`Generation::OWNED`]: The associated data is owned by some thread. Only the thread +/// owning this slot may update it. +/// +/// Note that ownership may be transferred between threads as long as this ownership +/// transfer is unambiguous and properly synchronized. +/// +/// * [`Generation::FROZEN`]: This data is protected and is not expected to be mutated. +/// #[derive(Default, Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] #[repr(transparent)] pub struct Generation(u64); impl Generation { + /// The maximum generation. This is the oldest possible generation. pub const MAX: Self = Self::new(u64::MAX); // Reserved generations. @@ -41,27 +83,51 @@ impl Generation { // don't require additional initialization. // // If you add states - make sure to increment the `RESERVED` marker! 
- pub(crate) const AVAILABLE: Self = Self::new(0); - pub(crate) const OWNED: Self = Self::new(1); - pub(crate) const FROZEN: Self = Self::new(2); - const RESERVED: Self = Self::FROZEN; + /// See [`Generation`]. + pub const AVAILABLE: Self = Self::new(0); + + /// See [`Generation`]. + pub const OWNED: Self = Self::new(1); + + /// See [`Generation`]. + pub const FROZEN: Self = Self::new(2); + + /// The maximum reserved generation. See [`Generation`]. + pub const RESERVED: Self = Self::FROZEN; + + /// Return `true` if `self` belongs to a reserved generation. #[must_use = "this function has no side-effects"] pub(crate) fn is_reserved(self) -> bool { self <= Self::RESERVED } + /// Construct a new [`Generation`] with `value`. #[inline] pub const fn new(value: u64) -> Self { Self(value) } + /// Return the value of `self`. #[inline] pub const fn value(self) -> u64 { self.0 } + + #[cfg(test)] + const fn add(self, v: u64) -> Self { + Self(self.0 + v) + } + + #[cfg(test)] + const fn sub(self, v: u64) -> Self { + Self(self.0 - v) + } } +/// A read-only handle to a [`Tag`]. +/// +/// Provides atomic load access to the underlying generation value. #[derive(Debug, Clone, Copy)] #[repr(transparent)] pub struct Ref<'a>(&'a AtomicU64); @@ -77,12 +143,17 @@ impl<'a> Ref<'a> { self.0 } + /// Load the current [`Generation`] with the given ordering. #[inline] pub fn get(&self, ordering: Ordering) -> Generation { Generation::new(self.0.load(ordering)) } } +/// A read-write handle to a [`Tag`]. +/// +/// Provides atomic store and compare-exchange access in addition to the read access +/// inherited from [`Ref`] via [`Deref`](std::ops::Deref). #[derive(Debug, Clone, Copy)] #[repr(transparent)] pub struct Mut<'a>(Ref<'a>); @@ -93,6 +164,9 @@ impl<'a> Mut<'a> { Self(Ref::new(slot)) } + /// Attempt to atomically update the generation from `current` to `new`. + /// + /// Returns `Ok(current)` on success, or `Err(actual)` if the value was not `current`. 
#[inline] pub fn try_set( &self, @@ -107,6 +181,7 @@ impl<'a> Mut<'a> { .map_err(Generation::new) } + /// Atomically store a [`Generation`] with the given ordering. #[inline] pub fn set(&self, generation: Generation, ordering: Ordering) { self.inner().store(generation.value(), ordering) @@ -119,3 +194,73 @@ impl<'a> std::ops::Deref for Mut<'a> { &self.0 } } + +/////////// +// Tests // +/////////// + +#[cfg(test)] +mod tests { + use super::*; + + use std::{thread, sync::Barrier}; + + use crate::{num::{Bytes,Align}, arbiter::Buffer}; + + fn spin_decrement(m: Mut<'_>, count: usize) { + for i in 0..count { + let mut current = m.get(Ordering::Relaxed); + while let Err(c) = m.try_set( + current, + current.sub(1), + Ordering::Relaxed, + Ordering::Relaxed, + ) { + current = c; + } + } + } + + #[test] + fn test_atomic() { + let threads = 4; + let barrier = &Barrier::new(threads); + + // This dance basically verifies that we can view the tag though a proper-aligned + // raw pointer. + let buffer = Buffer::new(1, Bytes::size_of::(), Align::of::()).unwrap(); + let ptr = buffer.get(0).unwrap().as_mut_ptr().cast::(); + + { + let tag = unsafe { Tag::from_ptr(ptr) }; + tag.as_mut().set(Generation::MAX, Ordering::Relaxed); + } + + let count = 1000; + thread::scope(|s| { + for i in 0..threads { + s.spawn(|| { + // Re-derive `p` to avoid issues with `Send`. 
+ let p = buffer.get(0).unwrap().as_mut_ptr().cast::(); + let tag = unsafe { Tag::from_ptr(p) }; + barrier.wait(); + spin_decrement(tag.as_mut(), count); + }); + } + }); + + { + let tag = unsafe { Tag::from_ptr(ptr) }; + let g = tag.as_ref().get(Ordering::Relaxed); + assert_eq!(g, Generation::MAX.sub((count * threads) as u64)); + } + } + + #[test] + fn test_is_reserved() { + assert!(Generation::AVAILABLE.is_reserved()); + assert!(Generation::OWNED.is_reserved()); + assert!(Generation::FROZEN.is_reserved()); + assert!(!Generation::FROZEN.add(1).is_reserved()); + } +} diff --git a/diskann-inmem/src/arbiter/mod.rs b/diskann-inmem/src/arbiter/mod.rs index 097962ac1..9499e4d0f 100644 --- a/diskann-inmem/src/arbiter/mod.rs +++ b/diskann-inmem/src/arbiter/mod.rs @@ -4,7 +4,7 @@ */ pub(crate) mod buffer; -pub use buffer::{prefetch_cachelines, Buffer, Slice}; +pub use buffer::{Buffer, RawSlice}; pub mod epoch; diff --git a/diskann-inmem/src/layers/full.rs b/diskann-inmem/src/layers/full.rs index d964f3015..8aea3eed1 100644 --- a/diskann-inmem/src/layers/full.rs +++ b/diskann-inmem/src/layers/full.rs @@ -6,7 +6,7 @@ use diskann::{ANNError, ANNResult}; use diskann_vector::{ distance::{self, DistanceProvider, Metric}, - AsUnaligned, UnalignedSlice, + UnalignedSlice, }; use thiserror::Error; @@ -46,7 +46,7 @@ where } pub fn bytes(&self) -> Bytes { - Bytes(self.dim() * std::mem::size_of::()) + Bytes::new(self.dim() * std::mem::size_of::()) } } @@ -54,8 +54,8 @@ impl layers::Layer for Full where T: bytemuck::Pod + Send + Sync, { - fn bytes(&self) -> usize { - >::bytes(self).0 + fn bytes(&self) -> Bytes { + >::bytes(self) } } diff --git a/diskann-inmem/src/layers/mod.rs b/diskann-inmem/src/layers/mod.rs index a7edd0290..95be3f604 100644 --- a/diskann-inmem/src/layers/mod.rs +++ b/diskann-inmem/src/layers/mod.rs @@ -3,8 +3,10 @@ * Licensed under the MIT license. 
*/ -use diskann::{error::StandardError, utils::VectorRepr, ANNResult}; -use diskann_vector::{distance::Metric, DistanceFunction}; +use diskann::ANNResult; +use diskann_vector::DistanceFunction; + +use crate::num::Bytes; pub(crate) mod full; pub use full::Full; @@ -31,7 +33,7 @@ pub trait Layer: Send + Sync + 'static { /// Return the number of bytes needed by this layer representation. /// /// To be well-behaved, this function must be idempotent. - fn bytes(&self) -> usize; + fn bytes(&self) -> Bytes; } pub trait Set: Layer { diff --git a/diskann-inmem/src/neighbors.rs b/diskann-inmem/src/neighbors.rs index b493f6a3c..9ca3770d8 100644 --- a/diskann-inmem/src/neighbors.rs +++ b/diskann-inmem/src/neighbors.rs @@ -30,8 +30,8 @@ pub struct Neighbors { impl Neighbors { pub fn new(entries: usize, max_length: usize) -> Self { - let bytes = Bytes((max_length + 1) * std::mem::size_of::()); - let neighbors = Buffer::new(entries, bytes, Align(128)); + let bytes = Bytes::new((max_length + 1) * std::mem::size_of::()); + let neighbors = Buffer::new(entries, bytes, Align::_128).unwrap(); let locks = std::iter::repeat_with(|| RwLock::new(())) .take(entries.div_ceil(LOCK_GRANULARITY)) .collect(); @@ -42,7 +42,7 @@ impl Neighbors { /// Return the maximum length for any adjacency list. pub fn max_length(&self) -> usize { // We reserve 4 bytes at the beginning for the length of the adjacency list. - (self.neighbors.stride().0 - std::mem::size_of::()) / std::mem::size_of::() + (self.neighbors.stride().value() - std::mem::size_of::()) / std::mem::size_of::() } pub fn entries(&self) -> usize { @@ -59,9 +59,9 @@ impl Neighbors { // SAFETY: By consruction `self.buffer` has the same number of entries as // `self.locks` and we have already checked that `i` is in-bounds there. 
let (prefix, rest) = - unsafe { self.neighbors.get_unchecked(i) }.split(std::mem::size_of::()); + unsafe { self.neighbors.get_unchecked(i) }.split(Bytes::size_of::()); - debug_assert_eq!(prefix.len(), std::mem::size_of::()); + debug_assert_eq!(prefix.len(), Bytes::size_of::()); debug_assert!(prefix.as_ptr().is_aligned()); // SAFETY: We hold the read-lock, so reading is safe. From our bounds checks, we @@ -73,7 +73,7 @@ impl Neighbors { let mut resizer = neighbors.resize(len); unsafe { std::ptr::copy_nonoverlapping( - rest.as_ptr().as_ptr(), + rest.as_mut_ptr(), resizer.as_mut_ptr().cast::(), len * std::mem::size_of::(), ) @@ -98,8 +98,8 @@ impl Neighbors { let raw = unsafe { std::slice::from_raw_parts_mut( - slice.as_ptr().as_ptr().cast::(), - slice.len() / std::mem::size_of::(), + slice.as_mut_ptr().cast::(), + slice.len().value() / std::mem::size_of::(), ) }; diff --git a/diskann-inmem/src/num.rs b/diskann-inmem/src/num.rs index 8843fd17a..073e66fd2 100644 --- a/diskann-inmem/src/num.rs +++ b/diskann-inmem/src/num.rs @@ -3,8 +3,321 @@ * Licensed under the MIT license. 
*/ +use std::num::NonZeroUsize; + +#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +pub struct Bytes(usize); + +impl Bytes { + pub const CACHELINE: Self = Self::new(64); + pub const ZERO: Self = Self::new(0); + + #[inline] + pub const fn new(bytes: usize) -> Self { + Self(bytes) + } + + #[inline] + pub const fn value(self) -> usize { + self.0 + } + + #[inline] + pub const fn checked_add(self, other: Bytes) -> Option { + match self.value().checked_add(other.value()) { + Some(v) => Some(Bytes::new(v)), + None => None, + } + } + + #[inline] + pub const fn checked_mul(self, other: usize) -> Option { + match self.value().checked_mul(other) { + Some(v) => Some(Bytes::new(v)), + None => None, + } + } + + #[inline] + pub(crate) const fn unchecked_mul(self, other: usize) -> Bytes { + Bytes::new(self.value() * other) + } + + #[inline] + pub const fn checked_sub(self, other: Bytes) -> Option { + match self.value().checked_sub(other.value()) { + Some(v) => Some(Bytes::new(v)), + None => None, + } + } + + #[inline] + pub(crate) const fn unchecked_sub(self, other: Bytes) -> Bytes { + Self::new(self.value() - other.value()) + } + + #[inline] + pub const fn checked_next_multiple_of(self, other: Bytes) -> Option { + match self.value().checked_next_multiple_of(other.value()) { + Some(v) => Some(Bytes::new(v)), + None => None, + } + } + + #[inline] + pub const fn size_of() -> Self { + Self::new(std::mem::size_of::()) + } + + pub const fn is_zero(self) -> bool { + self.0 == 0 + } +} + +impl std::fmt::Display for Bytes { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{} bytes", self.value()) + } +} + #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] -pub struct Bytes(pub usize); +#[repr(transparent)] +pub struct Align(NonZeroUsize); + +impl Align { + pub const fn new(value: usize) -> Option { + match NonZeroUsize::new(value) { + Some(value) => { + if value.is_power_of_two() { + Some(Self(value)) + } else { + None + 
} + } + None => None, + } + } + + pub const fn value(self) -> usize { + self.0.get() + } + + pub const unsafe fn new_unchecked(value: usize) -> Self { + debug_assert!(value.is_power_of_two()); + Self(unsafe { NonZeroUsize::new_unchecked(value) }) + } + + pub const fn of() -> Self { + // SAFETY: `std::mem::align_of` is guaranteed to return a power of 2. + unsafe { Self::new_unchecked(std::mem::align_of::()) } + } + + pub const fn from_layout(layout: std::alloc::Layout) -> Self { + // SAFETY: `Layout::align` is guaranteed to be a power of 2. + unsafe { Self::new_unchecked(layout.align()) } + } + + // Constants. + pub const _1: Self = Self::new(1).unwrap(); + pub const _2: Self = Self::new(2).unwrap(); + pub const _4: Self = Self::new(4).unwrap(); + pub const _8: Self = Self::new(8).unwrap(); + pub const _16: Self = Self::new(16).unwrap(); + pub const _32: Self = Self::new(32).unwrap(); + pub const _64: Self = Self::new(64).unwrap(); + pub const _128: Self = Self::new(128).unwrap(); + pub const _256: Self = Self::new(256).unwrap(); + pub const _512: Self = Self::new(512).unwrap(); + pub const _1024: Self = Self::new(1024).unwrap(); + pub const _2048: Self = Self::new(2048).unwrap(); + pub const _4096: Self = Self::new(4096).unwrap(); +} + +impl std::fmt::Display for Align { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +/////////// +// Tests // +/////////// + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn new_and_value_roundtrip() { + assert_eq!(Bytes::new(42).value(), 42); + assert_eq!(Bytes::new(0).value(), 0); + } + + #[test] + fn cacheline_constant() { + assert_eq!(Bytes::CACHELINE, Bytes::new(64)); + } + + #[test] + fn size_of_returns_correct_size() { + assert_eq!(Bytes::size_of::(), Bytes::new(1)); + assert_eq!(Bytes::size_of::(), Bytes::new(8)); + assert_eq!(Bytes::size_of::<[u8; 128]>(), Bytes::new(128)); + } + + #[test] + fn checked_add_success() { + assert_eq!( + 
Bytes::new(10).checked_add(Bytes::new(20)), + Some(Bytes::new(30)) + ); + } + + #[test] + fn checked_add_overflow() { + assert_eq!(Bytes::new(usize::MAX).checked_add(Bytes::new(1)), None); + } + + #[test] + fn checked_sub_success() { + assert_eq!( + Bytes::new(30).checked_sub(Bytes::new(10)), + Some(Bytes::new(20)) + ); + } + + #[test] + fn checked_sub_underflow() { + assert_eq!(Bytes::new(5).checked_sub(Bytes::new(10)), None); + } + + #[test] + fn checked_mul_success() { + assert_eq!(Bytes::new(64).checked_mul(4), Some(Bytes::new(256))); + } + + #[test] + fn checked_mul_overflow() { + assert_eq!(Bytes::new(usize::MAX).checked_mul(2), None); + } + + #[test] + fn checked_mul_by_zero() { + assert_eq!(Bytes::new(100).checked_mul(0), Some(Bytes::new(0))); + } + + #[test] + fn unchecked_mul() { + assert_eq!(Bytes::new(64).unchecked_mul(3), Bytes::new(192)); + } + + #[test] + fn unchecked_sub() { + assert_eq!( + Bytes::new(100).unchecked_sub(Bytes::new(30)), + Bytes::new(70) + ); + } + + #[test] + fn checked_next_multiple_of_already_aligned() { + assert_eq!( + Bytes::new(128).checked_next_multiple_of(Bytes::new(64)), + Some(Bytes::new(128)) + ); + } + + #[test] + fn checked_next_multiple_of_rounds_up() { + assert_eq!( + Bytes::new(100).checked_next_multiple_of(Bytes::new(64)), + Some(Bytes::new(128)) + ); + } + + #[test] + fn checked_next_multiple_of_overflow() { + assert_eq!( + Bytes::new(usize::MAX).checked_next_multiple_of(Bytes::new(2)), + None + ); + } + + #[test] + fn ordering() { + assert!(Bytes::new(10) < Bytes::new(20)); + assert!(Bytes::new(20) > Bytes::new(10)); + assert_eq!(Bytes::new(5), Bytes::new(5)); + } + + #[test] + fn display() { + assert_eq!(format!("{}", Bytes::new(256)), "256 bytes"); + } + + // Align tests + + #[test] + fn align_new_power_of_two() { + assert_eq!(Align::new(1).unwrap().value(), 1); + assert_eq!(Align::new(2).unwrap().value(), 2); + assert_eq!(Align::new(64).unwrap().value(), 64); + assert_eq!(Align::new(4096).unwrap().value(), 
4096); + } + + #[test] + fn align_new_rejects_zero() { + assert!(Align::new(0).is_none()); + } + + #[test] + fn align_new_rejects_non_power_of_two() { + assert!(Align::new(3).is_none()); + assert!(Align::new(5).is_none()); + assert!(Align::new(6).is_none()); + assert!(Align::new(100).is_none()); + } + + #[test] + fn align_of_matches_std() { + assert_eq!(Align::of::<()>().value(), 1); + assert_eq!(Align::of::().value(), std::mem::align_of::()); + assert_eq!(Align::of::().value(), std::mem::align_of::()); + assert_eq!(Align::of::().value(), std::mem::align_of::()); + } + + #[test] + fn align_from_layout() { + let layout = std::alloc::Layout::from_size_align(256, 128).unwrap(); + assert_eq!(Align::from_layout(layout).value(), 128); + } + + #[test] + fn align_constants() { + assert_eq!(Align::_1.value(), 1); + assert_eq!(Align::_2.value(), 2); + assert_eq!(Align::_4.value(), 4); + assert_eq!(Align::_8.value(), 8); + assert_eq!(Align::_16.value(), 16); + assert_eq!(Align::_32.value(), 32); + assert_eq!(Align::_64.value(), 64); + assert_eq!(Align::_128.value(), 128); + assert_eq!(Align::_256.value(), 256); + assert_eq!(Align::_512.value(), 512); + assert_eq!(Align::_1024.value(), 1024); + assert_eq!(Align::_2048.value(), 2048); + assert_eq!(Align::_4096.value(), 4096); + } + + #[test] + fn align_ordering() { + assert!(Align::_1 < Align::_64); + assert!(Align::_128 > Align::_64); + assert_eq!(Align::_32, Align::new(32).unwrap()); + } -#[derive(Debug, Clone, Copy)] -pub struct Align(pub usize); + #[test] + fn align_display() { + assert_eq!(format!("{}", Align::_64), "64"); + } +} diff --git a/diskann-inmem/src/provider.rs b/diskann-inmem/src/provider.rs index 598aa2212..e158aa279 100644 --- a/diskann-inmem/src/provider.rs +++ b/diskann-inmem/src/provider.rs @@ -35,11 +35,7 @@ impl Provider { { let start_points: Vec<_> = start_points.into_iter().collect(); let bytes = layers::Layer::bytes(&layer); - let primary = Primary::new( - 
capacity.checked_add(start_points.len()).unwrap(), - Bytes(bytes), - 32, - ); + let primary = Primary::new(capacity.checked_add(start_points.len()).unwrap(), bytes, 32); let mut i = capacity; for v in start_points.into_iter() { @@ -223,21 +219,21 @@ type FExpandBeam = unsafe fn( ) -> ANNResult<()>; fn dispatch_expand_beam(bytes: Bytes) -> FExpandBeam { - if bytes <= Bytes(CACHE_LINE_SIZE) { + if bytes <= Bytes::CACHELINE { expand_beam_inner::<1> - } else if bytes <= Bytes(2 * CACHE_LINE_SIZE) { + } else if bytes <= Bytes::CACHELINE.unchecked_mul(2) { expand_beam_inner::<2> - } else if bytes <= Bytes(3 * CACHE_LINE_SIZE) { + } else if bytes <= Bytes::CACHELINE.unchecked_mul(3) { expand_beam_inner::<3> - } else if bytes <= Bytes(4 * CACHE_LINE_SIZE) { + } else if bytes <= Bytes::CACHELINE.unchecked_mul(4) { expand_beam_inner::<4> - } else if bytes <= Bytes(5 * CACHE_LINE_SIZE) { + } else if bytes <= Bytes::CACHELINE.unchecked_mul(5) { expand_beam_inner::<5> - } else if bytes <= Bytes(6 * CACHE_LINE_SIZE) { + } else if bytes <= Bytes::CACHELINE.unchecked_mul(6) { expand_beam_inner::<6> - } else if bytes <= Bytes(7 * CACHE_LINE_SIZE) { + } else if bytes <= Bytes::CACHELINE.unchecked_mul(7) { expand_beam_inner::<7> - } else if bytes <= Bytes(16 * CACHE_LINE_SIZE) { + } else if bytes <= Bytes::CACHELINE.unchecked_mul(16) { expand_beam_inner::<8> } else { expand_beam_inner::<16> @@ -268,8 +264,15 @@ unsafe fn expand_beam_inner( f: &mut dyn FnMut(u32, f32), ) -> ANNResult<()> { debug_assert!( - N * CACHE_LINE_SIZE <= reader.bytes().0.next_multiple_of(CACHE_LINE_SIZE), - "we really rely on this: {}, bytes = {}", N, reader.bytes().0 + N * CACHE_LINE_SIZE + <= reader + .bytes() + .checked_next_multiple_of(Bytes::CACHELINE) + .unwrap() + .value(), + "we really rely on this: {}, bytes = {}", + N, + reader.bytes() ); let len = list.len(); @@ -281,8 +284,6 @@ unsafe fn expand_beam_inner( reader .read_raw_unchecked(list.get_unchecked(j).into_usize()) .as_ptr() - .as_ptr() - 
.cast_const() .cast(), ) } @@ -296,15 +297,13 @@ unsafe fn expand_beam_inner( reader .read_raw_unchecked(list.get_unchecked(j).into_usize()) .as_ptr() - .as_ptr() - .cast_const() .cast(), ) } j += 1; } - if let Some(data) = reader.read(i.into_usize()) { + if let Some(data) = unsafe { reader.read_in_bounds(i.into_usize()) } { f(i, distance.evaluate(data)?) } } diff --git a/diskann-inmem/src/store.rs b/diskann-inmem/src/store.rs index c0c69e149..617aa8ace 100644 --- a/diskann-inmem/src/store.rs +++ b/diskann-inmem/src/store.rs @@ -6,14 +6,11 @@ use std::{ iter::repeat_n, num::NonZeroU32, - sync::{ - atomic::{AtomicU64, Ordering}, - Mutex, - }, + sync::{atomic::Ordering, Mutex}, }; use crate::{ - arbiter::{self, buffer, epoch, generation, Buffer, Freelist, Generation, Slice}, + arbiter::{epoch, generation, Buffer, Freelist, Generation, RawSlice}, num::{Align, Bytes}, Neighbors, }; @@ -25,7 +22,7 @@ pub struct Primary { // These tags are mirrored from `tags` - with the latter being used for secondary scans // offering slightly better locality. 
buffer: Buffer, - unpadded: usize, + unpadded: Bytes, tags: Vec, freelist: Freelist, registry: epoch::Registry, @@ -33,15 +30,15 @@ pub struct Primary { drain: Mutex>, } -const SPLIT: usize = std::mem::size_of::(); +const SPLIT: Bytes = Bytes::size_of::(); impl Primary { pub fn new(entries: usize, bytes: Bytes, max_neighbors: usize) -> Self { - let unpadded = bytes.0 + SPLIT; - let padded_bytes = unpadded.checked_next_multiple_of(64).unwrap(); + let unpadded = bytes.checked_add(SPLIT).unwrap(); + let padded_bytes = unpadded.checked_next_multiple_of(Bytes::CACHELINE).unwrap(); Self { - buffer: Buffer::new(entries, Bytes(padded_bytes), Align(128)), + buffer: Buffer::new(entries, padded_bytes, Align::_128).unwrap(), unpadded, tags: repeat_n(Generation::AVAILABLE, entries) .map(|v| generation::Tag::new(v)) @@ -137,12 +134,12 @@ impl Primary { } } - unsafe fn data(&self, i: usize) -> (generation::Mut<'_>, Slice<'_>) { + unsafe fn data(&self, i: usize) -> (generation::Mut<'_>, RawSlice<'_>) { let (mirror, data) = unsafe { self.buffer.get_unchecked(i) } .truncate(self.unpadded) .split(SPLIT); ( - unsafe { generation::Tag::from_ptr(mirror.as_ptr().as_ptr().cast()) }.as_mut(), + unsafe { generation::Tag::from_ptr(mirror.as_mut_ptr().cast()) }.as_mut(), data, ) } @@ -157,7 +154,7 @@ impl Primary { #[derive(Debug)] pub struct Reader<'a> { buffer: &'a Buffer, - unpadded: usize, + unpadded: Bytes, neighbors: &'a Neighbors, epoch: epoch::Guard<'a>, } @@ -170,13 +167,33 @@ impl<'a> Reader<'a> { /// 2. The read cannot be guaranteed to be race-free. #[inline] pub fn read(&self, i: usize) -> Option<&[u8]> { - let (generation, rest) = match self.buffer.get(i) { - Some(slice) => slice.truncate(self.unpadded).split(SPLIT), - None => return None, + if self.is_in_bounds(i) { + unsafe { self.read_in_bounds(i) } + } else { + None + } + } + + /// Return `true` if the index `i` is in-bounds. 
+ #[inline] + #[must_use = "this function has no side-effects"] + pub fn is_in_bounds(&self, i: usize) -> bool { + i < self.buffer.len() + } + + #[inline] + pub(crate) unsafe fn read_in_bounds(&self, i: usize) -> Option<&[u8]> { + debug_assert!(self.is_in_bounds(i)); + + let (generation, rest) = unsafe { + self.buffer + .get_unchecked(i) + .truncate_unchecked(self.unpadded) + .split_unchecked(SPLIT) }; // NOTE: Must be `Acquire` to correctly synchronize with writes. - let generation = unsafe { generation::Tag::from_ptr(generation.as_ptr().as_ptr().cast()) } + let generation = unsafe { generation::Tag::from_ptr(generation.as_mut_ptr().cast()) } .as_ref() .get(Ordering::Acquire); @@ -189,26 +206,19 @@ impl<'a> Reader<'a> { } } - /// Return `true` if the index `i` is in-bounds. - #[inline] - #[must_use = "this function has no side-effects"] - pub fn is_in_bounds(&self, i: usize) -> bool { - i < self.buffer.len() - } - /// Return the raw data slice for index `i` without any race guarantees. /// /// # Safety /// /// The index `i` must be in-bounds. #[inline] - pub(crate) unsafe fn read_raw_unchecked(&self, i: usize) -> Slice<'_> { + pub(crate) unsafe fn read_raw_unchecked(&self, i: usize) -> RawSlice<'_> { unsafe { self.buffer.get_unchecked(i) }.truncate(self.unpadded) } /// Return the number of bytes for each entry. 
pub(crate) fn bytes(&self) -> Bytes { - Bytes(self.unpadded) + self.unpadded } // TODO: We may want to lock `Neighbors` in some way to enable exclusive access during @@ -223,14 +233,10 @@ pub struct Write<'a> { tag: generation::Mut<'a>, mirror: generation::Mut<'a>, generation: Generation, - data: Slice<'a>, + data: RawSlice<'a>, } impl<'a> Write<'a> { - pub fn raw_slice(&mut self) -> Slice<'_> { - self.data - } - pub fn as_mut_slice(&mut self) -> &mut [u8] { unsafe { self.data.as_mut_slice() } } From f60cfe1369554e1fd8573fb4cc5df01183deedea Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Fri, 12 Jun 2026 15:20:55 -0700 Subject: [PATCH 06/45] Checkpoint. --- Cargo.lock | 1 + diskann-inmem/Cargo.toml | 1 + diskann-inmem/src/arbiter/buffer.rs | 12 +- diskann-inmem/src/arbiter/epoch.rs | 360 +++++++++++++++++++----- diskann-inmem/src/arbiter/freelist.rs | 80 +++--- diskann-inmem/src/arbiter/generation.rs | 56 +++- diskann-inmem/src/arbiter/mod.rs | 2 +- diskann-inmem/src/provider.rs | 40 ++- diskann-inmem/src/store.rs | 167 +++++++---- 9 files changed, 518 insertions(+), 201 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9553e67f4..97dd88072 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -812,6 +812,7 @@ name = "diskann-inmem" version = "0.54.0" dependencies = [ "bytemuck", + "crossbeam-queue", "diskann", "diskann-utils", "diskann-vector", diff --git a/diskann-inmem/Cargo.toml b/diskann-inmem/Cargo.toml index 235a70ea1..335a68f45 100644 --- a/diskann-inmem/Cargo.toml +++ b/diskann-inmem/Cargo.toml @@ -9,6 +9,7 @@ edition.workspace = true [dependencies] bytemuck = { workspace = true, features = ["must_cast"] } +crossbeam-queue = "0.3.12" diskann = { workspace = true } diskann-utils = { workspace = true, default-features = false } diskann-vector.workspace = true diff --git a/diskann-inmem/src/arbiter/buffer.rs b/diskann-inmem/src/arbiter/buffer.rs index b3c3ad95d..4a50b22c0 100644 --- a/diskann-inmem/src/arbiter/buffer.rs +++ 
b/diskann-inmem/src/arbiter/buffer.rs @@ -142,13 +142,13 @@ impl Drop for Buffer { } } -#[derive(Debug, Clone, Copy)] +#[derive(Debug)] #[non_exhaustive] pub struct BufferError; impl std::fmt::Display for BufferError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "requested allocation exceeds `isize::MAX`") + f.write_str("requested allocation exceeds `isize::MAX`") } } @@ -319,7 +319,7 @@ impl<'a> RawSlice<'a> { mod tests { use super::*; - use std::{thread, sync::Barrier}; + use std::{sync::Barrier, thread}; #[derive(Debug)] struct Ctx { @@ -373,7 +373,6 @@ mod tests { // Check Slice Methods check_slice_methods(&mut buffer, &ctx); - // Check that concurrent mutation is allowed. // // This is mainly a Miri check. @@ -402,7 +401,10 @@ mod tests { assert_eq!(raw_slice.len(), buffer.stride()); assert_eq!(raw_slice.as_non_null().as_ptr(), raw_slice.as_mut_ptr()); - assert_eq!(raw_slice.as_non_null().as_ptr().cast_const(), raw_slice.as_ptr()); + assert_eq!( + raw_slice.as_non_null().as_ptr().cast_const(), + raw_slice.as_ptr() + ); assert_eq!( raw_slice.as_ptr(), diff --git a/diskann-inmem/src/arbiter/epoch.rs b/diskann-inmem/src/arbiter/epoch.rs index bc5547b5a..f87fdc549 100644 --- a/diskann-inmem/src/arbiter/epoch.rs +++ b/diskann-inmem/src/arbiter/epoch.rs @@ -4,11 +4,17 @@ */ use std::sync::{ - atomic::{AtomicU64, Ordering}, - Mutex, + atomic::{AtomicUsize, Ordering}, + Mutex, TryLockError, }; -use crate::arbiter::Generation; +use crossbeam_queue::SegQueue; +use diskann::utils::IntoUsize; + +use crate::arbiter::{ + generation::{Mut, Tag}, + Generation, +}; const CAPACITY: usize = 256; @@ -16,11 +22,39 @@ const CAPACITY: usize = 256; pub struct Registry { /// A record of the active generations. /// - /// * 0 = "available". - /// * non-zero: generation is active. - slots: Box<[AtomicU64]>, - generation: AtomicU64, - barrier: Mutex, + /// * Generation::MAX = "available". + /// * Anything less = "registered". 
+ slots: Box<[Tag]>, + + // The current epoch. This begins as `Generation::MAX.sub(1)` and decrements over time. + // + // NOTE: This can only be mutated in `try_advance`. + generation: Tag, + + // A hint for the next available registration slot. + hint: AtomicUsize, + + // We use three queues for storing slots. + // + // 1. Belongs to the current generation and is getting filled. + // 2. Ready for the next generation that will be populated on the next `try_advance`. + // Note that after a `try_advance` call, both 1 and 2 can be added to. + // 3. The queue returned from `try_advance` to be drained. Items drained are safe to + // reclaim. + retiring: [SegQueue; 3], + + // We can only retire a single generation at a time. + // This guard avoids situations. + drain: Mutex<()>, +} + +// Return the queue index for the `generation`. +fn queue(generation: Generation) -> usize { + generation.value().into_usize() % 3 +} + +fn last_queue(generation: Generation) -> usize { + queue(Generation::new(generation.value().wrapping_add(1))) } impl Registry { @@ -30,103 +64,188 @@ impl Registry { pub fn with_capacity(capacity: usize) -> Self { Self { - slots: (0..capacity).map(|_| AtomicU64::new(0)).collect(), - generation: AtomicU64::new(u64::MAX), - barrier: Mutex::new(Hint(0)), + slots: (0..capacity).map(|_| Tag::new(Generation::MAX)).collect(), + generation: Tag::new(Generation::MAX.sub(1)), + hint: AtomicUsize::new(0), + retiring: core::array::from_fn(|_| SegQueue::new()), + drain: Mutex::new(()), } } + /// Return the current generation. + /// + /// This has [`Ordering::Acquire`] semantics. pub fn generation(&self) -> Generation { - Generation::new(self.generation.load(Ordering::Acquire)) + self.generation.as_ref().get(Ordering::Acquire) } - pub fn register(&self) -> Guard<'_> { - let mut barrier = self.barrier.lock().unwrap(); + /// Register the caller with the registry. 
+ /// + /// On success, the returned [`Guard`] will protect items tagged with + /// [`Guard::generation`] and higher. + pub fn register(&self) -> Result, Unavailable> { + self.register_inner(NoDelay) + } - // No synchronization happens on the global generation tag. - let generation = self.generation.load(Ordering::Acquire); - let hint = barrier.increment(); + fn register_inner(&self, _: T) -> Result, Unavailable> + where + T: RegisterDelay, + { + // REGISTER CHECK + let mut generation = self.generation(); + let hint = self.hint.fetch_add(1, Ordering::Relaxed); let nslots = self.slots.len(); for i in 0..nslots { let slot = (hint + i) % nslots; - if let Ok(_) = self.slots[slot].compare_exchange( - 0, + + let m = self.slots[slot].as_mut(); + if let Ok(_) = m.try_set( + Generation::MAX, generation, - Ordering::Release, + Ordering::Relaxed, Ordering::Relaxed, ) { - return Guard { - registry: self, - slot, - generation: Generation::new(generation), - }; - } - } - - panic!("Let's turn this into a proper error."); - } + let mut reset = false; + loop { + // REGISTER FENCE: This fence is paired with "WAITING FENCE". + // + // See that comment for details. + std::sync::atomic::fence(Ordering::SeqCst); - pub fn advance(&self) -> Generation { - // TODO: What to do on the unlikely event of a wrap-around? 
- Generation::new(self.generation.fetch_sub(1, Ordering::AcqRel)) - } + // REGISTER RECHECK + let current = self.generation(); + if current == generation { + break; + } - fn wait_for(&self, generation: Generation) { - let generation = generation.value(); - let wait_list = { - let _barrier = self.barrier.lock().unwrap(); - let mut wait_list = Vec::new(); - for (i, s) in self.slots.iter().enumerate() { - let g = s.load(Ordering::Relaxed); - if g != 0 && g >= generation { - wait_list.push(i); + reset = true; + generation = current; } - } - wait_list - }; - - for slot in wait_list { - let s = &self.slots[slot]; - loop { - let g = s.load(Ordering::Relaxed); - if g == 0 || g < generation { - break; + if reset { + m.set(generation, Ordering::Relaxed); } - std::hint::spin_loop(); + + return Ok(Guard { + slot: m, + retire: &self.retiring[queue(generation)], + generation, + }); } } - // This barrier synchronizes with all the relaxed loads on the slots, which are - // set with `Release` semantics. - std::sync::atomic::fence(Ordering::Acquire); + Err(Unavailable) } /// Return the oldest generation that is currently being protected. /// - /// Generations decrement from `Generation::MAX` + /// This uses a fast method that may be overly conservative. /// - /// This is a syncronizing operation. + /// This is a synchronizing operation with [`Ordering::Acquire`] semantics. pub fn waiting(&self) -> Generation { - let _barrier = self.barrier.lock().unwrap(); - let mut highest = 0; + self.waiting_inner(NoDelay) + } + + fn waiting_inner(&self, _: T) -> Generation + where + T: WaitingDelay, + { + // WAITING FENCE: This is a very important part for the correctness of the algorithm. + // + // What we're protecting against is a scenario where "registering" thread A reads a + // generation, then "waiting" thread B does a scan, thinks everything is safe, and + // then thread A finishes its CAS for its registration. + // + // This is prevented by the fence. Consider the following. + // + // 1. 
Thread A invokes "REGISTER FENCE" after a successful CAS, and then checks the + // generation at "REGISTER RECHECK". + // + // 2. Thread B now enters this block of code, executes "WAITING FENCE", then + // reads the generation tags for all slots. + // + // With the total order induced by the sequentially consistent fence, either thread + // A's fence executes first, or thread B's executes first. + // + // * If thread A's fence executes first, then thread B will see the CAS and the set + // value is guaranteed to be greater-than or equal to "WAITING CHECK" since the + // generation is monotonically decreasing and thread A's + // "REGISTER CHECK" is forced to happen before. + // + // * If Thread B's fence executes first, then thread A's "REGISTER RECHECK" will + // observe at least the result of "WAITING CHECK" and update itself on the retry. + // + // It's possible that thread B observes the CAS to "REGISTER CHECK", but since + // thread A will monotonically decrease it before exiting, the value thread B + // observes is conservative and not incorrect. + std::sync::atomic::fence(Ordering::SeqCst); + + // WAITING CHECK + let mut max = self.generation(); + for s in self.slots.iter() { - let g = s.load(Ordering::Relaxed); - highest = highest.max(g); + let generation = s.as_ref().get(Ordering::Relaxed); + if generation != Generation::MAX { + max = max.max(generation); + } } - // `acquires` with respect to all previous relaxed loads. + // This synchronizes with all the guard's `Release`s. std::sync::atomic::fence(Ordering::Acquire); + max + } - Generation::new(highest) + pub fn try_advance(&self) -> Option> { + self.try_advance_inner(NoDelay) + } + + fn try_advance_inner(&self, _: T) -> Option> + where + T: TryAdvanceDelay, + { + // We first try to acquire the `drain` lock. + // + // It can only fail if someone else is holding the drain lock, which means we can't + // proceed anyways. + // + // This can help save an expensive slot scan.
+ // + // We intentionally ignore lock-poison since we expect the guarded queue to be + // robust with respect to panics. + let drain = match self.drain.try_lock() { + Ok(drain) => drain, + Err(TryLockError::Poisoned(drain)) => drain.into_inner(), + Err(TryLockError::WouldBlock) => return None, + }; + + let waiting = self.waiting(); + let current = self.generation.as_ref().get(Ordering::Relaxed); + + // All waiters belong to the current generation. Therefore, it is safe to release + // the old array queue + if waiting == current { + // We are safe to use a `fetch_sub` here because `drain` is ensuring exclusivity + // of the access. + // + // However, this still needs to be `SeqCst` so that this properly synchronizes + // with "REGISTER FENCE" and "WAITING FENCE". + let _previous = self.generation.as_mut().fetch_decrement(Ordering::SeqCst); + debug_assert_eq!(_previous, current, "concurrency violation"); + + let queue = &self.retiring[last_queue(current)]; + Some(Drain { queue, drain }) + } else { + // Previous generation has not completely retired. + None + } } } #[derive(Debug)] pub struct Guard<'a> { - registry: &'a Registry, - slot: usize, + slot: Mut<'a>, + retire: &'a SegQueue, generation: Generation, } @@ -136,21 +255,114 @@ impl Guard<'_> { pub fn generation(&self) -> Generation { self.generation } + + /// Retire the slot `i` at the current generation. + #[inline] + pub fn retire(&self, i: u32) { + self.retire.push(i) + } + + /// Retire all items in `itr`.
+ pub fn retire_all(&self, itr: I) + where + I: IntoIterator, + { + for i in itr { + self.retire(i) + } + } } impl Drop for Guard<'_> { fn drop(&mut self) { - self.registry.slots[self.slot].store(0, Ordering::Release) + self.slot.set(Generation::MAX, Ordering::Release); + } +} + +#[derive(Debug)] +pub struct Drain<'a> { + queue: &'a SegQueue, + drain: std::sync::MutexGuard<'a, ()>, +} + +impl Drain<'_> { + #[must_use = "reclaimed ids must be reclaimed"] + pub fn pop(&self) -> Option { + self.queue.pop() + } + + pub fn len(&self) -> usize { + self.queue.len() + } + + pub fn is_empty(&self) -> bool { + self.len() == 0 + } +} + +impl Iterator for Drain<'_> { + type Item = u32; + fn next(&mut self) -> Option { + self.pop() + } + + fn size_hint(&self) -> (usize, Option) { + (self.len(), Some(self.len())) } } +// NOTE: This relies on `Drain` holding the `drain` guard. In this state, we are guaranteed +// that no-one is writing into the queue, which would otherwise invalidate the exact-size +// iterator guarantee. +impl ExactSizeIterator for Drain<'_> {} + #[derive(Debug)] -struct Hint(usize); +#[non_exhaustive] +pub struct Unavailable; + +impl std::fmt::Display for Unavailable { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("all available registry slots are occupied") + } +} + +impl std::error::Error for Unavailable {} -impl Hint { - fn increment(&mut self) -> usize { - let x = self.0; - self.0 += 1; - x +impl From for diskann::ANNError { + #[track_caller] + fn from(unavailable: Unavailable) -> Self { + diskann::ANNError::opaque(unavailable) } } + +// Delays +// +// To help test standard race scenarios without advanced tooling, we use optional delays +// that our tests can introduce to ensure threads are in various intermediate points. +// +// This does not necessarily test that the memory orderings are correct, but at least +// is a smoke test that various (known) races are handled properly. 
+ +#[derive(Debug)] +struct NoDelay; + +trait RegisterDelay {} + +impl RegisterDelay for NoDelay {} + +trait WaitingDelay {} + +impl WaitingDelay for NoDelay {} + +trait TryAdvanceDelay {} + +impl TryAdvanceDelay for NoDelay {} + +/////////// +// Tests // +/////////// + +#[cfg(test)] +mod tests { + use super::*; +} diff --git a/diskann-inmem/src/arbiter/freelist.rs b/diskann-inmem/src/arbiter/freelist.rs index 04bc24bae..6dfc85371 100644 --- a/diskann-inmem/src/arbiter/freelist.rs +++ b/diskann-inmem/src/arbiter/freelist.rs @@ -6,22 +6,27 @@ use std::{ num::NonZeroU32, sync::{ - atomic::{AtomicBool, AtomicU32, Ordering}, + atomic::{AtomicU32, AtomicUsize, Ordering}, Mutex, }, }; +use crossbeam_queue::ArrayQueue; +use diskann::utils::IntoUsize; + #[derive(Debug)] pub struct Freelist { - recycled: Mutex>, - capacity: NonZeroU32, - have_recycled: AtomicBool, + recycled: ArrayQueue, + + /// An (approximate) number of recycled IDs that exist outside the freelist. + orphaned: AtomicUsize, /// The highest ID the freelist manages. This is used when in "append" to determine the /// maximum ID we can return this way. max: u32, + /// The number of "unallocated" IDs remaining. - unallocated: AtomicU32, + current: AtomicU32, } #[derive(Debug, Clone, Copy)] @@ -33,37 +38,30 @@ pub enum Id { impl Freelist { pub fn new(max: u32, capacity: NonZeroU32) -> Self { Self { - recycled: Mutex::new(Vec::with_capacity(capacity.get() as usize)), - capacity, - have_recycled: AtomicBool::new(false), + recycled: ArrayQueue::new(capacity.get().into_usize()), + orphaned: AtomicUsize::new(0), max, - unallocated: AtomicU32::new(max), + current: AtomicU32::new(0), } } pub fn pop(&self) -> Id { - // Small performance optimization - avoid locking the mutex if looks like that won't - // succeed anyways. 
- if self.have_recycled.load(Ordering::Relaxed) { - let mut recycled = self.recycled.lock().unwrap(); - if let Some(id) = recycled.pop() { - return Id::Found(id); - } - self.have_recycled.store(false, Ordering::Relaxed); + if let Some(id) = self.recycled.pop() { + return Id::Found(id); } // Missed in the recycled buffer. Try pulling from the high-water mark. - let mut unallocated = self.unallocated.load(Ordering::Relaxed); - while unallocated != 0 { - match self.unallocated.compare_exchange( - unallocated, - unallocated - 1, + let mut current = self.current.load(Ordering::Relaxed); + while current != self.max { + match self.current.compare_exchange( + current, + current + 1, Ordering::Relaxed, Ordering::Relaxed, ) { - Ok(unallocated) => return Id::Found(self.max - unallocated), + Ok(current) => return Id::Found(current), Err(actual) => { - unallocated = actual; + current = actual; } } } @@ -76,13 +74,12 @@ impl Freelist { /// inserted. If `false` is returned, it is likely because the internal recycle /// buffer is full. pub fn push(&self, id: u32) -> bool { - let mut recycled = self.recycled.lock().unwrap(); - if recycled.len() < self.capacity() { - recycled.push(id); - self.have_recycled.store(true, Ordering::Relaxed); - true - } else { - false + match self.recycled.push(id) { + Ok(()) => true, + Err(_) => { + self.orphaned.fetch_add(1, Ordering::Relaxed); + false + } } } @@ -92,16 +89,19 @@ impl Freelist { where I: IntoIterator, { - let mut recycled = self.recycled.lock().unwrap(); - let available = self.capacity() - recycled.len(); + let mut itr = itr.into_iter(); let mut count = 0; - itr.into_iter().take(available).for_each(|id| { - count += 1; - recycled.push(id); - }); + while let Some(id) = itr.next() { + if let Err(_) = self.recycled.push(id) { + let (lower, _) = itr.size_hint(); - if count > 0 { - self.have_recycled.store(true, Ordering::Relaxed); + // Add 1 to "put back" the last ID. 
+ self.orphaned + .fetch_add(lower.saturating_add(1), Ordering::Relaxed); + break; + } else { + count += 1; + } } count @@ -112,6 +112,6 @@ impl Freelist { //----------// fn capacity(&self) -> usize { - self.capacity.get() as usize + self.recycled.capacity() } } diff --git a/diskann-inmem/src/arbiter/generation.rs b/diskann-inmem/src/arbiter/generation.rs index 4da9d4576..cbb31994b 100644 --- a/diskann-inmem/src/arbiter/generation.rs +++ b/diskann-inmem/src/arbiter/generation.rs @@ -50,8 +50,8 @@ impl Tag { /// representing older generations. This allows zero to stand for "unused" as it is newer /// than any valid generation. /// -/// Certain low-numbered generations are reserved for special uses. Any generation less -/// than or equal to [`Generation::RESERVED`] is reserved. +/// Certain low-numbered generations are reserved for special uses. Any generation for which +/// [`Generation::is_reserved`] returns `true` is reserved. /// /// # Reserved Generations /// @@ -90,16 +90,16 @@ impl Generation { /// See [`Generation`]. pub const OWNED: Self = Self::new(1); - /// See [`Generation`]. - pub const FROZEN: Self = Self::new(2); - /// The maximum reserved generation. See [`Generation`]. - pub const RESERVED: Self = Self::FROZEN; + const RESERVED: Self = Self::OWNED; + + /// See [`Generation`]. + pub const FROZEN: Self = Self::MAX; /// Return `true` if `self` belongs to a reserved generation. #[must_use = "this function has no side-effects"] pub(crate) fn is_reserved(self) -> bool { - self <= Self::RESERVED + (self <= Self::RESERVED) || (self == Self::FROZEN) } /// Construct a new [`Generation`] with `value`. 
@@ -114,13 +114,16 @@ impl Generation { self.0 } + pub(in crate::arbiter) fn max(self, other: Self) -> Self { + Self(self.0.max(other.0)) + } + #[cfg(test)] const fn add(self, v: u64) -> Self { Self(self.0 + v) } - #[cfg(test)] - const fn sub(self, v: u64) -> Self { + pub(in crate::arbiter) const fn sub(self, v: u64) -> Self { Self(self.0 - v) } } @@ -150,6 +153,21 @@ impl<'a> Ref<'a> { } } +impl std::fmt::Display for Generation { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let me = *self; + if me == Self::AVAILABLE { + f.write_str("Generation(AVAILABLE)") + } else if me == Self::OWNED { + f.write_str("Generation(OWNED)") + } else if me == Self::FROZEN { + f.write_str("Generation(FROZEN)") + } else { + write!(f, "Generation({})", me.value()) + } + } +} + /// A read-write handle to a [`Tag`]. /// /// Provides atomic store and compare-exchange access in addition to the read access @@ -181,6 +199,16 @@ impl<'a> Mut<'a> { .map_err(Generation::new) } + #[inline] + pub fn fetch_decrement(&self, ordering: Ordering) -> Generation { + Generation::new(self.inner().fetch_sub(1, ordering)) + } + + #[inline] + pub fn fetch_min(&self, generation: Generation, ordering: Ordering) -> Generation { + Generation::new(self.inner().fetch_min(generation.value(), ordering)) + } + /// Atomically store a [`Generation`] with the given ordering. 
#[inline] pub fn set(&self, generation: Generation, ordering: Ordering) { @@ -203,9 +231,12 @@ impl<'a> std::ops::Deref for Mut<'a> { mod tests { use super::*; - use std::{thread, sync::Barrier}; + use std::{sync::Barrier, thread}; - use crate::{num::{Bytes,Align}, arbiter::Buffer}; + use crate::{ + arbiter::Buffer, + num::{Align, Bytes}, + }; fn spin_decrement(m: Mut<'_>, count: usize) { for i in 0..count { @@ -260,7 +291,8 @@ mod tests { fn test_is_reserved() { assert!(Generation::AVAILABLE.is_reserved()); assert!(Generation::OWNED.is_reserved()); + assert!(!Generation::OWNED.add(1).is_reserved()); + assert!(Generation::FROZEN.is_reserved()); - assert!(!Generation::FROZEN.add(1).is_reserved()); } } diff --git a/diskann-inmem/src/arbiter/mod.rs b/diskann-inmem/src/arbiter/mod.rs index 9499e4d0f..268e57d18 100644 --- a/diskann-inmem/src/arbiter/mod.rs +++ b/diskann-inmem/src/arbiter/mod.rs @@ -8,7 +8,7 @@ pub use buffer::{Buffer, RawSlice}; pub mod epoch; -mod freelist; +pub mod freelist; pub use freelist::Freelist; pub mod generation; diff --git a/diskann-inmem/src/provider.rs b/diskann-inmem/src/provider.rs index e158aa279..47245d5d8 100644 --- a/diskann-inmem/src/provider.rs +++ b/diskann-inmem/src/provider.rs @@ -13,9 +13,10 @@ use diskann::{ utils::IntoUsize, ANNError, ANNErrorKind, ANNResult, }; -use diskann_utils::future::{AsyncFriendly, SendFuture}; +use diskann_utils::{views::Matrix, future::{AsyncFriendly, SendFuture}}; use crate::{ + arbiter::epoch, layers::{self, Distance, QueryDistance}, num::Bytes, store::{self, Primary}, @@ -35,19 +36,28 @@ impl Provider { { let start_points: Vec<_> = start_points.into_iter().collect(); let bytes = layers::Layer::bytes(&layer); - let primary = Primary::new(capacity.checked_add(start_points.len()).unwrap(), bytes, 32); + let mut data = Matrix::new(0u8, start_points.len(), bytes.value()); - let mut i = capacity; - for v in start_points.into_iter() { - let mut writer = primary.write(i).unwrap(); - 
layers::Set::into_bytes(&layer, v, writer.as_mut_slice()).unwrap(); - i += 1; + for (row, point) in std::iter::zip(data.row_iter_mut(), start_points.into_iter()) { + layers::Set::into_bytes(&layer, point, row).unwrap(); } + let primary = Primary::new( + capacity, + bytes, + 32, + data.as_view(), + ); + + // for v in start_points.into_iter() { + // let mut writer = primary.acquire(); + // layers::Set::into_bytes(&layer, v, writer.as_mut_slice()).unwrap(); + // } + Self { primary, layer } } - fn reader(&self) -> store::Reader<'_> { + fn reader(&self) -> Result, epoch::Unavailable> { self.primary.reader() } } @@ -109,9 +119,9 @@ where element: T, ) -> impl std::future::Future> + Send { let work = move || { - let mut write = self.primary.write(id.into_usize()).unwrap(); - >::into_bytes(&self.layer, element, write.as_mut_slice())?; - Ok(diskann::provider::NoopGuard::new(*id)) + let mut slot = self.primary.acquire(); + >::into_bytes(&self.layer, element, slot.as_mut_slice())?; + Ok(diskann::provider::NoopGuard::new(slot.slot())) }; ready(work) @@ -444,7 +454,7 @@ where query: T, ) -> ANNResult> { let distance = >::query_distance(&provider.layer, query)?; - let reader = provider.primary.reader(); + let reader = provider.primary.reader()?; let expand_beam = dispatch_expand_beam(reader.bytes()); let accessor = SearchAccessor { reader, @@ -468,17 +478,17 @@ where L: layers::Layer + layers::AsDistance, { type PruneAccessor<'a> = PruneAccessor<'a>; - type PruneAccessorError = diskann::error::Infallible; + type PruneAccessorError = ANNError; fn prune_accessor<'a>( &self, provider: &'a Provider, context: &'a Context, capacity: usize, - ) -> Result, diskann::error::Infallible> { + ) -> ANNResult> { let set = workingset::map::Builder::new(workingset::map::Capacity::Default).build(capacity); Ok(PruneAccessor { - reader: provider.primary.reader(), + reader: provider.primary.reader()?, set, distance: ::as_distance(&provider.layer), ids: AdjacencyList::new(), diff --git 
a/diskann-inmem/src/store.rs b/diskann-inmem/src/store.rs index 617aa8ace..73594692f 100644 --- a/diskann-inmem/src/store.rs +++ b/diskann-inmem/src/store.rs @@ -9,8 +9,11 @@ use std::{ sync::{atomic::Ordering, Mutex}, }; +use diskann::utils::IntoUsize; +use diskann_utils::views::MatrixView; + use crate::{ - arbiter::{epoch, generation, Buffer, Freelist, Generation, RawSlice}, + arbiter::{epoch, freelist, generation, Buffer, Freelist, Generation, RawSlice}, num::{Align, Bytes}, Neighbors, }; @@ -23,89 +26,137 @@ pub struct Primary { // offering slightly better locality. buffer: Buffer, unpadded: Bytes, + + // The number of unfrozen points. This is guaranteed to be less than `buffer`. + unfrozen: usize, tags: Vec, freelist: Freelist, registry: epoch::Registry, neighbors: Neighbors, - drain: Mutex>, } const SPLIT: Bytes = Bytes::size_of::(); impl Primary { - pub fn new(entries: usize, bytes: Bytes, max_neighbors: usize) -> Self { + pub fn new( + entries: usize, + bytes: Bytes, + max_neighbors: usize, + init: MatrixView<'_, u8>, + ) -> Self { + assert_eq!(init.ncols(), bytes.value()); + assert_ne!(init.nrows(), 0); + let unpadded = bytes.checked_add(SPLIT).unwrap(); let padded_bytes = unpadded.checked_next_multiple_of(Bytes::CACHELINE).unwrap(); - Self { - buffer: Buffer::new(entries, padded_bytes, Align::_128).unwrap(), + let total = entries.checked_add(init.nrows()).unwrap(); + + let this = Self { + buffer: Buffer::new(total, padded_bytes, Align::_128).unwrap(), unpadded, - tags: repeat_n(Generation::AVAILABLE, entries) + unfrozen: entries, + tags: repeat_n(Generation::AVAILABLE, total) .map(|v| generation::Tag::new(v)) .collect(), + + // NOTE: The `Freelist` is initialized to `entries` and not `total` because + // we do not want it to release frozen IDs. 
freelist: Freelist::new(entries.try_into().unwrap(), NonZeroU32::new(1024).unwrap()), registry: epoch::Registry::new(), neighbors: Neighbors::new(entries, max_neighbors), - drain: Mutex::new(Vec::new()), + }; + + // Populate frozen points. + for (i, data) in init.row_iter().enumerate() { + let mut slot = this.slot((entries + i).try_into().unwrap()); + slot.as_mut_slice().copy_from_slice(data); + slot.freeze(); } - } - #[inline] - fn tag(&self, i: usize) -> Option> { - self.tags.get(i).map(|v| v.as_ref()) + this } pub fn capacity(&self) -> usize { - self.buffer.len() - } - - pub fn drain(&self) -> usize { - let mut drain = self.drain.lock().unwrap(); - let waiter = self.registry.waiting(); - let before = drain.len(); - drain.retain(|(i, generation)| { - if waiter < *generation { - self.freelist.push(*i); - false - } else { - true + self.buffer.len() - self.unfrozen + } + + pub fn try_drain(&self) -> Option { + fn release(tag: generation::Mut<'_>, kind: &'static str) { + // Relaxed ordering is sufficient as all readers/writers are synchronized on + // the central generation. + if let Err(got) = tag.try_set( + Generation::OWNED, + Generation::AVAILABLE, + Ordering::Relaxed, + Ordering::Relaxed, + ) { + panic!( + "CONCURRENCY VIOLATION: {} - expected {} - got {}", + kind, + Generation::AVAILABLE, + got, + ); } - }); - before - drain.len() + } + + let drain = self.registry.try_advance()?; + let items = drain.len(); + for i in drain { + // We release the mirror before the main tag. The other direction would + // prematurely advertise availability. 
+ let (mirror, _) = unsafe { self.data_unchecked(i.into_usize()) }; + release(mirror, "mirror"); + release(self.tags[i.into_usize()].as_mut(), "tag"); + self.freelist.push(i); + } + Some(items) } - pub fn reader(&self) -> Reader<'_> { - Reader { + pub fn reader(&self) -> Result, epoch::Unavailable> { + Ok(Reader { buffer: &self.buffer, unpadded: self.unpadded, neighbors: &self.neighbors, - epoch: self.registry.register(), + epoch: self.registry.register()?, + }) + } + + /// Attempt to acquire new slot for writing. + pub fn acquire(&self) -> Slot<'_> { + match self.freelist.pop() { + freelist::Id::Found(id) => self.slot(id), + freelist::Id::Scan => unimplemented!("fallback scan not implemented"), } } - pub(crate) fn write(&self, i: usize) -> Option> { - let tag = self.tag_mut(i)?; - match tag.try_set( + fn slot(&self, i: u32) -> Slot<'_> { + let tag = self.tag_mut(i.into_usize()).unwrap(); + if let Err(got) = tag.try_set( Generation::AVAILABLE, Generation::OWNED, - Ordering::Acquire, + Ordering::Relaxed, Ordering::Relaxed, ) { - Ok(_) => { - let (mirror, data) = unsafe { self.data(i) }; - let write = Write { - tag, - mirror, - generation: self.registry.generation(), - data, - }; - Some(write) - } - Err(_) => None, + panic!( + "CONCURRENCY VIOLATION: acquire - expected {} - got {}", + Generation::AVAILABLE, + got + ); + } + + let (mirror, data) = unsafe { self.data_unchecked(i.into_usize()) }; + Slot { + tag, + mirror, + generation: self.registry.generation(), + data, + slot: i, } } pub(crate) fn delete(&self, i: usize) -> bool { + let guard = self.registry.register().unwrap(); let tag = self.tag_mut(i).unwrap(); let current = tag.get(Ordering::Relaxed); @@ -121,20 +172,16 @@ impl Primary { match tag.try_set(current, owned, Ordering::Relaxed, Ordering::Relaxed) { Ok(current) => { // Set the metadata in the mirror as well. 
- let (mirror, _) = unsafe { self.data(i) }; + let (mirror, _) = unsafe { self.data_unchecked(i) }; mirror.set(owned, Ordering::Relaxed); - let wait_for = self.registry.advance(); - self.drain - .lock() - .unwrap() - .push((i.try_into().unwrap(), wait_for)); + guard.retire(i as u32); true } Err(_) => false, } } - unsafe fn data(&self, i: usize) -> (generation::Mut<'_>, RawSlice<'_>) { + unsafe fn data_unchecked(&self, i: usize) -> (generation::Mut<'_>, RawSlice<'_>) { let (mirror, data) = unsafe { self.buffer.get_unchecked(i) } .truncate(self.unpadded) .split(SPLIT); @@ -229,20 +276,32 @@ impl<'a> Reader<'a> { } #[derive(Debug)] -pub struct Write<'a> { +pub struct Slot<'a> { tag: generation::Mut<'a>, mirror: generation::Mut<'a>, generation: Generation, data: RawSlice<'a>, + slot: u32, } -impl<'a> Write<'a> { +impl<'a> Slot<'a> { pub fn as_mut_slice(&mut self) -> &mut [u8] { unsafe { self.data.as_mut_slice() } } + + /// Return the slot associated with this write. + pub fn slot(&self) -> u32 { + self.slot + } + + fn freeze(self) { + let me = std::mem::ManuallyDrop::new(self); + me.mirror.set(Generation::FROZEN, Ordering::Release); + me.tag.set(Generation::FROZEN, Ordering::Release); + } } -impl Drop for Write<'_> { +impl Drop for Slot<'_> { fn drop(&mut self) { self.mirror.set(self.generation, Ordering::Release); self.tag.set(self.generation, Ordering::Release); From 8782b6b73c5425f54df1955e88a454ba2b0628d5 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Sat, 13 Jun 2026 08:24:55 -0700 Subject: [PATCH 07/45] Checkpoint. 
--- Cargo.lock | 1 + diskann-inmem/Cargo.toml | 1 + diskann-inmem/src/arbiter/epoch.rs | 296 ++++++++++++++++++++++++-- diskann-inmem/src/arbiter/freelist.rs | 5 +- diskann-inmem/src/neighbors.rs | 42 ++-- diskann-inmem/src/provider.rs | 65 ++---- diskann-inmem/src/store.rs | 55 ++--- 7 files changed, 354 insertions(+), 111 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 97dd88072..a50cd7528 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -816,6 +816,7 @@ dependencies = [ "diskann", "diskann-utils", "diskann-vector", + "parking_lot", "thiserror 2.0.17", "tokio", ] diff --git a/diskann-inmem/Cargo.toml b/diskann-inmem/Cargo.toml index 335a68f45..418b09859 100644 --- a/diskann-inmem/Cargo.toml +++ b/diskann-inmem/Cargo.toml @@ -13,6 +13,7 @@ crossbeam-queue = "0.3.12" diskann = { workspace = true } diskann-utils = { workspace = true, default-features = false } diskann-vector.workspace = true +parking_lot = "0.12.5" thiserror.workspace = true [lints] diff --git a/diskann-inmem/src/arbiter/epoch.rs b/diskann-inmem/src/arbiter/epoch.rs index f87fdc549..b12e00391 100644 --- a/diskann-inmem/src/arbiter/epoch.rs +++ b/diskann-inmem/src/arbiter/epoch.rs @@ -5,11 +5,11 @@ use std::sync::{ atomic::{AtomicUsize, Ordering}, - Mutex, TryLockError, }; use crossbeam_queue::SegQueue; use diskann::utils::IntoUsize; +use parking_lot::{Mutex, MutexGuard}; use crate::arbiter::{ generation::{Mut, Tag}, @@ -72,6 +72,10 @@ impl Registry { } } + pub fn capacity(&self) -> usize { + self.slots.len() + } + /// Return the current generation. /// /// This has [`Ordering::Acquire`] semantics. 
@@ -87,19 +91,21 @@ impl Registry { self.register_inner(NoDelay) } - fn register_inner(&self, _: T) -> Result, Unavailable> + #[inline] + fn register_inner(&self, mut delay: T) -> Result, Unavailable> where T: RegisterDelay, { // REGISTER CHECK let mut generation = self.generation(); let hint = self.hint.fetch_add(1, Ordering::Relaxed); - + delay.post_register_check(); let nslots = self.slots.len(); for i in 0..nslots { let slot = (hint + i) % nslots; let m = self.slots[slot].as_mut(); + delay.pre_cas(); if let Ok(_) = m.try_set( Generation::MAX, generation, @@ -111,7 +117,9 @@ impl Registry { // REGISTER FENCE: This fence is paired with "WAITING FENCE". // // See that comment for details. + delay.pre_fence(); std::sync::atomic::fence(Ordering::SeqCst); + delay.post_fence(); // REGISTER RECHECK let current = self.generation(); @@ -131,6 +139,8 @@ impl Registry { slot: m, retire: &self.retiring[queue(generation)], generation, + #[cfg(test)] + slot_index: slot, }); } } @@ -143,13 +153,13 @@ impl Registry { /// This uses a fast method that may be overly conservative. /// /// This is a synchronizing operation with [`Ordering::Acquire`] semantics. - pub fn waiting(&self) -> Generation { - self.waiting_inner(NoDelay) + pub fn can_advance(&self) -> bool { + self.can_advance_inner(&mut NoDelay).0 } - fn waiting_inner(&self, _: T) -> Generation + fn can_advance_inner(&self, delay: &mut T) -> (bool, Generation) where - T: WaitingDelay, + T: CanAdvanceDelay, { // WAITING FENCE: This is a very important part for the correctness of the algorithm. // @@ -179,10 +189,13 @@ impl Registry { // It's possible that thread B observes the CAS to "REGISTER CHECK", but since // thread A will monotonically decrease it before exiting, the value thread B // observes is conservative and not incorrect. 
+ delay.pre_fence(); std::sync::atomic::fence(Ordering::SeqCst); + delay.post_fence(); // WAITING CHECK - let mut max = self.generation(); + let current = self.generation(); + let mut max = current; for s in self.slots.iter() { let generation = s.as_ref().get(Ordering::Relaxed); @@ -193,14 +206,14 @@ impl Registry { // This synchronizes with all the guard's `Release`s. std::sync::atomic::fence(Ordering::Acquire); - max + (max == current, max) } pub fn try_advance(&self) -> Option> { self.try_advance_inner(NoDelay) } - fn try_advance_inner(&self, _: T) -> Option> + fn try_advance_inner(&self, mut delay: T) -> Option> where T: TryAdvanceDelay, { @@ -213,18 +226,13 @@ impl Registry { // // We intentionally ignore lock-poison since we expect the guarded queue to be // robust with respect to panics. - let drain = match self.drain.try_lock() { - Ok(drain) => drain, - Err(TryLockError::Poisoned(drain)) => drain.into_inner(), - Err(TryLockError::WouldBlock) => return None, - }; + let drain = self.drain.try_lock()?; - let waiting = self.waiting(); - let current = self.generation.as_ref().get(Ordering::Relaxed); + let (can_advance, current) = self.can_advance_inner(&mut delay); // All waiters belong to the current generation. Therefore, it is safe to release // the old array queue - if waiting == current { + if can_advance { // We are safe to use a `fetch_sub` here because `drain` is ensuring exclusivity // of the access. 
// @@ -240,6 +248,23 @@ impl Registry { None } } + + #[cfg(test)] + fn assert_no_workers(&self) { + for s in self.slots.iter() { + assert_eq!(s.as_ref().get(Ordering::Relaxed), Generation::MAX); + } + } + + #[cfg(test)] + fn snapshot(&self) -> Vec { + self.slots.iter().map(|s| s.as_ref().get(Ordering::Relaxed)).collect() + } + + #[cfg(test)] + fn waiting(&self) -> Generation { + self.can_advance_inner(&mut NoDelay).1 + } } #[derive(Debug)] @@ -247,6 +272,9 @@ pub struct Guard<'a> { slot: Mut<'a>, retire: &'a SegQueue, generation: Generation, + + #[cfg(test)] + slot_index: usize, } impl Guard<'_> { @@ -282,7 +310,7 @@ impl Drop for Guard<'_> { #[derive(Debug)] pub struct Drain<'a> { queue: &'a SegQueue, - drain: std::sync::MutexGuard<'a, ()>, + drain: MutexGuard<'a, ()>, } impl Drain<'_> { @@ -346,15 +374,23 @@ impl From for diskann::ANNError { #[derive(Debug)] struct NoDelay; -trait RegisterDelay {} +trait RegisterDelay { + fn post_register_check(&mut self) {} + fn pre_cas(&mut self) {} + fn pre_fence(&mut self) {} + fn post_fence(&mut self) {} +} impl RegisterDelay for NoDelay {} -trait WaitingDelay {} +trait CanAdvanceDelay { + fn pre_fence(&mut self) {} + fn post_fence(&mut self) {} +} -impl WaitingDelay for NoDelay {} +impl CanAdvanceDelay for NoDelay {} -trait TryAdvanceDelay {} +trait TryAdvanceDelay: CanAdvanceDelay {} impl TryAdvanceDelay for NoDelay {} @@ -365,4 +401,218 @@ impl TryAdvanceDelay for NoDelay {} #[cfg(test)] mod tests { use super::*; + + use std::{thread, sync::mpsc}; + + fn channel() -> (Sender, Receiver) { + let (s, r) = mpsc::channel(); + (Sender(s), Receiver(r)) + } + + struct Sender(mpsc::Sender<()>); + + impl Sender { + fn send(&self) { + self.0.send(()).unwrap() + } + } + + struct Receiver(mpsc::Receiver<()>); + + impl Receiver { + fn recv(&self) { + self.0.recv().unwrap(); + } + } + + // This test ensures that two threads racing on `hint` will correctly resolve themselves + // when claiming a slot. 
+ #[test] + fn test_cas_race() { + let (a_sender, a_receiver) = channel(); + let (b_sender, b_receiver) = channel(); + let (a_done_sender, a_done_receiver) = channel(); + + let delay = TestRegisterDelay::default() + .with_post_register_check(move || { + b_sender.send(); + a_receiver.recv(); + }); + + let registry = Registry::with_capacity(2); + assert_eq!(registry.capacity(), 2); + + thread::scope(|s| { + let registry_ref = ®istry; + + // Thread A + s.spawn(move || { + let g = registry_ref.register_inner(delay).unwrap(); + assert_eq!(g.slot_index, 1); + a_done_sender.send(); + }); + + // Thread B + s.spawn(move || { + // wait for Thread A to let us know it's in the delay slot. + b_receiver.recv(); + { + let g = registry_ref.register_inner(NoDelay).unwrap(); + assert_eq!(g.slot_index, 1); + } + let g = registry_ref.register_inner(NoDelay).unwrap(); + assert_eq!(g.slot_index, 0); + a_sender.send(); + a_done_receiver.recv(); + }); + }); + + registry.assert_no_workers(); + } + + #[test] + fn test_register_wait() { + // This tests the case where a thread enters registration, reads a generation, then + // sleeps for several generation advances. It ensures that the thread recovers + // properly. + let (ready_sender, ready_receiver) = channel(); + let (step0_sender, step0_receiver) = channel(); + + let delay = TestRegisterDelay::default() + .with_post_register_check(move || { + ready_sender.send(); + step0_receiver.recv(); + }) + .with_pre_fence(move || { + }); + + let registry = Registry::with_capacity(2); + + thread::scope(|s| { + let registry_ref = ®istry; + + let handle = s.spawn(move || registry_ref.register_inner(delay).unwrap()); + + // Wait for the spawned thread to reach the critical section. 
+ ready_receiver.recv(); + + assert_eq!(registry.waiting(), Generation::MAX.sub(1)); + { + let drain = registry.try_advance().unwrap(); + assert!(drain.is_empty()); + assert_eq!(registry.generation(), Generation::MAX.sub(2)); + } + + { + let drain = registry.try_advance().unwrap(); + assert!(drain.is_empty()); + assert_eq!(registry.generation(), Generation::MAX.sub(3)); + } + + step0_sender.send(); + + let expected = Generation::MAX.sub(3); + + // The generation should be the last set one - even though this thread was + // parked during the transition. + let r = handle.join().unwrap(); + assert_eq!(r.generation(), expected); + assert_eq!(registry.waiting(), expected); + }); + + registry.assert_no_workers(); + } + + //-------------// + // Test Delays // + //-------------// + + macro_rules! tester { + ($struct:ident, $trait:ident, $($with:ident => $f:ident),* $(,)?) => { + #[derive(Default)] + struct $struct<'a> { + $($f: Option>,)* + } + + impl<'a> $struct<'a> { + $( + fn $with(mut self, f: F) -> Self + where + F: FnMut() + Send + 'a + { + self.$f = Some(Box::new(f)); + self + } + )* + } + + impl $trait for $struct<'_> { + $( + fn $f(&mut self) { + if let Some(f) = self.$f.as_mut() { + f() + } + } + )* + } + } + } + + tester! { + TestRegisterDelay, + RegisterDelay, + with_post_register_check => post_register_check, + with_pre_cas => pre_cas, + with_pre_fenct => pre_fence, + with_post_fence => post_fence, + } + + // #[derive(Default)] + // struct TestRegisterDelay<'a> { + // post_register_check: Option<&'a mut dyn FnMut()>, + // pre_cas: Option<&'a mut dyn FnMut()>, + // pre_fence: Option<&'a mut dyn FnMut()>, + // post_fence: Option<&'a mut dyn FnMut()>, + // } + + // macro_rules! builder { + // ($f:ident, $field:ident) => { + // fn $f(mut self, f: &'a mut dyn FnMut()) -> Self { + // self.$field = Some(f); + // self + // } + // } + // } + + // macro_rules! 
forward { + // ($f:ident) => { + // fn $f(&mut self) { + // if let Some(f) = self.$f.as_mut() { + // f() + // } + // } + // } + // } + + // impl<'a> TestRegisterDelay<'a> { + // builder!(with_post_register_check, post_register_check); + // builder!(with_pre_cas, pre_cas); + // builder!(with_pre_fence, pre_fence); + // builder!(with_post_fence, post_fence); + // } + + // impl RegisterDelay for TestRegisterDelay<'_> { + // forward!(post_register_check); + // forward!(pre_cas); + // forward!(pre_fence); + // forward!(post_fence); + // } + + // struct CanAdvanceDelay; + + // impl CanAdvanceDelay for TestWaitingDelay {} + + // struct TestTryAdvanceDelay; + + // impl TryAdvanceDelay for TestTryAdvanceDelay {} } diff --git a/diskann-inmem/src/arbiter/freelist.rs b/diskann-inmem/src/arbiter/freelist.rs index 6dfc85371..f52fb17d4 100644 --- a/diskann-inmem/src/arbiter/freelist.rs +++ b/diskann-inmem/src/arbiter/freelist.rs @@ -5,10 +5,7 @@ use std::{ num::NonZeroU32, - sync::{ - atomic::{AtomicU32, AtomicUsize, Ordering}, - Mutex, - }, + sync::atomic::{AtomicU32, AtomicUsize, Ordering}, }; use crossbeam_queue::ArrayQueue; diff --git a/diskann-inmem/src/neighbors.rs b/diskann-inmem/src/neighbors.rs index 9ca3770d8..620990129 100644 --- a/diskann-inmem/src/neighbors.rs +++ b/diskann-inmem/src/neighbors.rs @@ -3,9 +3,8 @@ * Licensed under the MIT license. */ -use std::sync::RwLock; - use diskann::{graph::AdjacencyList, utils::IntoUsize}; +use parking_lot::{RawRwLock, RwLock, RwLockWriteGuard}; use thiserror::Error; use crate::{ @@ -54,7 +53,7 @@ impl Neighbors { let lock = unsafe { self.locks.get_unchecked(lock_index(i)) }; - let _guard = dismiss_poison(lock.read()); + let _guard = lock.read(); // SAFETY: By consruction `self.buffer` has the same number of entries as // `self.locks` and we have already checked that `i` is in-bounds there. 
@@ -88,7 +87,7 @@ impl Neighbors { } unsafe fn lock_unchecked(&self, i: usize) -> Lock<'_> { - let lock = dismiss_poison(unsafe { self.locks.get_unchecked(lock_index(i)) }.write()); + let lock = unsafe { self.locks.get_unchecked(lock_index(i)) }.write(); // SAFETY: By consruction `self.buffer` has the same number of entries as // `self.locks` and we have already checked that `i` is in-bounds there. @@ -140,17 +139,7 @@ pub enum SetError { TooLong(TooLong), } -// We carefully guard where locks are acquired in this function, so that panicking while -// holding a lock won't happen and if it does, we know we're still in decent shape. -fn dismiss_poison(r: std::sync::LockResult) -> T { - match r { - Ok(v) => v, - Err(poison) => poison.into_inner(), - } -} - /// A locked adjacency list to implement atomic read-modify-write operations. -#[derive(Debug)] pub struct Lock<'a> { // The raw adjacency list with the actual length stored as the first element. // @@ -160,7 +149,7 @@ pub struct Lock<'a> { raw: &'a mut [u32], // VERY IMPORTANT: `lock` has to be **after** `raw` because `lock` is guarding `raw` // and thus must be dropped **after** `raw`. - lock: std::sync::RwLockWriteGuard<'a, ()>, + lock: RwLockWriteGuard<'a, ()>, } impl Lock<'_> { @@ -200,6 +189,20 @@ impl Lock<'_> { Ok(()) } + pub fn append(self, neighbors: &[u32]) -> Result<(), TooLong> { + let len = self.len(); + let newlen = len.saturating_add(neighbors.len()); + + if newlen > self.capacity() { + return Err(TooLong); + } + + unsafe { self.raw.get_unchecked_mut(len..newlen) }.copy_from_slice(neighbors); + *unsafe { self.raw.get_unchecked_mut(0) } = newlen as u32; + Ok(()) + // `self.raw` is dropped first, then `self.lock` which was guarding it. 
+ } + unsafe fn write_unchecked(self, neighbors: &[u32]) { let len = neighbors.len(); debug_assert!(len <= self.capacity()); @@ -211,6 +214,15 @@ impl Lock<'_> { } } +impl std::fmt::Debug for Lock<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Lock") + .field("raw", &self.raw) + .field("lock", &()) + .finish() + } +} + #[derive(Debug, Clone, Copy, Error)] #[error("too long")] pub struct TooLong; diff --git a/diskann-inmem/src/provider.rs b/diskann-inmem/src/provider.rs index 47245d5d8..3d13e223a 100644 --- a/diskann-inmem/src/provider.rs +++ b/diskann-inmem/src/provider.rs @@ -13,7 +13,7 @@ use diskann::{ utils::IntoUsize, ANNError, ANNErrorKind, ANNResult, }; -use diskann_utils::{views::Matrix, future::{AsyncFriendly, SendFuture}}; +use diskann_utils::views::Matrix; use crate::{ arbiter::epoch, @@ -42,18 +42,7 @@ impl Provider { layers::Set::into_bytes(&layer, point, row).unwrap(); } - let primary = Primary::new( - capacity, - bytes, - 32, - data.as_view(), - ); - - // for v in start_points.into_iter() { - // let mut writer = primary.acquire(); - // layers::Set::into_bytes(&layer, v, writer.as_mut_slice()).unwrap(); - // } - + let primary = Primary::new(capacity, bytes, 32, data.as_view()); Self { primary, layer } } @@ -83,7 +72,7 @@ where fn to_internal_id( &self, - context: &Self::Context, + _context: &Self::Context, gid: &Self::ExternalId, ) -> Result { Ok(*gid) @@ -92,7 +81,7 @@ where /// Translate an internal id to its corresponding external id. 
fn to_external_id( &self, - context: &Self::Context, + _context: &Self::Context, id: Self::InternalId, ) -> Result { Ok(id) @@ -114,12 +103,12 @@ where fn set_element( &self, - context: &Self::Context, + _context: &Self::Context, id: &Self::ExternalId, element: T, ) -> impl std::future::Future> + Send { let work = move || { - let mut slot = self.primary.acquire(); + let mut slot = self.primary.acquire().unwrap(); >::into_bytes(&self.layer, element, slot.as_mut_slice())?; Ok(diskann::provider::NoopGuard::new(slot.slot())) }; @@ -328,7 +317,6 @@ unsafe fn expand_beam_inner( #[derive(Debug)] pub struct PruneAccessor<'a> { reader: store::Reader<'a>, - set: workingset::Map>, distance: &'a dyn Distance, ids: AdjacencyList, } @@ -346,7 +334,7 @@ impl glue::PruneAccessor for PruneAccessor<'_> { type ElementRef<'a> = &'a [u8]; type View<'a> - = workingset::map::View<'a, u32, Box<[u8]>> + = &'a Self where Self: 'a; @@ -361,19 +349,12 @@ impl glue::PruneAccessor for PruneAccessor<'_> { async fn fill<'a, Itr>( &'a mut self, - itr: Itr, + _itr: Itr, ) -> ANNResult<(Self::View<'a>, Self::Distance<'a>)> where Itr: ExactSizeIterator + Clone + Send + Sync, { - let v = self - .set - .fill(itr, |i| -> Result<_, Infallible> { - Ok(self.reader.read(i.into_usize()).map(|v| v.into())) - }) - .unwrap(); - - Ok((v, &*self.distance)) + Ok((self, &*self.distance)) } } @@ -416,16 +397,7 @@ impl provider::NeighborAccessorMut for PruneAccessor<'_> { neighbors: &[Self::Id], ) -> impl std::future::Future> + Send { let work = move || -> ANNResult<()> { - let current = self.reader.neighbors().lock(id.into_usize()).unwrap(); - - // Copy out the current neighbors. - let mut resize = self.ids.resize(current.len()); - resize.copy_from_slice(current.as_slice()); - resize.finish(current.len()); - - // Append the new neighbors. 
- self.ids.extend_from_slice(neighbors); - current.write(&self.ids).unwrap(); + self.reader.neighbors().lock(id.into_usize()).unwrap().append(neighbors).unwrap(); Ok(()) }; @@ -433,6 +405,17 @@ impl provider::NeighborAccessorMut for PruneAccessor<'_> { } } +impl workingset::View for &PruneAccessor<'_> { + type ElementRef<'a> = &'a [u8]; + type Element<'a> + = &'a [u8] + where + Self: 'a; + fn get(&self, id: u32) -> Option<&[u8]> { + self.reader.read(id.into_usize()) + } +} + //////////////// // Strategies // //////////////// @@ -450,7 +433,7 @@ where fn search_accessor( &'a self, provider: &'a Provider, - context: &'a Context, + _context: &'a Context, query: T, ) -> ANNResult> { let distance = >::query_distance(&provider.layer, query)?; @@ -483,13 +466,11 @@ where fn prune_accessor<'a>( &self, provider: &'a Provider, - context: &'a Context, + _context: &'a Context, capacity: usize, ) -> ANNResult> { - let set = workingset::map::Builder::new(workingset::map::Capacity::Default).build(capacity); Ok(PruneAccessor { reader: provider.primary.reader()?, - set, distance: ::as_distance(&provider.layer), ids: AdjacencyList::new(), }) diff --git a/diskann-inmem/src/store.rs b/diskann-inmem/src/store.rs index 73594692f..21fdcc80d 100644 --- a/diskann-inmem/src/store.rs +++ b/diskann-inmem/src/store.rs @@ -3,11 +3,7 @@ * Licensed under the MIT license. */ -use std::{ - iter::repeat_n, - num::NonZeroU32, - sync::{atomic::Ordering, Mutex}, -}; +use std::{iter::repeat_n, num::NonZeroU32, sync::atomic::Ordering}; use diskann::utils::IntoUsize; use diskann_utils::views::MatrixView; @@ -36,6 +32,7 @@ pub struct Primary { } const SPLIT: Bytes = Bytes::size_of::(); +const RETRY_LIMIT: usize = 20; impl Primary { pub fn new( @@ -69,7 +66,7 @@ impl Primary { // Populate frozen points. 
for (i, data) in init.row_iter().enumerate() { - let mut slot = this.slot((entries + i).try_into().unwrap()); + let mut slot = this.slot((entries + i).try_into().unwrap()).unwrap(); slot.as_mut_slice().copy_from_slice(data); slot.freeze(); } @@ -123,35 +120,39 @@ impl Primary { } /// Attempt to acquire new slot for writing. - pub fn acquire(&self) -> Slot<'_> { - match self.freelist.pop() { - freelist::Id::Found(id) => self.slot(id), - freelist::Id::Scan => unimplemented!("fallback scan not implemented"), + pub fn acquire(&self) -> Option> { + for _ in 0..RETRY_LIMIT { + match self.freelist.pop() { + freelist::Id::Found(id) => { + if let Some(slot) = self.slot(id) { + return Some(slot); + } + } + freelist::Id::Scan => unimplemented!("fallback scan not implemented"), + } } + None } - fn slot(&self, i: u32) -> Slot<'_> { + fn slot(&self, i: u32) -> Option> { let tag = self.tag_mut(i.into_usize()).unwrap(); - if let Err(got) = tag.try_set( + match tag.try_set( Generation::AVAILABLE, Generation::OWNED, Ordering::Relaxed, Ordering::Relaxed, ) { - panic!( - "CONCURRENCY VIOLATION: acquire - expected {} - got {}", - Generation::AVAILABLE, - got - ); - } - - let (mirror, data) = unsafe { self.data_unchecked(i.into_usize()) }; - Slot { - tag, - mirror, - generation: self.registry.generation(), - data, - slot: i, + Ok(_) => { + let (mirror, data) = unsafe { self.data_unchecked(i.into_usize()) }; + Some(Slot { + tag, + mirror, + generation: self.registry.generation(), + data, + slot: i, + }) + } + Err(_) => None, } } @@ -170,7 +171,7 @@ impl Primary { // Even if we make this change, we can't access any data until we wait for the // epoch to be bumped. As such, relaxed semantics are fine. match tag.try_set(current, owned, Ordering::Relaxed, Ordering::Relaxed) { - Ok(current) => { + Ok(_) => { // Set the metadata in the mirror as well. 
let (mirror, _) = unsafe { self.data_unchecked(i) }; mirror.set(owned, Ordering::Relaxed); From ab22b9028a049184a17876620c660268cfd9fc3a Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Sat, 13 Jun 2026 10:36:44 -0700 Subject: [PATCH 08/45] Checkpoint. --- diskann-inmem/Cargo.toml | 2 +- diskann-inmem/src/arbiter/epoch.rs | 198 +++++++++++++++++++------- diskann-inmem/src/arbiter/freelist.rs | 62 ++++++-- diskann-inmem/src/layers/full.rs | 2 +- diskann-inmem/src/provider.rs | 12 +- diskann-inmem/src/store.rs | 60 +++++++- 6 files changed, 258 insertions(+), 78 deletions(-) diff --git a/diskann-inmem/Cargo.toml b/diskann-inmem/Cargo.toml index 418b09859..ad22cdbd1 100644 --- a/diskann-inmem/Cargo.toml +++ b/diskann-inmem/Cargo.toml @@ -5,7 +5,7 @@ description.workspace = true authors.workspace = true repository.workspace = true license.workspace = true -edition.workspace = true +edition = "2024" [dependencies] bytemuck = { workspace = true, features = ["must_cast"] } diff --git a/diskann-inmem/src/arbiter/epoch.rs b/diskann-inmem/src/arbiter/epoch.rs index b12e00391..26862adde 100644 --- a/diskann-inmem/src/arbiter/epoch.rs +++ b/diskann-inmem/src/arbiter/epoch.rs @@ -3,17 +3,15 @@ * Licensed under the MIT license. */ -use std::sync::{ - atomic::{AtomicUsize, Ordering}, -}; +use std::sync::atomic::{AtomicUsize, Ordering}; use crossbeam_queue::SegQueue; use diskann::utils::IntoUsize; use parking_lot::{Mutex, MutexGuard}; use crate::arbiter::{ - generation::{Mut, Tag}, Generation, + generation::{Mut, Tag}, }; const CAPACITY: usize = 256; @@ -112,6 +110,7 @@ impl Registry { Ordering::Relaxed, Ordering::Relaxed, ) { + delay.post_cas(); let mut reset = false; loop { // REGISTER FENCE: This fence is paired with "WAITING FENCE". 
@@ -258,7 +257,10 @@ impl Registry { #[cfg(test)] fn snapshot(&self) -> Vec { - self.slots.iter().map(|s| s.as_ref().get(Ordering::Relaxed)).collect() + self.slots + .iter() + .map(|s| s.as_ref().get(Ordering::Relaxed)) + .collect() } #[cfg(test)] @@ -324,7 +326,7 @@ impl Drain<'_> { } pub fn is_empty(&self) -> bool { - self.len() == 0 + self.queue.is_empty() } } @@ -377,6 +379,7 @@ struct NoDelay; trait RegisterDelay { fn post_register_check(&mut self) {} fn pre_cas(&mut self) {} + fn post_cas(&mut self) {} fn pre_fence(&mut self) {} fn post_fence(&mut self) {} } @@ -402,26 +405,85 @@ impl TryAdvanceDelay for NoDelay {} mod tests { use super::*; - use std::{thread, sync::mpsc}; + use std::{sync::Arc, thread}; + + use parking_lot::Condvar; + + #[derive(Clone)] + struct Sequencer(Arc); - fn channel() -> (Sender, Receiver) { - let (s, r) = mpsc::channel(); - (Sender(s), Receiver(r)) + struct SequencerInner { + state: Mutex, + condvar: Condvar, } - struct Sender(mpsc::Sender<()>); + #[derive(Debug, Clone, Copy, PartialEq)] + enum State { + Empty, + Parked(usize), + Released(usize), + } - impl Sender { - fn send(&self) { - self.0.send(()).unwrap() + impl Sequencer { + fn new() -> Self { + Self(Arc::new(SequencerInner { + state: Mutex::new(State::Empty), + condvar: Condvar::new(), + })) } - } - struct Receiver(mpsc::Receiver<()>); + fn wait_for(&self, stage: usize) { + let mut state = self.0.state.lock(); + if stage == 0 { + assert_eq!(*state, State::Empty) + } else { + assert_eq!(*state, State::Released(stage - 1)) + } - impl Receiver { - fn recv(&self) { - self.0.recv().unwrap(); + *state = State::Parked(stage); + self.0.condvar.notify_all(); + self.0 + .condvar + .wait_while(&mut state, move |s| *s != State::Released(stage)); + } + + fn advance_past(&self, stage: usize) { + let mut state = self.0.state.lock(); + self.0 + .condvar + .wait_while(&mut state, move |s| Self::check_release(*s, stage)); + *state = State::Released(stage); + self.0.condvar.notify_all(); + } + 
+ fn until_waiting_for(&self, stage: usize) { + let mut state = self.0.state.lock(); + if *state != State::Parked(stage) { + self.0 + .condvar + .wait_while(&mut state, move |s| Self::check_release(*s, stage)) + } + } + + fn check_release(current: State, stage: usize) -> bool { + match current { + State::Empty => { + assert_eq!(stage, 0); + true + } + State::Released(s) => { + if s + 1 != stage { + panic!("observed {:?} while releasing stage {}", current, stage); + } + true + } + State::Parked(s) => { + if s != stage { + panic!("observed {:?} while releasing stage {}", current, stage) + } + false + } + } } } @@ -429,41 +491,33 @@ mod tests { // when claiming a slot. #[test] fn test_cas_race() { - let (a_sender, a_receiver) = channel(); - let (b_sender, b_receiver) = channel(); - let (a_done_sender, a_done_receiver) = channel(); + let seq = Sequencer::new(); - let delay = TestRegisterDelay::default() - .with_post_register_check(move || { - b_sender.send(); - a_receiver.recv(); - }); + let delay = TestRegisterDelay::default().with_post_register_check(|| seq.wait_for(0)); let registry = Registry::with_capacity(2); assert_eq!(registry.capacity(), 2); thread::scope(|s| { - let registry_ref = ®istry; - // Thread A - s.spawn(move || { - let g = registry_ref.register_inner(delay).unwrap(); + s.spawn(|| { + let g = registry.register_inner(delay).unwrap(); assert_eq!(g.slot_index, 1); - a_done_sender.send(); + seq.wait_for(1); }); // Thread B - s.spawn(move || { - // wait for Thread A to let us know it's in the delay slot. - b_receiver.recv(); + s.spawn(|| { + // wait for Thread A to reach the delay point. 
+ seq.until_waiting_for(0); { - let g = registry_ref.register_inner(NoDelay).unwrap(); + let g = registry.register_inner(NoDelay).unwrap(); assert_eq!(g.slot_index, 1); } - let g = registry_ref.register_inner(NoDelay).unwrap(); + let g = registry.register_inner(NoDelay).unwrap(); assert_eq!(g.slot_index, 0); - a_sender.send(); - a_done_receiver.recv(); + seq.advance_past(0); + seq.advance_past(1); }); }); @@ -473,28 +527,30 @@ mod tests { #[test] fn test_register_wait() { // This tests the case where a thread enters registration, reads a generation, then - // sleeps for several generation advances. It ensures that the thread recovers - // properly. - let (ready_sender, ready_receiver) = channel(); - let (step0_sender, step0_receiver) = channel(); + // sleeps for several generation advances. It ensures that the thread recovers properly. + let seq = Sequencer::new(); + let mut loop_count = 0; let delay = TestRegisterDelay::default() - .with_post_register_check(move || { - ready_sender.send(); - step0_receiver.recv(); - }) - .with_pre_fence(move || { - }); + .with_post_register_check(|| seq.wait_for(0)) + .with_post_cas(|| seq.wait_for(1)) + .with_pre_fence(|| loop_count += 1); let registry = Registry::with_capacity(2); thread::scope(|s| { - let registry_ref = ®istry; - - let handle = s.spawn(move || registry_ref.register_inner(delay).unwrap()); + let handle = s.spawn(|| { + let guard = registry.register_inner(delay).unwrap(); + + // Since we hit the CAS loop - this serves as a sanity check that we have + // the correct drain buffer. + guard.retire(10); + guard.retire_all([1, 2, 3]); + guard + }); // Wait for the spawned thread to reach the critical section. - ready_receiver.recv(); + seq.until_waiting_for(0); assert_eq!(registry.waiting(), Generation::MAX.sub(1)); { @@ -509,7 +565,19 @@ mod tests { assert_eq!(registry.generation(), Generation::MAX.sub(3)); } - step0_sender.send(); + // We allow the registering thread to make it past the CAS. 
+ // + // We pause it again because we want to verify that it registers an old generation. + seq.advance_past(0); + seq.until_waiting_for(1); + let (can_advance, waiter) = registry.can_advance_inner(&mut NoDelay); + assert!(!can_advance); + assert_eq!( + waiter, + Generation::MAX.sub(1), + "waiting thread registers an older generation before observing the change" + ); + seq.advance_past(1); let expected = Generation::MAX.sub(3); @@ -520,7 +588,26 @@ mod tests { assert_eq!(registry.waiting(), expected); }); + assert_eq!( + loop_count, 2, + "the registering thread should have looped to update its generation" + ); + registry.assert_no_workers(); + + // Verify that we reclaim the ID flushed by the registering thread. + // + // This requires two epoch advancements. + { + let drain = registry.try_advance().unwrap(); + assert!(drain.is_empty()); + } + + { + let drain = registry.try_advance().unwrap(); + let ids: Vec<_> = drain.collect(); + assert_eq!(ids, &[10, 1, 2, 3]); + } } //-------------// @@ -563,7 +650,8 @@ mod tests { RegisterDelay, with_post_register_check => post_register_check, with_pre_cas => pre_cas, - with_pre_fenct => pre_fence, + with_post_cas => post_cas, + with_pre_fence => pre_fence, with_post_fence => post_fence, } diff --git a/diskann-inmem/src/arbiter/freelist.rs b/diskann-inmem/src/arbiter/freelist.rs index f52fb17d4..125cbdc0d 100644 --- a/diskann-inmem/src/arbiter/freelist.rs +++ b/diskann-inmem/src/arbiter/freelist.rs @@ -11,19 +11,20 @@ use std::{ use crossbeam_queue::ArrayQueue; use diskann::utils::IntoUsize; +const SCAN_SIZE: u32 = 16; + #[derive(Debug)] pub struct Freelist { recycled: ArrayQueue, - /// An (approximate) number of recycled IDs that exist outside the freelist. - orphaned: AtomicUsize, - /// The highest ID the freelist manages. This is used when in "append" to determine the /// maximum ID we can return this way. max: u32, /// The number of "unallocated" IDs remaining. 
current: AtomicU32, + + scan_cursor: AtomicU32, } #[derive(Debug, Clone, Copy)] @@ -36,9 +37,9 @@ impl Freelist { pub fn new(max: u32, capacity: NonZeroU32) -> Self { Self { recycled: ArrayQueue::new(capacity.get().into_usize()), - orphaned: AtomicUsize::new(0), max, current: AtomicU32::new(0), + scan_cursor: AtomicU32::new(0), } } @@ -67,16 +68,26 @@ impl Freelist { Id::Scan } + pub fn pop_recycled(&self) -> Option { + self.recycled.pop() + } + + pub fn scan(&self) -> Scan { + let current = self.scan_cursor.fetch_add(SCAN_SIZE, Ordering::Relaxed) % self.max; + Scan { + current, + max: self.max, + len: SCAN_SIZE.into_usize() + } + } + /// Attempt to push `id` into the recycled list. Return `true` if `id` was /// inserted. If `false` is returned, it is likely because the internal recycle /// buffer is full. pub fn push(&self, id: u32) -> bool { match self.recycled.push(id) { Ok(()) => true, - Err(_) => { - self.orphaned.fetch_add(1, Ordering::Relaxed); - false - } + Err(_) => false, } } @@ -90,11 +101,6 @@ impl Freelist { let mut count = 0; while let Some(id) = itr.next() { if let Err(_) = self.recycled.push(id) { - let (lower, _) = itr.size_hint(); - - // Add 1 to "put back" the last ID. 
- self.orphaned - .fetch_add(lower.saturating_add(1), Ordering::Relaxed); break; } else { count += 1; @@ -112,3 +118,33 @@ impl Freelist { self.recycled.capacity() } } + +#[derive(Debug)] +pub struct Scan { + current: u32, + max: u32, + len: usize, +} + +impl Iterator for Scan { + type Item = u32; + fn next(&mut self) -> Option { + if self.len == 0 { + None + } else { + let mut i = self.current; + self.current += 1; + self.len -= 1; + if self.current == self.max { + self.current -= self.max; + } + Some(i) + } + } + + fn size_hint(&self) -> (usize, Option) { + (self.len, Some(self.len)) + } +} + +impl ExactSizeIterator for Scan {} diff --git a/diskann-inmem/src/layers/full.rs b/diskann-inmem/src/layers/full.rs index 8aea3eed1..d38ea0229 100644 --- a/diskann-inmem/src/layers/full.rs +++ b/diskann-inmem/src/layers/full.rs @@ -5,8 +5,8 @@ use diskann::{ANNError, ANNResult}; use diskann_vector::{ - distance::{self, DistanceProvider, Metric}, UnalignedSlice, + distance::{self, DistanceProvider, Metric}, }; use thiserror::Error; diff --git a/diskann-inmem/src/provider.rs b/diskann-inmem/src/provider.rs index 3d13e223a..977081243 100644 --- a/diskann-inmem/src/provider.rs +++ b/diskann-inmem/src/provider.rs @@ -4,14 +4,15 @@ */ use diskann::{ + ANNError, ANNErrorKind, ANNResult, error::Infallible, graph::{ + AdjacencyList, glue::{self, HybridPredicate}, - workingset, AdjacencyList, + workingset, }, provider, utils::IntoUsize, - ANNError, ANNErrorKind, ANNResult, }; use diskann_utils::views::Matrix; @@ -397,7 +398,12 @@ impl provider::NeighborAccessorMut for PruneAccessor<'_> { neighbors: &[Self::Id], ) -> impl std::future::Future> + Send { let work = move || -> ANNResult<()> { - self.reader.neighbors().lock(id.into_usize()).unwrap().append(neighbors).unwrap(); + self.reader + .neighbors() + .lock(id.into_usize()) + .unwrap() + .append(neighbors) + .unwrap(); Ok(()) }; diff --git a/diskann-inmem/src/store.rs b/diskann-inmem/src/store.rs index 21fdcc80d..e28bcaa93 100644 --- 
a/diskann-inmem/src/store.rs +++ b/diskann-inmem/src/store.rs @@ -9,9 +9,9 @@ use diskann::utils::IntoUsize; use diskann_utils::views::MatrixView; use crate::{ - arbiter::{epoch, freelist, generation, Buffer, Freelist, Generation, RawSlice}, - num::{Align, Bytes}, Neighbors, + arbiter::{Buffer, Freelist, Generation, RawSlice, epoch, freelist, generation}, + num::{Align, Bytes}, }; #[derive(Debug)] @@ -128,7 +128,53 @@ impl Primary { return Some(slot); } } - freelist::Id::Scan => unimplemented!("fallback scan not implemented"), + freelist::Id::Scan => match self.scan_acquire() { + Some(slot) => return Some(slot), + None => { self.try_drain(); }, + }, + } + } + None + } + + fn scan_acquire(&self) -> Option> { + // This is potentially quite slow - but stop if we've scanned the entire range + // without finding anything. + let mut remaining = self.unfrozen; + let mut chunks_since_freelist_check = 0; + let mut acquired: Option> = None; + + while remaining != 0 { + let chunk = self.freelist.scan(); + remaining = remaining.saturating_sub(chunk.len()); + + for slot in chunk { + let tag = self.tag_mut(slot.into_usize()).unwrap(); + + // If this slot is available and we haven't claimed a slot yet, try to + // claim it. Otherwise, continue with the scan to partially repopulate the + // freelist for other threads. 
+ if tag.get(Ordering::Relaxed) == Generation::AVAILABLE { + if acquired.is_none() { + acquired = unsafe { self.try_acquire(tag, slot) }; + } else { + self.freelist.push(slot); + } + } + } + + if acquired.is_some() { + return acquired; + } + + chunks_since_freelist_check += 1; + if chunks_since_freelist_check == 4 { + if let Some(id) = self.freelist.pop_recycled() + && let Some(slot) = self.slot(id) + { + return Some(slot); + } + chunks_since_freelist_check = 0; } } None @@ -136,6 +182,10 @@ impl Primary { fn slot(&self, i: u32) -> Option> { let tag = self.tag_mut(i.into_usize()).unwrap(); + unsafe { self.try_acquire(tag, i) } + } + + unsafe fn try_acquire<'a>(&'a self, tag: generation::Mut<'a>, slot: u32) -> Option> { match tag.try_set( Generation::AVAILABLE, Generation::OWNED, @@ -143,13 +193,13 @@ impl Primary { Ordering::Relaxed, ) { Ok(_) => { - let (mirror, data) = unsafe { self.data_unchecked(i.into_usize()) }; + let (mirror, data) = unsafe { self.data_unchecked(slot.into_usize()) }; Some(Slot { tag, mirror, generation: self.registry.generation(), data, - slot: i, + slot, }) } Err(_) => None, From b3a2a2a0cb38b727c3795e09ed5bd56fad542812 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Sat, 13 Jun 2026 16:12:35 -0700 Subject: [PATCH 09/45] ID translation. 
--- Cargo.lock | 1 + diskann-benchmark/src/index/inmem2.rs | 1 + diskann-inmem/Cargo.toml | 1 + diskann-inmem/src/arbiter/freelist.rs | 6 +- diskann-inmem/src/ids/mod.rs | 7 ++ diskann-inmem/src/ids/sharded.rs | 150 ++++++++++++++++++++++++++ diskann-inmem/src/lib.rs | 3 +- diskann-inmem/src/neighbors.rs | 11 +- diskann-inmem/src/provider.rs | 123 +++++++++++++++++---- diskann-inmem/src/store.rs | 17 ++- 10 files changed, 284 insertions(+), 36 deletions(-) create mode 100644 diskann-inmem/src/ids/mod.rs create mode 100644 diskann-inmem/src/ids/sharded.rs diff --git a/Cargo.lock b/Cargo.lock index a50cd7528..4cec4ce43 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -813,6 +813,7 @@ version = "0.54.0" dependencies = [ "bytemuck", "crossbeam-queue", + "dashmap", "diskann", "diskann-utils", "diskann-vector", diff --git a/diskann-benchmark/src/index/inmem2.rs b/diskann-benchmark/src/index/inmem2.rs index 2587bcadd..8f8448120 100644 --- a/diskann-benchmark/src/index/inmem2.rs +++ b/diskann-benchmark/src/index/inmem2.rs @@ -273,3 +273,4 @@ impl Benchmark for Inmem2 { Ok(()) } } + diff --git a/diskann-inmem/Cargo.toml b/diskann-inmem/Cargo.toml index ad22cdbd1..125222513 100644 --- a/diskann-inmem/Cargo.toml +++ b/diskann-inmem/Cargo.toml @@ -10,6 +10,7 @@ edition = "2024" [dependencies] bytemuck = { workspace = true, features = ["must_cast"] } crossbeam-queue = "0.3.12" +dashmap.workspace = true diskann = { workspace = true } diskann-utils = { workspace = true, default-features = false } diskann-vector.workspace = true diff --git a/diskann-inmem/src/arbiter/freelist.rs b/diskann-inmem/src/arbiter/freelist.rs index 125cbdc0d..84321d3ee 100644 --- a/diskann-inmem/src/arbiter/freelist.rs +++ b/diskann-inmem/src/arbiter/freelist.rs @@ -5,7 +5,7 @@ use std::{ num::NonZeroU32, - sync::atomic::{AtomicU32, AtomicUsize, Ordering}, + sync::atomic::{AtomicU32, Ordering}, }; use crossbeam_queue::ArrayQueue; @@ -77,7 +77,7 @@ impl Freelist { Scan { current, max: self.max, - len: 
SCAN_SIZE.into_usize() + len: SCAN_SIZE.into_usize(), } } @@ -132,7 +132,7 @@ impl Iterator for Scan { if self.len == 0 { None } else { - let mut i = self.current; + let i = self.current; self.current += 1; self.len -= 1; if self.current == self.max { diff --git a/diskann-inmem/src/ids/mod.rs b/diskann-inmem/src/ids/mod.rs new file mode 100644 index 000000000..ab2ae10e8 --- /dev/null +++ b/diskann-inmem/src/ids/mod.rs @@ -0,0 +1,7 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +pub(crate) mod sharded; +pub(crate) use sharded::Sharded; diff --git a/diskann-inmem/src/ids/sharded.rs b/diskann-inmem/src/ids/sharded.rs new file mode 100644 index 000000000..62cec3fdc --- /dev/null +++ b/diskann-inmem/src/ids/sharded.rs @@ -0,0 +1,150 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +use std::hash::Hash; + +use dashmap::{DashMap, mapref::entry::Entry}; +use diskann::utils::IntoUsize; +use parking_lot::RwLock; +use thiserror::Error; + +const SHARD_SIZE: usize = 1024; + +/// Bidirectional mapping between an external id `I` and a dense internal `u32` id. +#[derive(Debug)] +pub(crate) struct Sharded +where + I: Hash + Eq, +{ + forward: DashMap, + backward: Vec]>>>, + capacity: usize, +} + +impl Sharded +where + I: Hash + Eq, +{ + pub(crate) fn new(capacity: usize) -> Self { + let backward = std::iter::repeat_with(|| { + let shard = std::iter::repeat_with(|| None).take(SHARD_SIZE).collect(); + RwLock::new(shard) + }) + .take(capacity.div_ceil(SHARD_SIZE)) + .collect(); + + Self { + forward: DashMap::new(), + backward, + capacity, + } + } + + pub(crate) fn capacity(&self) -> usize { + self.capacity + } + + /// Establish a mapping between `external` and `internal`. + /// + /// # Errors + /// + /// Returns [`InsertError::OutOfBounds`] if `internal` is outside the table's capacity. + /// Returns [`InsertError::ExternalExists`] if `external` is already mapped. 
+ /// Returns [`InsertError::InternalExists`] if `internal` is already mapped. + pub(crate) fn insert(&self, external: I, internal: u32) -> Result<(), InsertError> + where + I: Eq + Hash + Clone, + { + if internal.into_usize() >= self.capacity { + return Err(InsertError::OutOfBounds); + } + + let Shard { outer, inner } = self.shard(internal); + + // Take the forward entry first and hold it vacant until the reverse slot is + // confirmed empty. This makes the pair-write atomic with respect to other + // `insert` callers: another thread racing on the same `external` will block + // on the dashmap shard, and another thread racing on the same `internal` will + // block on the backward shard's write lock. + let forward = match self.forward.entry(external.clone()) { + Entry::Occupied(_) => return Err(InsertError::ExternalExists), + Entry::Vacant(vacant) => vacant, + }; + + let mut shard = self.backward[outer].write(); + if shard[inner].is_some() { + // Forward entry drops as vacant — no insertion happened. + return Err(InsertError::InternalExists); + } + shard[inner] = Some(external); + forward.insert(internal); + Ok(()) + } + + /// Look up the internal id for an external id. + pub(crate) fn to_internal(&self, external: &Q) -> Option + where + I: std::borrow::Borrow, + Q: Eq + Hash + ?Sized, + { + self.forward.get(external).map(|v| *v) + } + + /// Look up the external id for an internal id. + pub(crate) fn to_external(&self, internal: u32) -> Option + where + I: Clone, + { + if internal.into_usize() >= self.capacity { + return None; + } + + let Shard { outer, inner } = self.shard(internal); + self.backward[outer].read()[inner].clone() + } + + /// Remove the mapping for `external`. Returns the freed internal id, or `None` if + /// no such mapping existed. 
+ pub(crate) fn remove(&self, external: &Q) -> Option + where + I: Eq + Hash + std::borrow::Borrow, + Q: Eq + Hash + ?Sized, + { + let (_, internal) = self.forward.remove(external)?; + let Shard { outer, inner } = self.shard(internal); + + // The backward slot should be populated by the `insert` invariant. + // + // If not - this is a program bug. + let mut shard = self.backward[outer].write(); + assert!(shard[inner].is_some(), "id {} removed improperly", internal); + shard[inner] = None; + + Some(internal) + } + + fn shard(&self, i: u32) -> Shard { + let i = i.into_usize(); + Shard { + outer: i / SHARD_SIZE, + inner: i % SHARD_SIZE, + } + } +} + +struct Shard { + outer: usize, + inner: usize, +} + +#[derive(Debug, Error)] +pub(crate) enum InsertError { + #[error("internal id is out of bounds")] + OutOfBounds, + #[error("the external id is already mapped")] + ExternalExists, + #[error("the internal id is already mapped")] + InternalExists, +} diff --git a/diskann-inmem/src/lib.rs b/diskann-inmem/src/lib.rs index b77470612..b6d4e14bd 100644 --- a/diskann-inmem/src/lib.rs +++ b/diskann-inmem/src/lib.rs @@ -4,12 +4,13 @@ */ mod arbiter; +pub mod num; +pub mod ids; pub mod layers; mod store; pub mod neighbors; -pub mod num; pub mod provider; pub use neighbors::Neighbors; diff --git a/diskann-inmem/src/neighbors.rs b/diskann-inmem/src/neighbors.rs index 620990129..f3cd52dc7 100644 --- a/diskann-inmem/src/neighbors.rs +++ b/diskann-inmem/src/neighbors.rs @@ -4,7 +4,7 @@ */ use diskann::{graph::AdjacencyList, utils::IntoUsize}; -use parking_lot::{RawRwLock, RwLock, RwLockWriteGuard}; +use parking_lot::{RwLock, RwLockWriteGuard}; use thiserror::Error; use crate::{ @@ -226,12 +226,3 @@ impl std::fmt::Debug for Lock<'_> { #[derive(Debug, Clone, Copy, Error)] #[error("too long")] pub struct TooLong; - -/////////// -// Tests // -/////////// - -#[cfg(test)] -mod tests { - use super::*; -} diff --git a/diskann-inmem/src/provider.rs b/diskann-inmem/src/provider.rs index 
977081243..cf2ef70c1 100644 --- a/diskann-inmem/src/provider.rs +++ b/diskann-inmem/src/provider.rs @@ -3,14 +3,17 @@ * Licensed under the MIT license. */ +use std::hash::Hash; + use diskann::{ ANNError, ANNErrorKind, ANNResult, - error::Infallible, graph::{ AdjacencyList, glue::{self, HybridPredicate}, workingset, + SearchOutputBuffer, }, + neighbor::Neighbor, provider, utils::IntoUsize, }; @@ -18,18 +21,29 @@ use diskann_utils::views::Matrix; use crate::{ arbiter::epoch, + ids, layers::{self, Distance, QueryDistance}, num::Bytes, store::{self, Primary}, }; +pub trait Id: Send + Sync + Hash + Eq + Clone + 'static {} +impl Id for T where T: Send + Sync + Hash + Eq + Clone + 'static {} + #[derive(Debug)] -pub struct Provider { +pub struct Provider +where + M: Id, +{ primary: Primary, layer: T, + mapping: ids::Sharded, } -impl Provider { +impl Provider +where + M: Id, +{ pub fn new(layer: T, capacity: usize, start_points: I) -> Self where I: IntoIterator, @@ -44,7 +58,12 @@ impl Provider { } let primary = Primary::new(capacity, bytes, 32, data.as_view()); - Self { primary, layer } + let mapping = ids::Sharded::new(capacity); + Self { + primary, + layer, + mapping, + } } fn reader(&self) -> Result, epoch::Unavailable> { @@ -61,22 +80,24 @@ pub struct Context; impl diskann::provider::ExecutionContext for Context {} -impl diskann::provider::DataProvider for Provider +impl diskann::provider::DataProvider for Provider where T: Send + Sync + 'static, + M: Id, { type Context = Context; type InternalId = u32; - type ExternalId = u32; - type Error = diskann::error::Infallible; + type ExternalId = M; + type Error = ANNError; type Guard = diskann::provider::NoopGuard; fn to_internal_id( &self, _context: &Self::Context, - gid: &Self::ExternalId, + gid: &M, ) -> Result { - Ok(*gid) + let id = self.mapping.to_internal(gid).unwrap(); + Ok(id) } /// Translate an internal id to its corresponding external id. 
@@ -85,7 +106,8 @@ where _context: &Self::Context, id: Self::InternalId, ) -> Result { - Ok(id) + let gid = self.mapping.to_external(id).unwrap(); + Ok(gid) } } @@ -96,21 +118,27 @@ where std::future::ready(f()) } -impl diskann::provider::SetElement for Provider +impl diskann::provider::SetElement for Provider where L: layers::Layer + layers::Set, + M: Id, { type SetError = ANNError; fn set_element( &self, _context: &Self::Context, - id: &Self::ExternalId, + id: &M, element: T, ) -> impl std::future::Future> + Send { let work = move || { let mut slot = self.primary.acquire().unwrap(); + + // TODO: Proper cleanup via `Guard` or some other mechanism on the event of + // insert failure. >::into_bytes(&self.layer, element, slot.as_mut_slice())?; + self.mapping.insert(id.clone(), slot.slot()).unwrap(); + Ok(diskann::provider::NoopGuard::new(slot.slot())) }; @@ -132,6 +160,9 @@ pub struct SearchAccessor<'a> { distance: Box, ids: AdjacencyList, expand_beam: FExpandBeam, + + // The parent provider for the accessor. 
+ provider: &'a (dyn std::any::Any + Send + Sync), } impl diskann::provider::HasId for SearchAccessor<'_> { @@ -319,7 +350,6 @@ unsafe fn expand_beam_inner( pub struct PruneAccessor<'a> { reader: store::Reader<'a>, distance: &'a dyn Distance, - ids: AdjacencyList, } impl diskann::provider::HasId for PruneAccessor<'_> { @@ -429,16 +459,17 @@ impl workingset::View for &PruneAccessor<'_> { #[derive(Debug, Clone, Copy)] pub struct Strategy; -impl<'a, T, L> glue::SearchStrategy<'a, Provider, T> for Strategy +impl<'a, T, L, M> glue::SearchStrategy<'a, Provider, T> for Strategy where L: layers::Search<'a, T>, + M: Id, { type SearchAccessor = SearchAccessor<'a>; type SearchAccessorError = ANNError; fn search_accessor( &'a self, - provider: &'a Provider, + provider: &'a Provider, _context: &'a Context, query: T, ) -> ANNResult> { @@ -450,42 +481,92 @@ where distance, ids: AdjacencyList::new(), expand_beam, + provider, }; Ok(accessor) } } -impl<'a, T, L> glue::DefaultPostProcessor<'a, Provider, T> for Strategy +#[derive(Debug, Clone, Copy)] +pub struct Translate(std::marker::PhantomData<(L, M)>); + +impl Default for Translate { + fn default() -> Self { + Self(std::marker::PhantomData) + } +} + +impl<'a, T, L, M> glue::SearchPostProcess, T, M> for Translate where L: layers::Search<'a, T>, + M: Id, { - diskann::default_post_processor!(glue::CopyIds); + type Error = ANNError; + + fn post_process( + &self, + accessor: &mut SearchAccessor<'_>, + query: T, + candidates: I, + output: &mut B, + ) -> impl std::future::Future> + Send + where + I: Iterator> + Send, + B: SearchOutputBuffer + Send + ?Sized { + + let work = move || { + // By construction - the downcast should succeed. Otherwise, this is a program bug. 
+ let provider = accessor.provider.downcast_ref::>().unwrap(); + let mut count = 0; + for c in candidates { + if let Some(ext) = provider.mapping.to_external(c.id) { + if output.push(ext, c.distance).is_available() { + count += 1; + } else { + break; + } + } + } + Ok(count) + }; + + ready(work) + } } -impl glue::PruneStrategy> for Strategy +impl<'a, T, L, M> glue::DefaultPostProcessor<'a, Provider, T, M> for Strategy +where + L: layers::Search<'a, T>, + M: Id, +{ + diskann::default_post_processor!(Translate); +} + +impl glue::PruneStrategy> for Strategy where L: layers::Layer + layers::AsDistance, + M: Id, { type PruneAccessor<'a> = PruneAccessor<'a>; type PruneAccessorError = ANNError; fn prune_accessor<'a>( &self, - provider: &'a Provider, + provider: &'a Provider, _context: &'a Context, capacity: usize, ) -> ANNResult> { Ok(PruneAccessor { reader: provider.primary.reader()?, distance: ::as_distance(&provider.layer), - ids: AdjacencyList::new(), }) } } -impl<'a, L, T> glue::InsertStrategy<'a, Provider, T> for Strategy +impl<'a, L, M, T> glue::InsertStrategy<'a, Provider, T> for Strategy where L: layers::Insert<'a, T>, + M: Id, { type PruneStrategy = Self; fn prune_strategy(&self) -> Self::PruneStrategy { diff --git a/diskann-inmem/src/store.rs b/diskann-inmem/src/store.rs index e28bcaa93..610024c3b 100644 --- a/diskann-inmem/src/store.rs +++ b/diskann-inmem/src/store.rs @@ -130,7 +130,9 @@ impl Primary { } freelist::Id::Scan => match self.scan_acquire() { Some(slot) => return Some(slot), - None => { self.try_drain(); }, + None => { + self.try_drain(); + } }, } } @@ -279,6 +281,19 @@ impl<'a> Reader<'a> { i < self.buffer.len() } + pub(crate) fn can_read(&self, i: usize) -> Option { + if !self.is_in_bounds(i) { + return None; + } + + let generation = unsafe { self.buffer.get_unchecked(i).truncate_unchecked(SPLIT) }; + let generation = unsafe { generation::Tag::from_ptr(generation.as_mut_ptr().cast()) } + .as_ref() + .get(Ordering::Acquire); + + Some(generation >= 
self.epoch.generation()) + } + #[inline] pub(crate) unsafe fn read_in_bounds(&self, i: usize) -> Option<&[u8]> { debug_assert!(self.is_in_bounds(i)); From 1f9f20e9245756ac1e765a3d28c1c1233c5966dd Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Sat, 13 Jun 2026 18:05:31 -0700 Subject: [PATCH 10/45] Streaming works (kinda). --- diskann-benchmark/src/index/benchmarks.rs | 10 +- diskann-benchmark/src/index/inmem2.rs | 473 +++++++++++++++++++++- diskann-inmem/src/ids/sharded.rs | 8 + diskann-inmem/src/provider.rs | 160 +++++++- diskann-inmem/src/store.rs | 6 +- 5 files changed, 613 insertions(+), 44 deletions(-) diff --git a/diskann-benchmark/src/index/benchmarks.rs b/diskann-benchmark/src/index/benchmarks.rs index d5e4d944b..2d4266dd3 100644 --- a/diskann-benchmark/src/index/benchmarks.rs +++ b/diskann-benchmark/src/index/benchmarks.rs @@ -94,11 +94,11 @@ pub(crate) fn register_benchmarks(registry: &mut Registry) -> anyhow::Result<()> // FullPrecision::::new().search(plugins::Topk), // )?; - // // Dynamic Full Precision - // registry.register( - // "graph-index-dynamic-full-precision-f32", - // DynamicFullPrecision::::new(), - // )?; + // Dynamic Full Precision + registry.register( + "graph-index-dynamic-full-precision-f32", + DynamicFullPrecision::::new(), + )?; // registry.register( // "graph-index-dynamic-full-precision-f16", // DynamicFullPrecision::::new(), diff --git a/diskann-benchmark/src/index/inmem2.rs b/diskann-benchmark/src/index/inmem2.rs index 8f8448120..ef104a718 100644 --- a/diskann-benchmark/src/index/inmem2.rs +++ b/diskann-benchmark/src/index/inmem2.rs @@ -9,11 +9,17 @@ //! infrastructure in `diskann-benchmark-core`, giving us parallel insertion via //! [`SingleInsert`] and KNN search with recall/latency reporting via [`KNN`]. 
-use std::{io::Write, num::NonZeroUsize, sync::Arc}; +use std::{io::Write, num::NonZeroUsize, ops::Range, sync::Arc}; -use diskann::graph::{self, DiskANNIndex}; +use diskann::{ + graph::{self, DiskANNIndex, StartPointStrategy}, + provider::{self as ann_provider}, +}; use diskann_benchmark_core::{ - self as benchmark_core, build as build_core, recall::GroundTruthMode, search as core_search, + self as benchmark_core, build as build_core, + recall::{self, GroundTruthMode}, + search as core_search, + streaming::{self, executors::bigann, Executor}, }; use diskann_benchmark_runner::{ benchmark::{FailureScore, MatchScore}, @@ -22,14 +28,18 @@ use diskann_benchmark_runner::{ Benchmark, Checker, Checkpoint, Input, Registry, }; use diskann_inmem::{layers::Full, Provider, Strategy}; -use diskann_utils::views::Matrix; +use diskann_utils::views::{Matrix, MatrixView}; use diskann_vector::distance::Metric; use serde::{Deserialize, Serialize}; -use crate::{backend::index::build::ProgressMeter, utils::datafiles}; +use crate::{ + backend::index::build::ProgressMeter, inputs::graph_index::DynamicRunbookParams, + utils::datafiles, +}; pub(crate) fn register_benchmarks(registry: &mut Registry) -> anyhow::Result<()> { registry.register("inmem2-f32", Inmem2)?; + registry.register("inmem2-f32-stream", Inmem2Stream)?; Ok(()) } @@ -146,23 +156,12 @@ impl Benchmark for Inmem2 { let num_points = data.nrows(); writeln!(output, "Loaded {num_points} points, dim={dim}")?; - // Compute a start point as the centroid of the first min(1000, N) points. - let sample = num_points.min(1000); - let mut centroid = vec![0.0f32; dim]; - for i in 0..sample { - for (c, &v) in centroid.iter_mut().zip(data.row(i)) { - *c += v; - } - } - let inv = 1.0 / sample as f32; - centroid.iter_mut().for_each(|c| *c *= inv); - - // Build inmem2 provider. + // Compute the medoid of the dataset as the single start point. 
+ let start = StartPointStrategy::Medoid.compute(data.as_view())?; let metric = Metric::L2; let exact_max_degree = (input.max_degree as f32 * 1.3) as usize; let layer = Full::::new(dim, metric); - let start_points: [&[f32]; 1] = [¢roid]; - let provider = Provider::new(layer, num_points, start_points); + let provider = Provider::new(layer, num_points, start.row_iter()); let config = graph::config::Builder::new_with( input.max_degree, @@ -274,3 +273,439 @@ impl Benchmark for Inmem2 { } } +/////////////// +// Streaming // +/////////////// + +/// Input for the streaming inmem2 benchmark. +/// +/// Drives the inmem2 provider through a BigANN-style runbook. Because the inmem2 +/// provider already does external↔internal id translation, no `Managed`/ +/// `TagSlotManager` adapter is needed — the runbook talks to the provider directly. +#[derive(Debug, Serialize, Deserialize)] +pub(crate) struct Inmem2StreamInput { + /// Full dataset that the runbook indexes into. + data: InputFile, + /// Query set used for every search stage. + queries: InputFile, + + /// Runbook parameters (path, dataset name, gt directory, ...). + runbook_params: DynamicRunbookParams, + + max_degree: usize, + l_build: usize, + alpha: f32, + + search_n: usize, + search_l: Vec, + recall_k: usize, + + num_threads: usize, + reps: NonZeroUsize, +} + +impl Input for Inmem2StreamInput { + type Raw = Inmem2StreamInput; + + fn tag() -> &'static str { + "inmem2-stream" + } + + fn from_raw(mut raw: Self::Raw, checker: &mut Checker) -> anyhow::Result { + raw.data.resolve(checker)?; + raw.queries.resolve(checker)?; + raw.runbook_params.validate(checker)?; + Ok(raw) + } + + fn serialize(&self) -> anyhow::Result { + Ok(serde_json::to_value(self)?) 
+ } + + fn example() -> Self { + Self { + data: InputFile::new("path/to/base.bin"), + queries: InputFile::new("path/to/query.bin"), + runbook_params: ::example(), + max_degree: 64, + l_build: 100, + alpha: 1.2, + search_n: 10, + search_l: vec![10, 20, 50, 100], + recall_k: 10, + num_threads: 4, + reps: NonZeroUsize::new(3).unwrap(), + } + } +} + +impl std::fmt::Display for Inmem2StreamInput { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + writeln!(f, "inmem2 f32 streaming benchmark")?; + writeln!( + f, + " runbook: {}", + self.runbook_params.runbook_path.display() + )?; + writeln!(f, " dataset: {}", self.runbook_params.dataset_name)?; + writeln!(f, " max_degree: {}", self.max_degree)?; + writeln!(f, " l_build: {}", self.l_build)?; + writeln!(f, " alpha: {}", self.alpha)?; + writeln!(f, " search_n: {}", self.search_n)?; + writeln!(f, " search_l: {:?}", self.search_l)?; + writeln!(f, " recall_k: {}", self.recall_k)?; + writeln!(f, " num_threads: {}", self.num_threads)?; + writeln!(f, " reps: {}", self.reps) + } +} + +#[derive(Debug)] +struct Inmem2Stream; + +impl Benchmark for Inmem2Stream { + type Input = Inmem2StreamInput; + type Output = Vec; + + fn try_match(&self, _input: &Inmem2StreamInput) -> Result { + Ok(MatchScore(0)) + } + + fn description( + &self, + f: &mut std::fmt::Formatter<'_>, + input: Option<&Inmem2StreamInput>, + ) -> std::fmt::Result { + match input { + Some(i) => write!(f, "{i}"), + None => write!(f, "inmem2 f32 streaming benchmark"), + } + } + + fn run( + &self, + input: &Inmem2StreamInput, + _checkpoint: Checkpoint<'_>, + mut output: &mut dyn Output, + ) -> anyhow::Result { + writeln!(output, "{input}")?; + + // Load the runbook so we know the eventual capacity. 
+ let gt_dir = input + .runbook_params + .resolved_gt_directory + .as_ref() + .ok_or_else(|| anyhow::anyhow!("groundtruth directory not resolved"))?; + + let runbook = bigann::RunBook::load( + &input.runbook_params.runbook_path, + &input.runbook_params.dataset_name, + &mut bigann::ScanDirectory::new(gt_dir)?, + )?; + let max_points = runbook.max_points(); + + // Load the dataset (consumed by `WithData`) and queries. + let dataset: Matrix = datafiles::load_dataset(datafiles::BinFile(&input.data))?; + let queries: Arc> = + Arc::new(datafiles::load_dataset(datafiles::BinFile(&input.queries))?); + let dim = dataset.ncols(); + + writeln!( + output, + "Loaded dataset: {} points, dim={}", + dataset.nrows(), + dim + )?; + writeln!(output, "Loaded queries: {}", queries.nrows())?; + writeln!(output, "Runbook max_points: {max_points}")?; + + // Compute the medoid of the dataset as the single start point. + let start = StartPointStrategy::Medoid.compute(dataset.as_view())?; + let metric = Metric::L2; + let exact_max_degree = (input.max_degree as f32 * 1.3) as usize; + let layer = Full::::new(dim, metric); + let provider = Provider::new(layer, max_points, start.row_iter()); + + let config = graph::config::Builder::new_with( + input.max_degree, + graph::config::MaxDegree::new(exact_max_degree), + input.l_build, + metric.into(), + |b| { + b.alpha(input.alpha); + }, + ) + .build()?; + + let index = Arc::new(DiskANNIndex::new(config, provider, None)); + + let num_threads = NonZeroUsize::new(input.num_threads.max(1)).unwrap(); + let runtime = benchmark_core::tokio::runtime(num_threads.get())?; + + // Build the inner stream and wrap it with `WithData`. 
+ let stream = Stream { + index: index.clone(), + runtime, + ntasks: num_threads, + search_n: input.search_n, + search_l: input.search_l.clone(), + recall_k: input.recall_k, + reps: input.reps, + }; + + let max_k = input.recall_k; + let queries_for_search = queries.clone(); + let mut layered = bigann::WithData::new(stream, dataset, queries_for_search, move |path| { + Ok(Box::new(datafiles::load_groundtruth( + datafiles::BinFile(path), + Some(max_k), + )?)) + }); + + // Drive the runbook. + let mut runbook = runbook; + let mut results = Vec::new(); + let stages = runbook.len(); + let mut stage_idx = 1usize; + + runbook.run_with(&mut layered, |o: StreamOutput| -> anyhow::Result<()> { + let banner = format!("Stage {} of {}: {}", stage_idx, stages, o.kind()); + write!(output, "{}", crate::utils::SmallBanner(&banner))?; + writeln!(output, "{o}")?; + stage_idx += 1; + results.push(o); + Ok(()) + })?; + + write!( + output, + "{}", + crate::utils::SmallBanner("End of Run Summary") + )?; + let total_inserts: usize = results.iter().filter_map(|r| r.insert_count()).sum(); + let total_deletes: usize = results.iter().filter_map(|r| r.delete_count()).sum(); + let n_searches = results + .iter() + .filter(|r| matches!(r, StreamOutput::Search { .. })) + .count(); + writeln!( + output, + "stages={stages} inserts={total_inserts} deletes={total_deletes} searches={n_searches}", + )?; + + Ok(results) + } +} + +///////////////// +// Stream impl // +///////////////// + +/// Inner streaming index over `inmem2`. +/// +/// Implements `streaming::Stream>` so it can be wrapped +/// by `bigann::WithData` and driven by `bigann::RunBook`. Replace and maintain are +/// not supported in v1; deletes are eager so no consolidation pass is needed. 
+struct Stream { + index: Arc>>>, + runtime: tokio::runtime::Runtime, + ntasks: NonZeroUsize, + search_n: usize, + search_l: Vec, + recall_k: usize, + reps: NonZeroUsize, +} + +#[derive(Debug, Serialize)] +pub(crate) enum StreamOutput { + Insert { count: usize, latency_s: f64 }, + Delete { count: usize, latency_s: f64 }, + Search(Vec), +} + +#[derive(Debug, Serialize)] +pub(crate) struct SearchPoint { + pub search_l: usize, + pub recall: f64, + pub mean_qps: f64, + pub max_qps: f64, +} + +impl StreamOutput { + fn kind(&self) -> &'static str { + match self { + Self::Insert { .. } => "insert", + Self::Delete { .. } => "delete", + Self::Search(_) => "search", + } + } + + fn insert_count(&self) -> Option { + match self { + Self::Insert { count, .. } => Some(*count), + _ => None, + } + } + + fn delete_count(&self) -> Option { + match self { + Self::Delete { count, .. } => Some(*count), + _ => None, + } + } +} + +impl std::fmt::Display for StreamOutput { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Insert { count, latency_s } => { + writeln!(f, " inserted {count} points in {latency_s:.3}s") + } + Self::Delete { count, latency_s } => { + writeln!(f, " deleted {count} points in {latency_s:.3}s") + } + Self::Search(points) => { + for p in points { + writeln!( + f, + " L={:<4} recall={:.4} QPS={:.0} (max {:.0})", + p.search_l, p.recall, p.mean_qps, p.max_qps, + )?; + } + Ok(()) + } + } + } +} + +impl streaming::Stream> for Stream { + type Output = StreamOutput; + + fn search( + &mut self, + (queries, groundtruth): (Arc>, &dyn recall::Rows), + ) -> anyhow::Result { + let knn = benchmark_core::search::graph::KNN::new( + self.index.clone(), + queries, + benchmark_core::search::graph::Strategy::broadcast(Strategy), + )?; + + let mut points = Vec::with_capacity(self.search_l.len()); + for &search_l in &self.search_l { + let params = graph::search::Knn::new(self.search_n, search_l, None)?; + let setup = core_search::Setup { + 
threads: self.ntasks, + tasks: self.ntasks, + reps: self.reps, + }; + let run = core_search::Run::new(params, setup); + + let summaries = core_search::search_all( + knn.clone(), + std::iter::once(run), + benchmark_core::search::graph::knn::Aggregator::new( + groundtruth, + self.recall_k, + self.search_n, + GroundTruthMode::Fixed, + ), + )?; + + for summary in &summaries { + let qps: Vec = summary + .end_to_end_latencies + .iter() + .map(|lat| summary.recall.num_queries as f64 / lat.as_seconds()) + .collect(); + let max_qps = qps.iter().cloned().fold(0.0f64, f64::max); + let mean_qps = qps.iter().sum::() / qps.len().max(1) as f64; + points.push(SearchPoint { + search_l, + recall: summary.recall.average, + mean_qps, + max_qps, + }); + } + } + Ok(StreamOutput::Search(points)) + } + + fn insert( + &mut self, + (data, ids): (MatrixView<'_, f32>, Range), + ) -> anyhow::Result { + anyhow::ensure!( + data.nrows() == ids.len(), + "insert: data rows ({}) != ids range ({})", + data.nrows(), + ids.len(), + ); + + let count = data.nrows(); + let slots: Box<[u32]> = ids + .map(|id| u32::try_from(id)) + .collect::, _>>()?; + + let runner = build_core::graph::SingleInsert::new( + self.index.clone(), + Arc::new(data.to_owned()), + Strategy, + build_core::ids::Slice::new(slots), + ); + + let results = build_core::build( + runner, + build_core::Parallelism::dynamic(diskann::utils::ONE, self.ntasks), + &self.runtime, + )?; + + let latency_s = results.end_to_end_latency().as_seconds(); + Ok(StreamOutput::Insert { count, latency_s }) + } + + fn delete(&mut self, ids: Range) -> anyhow::Result { + let count = ids.len(); + let provider = self.index.provider(); + let ctx = diskann_inmem::Context; + + let start = std::time::Instant::now(); + + let runner = streaming::graph::InplaceDelete::new( + self.index.clone(), + Strategy, + 3, + diskann::graph::InplaceDeleteMethod::OneHop, + build_core::ids::Slice::new(ids.clone().into_iter().map(|i| i as u32).collect()), + ); + + let _ = 
build_core::build( + runner, + diskann_benchmark_core::build::Parallelism::fixed( + Some(diskann::utils::ONE), + self.ntasks, + ), + &self.runtime, + )?; + + let latency_s = start.elapsed().as_secs_f64(); + + Ok(StreamOutput::Delete { count, latency_s }) + } + + fn replace( + &mut self, + _args: (MatrixView<'_, f32>, Range), + ) -> anyhow::Result { + anyhow::bail!("inmem2-f32-stream: replace is not supported in v1") + } + + fn maintain(&mut self, _: ()) -> anyhow::Result { + anyhow::bail!( + "inmem2-f32-stream: maintain is not supported (deletes are eager, no consolidation needed)" + ) + } + + fn needs_maintenance(&mut self) -> bool { + false + } +} diff --git a/diskann-inmem/src/ids/sharded.rs b/diskann-inmem/src/ids/sharded.rs index 62cec3fdc..2714a1192 100644 --- a/diskann-inmem/src/ids/sharded.rs +++ b/diskann-inmem/src/ids/sharded.rs @@ -83,6 +83,14 @@ where Ok(()) } + pub(crate) fn contains_external(&self, external: &Q) -> bool + where + I: std::borrow::Borrow, + Q: Eq + Hash + ?Sized, + { + self.forward.contains_key(external) + } + /// Look up the internal id for an external id. pub(crate) fn to_internal(&self, external: &Q) -> Option where diff --git a/diskann-inmem/src/provider.rs b/diskann-inmem/src/provider.rs index cf2ef70c1..b954521d2 100644 --- a/diskann-inmem/src/provider.rs +++ b/diskann-inmem/src/provider.rs @@ -8,10 +8,9 @@ use std::hash::Hash; use diskann::{ ANNError, ANNErrorKind, ANNResult, graph::{ - AdjacencyList, + AdjacencyList, SearchOutputBuffer, glue::{self, HybridPredicate}, workingset, - SearchOutputBuffer, }, neighbor::Neighbor, provider, @@ -111,6 +110,84 @@ where } } +// TODO: The element-status checks here are profoundly expensive as they require epoch +// registration for each check! +// +// `diskann` has plans to move deletion checks behind an accessor trait, which will help +// with this situation. 
+impl diskann::provider::Delete for Provider +where + L: Send + Sync + 'static, + M: Id, +{ + async fn delete(&self, _context: &Context, gid: &M) -> ANNResult<()> { + // TODO: These need to actually happen in lock-step. + let internal = self.mapping.remove(gid).unwrap(); + assert!(self.primary.delete(internal.into_usize())); + Ok(()) + } + + async fn release(&self, _context: &Context, id: Self::InternalId) -> ANNResult<()> { + Ok(()) + } + + async fn status_by_internal_id( + &self, + _context: &Context, + id: u32, + ) -> ANNResult { + if self + .primary + .reader() + .unwrap() + .can_read(id.into_usize()) + .unwrap() + { + Ok(diskann::provider::ElementStatus::Valid) + } else { + Ok(diskann::provider::ElementStatus::Deleted) + } + } + + /// Check the status via external ID. + async fn status_by_external_id( + &self, + _context: &Context, + gid: &M, + ) -> ANNResult { + if self.mapping.contains_external(gid) { + Ok(diskann::provider::ElementStatus::Valid) + } else { + Ok(diskann::provider::ElementStatus::Deleted) + } + } + + fn statuses_unordered( + &self, + context: &Self::Context, + itr: Itr, + mut f: F, + ) -> impl std::future::Future> + Send + where + Itr: Iterator + Send, + F: FnMut(ANNResult, Self::InternalId) + Send, + { + let work = move || { + let reader = self.primary.reader().unwrap(); + for i in itr { + if reader.can_read(i.into_usize()).unwrap() { + f(Ok(diskann::provider::ElementStatus::Valid), i) + } else { + f(Ok(diskann::provider::ElementStatus::Deleted), i) + } + } + Ok(()) + }; + + ready(work) + } +} + fn ready(f: F) -> std::future::Ready where F: FnOnce() -> R, @@ -150,10 +227,6 @@ where // Search // //////////// -const fn start_point() -> u32 { - 0 -} - #[derive(Debug)] pub struct SearchAccessor<'a> { reader: store::Reader<'a>, @@ -163,6 +236,7 @@ pub struct SearchAccessor<'a> { // The parent provider for the accessor. 
provider: &'a (dyn std::any::Any + Send + Sync), + start_points: std::ops::Range, } impl diskann::provider::HasId for SearchAccessor<'_> { @@ -173,7 +247,7 @@ impl glue::SearchAccessor for SearchAccessor<'_> { fn starting_points( &self, ) -> impl std::future::Future>> + Send { - std::future::ready(Ok(vec![start_point()])) + std::future::ready(Ok(self.start_points.clone().collect())) } fn start_point_distances( @@ -184,18 +258,20 @@ impl glue::SearchAccessor for SearchAccessor<'_> { F: FnMut(Self::Id, f32) + Send, { let work = move || { - let start = start_point(); - match self.reader.read(start.into_usize()) { - Some(point) => { - f(start, self.distance.evaluate(point)?); - Ok(()) + for p in self.start_points.clone() { + match self.reader.read(p.into_usize()) { + Some(point) => { + f(p, self.distance.evaluate(point)?); + } + None => { + return Err(ANNError::message( + ANNErrorKind::Opaque, + "could not retrieve start point", + )); + } } - // TODO: "lock" start points. - None => Err(ANNError::message( - ANNErrorKind::Opaque, - "could not retrieve start point", - )), } + Ok(()) }; ready(work) @@ -482,6 +558,7 @@ where ids: AdjacencyList::new(), expand_beam, provider, + start_points: provider.primary.frozen(), }; Ok(accessor) } @@ -512,8 +589,8 @@ where ) -> impl std::future::Future> + Send where I: Iterator> + Send, - B: SearchOutputBuffer + Send + ?Sized { - + B: SearchOutputBuffer + Send + ?Sized, + { let work = move || { // By construction - the downcast should succeed. Otherwise, this is a program bug. let provider = accessor.provider.downcast_ref::>().unwrap(); @@ -574,6 +651,51 @@ where } } +// TODO: This is such a hack. 
+impl glue::InplaceDeleteStrategy, M>> for Strategy +where + Self: glue::PruneStrategy, M>>, + Self: for<'a> glue::InsertStrategy<'a, Provider, M>, &'a [f32], SearchAccessor = SearchAccessor<'a>>, + M: Id, +{ + type DeleteElement<'a> = &'a [f32]; + type DeleteElementGuard = Box<[f32]>; + type DeleteElementError = ANNError; + + type PruneStrategy = Self; + type DeleteSearchAccessor<'a> = SearchAccessor<'a>; + type SearchPostProcessor = glue::CopyIds; + type SearchStrategy = Self; + + fn prune_strategy(&self) -> Self { + *self + } + + fn search_strategy(&self) -> Self { + *self + } + + fn search_post_processor(&self) -> Self::SearchPostProcessor { + glue::CopyIds + } + + fn get_delete_element<'a>( + &'a self, + provider: &'a Provider, M>, + context: &'a Context, + id: u32 + ) -> impl Future> + Send { + let work = move || { + let reader = provider.primary.reader().unwrap(); + let mut buf: Box<[_]> = std::iter::repeat_n(0.0, provider.layer.dim()).collect(); + let data = reader.read(id.into_usize()).unwrap(); + bytemuck::must_cast_slice_mut::(&mut buf).copy_from_slice(data); + Ok(buf) + }; + ready(work) + } +} + /////////// // Tests // /////////// diff --git a/diskann-inmem/src/store.rs b/diskann-inmem/src/store.rs index 610024c3b..8500511e5 100644 --- a/diskann-inmem/src/store.rs +++ b/diskann-inmem/src/store.rs @@ -61,7 +61,7 @@ impl Primary { // we do not want it to release frozen IDs. freelist: Freelist::new(entries.try_into().unwrap(), NonZeroU32::new(1024).unwrap()), registry: epoch::Registry::new(), - neighbors: Neighbors::new(entries, max_neighbors), + neighbors: Neighbors::new(total, max_neighbors), }; // Populate frozen points. 
@@ -74,6 +74,10 @@ impl Primary { this } + pub fn frozen(&self) -> std::ops::Range { + (self.unfrozen as u32)..(self.buffer.len() as u32) + } + pub fn capacity(&self) -> usize { self.buffer.len() - self.unfrozen } From c4e20b35cd2e3cf2632ba5a822c60bab48ae7607 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Wed, 17 Jun 2026 14:18:29 -0700 Subject: [PATCH 11/45] Checkpoint. --- Cargo.lock | 1 + diskann-inmem/Cargo.toml | 1 + diskann-inmem/src/arbiter/epoch.rs | 706 ---------------- diskann-inmem/src/arbiter/freelist.rs | 150 ---- diskann-inmem/src/arbiter/generation.rs | 298 ------- diskann-inmem/src/{arbiter => }/buffer.rs | 127 +-- diskann-inmem/src/layers/full.rs | 20 +- diskann-inmem/src/layers/mod.rs | 47 +- diskann-inmem/src/lib.rs | 6 +- diskann-inmem/src/neighbors.rs | 429 +++++++++- diskann-inmem/src/provider.rs | 60 +- diskann-inmem/src/store.rs | 111 +-- diskann-inmem/src/sync/epoch.rs | 911 +++++++++++++++++++++ diskann-inmem/src/sync/freelist.rs | 502 ++++++++++++ diskann-inmem/src/{arbiter => sync}/mod.rs | 11 +- diskann-inmem/src/sync/tag.rs | 308 +++++++ diskann-inmem/src/sync/test.rs | 282 +++++++ diskann-inmem/src/test.rs | 86 ++ 18 files changed, 2660 insertions(+), 1396 deletions(-) delete mode 100644 diskann-inmem/src/arbiter/epoch.rs delete mode 100644 diskann-inmem/src/arbiter/freelist.rs delete mode 100644 diskann-inmem/src/arbiter/generation.rs rename diskann-inmem/src/{arbiter => }/buffer.rs (84%) create mode 100644 diskann-inmem/src/sync/epoch.rs create mode 100644 diskann-inmem/src/sync/freelist.rs rename diskann-inmem/src/{arbiter => sync}/mod.rs (56%) create mode 100644 diskann-inmem/src/sync/tag.rs create mode 100644 diskann-inmem/src/sync/test.rs create mode 100644 diskann-inmem/src/test.rs diff --git a/Cargo.lock b/Cargo.lock index 4cec4ce43..a30491ab8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -818,6 +818,7 @@ dependencies = [ "diskann-utils", "diskann-vector", "parking_lot", + "rand 0.9.4", "thiserror 2.0.17", "tokio", ] diff 
--git a/diskann-inmem/Cargo.toml b/diskann-inmem/Cargo.toml index 125222513..c5f8636cd 100644 --- a/diskann-inmem/Cargo.toml +++ b/diskann-inmem/Cargo.toml @@ -21,4 +21,5 @@ thiserror.workspace = true workspace = true [dev-dependencies] +rand.workspace = true tokio = { workspace = true, features = ["macros"] } diff --git a/diskann-inmem/src/arbiter/epoch.rs b/diskann-inmem/src/arbiter/epoch.rs deleted file mode 100644 index 26862adde..000000000 --- a/diskann-inmem/src/arbiter/epoch.rs +++ /dev/null @@ -1,706 +0,0 @@ -/* - * Copyright (c) Microsoft Corporation. - * Licensed under the MIT license. - */ - -use std::sync::atomic::{AtomicUsize, Ordering}; - -use crossbeam_queue::SegQueue; -use diskann::utils::IntoUsize; -use parking_lot::{Mutex, MutexGuard}; - -use crate::arbiter::{ - Generation, - generation::{Mut, Tag}, -}; - -const CAPACITY: usize = 256; - -#[derive(Debug)] -pub struct Registry { - /// A record of the active generations. - /// - /// * Generation::MAX = "available". - /// * Anything less = "registered". - slots: Box<[Tag]>, - - // The current epoch. This begins as `Generation::MAX.sub(1)` and decrements over time. - // - // NOTE: This can only be mutated in `try_advance`. - generation: Tag, - - // A hint for the next available registration slot. - hint: AtomicUsize, - - // We use three queues for storing slots. - // - // 1. Belongs to the current generation and is getting filled. - // 2. Ready for the next generation that will be populated on the next `try_advance`. - // Note that after a `try_advance` call, both 1 and 2 can be added to. - // 3. The queue returned from `try_advance` to be drained. Items drained are safe to - // reclaim. - retiring: [SegQueue; 3], - - // We can only retire a single generation at a time. - // This guard avoids situations. - drain: Mutex<()>, -} - -// Return the queue index for the `generation`. 
-fn queue(generation: Generation) -> usize { - generation.value().into_usize() % 3 -} - -fn last_queue(generation: Generation) -> usize { - queue(Generation::new(generation.value().wrapping_add(1))) -} - -impl Registry { - pub fn new() -> Self { - Self::with_capacity(CAPACITY) - } - - pub fn with_capacity(capacity: usize) -> Self { - Self { - slots: (0..capacity).map(|_| Tag::new(Generation::MAX)).collect(), - generation: Tag::new(Generation::MAX.sub(1)), - hint: AtomicUsize::new(0), - retiring: core::array::from_fn(|_| SegQueue::new()), - drain: Mutex::new(()), - } - } - - pub fn capacity(&self) -> usize { - self.slots.len() - } - - /// Return the current generation. - /// - /// This has [`Ordering::Acquire`] semantics. - pub fn generation(&self) -> Generation { - self.generation.as_ref().get(Ordering::Acquire) - } - - /// Register the caller with the registry. - /// - /// On success, the returned [`Guard`] will protect items tagged with - /// [`Guard::generation`] and higher. - pub fn register(&self) -> Result, Unavailable> { - self.register_inner(NoDelay) - } - - #[inline] - fn register_inner(&self, mut delay: T) -> Result, Unavailable> - where - T: RegisterDelay, - { - // REGISTER CHECK - let mut generation = self.generation(); - let hint = self.hint.fetch_add(1, Ordering::Relaxed); - delay.post_register_check(); - let nslots = self.slots.len(); - for i in 0..nslots { - let slot = (hint + i) % nslots; - - let m = self.slots[slot].as_mut(); - delay.pre_cas(); - if let Ok(_) = m.try_set( - Generation::MAX, - generation, - Ordering::Relaxed, - Ordering::Relaxed, - ) { - delay.post_cas(); - let mut reset = false; - loop { - // REGISTER FENCE: This fence is paired with "WAITING FENCE". - // - // See that comment for details. 
- delay.pre_fence(); - std::sync::atomic::fence(Ordering::SeqCst); - delay.post_fence(); - - // REGISTER RECHECK - let current = self.generation(); - if current == generation { - break; - } - - reset = true; - generation = current; - } - - if reset { - m.set(generation, Ordering::Relaxed); - } - - return Ok(Guard { - slot: m, - retire: &self.retiring[queue(generation)], - generation, - #[cfg(test)] - slot_index: slot, - }); - } - } - - Err(Unavailable) - } - - /// Return the oldest generation that is currently being protected. - /// - /// This uses a fast method that may be overly conservative. - /// - /// This is a synchronizing operation with [`Ordering::Acquire`] semantics. - pub fn can_advance(&self) -> bool { - self.can_advance_inner(&mut NoDelay).0 - } - - fn can_advance_inner(&self, delay: &mut T) -> (bool, Generation) - where - T: CanAdvanceDelay, - { - // WAITING FENCE: This is a very important part for the correctness of the algorithm. - // - // What we're protecting against is a scenario where "registering" thread A reads a - // generation, then "waiting" thread B does a scan, thinks everything is safe, and - // then thread A finishes its CAS for its registration. - // - // This is prevented by the fence. Consider the following. - // - // 1. Thread A invokes "REGISTER FENCE" after a successful CAS, and then checks the - // generation at "REGISTER RECHECK". - // - // 2. Thread B now enters the this block of code, executes "WAITING FENCE", then - // reads the generation tags for all slots. - // - // With the total order induced by the sequentially consistent fence, either thread - // A's fence executes first, or thread B's executes first. - // - // * If thread A's fence executes first, then thread B will see the CAS and the set - // value is guaranteed to be greater-than or equal to "WAITING CHECK" since the - // generation check since is monotonically decreasing and thread A's - // "REGISTER CHECK" is forced to happen before. 
- // - // * If Thread B's fence executes first, then thread A's "REGISTER RECHECK" will - // observe at least the result of "WAITING CHECK" and update itself on the retry. - // - // It's possible that thread B observes the CAS to "REGISTER CHECK", but since - // thread A will monotonically decrease it before exiting, the value thread B - // observes is conservative and not incorrect. - delay.pre_fence(); - std::sync::atomic::fence(Ordering::SeqCst); - delay.post_fence(); - - // WAITING CHECK - let current = self.generation(); - let mut max = current; - - for s in self.slots.iter() { - let generation = s.as_ref().get(Ordering::Relaxed); - if generation != Generation::MAX { - max = max.max(generation); - } - } - - // This synchronizes with all the guard's `Release`s. - std::sync::atomic::fence(Ordering::Acquire); - (max == current, max) - } - - pub fn try_advance(&self) -> Option> { - self.try_advance_inner(NoDelay) - } - - fn try_advance_inner(&self, mut delay: T) -> Option> - where - T: TryAdvanceDelay, - { - // We first try to acquire the `drain` lock. - // - // It can only fail if someone else is holding the drain lock, which means we can't - // proceed anyways. - // - // This can help save an expensive slot scan. - // - // We intentionally ignore lock-poison since we expect the guarded queue to be - // robust with respect to panics. - let drain = self.drain.try_lock()?; - - let (can_advance, current) = self.can_advance_inner(&mut delay); - - // All waiters belong to the current generation. Therefore, it is safe to release - // the old array queue - if can_advance { - // We are safe to use a `fetch_sub` here because `drain` is ensuring exclusivity - // of the access. - // - // However, this still needs to be `SeqCst` so that this properly synchronizes - // with "REGISTER FENCE" and "WAITER FENCE". 
- let _previous = self.generation.as_mut().fetch_decrement(Ordering::SeqCst); - debug_assert_eq!(_previous, current, "concurrency violation"); - - let queue = &self.retiring[last_queue(current)]; - Some(Drain { queue, drain }) - } else { - // Previous generation has not completely retired. - None - } - } - - #[cfg(test)] - fn assert_no_workers(&self) { - for s in self.slots.iter() { - assert_eq!(s.as_ref().get(Ordering::Relaxed), Generation::MAX); - } - } - - #[cfg(test)] - fn snapshot(&self) -> Vec { - self.slots - .iter() - .map(|s| s.as_ref().get(Ordering::Relaxed)) - .collect() - } - - #[cfg(test)] - fn waiting(&self) -> Generation { - self.can_advance_inner(&mut NoDelay).1 - } -} - -#[derive(Debug)] -pub struct Guard<'a> { - slot: Mut<'a>, - retire: &'a SegQueue, - generation: Generation, - - #[cfg(test)] - slot_index: usize, -} - -impl Guard<'_> { - /// Return the generation associated with the [`Guard`]'s creation. - #[inline] - pub fn generation(&self) -> Generation { - self.generation - } - - /// Retire the slot `i` at the current generation. - #[inline] - pub fn retire(&self, i: u32) { - self.retire.push(i) - } - - /// Retire all items in `itr`. - pub fn retire_all(&self, itr: I) - where - I: IntoIterator, - { - for i in itr { - self.retire(i) - } - } -} - -impl Drop for Guard<'_> { - fn drop(&mut self) { - self.slot.set(Generation::MAX, Ordering::Release); - } -} - -#[derive(Debug)] -pub struct Drain<'a> { - queue: &'a SegQueue, - drain: MutexGuard<'a, ()>, -} - -impl Drain<'_> { - #[must_use = "reclaimed ids must be reclaimed"] - pub fn pop(&self) -> Option { - self.queue.pop() - } - - pub fn len(&self) -> usize { - self.queue.len() - } - - pub fn is_empty(&self) -> bool { - self.queue.is_empty() - } -} - -impl Iterator for Drain<'_> { - type Item = u32; - fn next(&mut self) -> Option { - self.pop() - } - - fn size_hint(&self) -> (usize, Option) { - (self.len(), Some(self.len())) - } -} - -// NOTE: This relies on `Drain` holding the `drain` guard. 
In this state, we are guaranteed -// that no-one is writing into the queue, which would otherwise invalidate the exact-size -// iterator guarantee. -impl ExactSizeIterator for Drain<'_> {} - -#[derive(Debug)] -#[non_exhaustive] -pub struct Unavailable; - -impl std::fmt::Display for Unavailable { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str("all available registry slots are occupied") - } -} - -impl std::error::Error for Unavailable {} - -impl From for diskann::ANNError { - #[track_caller] - fn from(unavailable: Unavailable) -> Self { - diskann::ANNError::opaque(unavailable) - } -} - -// Delays -// -// To help test standard race scenarios without advanced tooling, we use optional delays -// that our tests can introduce to ensure threads are in various intermediate points. -// -// This does not necessarily test that the memory orderings are correct, but at least -// is a smoke test that various (known) races are handled properly. - -#[derive(Debug)] -struct NoDelay; - -trait RegisterDelay { - fn post_register_check(&mut self) {} - fn pre_cas(&mut self) {} - fn post_cas(&mut self) {} - fn pre_fence(&mut self) {} - fn post_fence(&mut self) {} -} - -impl RegisterDelay for NoDelay {} - -trait CanAdvanceDelay { - fn pre_fence(&mut self) {} - fn post_fence(&mut self) {} -} - -impl CanAdvanceDelay for NoDelay {} - -trait TryAdvanceDelay: CanAdvanceDelay {} - -impl TryAdvanceDelay for NoDelay {} - -/////////// -// Tests // -/////////// - -#[cfg(test)] -mod tests { - use super::*; - - use std::{sync::Arc, thread}; - - use parking_lot::Condvar; - - #[derive(Clone)] - struct Sequencer(Arc); - - struct SequencerInner { - state: Mutex, - condvar: Condvar, - } - - #[derive(Debug, Clone, Copy, PartialEq)] - enum State { - Empty, - Parked(usize), - Released(usize), - } - - impl Sequencer { - fn new() -> Self { - Self(Arc::new(SequencerInner { - state: Mutex::new(State::Empty), - condvar: Condvar::new(), - })) - } - - fn wait_for(&self, stage: 
usize) { - let mut state = self.0.state.lock(); - if stage == 0 { - assert_eq!(*state, State::Empty) - } else { - assert_eq!(*state, State::Released(stage - 1)) - } - - *state = State::Parked(stage); - self.0.condvar.notify_all(); - self.0 - .condvar - .wait_while(&mut state, move |s| *s != State::Released(stage)); - } - - fn advance_past(&self, stage: usize) { - let mut state = self.0.state.lock(); - self.0 - .condvar - .wait_while(&mut state, move |s| Self::check_release(*s, stage)); - *state = State::Released(stage); - self.0.condvar.notify_all(); - } - - fn until_waiting_for(&self, stage: usize) { - let mut state = self.0.state.lock(); - if *state != State::Parked(stage) { - self.0 - .condvar - .wait_while(&mut state, move |s| Self::check_release(*s, stage)) - } - } - - fn check_release(current: State, stage: usize) -> bool { - match current { - State::Empty => { - assert_eq!(stage, 0); - true - } - State::Released(s) => { - if s + 1 != stage { - panic!("observed {:?} while releasing stage {}", current, stage); - } - true - } - State::Parked(s) => { - if s != stage { - panic!("observed {:?} while releasing stage {}", current, stage) - } - false - } - } - } - } - - // This test ensures that two threads racing on `hint` will correctly resolve themselves - // when claiming a slot. - #[test] - fn test_cas_race() { - let seq = Sequencer::new(); - - let delay = TestRegisterDelay::default().with_post_register_check(|| seq.wait_for(0)); - - let registry = Registry::with_capacity(2); - assert_eq!(registry.capacity(), 2); - - thread::scope(|s| { - // Thread A - s.spawn(|| { - let g = registry.register_inner(delay).unwrap(); - assert_eq!(g.slot_index, 1); - seq.wait_for(1); - }); - - // Thread B - s.spawn(|| { - // wait for Thread A to reach the delay point. 
- seq.until_waiting_for(0); - { - let g = registry.register_inner(NoDelay).unwrap(); - assert_eq!(g.slot_index, 1); - } - let g = registry.register_inner(NoDelay).unwrap(); - assert_eq!(g.slot_index, 0); - seq.advance_past(0); - seq.advance_past(1); - }); - }); - - registry.assert_no_workers(); - } - - #[test] - fn test_register_wait() { - // This tests the case where a thread enters registration, reads a generation, then - // sleeps for several generation advances. It ensures that the thread recovers properly. - let seq = Sequencer::new(); - - let mut loop_count = 0; - let delay = TestRegisterDelay::default() - .with_post_register_check(|| seq.wait_for(0)) - .with_post_cas(|| seq.wait_for(1)) - .with_pre_fence(|| loop_count += 1); - - let registry = Registry::with_capacity(2); - - thread::scope(|s| { - let handle = s.spawn(|| { - let guard = registry.register_inner(delay).unwrap(); - - // Since we hit the CAS loop - this serves as a sanity check that we have - // the correct drain buffer. - guard.retire(10); - guard.retire_all([1, 2, 3]); - guard - }); - - // Wait for the spawned thread to reach the critical section. - seq.until_waiting_for(0); - - assert_eq!(registry.waiting(), Generation::MAX.sub(1)); - { - let drain = registry.try_advance().unwrap(); - assert!(drain.is_empty()); - assert_eq!(registry.generation(), Generation::MAX.sub(2)); - } - - { - let drain = registry.try_advance().unwrap(); - assert!(drain.is_empty()); - assert_eq!(registry.generation(), Generation::MAX.sub(3)); - } - - // We allow the registering thread to make it past the CAS. - // - // We pause it again because we want to verify that it registers an old generation. 
- seq.advance_past(0); - seq.until_waiting_for(1); - let (can_advance, waiter) = registry.can_advance_inner(&mut NoDelay); - assert!(!can_advance); - assert_eq!( - waiter, - Generation::MAX.sub(1), - "waiting thread registers an older generation before observing the change" - ); - seq.advance_past(1); - - let expected = Generation::MAX.sub(3); - - // The generation should be the last set one - even though this thread was - // parked during the transition. - let r = handle.join().unwrap(); - assert_eq!(r.generation(), expected); - assert_eq!(registry.waiting(), expected); - }); - - assert_eq!( - loop_count, 2, - "the registering thread should have looped to update its generation" - ); - - registry.assert_no_workers(); - - // Verify that we reclaim the ID flushed by the registering thread. - // - // This requires two epoch advancements. - { - let drain = registry.try_advance().unwrap(); - assert!(drain.is_empty()); - } - - { - let drain = registry.try_advance().unwrap(); - let ids: Vec<_> = drain.collect(); - assert_eq!(ids, &[10, 1, 2, 3]); - } - } - - //-------------// - // Test Delays // - //-------------// - - macro_rules! tester { - ($struct:ident, $trait:ident, $($with:ident => $f:ident),* $(,)?) => { - #[derive(Default)] - struct $struct<'a> { - $($f: Option>,)* - } - - impl<'a> $struct<'a> { - $( - fn $with(mut self, f: F) -> Self - where - F: FnMut() + Send + 'a - { - self.$f = Some(Box::new(f)); - self - } - )* - } - - impl $trait for $struct<'_> { - $( - fn $f(&mut self) { - if let Some(f) = self.$f.as_mut() { - f() - } - } - )* - } - } - } - - tester! 
{ - TestRegisterDelay, - RegisterDelay, - with_post_register_check => post_register_check, - with_pre_cas => pre_cas, - with_post_cas => post_cas, - with_pre_fence => pre_fence, - with_post_fence => post_fence, - } - - // #[derive(Default)] - // struct TestRegisterDelay<'a> { - // post_register_check: Option<&'a mut dyn FnMut()>, - // pre_cas: Option<&'a mut dyn FnMut()>, - // pre_fence: Option<&'a mut dyn FnMut()>, - // post_fence: Option<&'a mut dyn FnMut()>, - // } - - // macro_rules! builder { - // ($f:ident, $field:ident) => { - // fn $f(mut self, f: &'a mut dyn FnMut()) -> Self { - // self.$field = Some(f); - // self - // } - // } - // } - - // macro_rules! forward { - // ($f:ident) => { - // fn $f(&mut self) { - // if let Some(f) = self.$f.as_mut() { - // f() - // } - // } - // } - // } - - // impl<'a> TestRegisterDelay<'a> { - // builder!(with_post_register_check, post_register_check); - // builder!(with_pre_cas, pre_cas); - // builder!(with_pre_fence, pre_fence); - // builder!(with_post_fence, post_fence); - // } - - // impl RegisterDelay for TestRegisterDelay<'_> { - // forward!(post_register_check); - // forward!(pre_cas); - // forward!(pre_fence); - // forward!(post_fence); - // } - - // struct CanAdvanceDelay; - - // impl CanAdvanceDelay for TestWaitingDelay {} - - // struct TestTryAdvanceDelay; - - // impl TryAdvanceDelay for TestTryAdvanceDelay {} -} diff --git a/diskann-inmem/src/arbiter/freelist.rs b/diskann-inmem/src/arbiter/freelist.rs deleted file mode 100644 index 84321d3ee..000000000 --- a/diskann-inmem/src/arbiter/freelist.rs +++ /dev/null @@ -1,150 +0,0 @@ -/* - * Copyright (c) Microsoft Corporation. - * Licensed under the MIT license. - */ - -use std::{ - num::NonZeroU32, - sync::atomic::{AtomicU32, Ordering}, -}; - -use crossbeam_queue::ArrayQueue; -use diskann::utils::IntoUsize; - -const SCAN_SIZE: u32 = 16; - -#[derive(Debug)] -pub struct Freelist { - recycled: ArrayQueue, - - /// The highest ID the freelist manages. 
This is used when in "append" to determine the - /// maximum ID we can return this way. - max: u32, - - /// The number of "unallocated" IDs remaining. - current: AtomicU32, - - scan_cursor: AtomicU32, -} - -#[derive(Debug, Clone, Copy)] -pub enum Id { - Found(u32), - Scan, -} - -impl Freelist { - pub fn new(max: u32, capacity: NonZeroU32) -> Self { - Self { - recycled: ArrayQueue::new(capacity.get().into_usize()), - max, - current: AtomicU32::new(0), - scan_cursor: AtomicU32::new(0), - } - } - - pub fn pop(&self) -> Id { - if let Some(id) = self.recycled.pop() { - return Id::Found(id); - } - - // Missed in the recycled buffer. Try pulling from the high-water mark. - let mut current = self.current.load(Ordering::Relaxed); - while current != self.max { - match self.current.compare_exchange( - current, - current + 1, - Ordering::Relaxed, - Ordering::Relaxed, - ) { - Ok(current) => return Id::Found(current), - Err(actual) => { - current = actual; - } - } - } - - // Missed in the recycle bin and from unallocated IDs. Time to indicate a scan. - Id::Scan - } - - pub fn pop_recycled(&self) -> Option { - self.recycled.pop() - } - - pub fn scan(&self) -> Scan { - let current = self.scan_cursor.fetch_add(SCAN_SIZE, Ordering::Relaxed) % self.max; - Scan { - current, - max: self.max, - len: SCAN_SIZE.into_usize(), - } - } - - /// Attempt to push `id` into the recycled list. Return `true` if `id` was - /// inserted. If `false` is returned, it is likely because the internal recycle - /// buffer is full. - pub fn push(&self, id: u32) -> bool { - match self.recycled.push(id) { - Ok(()) => true, - Err(_) => false, - } - } - - /// Append items from `itr` into the recycled buffer. Return the number of items - /// actually added. 
- pub fn append(&self, itr: I) -> usize - where - I: IntoIterator, - { - let mut itr = itr.into_iter(); - let mut count = 0; - while let Some(id) = itr.next() { - if let Err(_) = self.recycled.push(id) { - break; - } else { - count += 1; - } - } - - count - } - - //----------// - // Internal // - //----------// - - fn capacity(&self) -> usize { - self.recycled.capacity() - } -} - -#[derive(Debug)] -pub struct Scan { - current: u32, - max: u32, - len: usize, -} - -impl Iterator for Scan { - type Item = u32; - fn next(&mut self) -> Option { - if self.len == 0 { - None - } else { - let i = self.current; - self.current += 1; - self.len -= 1; - if self.current == self.max { - self.current -= self.max; - } - Some(i) - } - } - - fn size_hint(&self) -> (usize, Option) { - (self.len, Some(self.len)) - } -} - -impl ExactSizeIterator for Scan {} diff --git a/diskann-inmem/src/arbiter/generation.rs b/diskann-inmem/src/arbiter/generation.rs deleted file mode 100644 index cbb31994b..000000000 --- a/diskann-inmem/src/arbiter/generation.rs +++ /dev/null @@ -1,298 +0,0 @@ -/* - * Copyright (c) Microsoft Corporation. - * Licensed under the MIT license. - */ - -use std::sync::atomic::{AtomicU64, Ordering}; - -/// An atomic [`Generation`] tag. -/// -/// Access is performed through [`Ref`] and [`Mut`]. -#[derive(Debug)] -#[repr(transparent)] -pub struct Tag(AtomicU64); - -impl Tag { - /// Construct a new [`Tag`] initialized to `generation`. - pub const fn new(generation: Generation) -> Self { - Self(AtomicU64::new(generation.value())) - } - - /// Return a read-only [`Ref`] to `self`. - pub fn as_ref(&self) -> Ref<'_> { - Ref::new(&self.0) - } - - /// Return a read-write [`Mut`] to `self`. - pub fn as_mut(&self) -> Mut<'_> { - Mut::new(&self.0) - } - - /// Creates a new reference to a `Tag` from a raw pointer. - /// - /// # Safety - /// - /// * `ptr` must be aligned to `align_of::()`. - /// * `ptr` must be valid for both reads and writes for the whole lifetime `'a`. 
- /// * This must adhere to the memory model for atomic accesses. In particular, it must - /// not admit conflicting atomic and non-atomic accesses, or atomic accesses of - /// different sizes without synchronization. - /// - /// See: - pub unsafe fn from_ptr<'a>(ptr: *mut Tag) -> &'a Self { - unsafe { &*ptr } - } -} - -/// A generation tag for controlling concurrent access to data. -/// -/// Generally, generations are decremented from `Generation::MAX`, with higher values -/// representing older generations. This allows zero to stand for "unused" as it is newer -/// than any valid generation. -/// -/// Certain low-numbered generations are reserved for special uses. Any generation for which -/// [`Generation::is_reserved`] returns `true` is reserved. -/// -/// # Reserved Generations -/// -/// * [`Generation::AVAILABLE`]: The associated slot is not currently storing valid data -/// and is available to use. -/// -/// To acquire ownership, an atomic compare-exchange must be used away from this state. -/// -/// * [`Generation::OWNED`]: The associated data is owned by some thread. Only the thread -/// owning this slot may update it. -/// -/// Note that ownership may be transferred between threads as long as this ownership -/// transfer is unambiguous and properly synchronized. -/// -/// * [`Generation::FROZEN`]: This data is protected and is not expected to be mutated. -/// -#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] -#[repr(transparent)] -pub struct Generation(u64); - -impl Generation { - /// The maximum generation. This is the oldest possible generation. - pub const MAX: Self = Self::new(u64::MAX); - - // Reserved generations. - // - // These all have small values, with `0` marking the "available" state. - // In this way, zeroed allocations for tags naturally begin in the "available" state and - // don't require additional initialization. - // - // If you add states - make sure to increment the `RESERVED` marker! - - /// See [`Generation`]. 
- pub const AVAILABLE: Self = Self::new(0); - - /// See [`Generation`]. - pub const OWNED: Self = Self::new(1); - - /// The maximum reserved generation. See [`Generation`]. - const RESERVED: Self = Self::OWNED; - - /// See [`Generation`]. - pub const FROZEN: Self = Self::MAX; - - /// Return `true` if `self` belongs to a reserved generation. - #[must_use = "this function has no side-effects"] - pub(crate) fn is_reserved(self) -> bool { - (self <= Self::RESERVED) || (self == Self::FROZEN) - } - - /// Construct a new [`Generation`] with `value`. - #[inline] - pub const fn new(value: u64) -> Self { - Self(value) - } - - /// Return the value of `self`. - #[inline] - pub const fn value(self) -> u64 { - self.0 - } - - pub(in crate::arbiter) fn max(self, other: Self) -> Self { - Self(self.0.max(other.0)) - } - - #[cfg(test)] - const fn add(self, v: u64) -> Self { - Self(self.0 + v) - } - - pub(in crate::arbiter) const fn sub(self, v: u64) -> Self { - Self(self.0 - v) - } -} - -/// A read-only handle to a [`Tag`]. -/// -/// Provides atomic load access to the underlying generation value. -#[derive(Debug, Clone, Copy)] -#[repr(transparent)] -pub struct Ref<'a>(&'a AtomicU64); - -impl<'a> Ref<'a> { - #[inline] - pub(crate) fn new(slot: &'a AtomicU64) -> Self { - Self(slot) - } - - #[inline] - fn inner(&self) -> &'a AtomicU64 { - self.0 - } - - /// Load the current [`Generation`] with the given ordering. - #[inline] - pub fn get(&self, ordering: Ordering) -> Generation { - Generation::new(self.0.load(ordering)) - } -} - -impl std::fmt::Display for Generation { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let me = *self; - if me == Self::AVAILABLE { - f.write_str("Generation(AVAILABLE)") - } else if me == Self::OWNED { - f.write_str("Generation(OWNED)") - } else if me == Self::FROZEN { - f.write_str("Generation(FROZEN)") - } else { - write!(f, "Generation({})", me.value()) - } - } -} - -/// A read-write handle to a [`Tag`]. 
-/// -/// Provides atomic store and compare-exchange access in addition to the read access -/// inherited from [`Ref`] via [`Deref`](std::ops::Deref). -#[derive(Debug, Clone, Copy)] -#[repr(transparent)] -pub struct Mut<'a>(Ref<'a>); - -impl<'a> Mut<'a> { - #[inline] - pub(crate) fn new(slot: &'a AtomicU64) -> Self { - Self(Ref::new(slot)) - } - - /// Attempt to atomically update the generation from `current` to `new`. - /// - /// Returns `Ok(current)` on success, or `Err(actual)` if the value was not `current`. - #[inline] - pub fn try_set( - &self, - current: Generation, - new: Generation, - success: Ordering, - failure: Ordering, - ) -> Result { - self.inner() - .compare_exchange(current.value(), new.value(), success, failure) - .map(Generation::new) - .map_err(Generation::new) - } - - #[inline] - pub fn fetch_decrement(&self, ordering: Ordering) -> Generation { - Generation::new(self.inner().fetch_sub(1, ordering)) - } - - #[inline] - pub fn fetch_min(&self, generation: Generation, ordering: Ordering) -> Generation { - Generation::new(self.inner().fetch_min(generation.value(), ordering)) - } - - /// Atomically store a [`Generation`] with the given ordering. 
- #[inline] - pub fn set(&self, generation: Generation, ordering: Ordering) { - self.inner().store(generation.value(), ordering) - } -} - -impl<'a> std::ops::Deref for Mut<'a> { - type Target = Ref<'a>; - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -/////////// -// Tests // -/////////// - -#[cfg(test)] -mod tests { - use super::*; - - use std::{sync::Barrier, thread}; - - use crate::{ - arbiter::Buffer, - num::{Align, Bytes}, - }; - - fn spin_decrement(m: Mut<'_>, count: usize) { - for i in 0..count { - let mut current = m.get(Ordering::Relaxed); - while let Err(c) = m.try_set( - current, - current.sub(1), - Ordering::Relaxed, - Ordering::Relaxed, - ) { - current = c; - } - } - } - - #[test] - fn test_atomic() { - let threads = 4; - let barrier = &Barrier::new(threads); - - // This dance basically verifies that we can view the tag though a proper-aligned - // raw pointer. - let buffer = Buffer::new(1, Bytes::size_of::(), Align::of::()).unwrap(); - let ptr = buffer.get(0).unwrap().as_mut_ptr().cast::(); - - { - let tag = unsafe { Tag::from_ptr(ptr) }; - tag.as_mut().set(Generation::MAX, Ordering::Relaxed); - } - - let count = 1000; - thread::scope(|s| { - for i in 0..threads { - s.spawn(|| { - // Re-derive `p` to avoid issues with `Send`. 
- let p = buffer.get(0).unwrap().as_mut_ptr().cast::(); - let tag = unsafe { Tag::from_ptr(p) }; - barrier.wait(); - spin_decrement(tag.as_mut(), count); - }); - } - }); - - { - let tag = unsafe { Tag::from_ptr(ptr) }; - let g = tag.as_ref().get(Ordering::Relaxed); - assert_eq!(g, Generation::MAX.sub((count * threads) as u64)); - } - } - - #[test] - fn test_is_reserved() { - assert!(Generation::AVAILABLE.is_reserved()); - assert!(Generation::OWNED.is_reserved()); - assert!(!Generation::OWNED.add(1).is_reserved()); - - assert!(Generation::FROZEN.is_reserved()); - } -} diff --git a/diskann-inmem/src/arbiter/buffer.rs b/diskann-inmem/src/buffer.rs similarity index 84% rename from diskann-inmem/src/arbiter/buffer.rs rename to diskann-inmem/src/buffer.rs index 4a50b22c0..4bee9fb07 100644 --- a/diskann-inmem/src/arbiter/buffer.rs +++ b/diskann-inmem/src/buffer.rs @@ -16,7 +16,7 @@ use crate::num::{Align, Bytes}; /// /// Note that `Buffer` is unconditionally `Send` and `Sync`. #[derive(Debug)] -pub struct Buffer { +pub(crate) struct Buffer { ptr: NonNull, stride: Bytes, entries: usize, @@ -32,7 +32,7 @@ impl Buffer { /// /// Returns an error if the number of bytes `bytes_per_entry * entries` rounded up to /// the next multiple of `align` exceeds `isize::MAX`. - pub fn new(entries: usize, bytes_per_entry: Bytes, align: Align) -> Result { + pub(crate) fn new(entries: usize, bytes_per_entry: Bytes, align: Align) -> Result { // If we overflow `usize::MAX`, we will definitely overflow `isize::MAX`. let bytes = bytes_per_entry.checked_mul(entries).ok_or(BufferError)?; @@ -63,45 +63,16 @@ impl Buffer { /// Return the number of entries in this [`Buffer`]. #[inline] - pub fn len(&self) -> usize { + pub(crate) fn len(&self) -> usize { self.entries } /// Return the number of bytes for each entry. #[inline] - pub fn stride(&self) -> Bytes { + pub(crate) fn stride(&self) -> Bytes { self.stride } - /// Return the minimum alignment of the base pointer for the buffer. 
- #[inline] - pub fn align(&self) -> Align { - Align::from_layout(self.layout) - } - - /// Return the result of `self.len() == 0`. - #[inline] - pub fn is_empty(&self) -> bool { - self.len() == 0 - } - - /// Return the `i`th entry if `i < self.len()`. - /// - /// The returned [`RawSlice`] is guaranteed to have a length of [`Self::stride`] and - /// begin at `self.as_ptr().add(self.stride().value() * i)`. - #[inline] - pub fn get(&self, i: usize) -> Option> { - if i >= self.entries { - None - } else { - // SAFETY: We have validated that `i < self.entries`. This does two things: - // - // 1. Ensure that the multiplication will not overflow. - // 2. Ensures that the computed offset is within the original allocation. - Some(unsafe { self.get_unchecked(i) }) - } - } - /// Return the `i`th entry without bounds checking. /// /// The returned [`RawSlice`] is guaranteed to have a length of [`Self::stride`] and @@ -111,7 +82,7 @@ impl Buffer { /// /// `i` must be less than [`len`](Self::len). #[inline] - pub unsafe fn get_unchecked(&self, i: usize) -> RawSlice<'_> { + pub(crate) unsafe fn get_unchecked(&self, i: usize) -> RawSlice<'_> { debug_assert!(i < self.entries); let ptr = unsafe { self.ptr.add(self.stride().value() * i) }; RawSlice { @@ -121,12 +92,21 @@ impl Buffer { } } - /// Return the base pointer of the [`Buffer`]. - /// - /// If the requested allocation was non-zero, this is guaranteed to be a multiple of the - /// requested alignment. - #[inline] - pub fn as_ptr(&self) -> *const u8 { + #[cfg(test)] + pub(crate) fn get(&self, i: usize) -> Option> { + if i >= self.entries { + None + } else { + // SAFETY: We have validated that `i < self.entries`. This does two things: + // + // 1. Ensure that the multiplication will not overflow. + // 2. Ensures that the computed offset is within the original allocation. 
+ Some(unsafe { self.get_unchecked(i) }) + } + } + + #[cfg(test)] + fn as_ptr(&self) -> *const u8 { self.ptr.as_ptr().cast_const() } } @@ -144,7 +124,7 @@ impl Drop for Buffer { #[derive(Debug)] #[non_exhaustive] -pub struct BufferError; +pub(crate) struct BufferError; impl std::fmt::Display for BufferError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -169,7 +149,7 @@ unsafe impl Sync for Buffer {} /// /// This has borrowing semantics of a raw pointer. #[derive(Debug)] -pub struct RawSlice<'a> { +pub(crate) struct RawSlice<'a> { ptr: NonNull, len: Bytes, _lifetime: PhantomData<&'a ()>, @@ -195,7 +175,7 @@ impl<'a> RawSlice<'a> { /// Create a new slice to the first `n.min(self.len())` bytes of `self`. #[inline] - pub fn truncate(&self, n: Bytes) -> RawSlice<'a> { + pub(crate) fn truncate(&self, n: Bytes) -> RawSlice<'a> { // SAFETY: The `min` operation ensures we provide an argument <= `self.len()`. unsafe { self.truncate_unchecked(self.len.min(n)) } } @@ -213,25 +193,10 @@ impl<'a> RawSlice<'a> { unsafe { Self::new(self.ptr, n) } } - /// Create a new slice skipping the first `n.min(self.len())` bytes of self. - #[inline] - pub fn skip(&self, n: Bytes) -> RawSlice<'a> { - let advance_by = self.len.min(n); - - // SAFETY: `advance_by <= self.len()`, so the pointer offset is valid and the - // `unchecked_sub` cannot underflow. - unsafe { - Self::new( - self.ptr.add(advance_by.value()), - self.len.unchecked_sub(advance_by), - ) - } - } - /// Split `self` into two as `([ptr, ptr.add(m)), [ptr.add(m), ptr.add(self.len())))` /// where `m = n.min(self.len())`. #[inline] - pub fn split(&self, n: Bytes) -> (RawSlice<'a>, RawSlice<'a>) { + pub(crate) fn split(&self, n: Bytes) -> (RawSlice<'a>, RawSlice<'a>) { // SAFETY: The argument is <= `self.len()`. unsafe { self.split_unchecked(self.len.min(n)) } } @@ -254,23 +219,17 @@ impl<'a> RawSlice<'a> { /// Return the length of the slice in bytes. 
#[inline] - pub fn len(&self) -> Bytes { + pub(crate) fn len(&self) -> Bytes { self.len } - /// Return the result of `self.len() == 0`. - #[inline] - pub fn is_empty(&self) -> bool { - self.len() == Bytes::new(0) - } - /// Return the base [`NonNull`] pointer of the slice. - pub fn as_non_null(&self) -> NonNull { + pub(crate) fn as_non_null(&self) -> NonNull { self.ptr } /// Return the base pointer of the slice as `*const u8`. - pub fn as_ptr(&self) -> *const u8 { + pub(crate) fn as_ptr(&self) -> *const u8 { self.ptr.as_ptr().cast_const() } @@ -278,7 +237,7 @@ impl<'a> RawSlice<'a> { /// /// This returns a mutable pointer regardless of the receiver's mutability, matching /// the raw-pointer semantics of [`RawSlice`]. - pub fn as_mut_ptr(&self) -> *mut u8 { + pub(crate) fn as_mut_ptr(&self) -> *mut u8 { self.ptr.as_ptr() } @@ -292,7 +251,7 @@ impl<'a> RawSlice<'a> { /// However, it is the responsibility of the caller to ensure that materializing this /// slice does not violate Rust's borrowing rules. #[inline] - pub unsafe fn as_slice(&self) -> &'a [u8] { + pub(crate) unsafe fn as_slice(&self) -> &'a [u8] { unsafe { std::slice::from_raw_parts(self.ptr.as_ptr(), self.len.value()) } } @@ -306,7 +265,7 @@ impl<'a> RawSlice<'a> { /// However, it is the responsibility of the caller to ensure that materializing this /// slice does not violate Rust's borrowing rules. 
#[inline] - pub unsafe fn as_mut_slice(&mut self) -> &'a mut [u8] { + pub(crate) unsafe fn as_mut_slice(&mut self) -> &'a mut [u8] { unsafe { std::slice::from_raw_parts_mut(self.ptr.as_ptr(), self.len.value()) } } } @@ -349,7 +308,6 @@ mod tests { // Initial Checks assert_eq!(buffer.len(), entries, "{}", ctx); assert_eq!(buffer.stride(), bytes_per_entry, "{}", ctx); - assert_eq!(buffer.align(), align, "{}", ctx); if entries != 0 && !bytes_per_entry.is_zero() { let addr = buffer.as_ptr() as usize; @@ -361,12 +319,6 @@ mod tests { ); } - if entries == 0 { - assert!(buffer.is_empty(), "{}", ctx); - } else { - assert!(!buffer.is_empty(), "{}", ctx); - } - // Verify zero initialization assert_is_zeroed(&mut buffer, &ctx); @@ -415,12 +367,6 @@ mod tests { ctx ); - if raw_slice.len().is_zero() { - assert!(raw_slice.is_empty()); - } else { - assert!(!raw_slice.is_empty()); - } - let slice = unsafe { raw_slice.as_slice() }; assert_eq!(slice.len(), buffer.stride().value()); assert!(slice.iter().all(|&i| i == 0), "{}", ctx); @@ -454,19 +400,6 @@ mod tests { assert!(is_iota(unsafe { truncated.as_slice() }, base), "{}", ctx); } - // skip // - - for i in 0..raw.len().value() + base_usize { - let expected = raw.len().value() - i.min(raw.len().value()); - let skipped = raw.skip(Bytes::new(i)); - assert_eq!(skipped.len().value(), expected, "{}", ctx); - assert!( - is_iota(unsafe { skipped.as_slice() }, base.wrapping_add(i as u8)), - "{}", - ctx - ); - } - // split // for i in 0..raw.len().value() + base_usize { diff --git a/diskann-inmem/src/layers/full.rs b/diskann-inmem/src/layers/full.rs index d38ea0229..9edf3f88a 100644 --- a/diskann-inmem/src/layers/full.rs +++ b/diskann-inmem/src/layers/full.rs @@ -70,11 +70,13 @@ where } } -impl<'a, T> layers::Search<'a, &'a [T]> for Full +impl layers::Search for Full where T: std::fmt::Debug + Send + Sync + 'static, { - fn query_distance(&'a self, query: &'a [T]) -> ANNResult> { + type Query<'a> = &'a [T]; + + fn query_distance<'a>(&'a 
self, query: &'a [T]) -> ANNResult> { Ok(Box::new(QueryDistance::new(self.distance, query))) } } @@ -88,23 +90,11 @@ where } } -impl<'a, T> layers::Insert<'a, &'a [T]> for Full where +impl layers::Insert for Full where T: bytemuck::Pod + std::fmt::Debug + Send + Sync { } -// impl<'a, T> layers::Insert<'a, &'a [T]> for Full -// where -// T: bytemuck::Pod + std::fmt::Debug + Send + Sync, -// { -// fn search_distance( -// &'a self, -// query: &'a [T], -// ) -> ANNResult> { -// Ok(Box::new(QueryDistance::new(self.distance, query))) -// } -// } - ////////////// // Distance // ////////////// diff --git a/diskann-inmem/src/layers/mod.rs b/diskann-inmem/src/layers/mod.rs index 95be3f604..1a82e9d9a 100644 --- a/diskann-inmem/src/layers/mod.rs +++ b/diskann-inmem/src/layers/mod.rs @@ -11,6 +11,33 @@ use crate::num::Bytes; pub(crate) mod full; pub use full::Full; +pub trait AddLifetime: Send + Sync + 'static { + type Of<'a>: Send + Sync; +} + +#[derive(Debug)] +pub struct Slice(std::marker::PhantomData); + +impl Slice { + pub fn new() -> Self { + Self(std::marker::PhantomData) + } +} + +impl Clone for Slice { + fn clone(&self) -> Self { + *self + } +} + +impl Copy for Slice {} + +impl Default for Slice { + fn default() -> Self { + Self::new() + } +} + pub trait Distance: Send + Sync + std::fmt::Debug { fn evaluate(&self, x: &[u8], y: &[u8]) -> ANNResult; } @@ -38,12 +65,24 @@ pub trait Layer: Send + Sync + 'static { pub trait Set: Layer { /// Write into the stored representation. - fn into_bytes<'a>(&self, element: T, bytes: &'a mut [u8]) -> ANNResult<()>; + fn into_bytes(&self, element: T, bytes: &mut [u8]) -> ANNResult<()>; } // Meta traits for `Search` and `Insert` compatibility. 
-pub trait Search<'a, T>: Send + Sync + 'static { - fn query_distance(&'a self, query: T) -> ANNResult>; +pub trait Search: Send + Sync + 'static { + type Query<'a>; + + fn query_distance<'a>( + &'a self, + query: Self::Query<'a>, + ) -> ANNResult>; } -pub trait Insert<'a, T>: Search<'a, T> + Set + AsDistance {} +pub trait Insert: Search + for<'a> Set> + AsDistance { + fn insert_distance<'a>( + &'a self, + query: Self::Query<'a>, + ) -> ANNResult> { + self.query_distance(query) + } +} diff --git a/diskann-inmem/src/lib.rs b/diskann-inmem/src/lib.rs index b6d4e14bd..c8d438a5f 100644 --- a/diskann-inmem/src/lib.rs +++ b/diskann-inmem/src/lib.rs @@ -3,8 +3,9 @@ * Licensed under the MIT license. */ -mod arbiter; +mod buffer; pub mod num; +mod sync; pub mod ids; pub mod layers; @@ -15,3 +16,6 @@ pub mod provider; pub use neighbors::Neighbors; pub use provider::{Context, Provider, Strategy}; + +#[cfg(test)] +mod test; diff --git a/diskann-inmem/src/neighbors.rs b/diskann-inmem/src/neighbors.rs index f3cd52dc7..7138433f7 100644 --- a/diskann-inmem/src/neighbors.rs +++ b/diskann-inmem/src/neighbors.rs @@ -3,17 +3,28 @@ * Licensed under the MIT license. */ +use std::ptr::NonNull; + use diskann::{graph::AdjacencyList, utils::IntoUsize}; use parking_lot::{RwLock, RwLockWriteGuard}; use thiserror::Error; use crate::{ - arbiter::Buffer, + buffer::{Buffer, BufferError}, num::{Align, Bytes}, }; type Id = u32; +/// Locks are shared among groups of adjacency lists. +/// +/// Adjacency lists whose indices map to the same lock group (i.e. `i / LOCK_GRANULARITY`) +/// share a single `RwLock`. This means that holding a [`Lock`] on slot `i` will also block +/// operations on any slot `j` in the same group. +/// +/// **Deadlock hazard**: attempting to acquire two [`Lock`]s simultaneously can deadlock if +/// they fall in the same lock group — or even across groups, depending on acquisition order. +/// Callers must not hold more than one [`Lock`] at a time. 
const LOCK_GRANULARITY: usize = 16; fn lock_index(i: usize) -> usize { @@ -23,19 +34,41 @@ fn lock_index(i: usize) -> usize { #[derive(Debug)] pub struct Neighbors { neighbors: Buffer, - // One lock for each slot in `neighbors`. locks: Vec>, } impl Neighbors { - pub fn new(entries: usize, max_length: usize) -> Self { - let bytes = Bytes::new((max_length + 1) * std::mem::size_of::()); - let neighbors = Buffer::new(entries, bytes, Align::_128).unwrap(); + pub fn new(entries: usize, max_length: usize) -> Result { + // This is exceedingly unlikely and + if max_length > (u32::MAX).into_usize() { + return Err(NeighborsError::AdjacencyListTooLong(max_length)); + } + + let bytes = max_length + .checked_add(1) + .and_then(|len| len.checked_mul(std::mem::size_of::())) + .map(Bytes::new) + .ok_or(NeighborsError::AdjacencyListTooLong(max_length))?; + + // We materialize slices of `Id` into the raw byte buffers. + // + // To make this sound, the base allocation must be that of `Id` so the slice + // materialization is properly aligned. + const ALIGN: Align = Align::_128; + const { + assert!( + ALIGN.value() >= Align::of::().value(), + "buffer alignment must be at least that of the ID" + ); + } + + let neighbors = Buffer::new(entries, bytes, ALIGN)?; + let locks = std::iter::repeat_with(|| RwLock::new(())) .take(entries.div_ceil(LOCK_GRANULARITY)) .collect(); - Self { neighbors, locks } + Ok(Self { neighbors, locks }) } /// Return the maximum length for any adjacency list. @@ -61,7 +94,7 @@ impl Neighbors { unsafe { self.neighbors.get_unchecked(i) }.split(Bytes::size_of::()); debug_assert_eq!(prefix.len(), Bytes::size_of::()); - debug_assert!(prefix.as_ptr().is_aligned()); + debug_assert!(prefix.as_ptr().cast::().is_aligned()); // SAFETY: We hold the read-lock, so reading is safe. From our bounds checks, we // know that this pointer is valid. @@ -93,16 +126,13 @@ impl Neighbors { // `self.locks` and we have already checked that `i` is in-bounds there. 
let slice = unsafe { self.neighbors.get_unchecked(i) }; - debug_assert!(slice.as_ptr().is_aligned()); - - let raw = unsafe { - std::slice::from_raw_parts_mut( - slice.as_mut_ptr().cast::(), - slice.len().value() / std::mem::size_of::(), - ) - }; + debug_assert!(slice.as_ptr().cast::().is_aligned()); - Lock { raw, lock } + Lock { + ptr: slice.as_non_null().cast::(), + capacity: self.max_length(), + _lock: lock, + } } pub fn set(&self, i: usize, neighbors: &[u32]) -> Result<(), SetError> { @@ -140,22 +170,19 @@ pub enum SetError { } /// A locked adjacency list to implement atomic read-modify-write operations. +/// +/// Callers must not hold more than one `Lock` at a time. See [`LOCK_GRANULARITY`] for +/// details on the deadlock hazard. pub struct Lock<'a> { - // The raw adjacency list with the actual length stored as the first element. - // - // This **must** have a length of at least one. - // - // Also, `raw.len()` must be less than `u32::MAX`. - raw: &'a mut [u32], - // VERY IMPORTANT: `lock` has to be **after** `raw` because `lock` is guarding `raw` - // and thus must be dropped **after** `raw`. - lock: RwLockWriteGuard<'a, ()>, + ptr: NonNull, + capacity: usize, + _lock: RwLockWriteGuard<'a, ()>, } impl Lock<'_> { /// Return the capacity of the neighbor buffer. pub fn capacity(&self) -> usize { - self.raw.len() - 1 + self.capacity } /// Return the current length of the neighbor list. @@ -165,15 +192,18 @@ impl Lock<'_> { // SAFETY: By construction, `self.raw` has a length of at least 1. // // The `min` operation is to be conservative. - unsafe { self.raw.get_unchecked(0) } - .into_usize() - .min(self.capacity()) + unsafe { self.ptr.read() }.into_usize().min(self.capacity()) + } + + /// Return `true` only if `self.len() == 0`. + pub fn is_empty(&self) -> bool { + self.len() == 0 } /// View the current contents of the locked adjacency list as a slice. 
pub fn as_slice(&self) -> &[u32] { let len = self.len(); - unsafe { self.raw.get_unchecked(1..len + 1) } + unsafe { std::slice::from_raw_parts(self.ptr.add(1).as_ptr().cast_const(), len) } } /// Consume the [`Lock`] - copying the contents of `neighbors`. @@ -197,32 +227,353 @@ impl Lock<'_> { return Err(TooLong); } - unsafe { self.raw.get_unchecked_mut(len..newlen) }.copy_from_slice(neighbors); - *unsafe { self.raw.get_unchecked_mut(0) } = newlen as u32; + unsafe { + std::ptr::copy_nonoverlapping( + neighbors.as_ptr(), + self.ptr.add(len + 1).as_ptr(), + neighbors.len(), + ) + } + + unsafe { self.ptr.write(newlen as u32) }; Ok(()) - // `self.raw` is dropped first, then `self.lock` which was guarding it. } unsafe fn write_unchecked(self, neighbors: &[u32]) { let len = neighbors.len(); debug_assert!(len <= self.capacity()); - unsafe { - std::ptr::copy_nonoverlapping(neighbors.as_ptr(), self.raw.as_mut_ptr().add(1), len) - } - *unsafe { self.raw.get_unchecked_mut(0) } = len as u32; - // `self.raw` is dropped first, then `self.lock` which was guarding it. 
+ unsafe { std::ptr::copy_nonoverlapping(neighbors.as_ptr(), self.ptr.as_ptr().add(1), len) } + unsafe { self.ptr.write(len as u32) }; } } impl std::fmt::Debug for Lock<'_> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Lock") - .field("raw", &self.raw) + .field("ptr", &self.ptr) + .field("capacity", &self.capacity) .field("lock", &()) .finish() } } +#[derive(Debug, Error)] +pub enum NeighborsError { + #[error("adjacency list length of {} is too long", 0)] + AdjacencyListTooLong(usize), + #[error("neighbor bufffer allocation failed")] + AllocationFailed(#[from] BufferError), +} + #[derive(Debug, Clone, Copy, Error)] #[error("too long")] pub struct TooLong; + +/////////// +// Tests // +/////////// + +#[cfg(test)] +mod tests { + use super::*; + + use crate::test::Sequencer; + + // Constructor errors + + #[test] + fn new_rejects_max_length_exceeding_u32_max() { + let result = Neighbors::new(10, (u32::MAX as usize) + 1); + assert!(matches!( + result, + Err(NeighborsError::AdjacencyListTooLong(_)) + )); + } + + #[test] + fn new_rejects_allocation_overflow() { + // entries * (max_length + 1) * sizeof(Id) overflows. 
+ let result = Neighbors::new(usize::MAX, 64); + assert!(matches!(result, Err(NeighborsError::AllocationFailed(_)))); + } + + // TooLong errors + + #[test] + fn set_rejects_oversized_neighbors() { + let n = Neighbors::new(4, 3).unwrap(); + let too_many = &[1, 2, 3, 4]; + assert!(matches!(n.set(0, too_many), Err(SetError::TooLong(_)))); + } + + #[test] + fn lock_write_rejects_oversized_neighbors() { + let n = Neighbors::new(4, 3).unwrap(); + let lock = n.lock(0).unwrap(); + assert!(lock.write(&[1, 2, 3, 4]).is_err()); + } + + #[test] + fn lock_append_rejects_overflow() { + let n = Neighbors::new(4, 3).unwrap(); + n.set(0, &[1, 2]).unwrap(); + let lock = n.lock(0).unwrap(); + assert!(lock.append(&[3, 4]).is_err()); + } + + #[test] + fn lock_implements_debug() { + let n = Neighbors::new(4, 3).unwrap(); + let lock = n.lock(0).unwrap(); + let _ = format!("{:?}", lock); + } + + // -- Lock::append -- + + #[test] + fn append_preserves_existing_and_adds_new() { + let n = Neighbors::new(4, 6).unwrap(); + n.set(0, &[10, 20]).unwrap(); + + let lock = n.lock(0).unwrap(); + assert_eq!(lock.as_slice(), &[10, 20]); + lock.append(&[30, 40, 50]).unwrap(); + + let mut out = AdjacencyList::with_capacity(6); + n.get(0, &mut out).unwrap(); + assert_eq!(&*out, &[10, 20, 30, 40, 50]); + } + + #[test] + fn append_to_empty() { + let n = Neighbors::new(4, 4).unwrap(); + + let lock = n.lock(0).unwrap(); + assert_eq!(lock.as_slice(), &[]); + lock.append(&[1, 2, 3]).unwrap(); + + let mut out = AdjacencyList::with_capacity(4); + n.get(0, &mut out).unwrap(); + assert_eq!(&*out, &[1, 2, 3]); + } + + #[test] + fn append_fills_to_capacity() { + let n = Neighbors::new(1, 3).unwrap(); + n.set(0, &[1]).unwrap(); + + let lock = n.lock(0).unwrap(); + lock.append(&[2, 3]).unwrap(); + + let mut out = AdjacencyList::with_capacity(3); + n.get(0, &mut out).unwrap(); + assert_eq!(&*out, &[1, 2, 3]); + } + + #[test] + fn append_empty_slice_is_noop() { + let n = Neighbors::new(1, 4).unwrap(); + n.set(0, &[10, 
20]).unwrap(); + + let lock = n.lock(0).unwrap(); + lock.append(&[]).unwrap(); + + let mut out = AdjacencyList::with_capacity(4); + n.get(0, &mut out).unwrap(); + assert_eq!(&*out, &[10, 20]); + } + + #[test] + fn write_overwrites_longer_list() { + let n = Neighbors::new(1, 5).unwrap(); + n.set(0, &[1, 2, 3, 4, 5]).unwrap(); + + // Overwrite with a shorter list. + let lock = n.lock(0).unwrap(); + assert_eq!(lock.len(), 5); + lock.write(&[99]).unwrap(); + + // The length must reflect the new shorter list, not the old one. + let mut out = AdjacencyList::with_capacity(5); + n.get(0, &mut out).unwrap(); + assert_eq!(&*out, &[99]); + } + + // Clear the adjacency list in `neighbors`. + // + // Receives by `&mut` to ensure exclusivity. + fn clear(neighbors: &mut Neighbors) { + for i in 0..neighbors.entries() { + neighbors.set(i, &[]).unwrap(); + } + + assert_is_cleared(neighbors); + } + + fn assert_is_cleared(neighbors: &mut Neighbors) { + for i in 0..neighbors.entries() { + assert!(neighbors.lock(i).unwrap().is_empty()); + } + } + + #[test] + fn basic_test() { + let mut neighbors = Neighbors::new(10, 4).unwrap(); + assert_eq!(neighbors.entries(), 10); + assert_eq!(neighbors.max_length(), 4); + + let mut list = AdjacencyList::new(); + for i in 0..neighbors.entries() { + list.clear(); + list.extend_from_slice(&[1, 2, 3, 4]); + neighbors.get(i, &mut list).unwrap(); + assert!(list.is_empty()); + + let lock = neighbors.lock(i).unwrap(); + assert_eq!(lock.capacity(), neighbors.max_length()); + assert_eq!(lock.len(), 0); + assert!(lock.is_empty()); + assert_eq!(lock.as_slice(), &[]); + } + + // Verify out-of-bounds accesses error. 
+ let oob = neighbors.entries(); + assert!(matches!(neighbors.get(oob, &mut list), Err(OutOfBounds(_)))); + assert!(matches!(neighbors.lock(oob), Err(OutOfBounds(_)))); + assert!(matches!( + neighbors.set(oob, &[1, 2, 3, 4, 5, 6]), + Err(SetError::OutOfBounds(_)) + )); + + let generate = |round: usize, entry: usize| -> Vec { + (0..(round + 1)) + .map(|r| (entry + r).try_into().unwrap()) + .collect() + }; + + // Test mutation via `Neighbors::set`. + for round in 0..neighbors.max_length() { + for i in 0..neighbors.entries() { + let v = generate(round, i); + neighbors.set(i, &v).unwrap(); + } + + for i in 0..neighbors.entries() { + let expected = generate(round, i); + neighbors.get(i, &mut list).unwrap(); + assert_eq!(&*list, &*expected); + + let lock = neighbors.lock(i).unwrap(); + assert_eq!(lock.as_slice(), &*expected); + } + } + + clear(&mut neighbors); + + // Test mutation via `lock + write`. + for round in 0..neighbors.max_length() { + for i in 0..neighbors.entries() { + let v = generate(round, i); + neighbors.lock(i).unwrap().write(&v).unwrap(); + } + + for i in 0..neighbors.entries() { + let expected = generate(round, i); + neighbors.get(i, &mut list).unwrap(); + assert_eq!(&*list, &*expected); + + let lock = neighbors.lock(i).unwrap(); + assert_eq!(lock.as_slice(), &*expected); + } + } + + clear(&mut neighbors); + + // Test mutation via `lock + append`. 
+ for round in 0..neighbors.max_length() { + for i in 0..neighbors.entries() { + neighbors + .lock(i) + .unwrap() + .append(&[(round + i).try_into().unwrap()]) + .unwrap(); + } + + for i in 0..neighbors.entries() { + let expected = generate(round, i); + + neighbors.get(i, &mut list).unwrap(); + assert_eq!(&*list, &*expected); + + let lock = neighbors.lock(i).unwrap(); + assert_eq!(lock.as_slice(), &*expected); + } + } + + clear(&mut neighbors); + } + + //-------------------// + // Concurrency Tests // + //-------------------// + + // Verify that holding a `Lock` correctly blocks reads for the same adjacency list. + #[test] + fn lock_blocks_get() { + for i in 0..10 { + let neighbors = Neighbors::new(3, 4).unwrap(); + let seq = Sequencer::new(); + + std::thread::scope(|s| { + let handle = s.spawn(|| { + seq.wait_for(0); + let mut list = AdjacencyList::new(); + neighbors.get(0, &mut list).unwrap(); + list + }); + + seq.until_waiting_for(0); + let lock = neighbors.lock(0).unwrap(); + seq.advance_past(0); + + lock.write(&[1, 2, 3, 4]).unwrap(); + let list = handle.join().unwrap(); + assert_eq!(&*list, &[1, 2, 3, 4]); + }); + } + } + + #[test] + fn many_appends() { + let max_length = if cfg!(miri) { 100 } else { 1000 }; + + let neighbors = Neighbors::new(1, max_length).unwrap(); + + let num_threads = 4; + let barrier = std::sync::Barrier::new(num_threads); + + std::thread::scope(|s| { + let neighbors_ref = &neighbors; + let barrier_ref = &barrier; + + for thread_id in 0..num_threads { + s.spawn(move || { + barrier_ref.wait(); + let mut i = thread_id as u32; + let upper = neighbors_ref.max_length() as u32; + while i < upper { + neighbors_ref.lock(0).unwrap().append(&[i]).unwrap(); + i += num_threads as u32; + } + }); + } + }); + + let mut list = AdjacencyList::new(); + let expected: Vec<_> = (0..neighbors.max_length()).map(|i| i as u32).collect(); + neighbors.get(0, &mut list).unwrap(); + list.sort(); + + assert_eq!(&*list, &*expected); + } +} diff --git 
a/diskann-inmem/src/provider.rs b/diskann-inmem/src/provider.rs index b954521d2..4f03c9ee0 100644 --- a/diskann-inmem/src/provider.rs +++ b/diskann-inmem/src/provider.rs @@ -19,34 +19,34 @@ use diskann::{ use diskann_utils::views::Matrix; use crate::{ - arbiter::epoch, ids, layers::{self, Distance, QueryDistance}, num::Bytes, store::{self, Primary}, + sync::epoch::Unavailable, }; pub trait Id: Send + Sync + Hash + Eq + Clone + 'static {} impl Id for T where T: Send + Sync + Hash + Eq + Clone + 'static {} #[derive(Debug)] -pub struct Provider +pub struct Provider where M: Id, { primary: Primary, - layer: T, + layer: L, mapping: ids::Sharded, } -impl Provider +impl Provider where M: Id, { - pub fn new(layer: T, capacity: usize, start_points: I) -> Self + pub fn new(layer: L, capacity: usize, start_points: I) -> Self where - I: IntoIterator, - T: layers::Set, + I: IntoIterator, + L: layers::Set, { let start_points: Vec<_> = start_points.into_iter().collect(); let bytes = layers::Layer::bytes(&layer); @@ -65,7 +65,7 @@ where } } - fn reader(&self) -> Result, epoch::Unavailable> { + fn reader(&self) -> Result, Unavailable> { self.primary.reader() } } @@ -127,7 +127,7 @@ where Ok(()) } - async fn release(&self, _context: &Context, id: Self::InternalId) -> ANNResult<()> { + async fn release(&self, _context: &Context, _id: Self::InternalId) -> ANNResult<()> { Ok(()) } @@ -164,7 +164,7 @@ where fn statuses_unordered( &self, - context: &Self::Context, + _context: &Self::Context, itr: Itr, mut f: F, ) -> impl std::future::Future> + Send @@ -197,7 +197,7 @@ where impl diskann::provider::SetElement for Provider where - L: layers::Layer + layers::Set, + L: layers::Set, M: Id, { type SetError = ANNError; @@ -535,9 +535,9 @@ impl workingset::View for &PruneAccessor<'_> { #[derive(Debug, Clone, Copy)] pub struct Strategy; -impl<'a, T, L, M> glue::SearchStrategy<'a, Provider, T> for Strategy +impl<'a, L, M> glue::SearchStrategy<'a, Provider, L::Query<'a>> for Strategy where - L: 
layers::Search<'a, T>, + L: layers::Search, M: Id, { type SearchAccessor = SearchAccessor<'a>; @@ -547,9 +547,9 @@ where &'a self, provider: &'a Provider, _context: &'a Context, - query: T, + query: L::Query<'a>, ) -> ANNResult> { - let distance = >::query_distance(&provider.layer, query)?; + let distance = ::query_distance(&provider.layer, query)?; let reader = provider.primary.reader()?; let expand_beam = dispatch_expand_beam(reader.bytes()); let accessor = SearchAccessor { @@ -573,9 +573,9 @@ impl Default for Translate { } } -impl<'a, T, L, M> glue::SearchPostProcess, T, M> for Translate +impl<'a, L, M> glue::SearchPostProcess, L::Query<'a>, M> for Translate where - L: layers::Search<'a, T>, + L: layers::Search, M: Id, { type Error = ANNError; @@ -583,7 +583,7 @@ where fn post_process( &self, accessor: &mut SearchAccessor<'_>, - query: T, + _query: L::Query<'a>, candidates: I, output: &mut B, ) -> impl std::future::Future> + Send @@ -611,9 +611,9 @@ where } } -impl<'a, T, L, M> glue::DefaultPostProcessor<'a, Provider, T, M> for Strategy +impl<'a, L, M> glue::DefaultPostProcessor<'a, Provider, L::Query<'a>, M> for Strategy where - L: layers::Search<'a, T>, + L: layers::Search, M: Id, { diskann::default_post_processor!(Translate); @@ -631,7 +631,7 @@ where &self, provider: &'a Provider, _context: &'a Context, - capacity: usize, + _capacity: usize, ) -> ANNResult> { Ok(PruneAccessor { reader: provider.primary.reader()?, @@ -640,9 +640,9 @@ where } } -impl<'a, L, M, T> glue::InsertStrategy<'a, Provider, T> for Strategy +impl<'a, L, M> glue::InsertStrategy<'a, Provider, L::Query<'a>> for Strategy where - L: layers::Insert<'a, T>, + L: layers::Insert, M: Id, { type PruneStrategy = Self; @@ -655,7 +655,12 @@ where impl glue::InplaceDeleteStrategy, M>> for Strategy where Self: glue::PruneStrategy, M>>, - Self: for<'a> glue::InsertStrategy<'a, Provider, M>, &'a [f32], SearchAccessor = SearchAccessor<'a>>, + Self: for<'a> glue::InsertStrategy< + 'a, + Provider, M>, + &'a 
[f32], + SearchAccessor = SearchAccessor<'a>, + >, M: Id, { type DeleteElement<'a> = &'a [f32]; @@ -682,9 +687,10 @@ where fn get_delete_element<'a>( &'a self, provider: &'a Provider, M>, - context: &'a Context, - id: u32 - ) -> impl Future> + Send { + _context: &'a Context, + id: u32, + ) -> impl Future> + Send + { let work = move || { let reader = provider.primary.reader().unwrap(); let mut buf: Box<[_]> = std::iter::repeat_n(0.0, provider.layer.dim()).collect(); diff --git a/diskann-inmem/src/store.rs b/diskann-inmem/src/store.rs index 8500511e5..deeadae9c 100644 --- a/diskann-inmem/src/store.rs +++ b/diskann-inmem/src/store.rs @@ -10,8 +10,9 @@ use diskann_utils::views::MatrixView; use crate::{ Neighbors, - arbiter::{Buffer, Freelist, Generation, RawSlice, epoch, freelist, generation}, + buffer::{Buffer, RawSlice}, num::{Align, Bytes}, + sync::{AtomicTag, Freelist, Tag, Registry, epoch, freelist}, }; #[derive(Debug)] @@ -25,13 +26,13 @@ pub struct Primary { // The number of unfrozen points. This is guaranteed to be less than `buffer`. unfrozen: usize, - tags: Vec, + tags: Vec, freelist: Freelist, - registry: epoch::Registry, + registry: Registry, neighbors: Neighbors, } -const SPLIT: Bytes = Bytes::size_of::(); +const SPLIT: Bytes = Bytes::size_of::(); const RETRY_LIMIT: usize = 20; impl Primary { @@ -53,15 +54,15 @@ impl Primary { buffer: Buffer::new(total, padded_bytes, Align::_128).unwrap(), unpadded, unfrozen: entries, - tags: repeat_n(Generation::AVAILABLE, total) - .map(|v| generation::Tag::new(v)) + tags: repeat_n(Tag::AVAILABLE, total) + .map(|v| AtomicTag::new(v)) .collect(), // NOTE: The `Freelist` is initialized to `entries` and not `total` because // we do not want it to release frozen IDs. 
freelist: Freelist::new(entries.try_into().unwrap(), NonZeroU32::new(1024).unwrap()), - registry: epoch::Registry::new(), - neighbors: Neighbors::new(total, max_neighbors), + registry: Registry::new(), + neighbors: Neighbors::new(total, max_neighbors).unwrap(), }; // Populate frozen points. @@ -74,28 +75,33 @@ impl Primary { this } + /// Return the range of slots containing frozen items in `self`. pub fn frozen(&self) -> std::ops::Range { (self.unfrozen as u32)..(self.buffer.len() as u32) } + /// Return the number of unfrozen slots managed by `self`. pub fn capacity(&self) -> usize { self.buffer.len() - self.unfrozen } + /// Attempt to reclaim retired slots. + /// + /// If successful, returns the number of slots reclaimed. pub fn try_drain(&self) -> Option { - fn release(tag: generation::Mut<'_>, kind: &'static str) { + fn release(tag: &AtomicTag, kind: &'static str) { // Relaxed ordering is sufficient as all readers/writers are synchronized on // the central generation. - if let Err(got) = tag.try_set( - Generation::OWNED, - Generation::AVAILABLE, + if let Err(got) = tag.compare_exchange( + Tag::RETIRING, + Tag::AVAILABLE, Ordering::Relaxed, Ordering::Relaxed, ) { panic!( "CONCURRENCY VIOLATION: {} - expected {} - got {}", kind, - Generation::AVAILABLE, + Tag::AVAILABLE, got, ); } @@ -108,18 +114,23 @@ impl Primary { // prematurely advertise availability. let (mirror, _) = unsafe { self.data_unchecked(i.into_usize()) }; release(mirror, "mirror"); - release(self.tags[i.into_usize()].as_mut(), "tag"); + release(&self.tags[i.into_usize()], "tag"); self.freelist.push(i); } Some(items) } + /// Return a [`Reader`] into the store. + /// + /// # Errors + /// + /// Returns [`epoch::Unavailable`] if there are too many active readers. 
pub fn reader(&self) -> Result, epoch::Unavailable> { Ok(Reader { buffer: &self.buffer, unpadded: self.unpadded, neighbors: &self.neighbors, - epoch: self.registry.register()?, + epoch: self.registry.guard()?, }) } @@ -155,12 +166,12 @@ impl Primary { remaining = remaining.saturating_sub(chunk.len()); for slot in chunk { - let tag = self.tag_mut(slot.into_usize()).unwrap(); + let tag = self.tags.get(slot.into_usize()).unwrap(); // If this slot is available and we haven't claimed a slot yet, try to // claim it. Otherwise, continue with the scan to partially repopulate the // freelist for other threads. - if tag.get(Ordering::Relaxed) == Generation::AVAILABLE { + if tag.load(Ordering::Relaxed) == Tag::AVAILABLE { if acquired.is_none() { acquired = unsafe { self.try_acquire(tag, slot) }; } else { @@ -187,14 +198,14 @@ impl Primary { } fn slot(&self, i: u32) -> Option> { - let tag = self.tag_mut(i.into_usize()).unwrap(); + let tag = &self.tags.get(i.into_usize()).unwrap(); unsafe { self.try_acquire(tag, i) } } - unsafe fn try_acquire<'a>(&'a self, tag: generation::Mut<'a>, slot: u32) -> Option> { - match tag.try_set( - Generation::AVAILABLE, - Generation::OWNED, + unsafe fn try_acquire<'a>(&'a self, tag: &'a AtomicTag, slot: u32) -> Option> { + match tag.compare_exchange( + Tag::AVAILABLE, + Tag::OWNED, Ordering::Relaxed, Ordering::Relaxed, ) { @@ -203,7 +214,6 @@ impl Primary { Some(Slot { tag, mirror, - generation: self.registry.generation(), data, slot, }) @@ -213,24 +223,24 @@ impl Primary { } pub(crate) fn delete(&self, i: usize) -> bool { - let guard = self.registry.register().unwrap(); - let tag = self.tag_mut(i).unwrap(); - let current = tag.get(Ordering::Relaxed); + let guard = self.registry.guard().unwrap(); + let tag = self.tags.get(i).unwrap(); + let current = tag.load(Ordering::Relaxed); // We can only perform a deletion if the generation is not in a reserved state. 
if current.is_reserved() { return false; } - let owned = Generation::OWNED; + let retiring = Tag::RETIRING; // Even if we make this change, we can't access any data until we wait for the // epoch to be bumped. As such, relaxed semantics are fine. - match tag.try_set(current, owned, Ordering::Relaxed, Ordering::Relaxed) { + match tag.compare_exchange(current, retiring, Ordering::Relaxed, Ordering::Relaxed) { Ok(_) => { // Set the metadata in the mirror as well. let (mirror, _) = unsafe { self.data_unchecked(i) }; - mirror.set(owned, Ordering::Relaxed); + mirror.store(retiring, Ordering::Relaxed); guard.retire(i as u32); true } @@ -238,21 +248,15 @@ impl Primary { } } - unsafe fn data_unchecked(&self, i: usize) -> (generation::Mut<'_>, RawSlice<'_>) { + unsafe fn data_unchecked(&self, i: usize) -> (&AtomicTag, RawSlice<'_>) { let (mirror, data) = unsafe { self.buffer.get_unchecked(i) } .truncate(self.unpadded) .split(SPLIT); ( - unsafe { generation::Tag::from_ptr(mirror.as_mut_ptr().cast()) }.as_mut(), + unsafe { AtomicTag::from_ptr(mirror.as_mut_ptr().cast()) }, data, ) } - - /// Creating a `Mut` is impossible for user code. Exposing this functionality would - /// allow user code to break all safety invariantes this data structure relies on. 
- fn tag_mut(&self, i: usize) -> Option> { - self.tags.get(i).map(|v| v.as_mut()) - } } #[derive(Debug)] @@ -290,19 +294,19 @@ impl<'a> Reader<'a> { return None; } - let generation = unsafe { self.buffer.get_unchecked(i).truncate_unchecked(SPLIT) }; - let generation = unsafe { generation::Tag::from_ptr(generation.as_mut_ptr().cast()) } - .as_ref() - .get(Ordering::Acquire); + let tag_ptr = unsafe { self.buffer.get_unchecked(i).truncate_unchecked(SPLIT) }; + let can_read = unsafe { AtomicTag::from_ptr(tag_ptr.as_mut_ptr().cast()) } + .load(Ordering::Acquire) + .can_read(); - Some(generation >= self.epoch.generation()) + Some(can_read) } #[inline] pub(crate) unsafe fn read_in_bounds(&self, i: usize) -> Option<&[u8]> { debug_assert!(self.is_in_bounds(i)); - let (generation, rest) = unsafe { + let (tag_ptr, rest) = unsafe { self.buffer .get_unchecked(i) .truncate_unchecked(self.unpadded) @@ -310,11 +314,11 @@ impl<'a> Reader<'a> { }; // NOTE: Must be `Acquire` to correctly synchronize with writes. - let generation = unsafe { generation::Tag::from_ptr(generation.as_mut_ptr().cast()) } - .as_ref() - .get(Ordering::Acquire); + let can_read = unsafe { AtomicTag::from_ptr(tag_ptr.as_mut_ptr().cast()) } + .load(Ordering::Acquire) + .can_read(); - if generation >= self.epoch.generation() { + if can_read { // SAFETY: tags and buffer always have the same length, and we // verified i < tags.len() above. 
Some(unsafe { rest.as_slice() }) @@ -347,9 +351,8 @@ impl<'a> Reader<'a> { #[derive(Debug)] pub struct Slot<'a> { - tag: generation::Mut<'a>, - mirror: generation::Mut<'a>, - generation: Generation, + tag: &'a AtomicTag, + mirror: &'a AtomicTag, data: RawSlice<'a>, slot: u32, } @@ -366,14 +369,14 @@ impl<'a> Slot<'a> { fn freeze(self) { let me = std::mem::ManuallyDrop::new(self); - me.mirror.set(Generation::FROZEN, Ordering::Release); - me.tag.set(Generation::FROZEN, Ordering::Release); + me.mirror.store(Tag::FROZEN, Ordering::Release); + me.tag.store(Tag::FROZEN, Ordering::Release); } } impl Drop for Slot<'_> { fn drop(&mut self) { - self.mirror.set(self.generation, Ordering::Release); - self.tag.set(self.generation, Ordering::Release); + self.mirror.store(Tag::PUBLISHED, Ordering::Release); + self.tag.store(Tag::PUBLISHED, Ordering::Release); } } diff --git a/diskann-inmem/src/sync/epoch.rs b/diskann-inmem/src/sync/epoch.rs new file mode 100644 index 000000000..b36c95403 --- /dev/null +++ b/diskann-inmem/src/sync/epoch.rs @@ -0,0 +1,911 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! The core logic for the epoch-based reclamation algorithm. +//! +//! ## What Problem is Being Solved? +//! +//! Epoch-based reclamation (EBR) can be used to safely implement read-heavy algorithms with +//! a moderate level of concurrent writes. In this context, we want readers to be able to ask +//! the question: "can I safely read some data" in a way that generates only read traffic to +//! the CPU caches. +//! +//! The crux is that after the safety check, a reader can hold a reference to the associated +//! data for an arbitrary period of time. Any actor trying to *write* to that data needs +//! to figure out when it is safe to do so. +//! +//! EBR solves this problem by separating when data is "retired" versus "reclaimed". +//! Retirement involves disabling the safety check. When an item is retired, concurrent +//! 
readers will fail the safety check and no longer try to read the associated data. +//! However, we still need to wait until we can prove that readers who passed the safety +//! check before retirement are no longer accessing the data. At this point, the data can be +//! "reclaimed" and written to safely. +//! +//! We can prove this by using a monotonically increasing epoch: if an item was "retired" +//! at epoch `N` its associated data could be in use by any reader belonging to any epoch +//! `N` or lower. Therefore, it is only safe to "reclaim" when all readers belong to epoch +//! `N+1` or higher. +//! +//! One consequence of this design is that misbehaving (e.g. long-lived) readers can delay +//! reclamation indefinitely. As such, this system must be used with care and in situations +//! where there is enough slack in the system to accommodate the lifetime of any readers. +//! +//! ## Primitives +//! +//! Actors call [`Registry::guard`] to receive a [`Guard`]. This guard protects items +//! at its creation epoch. Any items pushed to [`Guard::retire`] will be buffered until the +//! [`Registry`] can prove that all [`Guard`]s (correctly using the data structure) that +//! could have observed the retired item have been destroyed. +//! +//! Items can be reclaimed via [`Registry::try_advance`]. If successful, a [`Drain`] of +//! such items will be returned for processing. +//! +//! Note that retired payloads are fixed to `u32` ids (typically interpreted by the caller +//! as indices into some external storage); this is not a general-purpose deferred-drop EBR +//! system. + +use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering}; + +use crossbeam_queue::SegQueue; +use diskann::utils::IntoUsize; +use parking_lot::{Mutex, MutexGuard}; + +const CAPACITY: usize = 256; + +/// A registry of epoch-based [`Guard`]s. See the [module-level docs](self). +#[derive(Debug)] +pub struct Registry { + // A record of the active guards. + // + // * 0 = "available". 
+ // * Anything else = "guarded". + guards: Box<[AtomicU64]>, + + // A hint for the next available registration slot. + hint: AtomicUsize, + + // The current epoch. This begins at 1 (to not be conflated with the 0 state in `guards`) + // and increments over time. + // + // NOTE: This can **only** be mutated in `try_advance`. + // + // Additionally, the logic in the module ensures that there are at most two active epochs + // at any given time. It is only safe to advance an epoch if *all* readers belong to the + // current `epoch`. + epoch: AtomicU64, + + // We can only retire a single generation at a time. + // + // This guard avoids situations where two threads concurrently advance the epoch and + // hand out overlapping `Drain`s referring to the same retiring queue. + drain: Mutex<()>, + + // We use three queues for storing retiring items. + // + // 1. Belongs to the current generation and is getting filled. + // 2. Ready for the next generation that will be populated on the next `try_advance`. + // Note that after a `try_advance` call, both 1 and 2 can be receiving retired items. + // 3. The queue returned from `try_advance` to be drained. Items drained are safe to + // reclaim. + // + // We cycle among the queues in a round-robin manner. + retiring: [SegQueue; 3], +} + +// Return the queue index for `epoch`. +fn queue(epoch: u64) -> usize { + epoch.into_usize() % 3 +} + +fn last_queue(epoch: u64) -> usize { + queue(epoch.wrapping_sub(1)) +} + +impl Registry { + /// Construct a new [`Registry`] with the default number of guard slots (256). + pub fn new() -> Self { + Self::with_capacity(CAPACITY) + } + + /// Construct a new [`Registry`] with `capacity` guard slots. + /// + /// This is the number of [`Guard`]s that can be registered concurrently. 
+ pub fn with_capacity(capacity: usize) -> Self { + Self { + guards: (0..capacity).map(|_| AtomicU64::new(0)).collect(), + hint: AtomicUsize::new(0), + epoch: AtomicU64::new(1), + retiring: core::array::from_fn(|_| SegQueue::new()), + drain: Mutex::new(()), + } + } + + /// Return the number of [`Guard`]s this [`Registry`] supports. + pub fn capacity(&self) -> usize { + self.guards.len() + } + + /// Return the current epoch. + /// + /// This has [`Ordering::Acquire`] semantics. + pub fn epoch(&self) -> u64 { + self.epoch.load(Ordering::Acquire) + } + + /// Register the caller with `self`. + /// + /// Any items retired while [`Guard`] is held will be protected. + /// + /// # Errors + /// + /// Returns an error if the number of currently active guards exceeds [`Self::capacity`] + /// and thus a new guard cannot be made. + pub fn guard(&self) -> Result, Unavailable> { + self.guard_inner(NoDelay) + } + + #[inline] + fn guard_inner(&self, mut delay: T) -> Result, Unavailable> + where + T: GuardDelay, + { + // GUARD CHECK + let mut epoch = self.epoch(); + let hint = self.hint.fetch_add(1, Ordering::Relaxed); + delay.post_guard_check(); + let nguards = self.guards.len(); + for i in 0..nguards { + let slot = hint.wrapping_add(i) % nguards; + + let m = &self.guards[slot]; + delay.pre_cas(); + if let Ok(_) = m.compare_exchange(0, epoch, Ordering::Relaxed, Ordering::Relaxed) { + delay.post_cas(); + let mut reset = false; + loop { + // GUARD FENCE: This fence is paired with "WAITING FENCE". + // + // See that comment for details. 
+ delay.pre_fence(); + std::sync::atomic::fence(Ordering::SeqCst); + delay.post_fence(); + + // GUARD RECHECK + let current = self.epoch(); + if current == epoch { + break; + } + + reset = true; + epoch = current; + } + + if reset { + m.store(epoch, Ordering::Relaxed); + } + + return Ok(Guard { + slot: m, + retire: &self.retiring[queue(epoch)], + epoch, + #[cfg(test)] + slot_index: slot, + }); + } + } + + Err(Unavailable) + } + + /// Return `true` if the epoch can be advanced. + /// + /// This uses a fast method that may be conservative: it can return `false` even when a + /// subsequent call to [`Self::try_advance`] would succeed (for example, if a guard slot + /// is observed to hold an old epoch but the corresponding `Guard` is about to be + /// dropped). + /// + /// This is a synchronizing operation with [`Ordering::Acquire`] semantics. + pub fn can_advance(&self) -> bool { + self.can_advance_inner(&mut NoDelay).0 + } + + fn can_advance_inner(&self, delay: &mut T) -> (bool, u64) + where + T: CanAdvanceDelay, + { + // WAITING FENCE: This is a very important part for the correctness of the algorithm. + // + // What we're protecting against is a scenario where "registering" thread A reads an + // epoch, then "waiting" thread B does a scan, thinks everything is safe, and then + // thread A finishes its CAS for its registration. + // + // This is prevented by the sequentially consistent fences. Consider the following. + // + // 1. Thread A invokes "GUARD FENCE" after a successful CAS, and then checks the + // generation at "GUARD RECHECK". + // + // 2. Thread B now enters this block of code, executes "WAITING FENCE", then + // reads the epoch tags for all guards. + // + // With the total order induced by sequential consistency, either thread A's + // fence executes first, or thread B's executes first. 
+ // + // * If thread A's fence executes first, then thread B will see the CAS and the set + // value is guaranteed to be less-than or equal to "WAITING CHECK" because: + // + // 1. The epoch is monotonically increasing. + // 2. Writes to the epoch are also sequentially consistent. + // + // * If Thread B's fence executes first, then thread A's "GUARD RECHECK" will + // observe at least the result of "WAITING CHECK" and update itself on the retry. + // + // It's possible that thread B observes the CAS to "GUARD CHECK", but since + // thread A will monotonically increase it before exiting, the value thread B + // observes is conservative and not incorrect. + delay.pre_fence(); + std::sync::atomic::fence(Ordering::SeqCst); + delay.post_fence(); + + // WAITING CHECK + let current = self.epoch(); + let mut min = current; + + for s in self.guards.iter() { + let guarded = s.load(Ordering::Relaxed); + if guarded != 0 { + min = min.min(guarded); + } + } + + // This synchronizes with all the guard's `Release`s. + std::sync::atomic::fence(Ordering::Acquire); + (min == current, min) + } + + /// Try to advance the current epoch. + /// + /// If successful, returns a [`Drain`]. All items in the drain can be reclaimed. + /// + /// Returns `None` if the epoch cannot yet be advanced (some [`Guard`] still belongs to + /// a prior epoch) or if another [`Drain`] is currently active. + /// + /// # Panics + /// + /// Panics if the epoch counter is about to overflow `u64::MAX`. In practice this is + /// effectively unreachable. + pub fn try_advance(&self) -> Option> { + self.try_advance_inner(NoDelay) + } + + fn try_advance_inner(&self, mut delay: T) -> Option> + where + T: TryAdvanceDelay, + { + // We first try to acquire the `drain` lock. + // + // It can only fail if someone else is holding the drain lock, which means we can't + // proceed anyways. + // + // This can help save an expensive slot scan. 
+ let drain = self.drain.try_lock()?; + + let (can_advance, current) = self.can_advance_inner(&mut delay); + + // Don't wrap around! + if current == u64::MAX { + panic!( + "we've managed to go through nearly `u64::MAX` ids - this is unlikely in a real program" + ); + } + + // All waiters belong to the current epoch. Therefore, it is safe to release the old + // array queue + if can_advance { + // We are safe to use a `fetch_add` here because `drain` is ensuring exclusivity + // of the access. + // + // However, this still needs to be `SeqCst` so that this properly synchronizes + // with "GUARD FENCE" and "WAITING FENCE". + let _previous = self.epoch.fetch_add(1, Ordering::SeqCst); + debug_assert_eq!(_previous, current, "concurrency violation"); + + let queue = &self.retiring[last_queue(current)]; + Some(Drain { queue, drain }) + } else { + // Previous generation has not completely retired. + None + } + } + + #[cfg(test)] + fn assert_no_workers(&self) { + for s in self.guards.iter() { + assert_eq!(s.load(Ordering::Relaxed), 0); + } + } + + #[cfg(test)] + fn snapshot(&self) -> Vec { + self.guards + .iter() + .map(|s| s.load(Ordering::Relaxed)) + .collect() + } + + #[cfg(test)] + fn waiting(&self) -> u64 { + self.can_advance_inner(&mut NoDelay).1 + } +} + +/// A handle registering the caller as a reader at a particular epoch. +/// +/// While this guard is held, the [`Registry`] will not advance past the guard's epoch, and +/// any items retired through *any* guard at that epoch (or earlier) will not be reclaimed. +/// +/// Obtained via [`Registry::guard`]. +#[derive(Debug)] +pub struct Guard<'a> { + slot: &'a AtomicU64, + retire: &'a SegQueue, + epoch: u64, + + #[cfg(test)] + slot_index: usize, +} + +impl Guard<'_> { + /// Return the epoch associated with this [`Guard`]'s creation. + #[inline] + pub fn epoch(&self) -> u64 { + self.epoch + } + + /// Retire the id `i` at this guard's epoch. 
+ /// + /// `i` is a caller-defined id (typically an index into external storage). It will be + /// returned from a future [`Drain`] once the registry has advanced far enough that no + /// reader could observe it. + /// + /// See also: [`Self::retire_all`]. + #[inline] + pub fn retire(&self, i: u32) { + self.retire.push(i) + } + + /// Retire all ids in `itr`. See [`Self::retire`]. + pub fn retire_all(&self, itr: I) + where + I: IntoIterator, + { + for i in itr { + self.retire(i) + } + } +} + +impl Drop for Guard<'_> { + fn drop(&mut self) { + self.slot.store(0, Ordering::Release); + } +} + +/// An iterator over ids that are safe to reclaim, returned from [`Registry::try_advance`]. +/// +/// While this drain is alive, no other thread can advance the [`Registry`]'s epoch. Drop +/// it promptly after processing. +#[derive(Debug)] +pub struct Drain<'a> { + queue: &'a SegQueue, + drain: MutexGuard<'a, ()>, +} + +impl Drain<'_> { + /// Pop the next id ready for reclamation, or `None` if the drain is empty. + #[must_use = "reclaimed ids must be reclaimed"] + pub fn pop(&self) -> Option { + self.queue.pop() + } + + /// Return the number of ids remaining in this drain. + pub fn len(&self) -> usize { + self.queue.len() + } + + /// Return `true` if there are no ids remaining in this drain. + pub fn is_empty(&self) -> bool { + self.queue.is_empty() + } +} + +impl Iterator for Drain<'_> { + type Item = u32; + fn next(&mut self) -> Option { + self.pop() + } + + fn size_hint(&self) -> (usize, Option) { + (self.len(), Some(self.len())) + } +} + +// NOTE: This relies on `Drain` holding the `drain` guard. In this state, we are guaranteed +// that no-one is writing into the queue, which would otherwise invalidate the exact-size +// iterator guarantee. +impl ExactSizeIterator for Drain<'_> {} + +/// Returned by [`Registry::guard`] when all guard slots are occupied. 
+#[derive(Debug)] +#[non_exhaustive] +pub struct Unavailable; + +impl std::fmt::Display for Unavailable { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("all available registry guard slots are occupied") + } +} + +impl std::error::Error for Unavailable {} + +impl From for diskann::ANNError { + #[track_caller] + fn from(unavailable: Unavailable) -> Self { + diskann::ANNError::opaque(unavailable) + } +} + +// Delays +// +// To help test standard race scenarios without advanced tooling, we use optional delays +// that our tests can introduce to ensure threads are in various intermediate points. +// +// This does not necessarily test that the memory orderings are correct, but at least +// is a smoke test that various (known) races are handled properly. + +#[derive(Debug)] +struct NoDelay; + +trait GuardDelay { + fn post_guard_check(&mut self) {} + fn pre_cas(&mut self) {} + fn post_cas(&mut self) {} + fn pre_fence(&mut self) {} + fn post_fence(&mut self) {} +} + +impl GuardDelay for NoDelay {} + +trait CanAdvanceDelay { + fn pre_fence(&mut self) {} + fn post_fence(&mut self) {} +} + +impl CanAdvanceDelay for NoDelay {} + +trait TryAdvanceDelay: CanAdvanceDelay {} + +impl TryAdvanceDelay for NoDelay {} + +/////////// +// Tests // +/////////// + +#[cfg(test)] +mod tests { + use super::*; + + use crate::test::Sequencer; + + // This test ensures that two threads racing on `hint` will correctly resolve themselves + // when claiming a slot. 
+ #[test] + fn test_cas_race() { + let seq = Sequencer::new(); + + let mut thread_a_loop_count = 0; + let mut thread_b_loop_count = 0; + let delay = TestGuardDelay::default() + .post_guard_check(|| seq.wait_for(0)) + .with_post_fence(|| thread_a_loop_count += 1); + + let registry = Registry::with_capacity(2); + assert_eq!(registry.capacity(), 2); + + std::thread::scope(|s| { + // Thread A + s.spawn(|| { + let g = registry.guard_inner(delay).unwrap(); + assert_eq!(g.slot_index, 1); + seq.wait_for(1); + }); + + // Thread B + s.spawn(|| { + // wait for Thread A to reach the delay point. + seq.until_waiting_for(0); + { + let delay = + TestGuardDelay::default().with_post_fence(|| thread_b_loop_count += 1); + let g = registry.guard_inner(delay).unwrap(); + assert_eq!(g.slot_index, 1); + } + let g = registry.guard_inner(NoDelay).unwrap(); + assert_eq!(g.slot_index, 0); + seq.advance_past(0); + seq.advance_past(1); + }); + }); + + assert_eq!(thread_a_loop_count, 1); + assert_eq!(thread_b_loop_count, 1); + + registry.assert_no_workers(); + } + + #[test] + fn test_register_wait() { + // This tests the case where a thread enters registration, reads a generation, then + // sleeps for several generation advances. It ensures that the thread recovers properly. + let seq = Sequencer::new(); + + let mut loop_count = 0; + let delay = TestGuardDelay::default() + .post_guard_check(|| seq.wait_for(0)) + .with_post_cas(|| seq.wait_for(1)) + .with_pre_fence(|| loop_count += 1); + + let registry = Registry::with_capacity(2); + + std::thread::scope(|s| { + let handle = s.spawn(|| { + let guard = registry.guard_inner(delay).unwrap(); + + // Since we hit the CAS loop - this serves as a sanity check that we have + // the correct drain buffer. + guard.retire(10); + guard.retire_all([1, 2, 3]); + guard + }); + + // Wait for the spawned thread to reach the critical section. 
+ seq.until_waiting_for(0); + + assert_eq!(registry.waiting(), 1); + { + let drain = registry.try_advance().unwrap(); + assert!(drain.is_empty()); + assert_eq!(registry.epoch(), 2); + } + + { + let drain = registry.try_advance().unwrap(); + assert!(drain.is_empty()); + assert_eq!(registry.epoch(), 3); + } + + // We allow the registering thread to make it past the CAS. + // + // We pause it again because we want to verify that it registers an old generation. + seq.advance_past(0); + seq.until_waiting_for(1); + let (can_advance, waiter) = registry.can_advance_inner(&mut NoDelay); + assert!(!can_advance); + assert_eq!( + waiter, 1, + "waiting thread registers an older generation before observing the change" + ); + seq.advance_past(1); + + let expected = 3; + + // The generation should be the last set one - even though this thread was + // parked during the transition. + let r = handle.join().unwrap(); + assert_eq!(r.epoch(), expected); + assert_eq!(registry.waiting(), expected); + }); + + assert_eq!( + loop_count, 2, + "the registering thread should have looped to update its generation" + ); + + registry.assert_no_workers(); + + // Verify that we reclaim the ID flushed by the registering thread. + // + // This requires two epoch advancements. + { + let drain = registry.try_advance().unwrap(); + assert!(drain.is_empty()); + } + + { + let drain = registry.try_advance().unwrap(); + let ids: Vec<_> = drain.collect(); + assert_eq!(ids, &[10, 1, 2, 3]); + } + } + + // Verifies that filling every slot causes `register` to return `Unavailable`, and that + // dropping an existing guard frees up its slot for a subsequent registration. + #[test] + fn test_slot_exhaustion() { + let registry = Registry::with_capacity(2); + + let g0 = registry.guard().unwrap(); + let g1 = registry.guard().unwrap(); + + // All guard slots are now occupied. The next registration must fail. 
+ assert!(matches!(registry.guard(), Err(Unavailable))); + assert!(matches!(registry.guard(), Err(Unavailable))); + + // Dropping a guard releases its slot. + let freed_slot = g0.slot_index; + drop(g0); + + let g2 = registry.guard().unwrap(); + assert_eq!( + g2.slot_index, freed_slot, + "newly freed slot should be reclaimed" + ); + + // Registry is full again. + assert!(matches!(registry.guard(), Err(Unavailable))); + + drop(g1); + drop(g2); + + registry.assert_no_workers(); + } + + #[test] + fn test_slot_wrap_around() { + let registry = Registry::with_capacity(4); + + let (g2, g3) = { + let _g0 = registry.guard().unwrap(); + let _g1 = registry.guard().unwrap(); + + let g2 = registry.guard().unwrap(); + let g3 = registry.guard().unwrap(); + (g2, g3) + }; + + assert_eq!(g2.slot_index, 2); + assert_eq!(g3.slot_index, 3); + + let f = || { + // Keep wrapping and hitting the first two guard slots. + for _ in 0..10 { + let g0 = registry.guard().unwrap(); + let g1 = registry.guard().unwrap(); + + let s0 = g0.slot_index; + let s1 = g1.slot_index; + + // Due to how the hint works, the slots could be acquired in either order. + if s0 < s1 { + assert_eq!((s0, s1), (0, 1)); + } else { + assert_eq!((s0, s1), (1, 0)); + }; + + assert!(matches!(registry.guard(), Err(Unavailable))); + } + }; + + // Run with the default hint. + f(); + + // Set the hint just below `usize::MAX` so that it wraps during the run. + registry.hint.store(usize::MAX - 10, Ordering::Relaxed); + + // Run tests again to ensure we can properly handle wrap-around. + f(); + + drop((g2, g3)); + registry.assert_no_workers(); + } + + // Verifies that `try_advance` short-circuits to `None` when another thread already holds + // the `drain` mutex, even if `can_advance` would otherwise succeed. This guards the + // early `try_lock` that avoids a redundant slot scan. + #[test] + fn test_concurrent_try_advance() { + let registry = Registry::with_capacity(2); + + // No outstanding registrations, so `can_advance` would succeed for any caller. 
+ let drain = registry + .try_advance() + .expect("first try_advance must succeed"); + let gen_after_first = registry.epoch(); + assert_eq!(gen_after_first, 2); + + // While the first `Drain` is alive (holding the drain mutex), a concurrent + // `try_advance` must return `None` without advancing the generation. + std::thread::scope(|s| { + s.spawn(|| { + assert!( + registry.try_advance().is_none(), + "try_advance must fail while another holds the drain mutex" + ); + assert_eq!( + registry.epoch(), + gen_after_first, + "generation must not advance when drain is contended" + ); + }); + }); + + // Releasing the drain unblocks subsequent advances. + drop(drain); + + let _drain2 = registry + .try_advance() + .expect("try_advance must succeed once drain is released"); + assert_eq!(registry.epoch(), 3); + } + + // Verifies the 3-queue rotation invariant: items retired at generation `G` are drained + // on the second `try_advance` after `G`. The first advance returns the queue from + // `(G - 1) % 3` (one cycle older), so it must NOT contain items from `G`. + #[test] + fn test_drain_rotation() { + let registry = Registry::with_capacity(1); + + // Helper: register, retire one item, drop. Returns the generation we retired at. + let retire_at = |id: u32| { + let g = registry.guard().unwrap(); + let epoch = g.epoch(); + g.retire(id); + epoch + }; + + // Retire 100 at generation A (= 1). + let gen_a = retire_at(100); + assert_eq!(gen_a, 1); + + // 1st advance after A: must NOT drain item 100. + { + let drain = registry.try_advance().unwrap(); + assert!( + drain.is_empty(), + "100 must not drain on 1st advance after A" + ); + } + + // Retire 200 at generation B (= A + 1). + let gen_b = retire_at(200); + assert_eq!(gen_b, gen_a + 1); + + // 2nd advance after A (1st after B): drains A's queue → [100]. + { + let drained: Vec<_> = registry.try_advance().unwrap().collect(); + assert_eq!(drained, &[100]); + } + + // Retire 300 at generation C. 
+ let _gen_c = retire_at(300); + + // 2nd advance after B: drains B's queue → [200]. + { + let drained: Vec<_> = registry.try_advance().unwrap().collect(); + assert_eq!(drained, &[200]); + } + + // 2nd advance after C: drains C's queue → [300]. + { + let drained: Vec<_> = registry.try_advance().unwrap().collect(); + assert_eq!(drained, &[300]); + } + + // Rotation has cycled back to where A's queue used to live — must be empty, + // proving the queue slot was drained cleanly and is reusable. + { + let drain = registry.try_advance().unwrap(); + assert!( + drain.is_empty(), + "rotation should leave queues empty after one cycle" + ); + } + + registry.assert_no_workers(); + } + + //-------------// + // Test Delays // + //-------------// + + macro_rules! tester { + ($struct:ident, $trait:ident, $($with:ident => $f:ident),* $(,)?) => { + #[derive(Default)] + struct $struct<'a> { + $($f: Option>,)* + } + + impl<'a> $struct<'a> { + $( + fn $with(mut self, f: F) -> Self + where + F: FnMut() + Send + 'a + { + self.$f = Some(Box::new(f)); + self + } + )* + } + + impl $trait for $struct<'_> { + $( + fn $f(&mut self) { + if let Some(f) = self.$f.as_mut() { + f() + } + } + )* + } + } + } + + tester! { + TestGuardDelay, + GuardDelay, + post_guard_check => post_guard_check, + with_pre_cas => pre_cas, + with_post_cas => post_cas, + with_pre_fence => pre_fence, + with_post_fence => post_fence, + } + + // #[derive(Default)] + // struct TestGuardDelay<'a> { + // post_guard_check: Option<&'a mut dyn FnMut()>, + // pre_cas: Option<&'a mut dyn FnMut()>, + // pre_fence: Option<&'a mut dyn FnMut()>, + // post_fence: Option<&'a mut dyn FnMut()>, + // } + + // macro_rules! builder { + // ($f:ident, $field:ident) => { + // fn $f(mut self, f: &'a mut dyn FnMut()) -> Self { + // self.$field = Some(f); + // self + // } + // } + // } + + // macro_rules! 
forward { + // ($f:ident) => { + // fn $f(&mut self) { + // if let Some(f) = self.$f.as_mut() { + // f() + // } + // } + // } + // } + + // impl<'a> TestGuardDelay<'a> { + // builder!(post_guard_check, post_guard_check); + // builder!(with_pre_cas, pre_cas); + // builder!(with_pre_fence, pre_fence); + // builder!(with_post_fence, post_fence); + // } + + // impl GuardDelay for TestGuardDelay<'_> { + // forward!(post_guard_check); + // forward!(pre_cas); + // forward!(pre_fence); + // forward!(post_fence); + // } + + // struct CanAdvanceDelay; + + // impl CanAdvanceDelay for TestWaitingDelay {} + + // struct TestTryAdvanceDelay; + + // impl TryAdvanceDelay for TestTryAdvanceDelay {} +} diff --git a/diskann-inmem/src/sync/freelist.rs b/diskann-inmem/src/sync/freelist.rs new file mode 100644 index 000000000..c3b12aa08 --- /dev/null +++ b/diskann-inmem/src/sync/freelist.rs @@ -0,0 +1,502 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Finding Unused IDs +//! +//! When working with slots into an index, finding an available slot efficiently can be +//! challenging. This module provides a [`Freelist`] to make this more efficient. +//! +//! IDs are retrieved from several sources, in order of precedence: +//! +//! ## Recycles +//! +//! Recycling previously reclaimed slots is the preferred way of finding slots. +//! Reclaimed slot IDs live inside an atomic queue and as such, the size of this queue is +//! bounded to conserve memory. +//! +//! ## Minted +//! If no slots live in the recycled queue, new slots can be "minted" up to the configured +//! maximum. This simply tracks the maximum slot ID that has been yielded so far and returns +//! the next one. +//! +//! This path only works during the initial filling of the managed slots and exists to +//! provide a fast-path for static index builds. Once the maximum slot has been yielded, +//! minting no longer applies. +//! +//! ## Scanning +//! +//! 
If a slot cannot be found via recycling or via minting, a scan is requested. Scans +//! typically involve searching over an authoritative source of slot usage to find and +//! claim an unused slot. +//! +//! The [`Freelist`] assists with scans in several ways: +//! +//! 1. [`Freelist::scan`]: Receive a range of managed IDs to scan. Multiple threads +//! can call this method to receive disjoint ranges to process. +//! +//! 2. [`Freelist::push`]/[`Freelist::append`]: Available slots can be placed into the +//! freelist for recycling. +//! +//! 3. [`Freelist::pop_recycled`]: Attempt to retrieve a slot ID directly from the recycled +//! buffer. +//! +//! Together, these tools can be used to build a cooperative scan. A thread scans a block of +//! IDs returned by [`Freelist::scan`]. If a slot is claimed this way, the thread can continue +//! scanning the rest of the block, pushing any available slot IDs to the freelist. +//! +//! Other threads that are unsuccessfully scanning can periodically check +//! [`Freelist::pop_recycled`] to benefit from the work done by another more successful thread. +//! +//! # Non-Authoritative +//! +//! Note that the [`Freelist`] does not attempt to be authoritative on the list of slot IDs +//! that are used and unused. Its job is mainly to improve performance. +//! +//! An authoritative collection of [`AtomicTag`](super::AtomicTag)s must be used to correctly +//! manage slots. + +use std::{ + num::NonZeroU32, + sync::atomic::{AtomicU32, Ordering}, +}; + +use crossbeam_queue::ArrayQueue; +use diskann::utils::IntoUsize; + +// NOTE: We want the scan size to be relatively big. Each tag occupies just a single byte, +// so a scan needs to be at least 64 to ensure a thread is working with just a single cache +// line. +const SCAN_SIZE: u32 = 256; + +/// A tool for quickly finding unused slots in an index. +/// +/// See [freelist](self) for details. +#[derive(Debug)] +pub struct Freelist { + // Bounded fast queue of retired slots. 
+ recycled: ArrayQueue, + + // The highest ID the freelist manages. IDs `>= max` are rejected by `push`/`append` + // and the minting path will not yield them. + max: u32, + + // The next `unminted` Id. This becomes unused once this reaches `max`. + next: AtomicU32, + + // The current bucket for scanning. + scan_bucket: AtomicU32, +} + +impl Freelist { + /// Construct a new [`Freelist`] that manages `max` ids. + /// + /// The internal fast recycled list will hold up to `recycled` items. + /// + /// The memory occupied by this struct is `O(recycled)`. + pub fn new(max: u32, recycled: NonZeroU32) -> Self { + Self { + recycled: ArrayQueue::new(recycled.get().into_usize()), + max, + next: AtomicU32::new(0), + scan_bucket: AtomicU32::new(0), + } + } + + /// Return the maximum number of slot IDs managed by `self`. + pub fn max(&self) -> u32 { + self.max + } + + /// Try to retrieve an id. + /// + /// If successful, return [`Id::Found`]. Otherwise, returns [`Id::Scan`]. + pub fn pop(&self) -> Id { + if let Some(id) = self.recycled.pop() { + return Id::Found(id); + } + + // Missed in the recycled buffer. Try pulling from the high-water mark. + let mut next = self.next.load(Ordering::Relaxed); + while next < self.max { + match self + .next + .compare_exchange(next, next + 1, Ordering::Relaxed, Ordering::Relaxed) + { + Ok(next) => return Id::Found(next), + Err(actual) => { + next = actual; + } + } + } + + // Missed in the recycle bin and from unallocated IDs. Time to indicate a scan. + Id::Scan + } + + /// Attempt to retrieve an ID directly from the recycled list. + /// + /// This may be used during scans to retrieve IDs found by other threads. + pub fn pop_recycled(&self) -> Option { + self.recycled.pop() + } + + /// Return a new [`Scan`] containing a range of IDs to check. + /// + /// This is managed such that multiple threads calling this function will receive + /// disjoint ranges to scan. 
+ pub fn scan(&self) -> Scan { + if self.max == 0 { + return Scan { start: 0, stop: 0 }; + } + + let num_buckets = self.max.div_ceil(SCAN_SIZE); + + // It's possible that if `scan_bucket` wraps, we do a bit of redundant scanning. + // + // This is fine as this should happen rarely. + let bucket = self.scan_bucket.fetch_add(1, Ordering::Relaxed) % num_buckets; + + let start = bucket * SCAN_SIZE; + let stop = match start.checked_add(SCAN_SIZE) { + Some(stop) => stop.min(self.max), + None => self.max, + }; + + Scan { start, stop } + } + + /// Attempt to push `id` into the recycled list. Return `true` if `id` was inserted. + /// + /// If `false` is returned, it is likely because the internal recycle buffer is full. + /// + /// IDs exceeding [`Self::max`] are discarded. + pub fn push(&self, id: u32) -> bool { + if id < self.max { + self.recycled.push(id).is_ok() + } else { + false + } + } + + /// Append items from `itr` into the recycled buffer. Return the number of items + /// actually added. + /// + /// Callers may not assume that `itr` is fully consumed. + /// + /// IDs exceeding [`Self::max`] are discarded. + pub fn append(&self, itr: I) -> usize + where + I: IntoIterator, + { + let mut count = 0; + for id in itr { + if id < self.max { + if let Err(_) = self.recycled.push(id) { + break; + } else { + count += 1; + } + } + } + + count + } +} + +/// The result of [`Freelist::pop`]. +#[derive(Debug, Clone, Copy)] +#[must_use] +pub enum Id { + /// An ID was found directly in the [`Freelist`]. + Found(u32), + /// No ID was found in the [`Freelist`] and an exhaustive scan is recommended. + Scan, +} + +#[cfg(test)] +impl Id { + fn unwrap(self) -> u32 { + match self { + Self::Found(i) => i, + Self::Scan => panic!("expected Id::Found, got Id::Scan"), + } + } + + fn is_scan(self) -> bool { + matches!(self, Self::Scan) + } +} + +/// An [`ExactSizeIterator`] over IDs to scan. Returned by [`Freelist::scan`]. 
+#[derive(Debug)] +pub struct Scan { + start: u32, + stop: u32, +} + +impl Scan { + pub fn as_range(&self) -> std::ops::Range { + self.start..self.stop + } +} + +impl Iterator for Scan { + type Item = u32; + fn next(&mut self) -> Option { + if self.start >= self.stop { + None + } else { + let i = self.start; + self.start += 1; + Some(i) + } + } + + fn size_hint(&self) -> (usize, Option) { + let len = (self.stop - self.start).into_usize(); + (len, Some(len)) + } +} + +impl ExactSizeIterator for Scan {} + +/////////// +// Tests // +/////////// + +#[cfg(test)] +mod tests { + use super::*; + + use std::{collections::HashSet, sync::Barrier, thread}; + + fn freelist(max: u32, recycled: u32) -> Freelist { + Freelist::new(max, NonZeroU32::new(recycled).unwrap()) + } + + //---------// + // Minting // + //---------// + + #[test] + fn pop_mints_sequentially_until_exhausted() { + let fl = freelist(4, 8); + assert_eq!(fl.max(), 4); + + let mut got = Vec::new(); + for _ in 0..4 { + got.push(fl.pop().unwrap()); + } + assert_eq!(got, vec![0, 1, 2, 3]); + assert!(fl.pop().is_scan()); + assert!(fl.pop().is_scan()); + } + + #[test] + fn pop_returns_scan_when_max_zero() { + let fl = freelist(0, 1); + assert!(fl.pop().is_scan()); + assert_eq!(fl.max(), 0); + } + + #[test] + fn recycled_ids_take_precedence_over_minting() { + let fl = freelist(4, 8); + // Seed the recycled queue. + assert!(fl.push(2)); + // First pop must come from the recycled queue, not mint 0. + assert_eq!(fl.pop().unwrap(), 2); + // Subsequent pops mint from 0. 
+ assert_eq!(fl.pop().unwrap(), 0); + } + + //------// + // Push // + //------// + + #[test] + fn push_rejects_ids_at_or_above_max() { + let fl = freelist(4, 8); + assert!(!fl.push(4)); + assert!(!fl.push(u32::MAX)); + assert!(fl.push(3)); + + assert_eq!(fl.pop_recycled().unwrap(), 3); + } + + #[test] + fn push_returns_false_when_recycled_full() { + let fl = freelist(16, 2); + assert!(fl.push(2)); + assert!(fl.push(3)); + assert!(!fl.push(5)); + + // Drained from recycled queue. + assert_eq!(fl.pop().unwrap(), 2); + assert_eq!(fl.pop().unwrap(), 3); + } + + #[test] + fn pop_recycled_empty_returns_none() { + let fl = freelist(4, 4); + assert!(fl.pop_recycled().is_none()); + } + + #[test] + fn pop_recycled_does_not_mint() { + let fl = freelist(4, 4); + // No pushes, no recycled entries — `pop_recycled` must not fall through to minting. + assert!(fl.pop_recycled().is_none()); + // The minting counter should be untouched. + assert_eq!(fl.pop().unwrap(), 0); + } + + //--------// + // Append // + //--------// + + #[test] + fn append_counts_inserted_and_skips_out_of_range() { + let fl = freelist(4, 8); + let count = fl.append([0u32, 4, 1, 7, 2].iter().copied()); + // 4 and 7 are >= max and skipped; 0, 1, 2 are inserted. + assert_eq!(count, 3); + let mut got = Vec::new(); + while let Some(id) = fl.pop_recycled() { + got.push(id); + } + got.sort(); + assert_eq!(got, vec![0, 1, 2]); + } + + #[test] + fn append_stops_when_buffer_full() { + let fl = freelist(16, 2); + let count = fl.append(0u32..16); + assert_eq!(count, 2); + } + + //------// + // Scan // + //------// + + fn as_vec(itr: I) -> Vec + where + I: Iterator, + { + itr.collect() + } + + #[test] + fn scan_on_empty_freelist_yields_nothing() { + let fl = freelist(0, 1); + let mut scan = fl.scan(); + assert_eq!(scan.len(), 0); + assert!(scan.next().is_none()); + } + + #[test] + fn scan_covers_full_range_in_one_pass() { + // Choose `max` to force a partial last bucket. 
+ let max = 2 * SCAN_SIZE + 50; + let fl = freelist(max, 4); + + // First Round + + let scan = fl.scan(); + assert_eq!(scan.as_range(), 0..SCAN_SIZE); + assert_eq!(scan.len(), SCAN_SIZE.into_usize()); + assert_eq!(as_vec(scan), as_vec(0..SCAN_SIZE)); + + let scan = fl.scan(); + assert_eq!(scan.as_range(), SCAN_SIZE..2 * SCAN_SIZE); + assert_eq!(scan.len(), SCAN_SIZE.into_usize()); + assert_eq!(as_vec(scan), as_vec(SCAN_SIZE..2 * SCAN_SIZE)); + + let scan = fl.scan(); + assert_eq!(scan.as_range(), 2 * SCAN_SIZE..(2 * SCAN_SIZE + 50)); + assert_eq!(scan.len(), 50); + assert_eq!(as_vec(scan), as_vec((2 * SCAN_SIZE)..(2 * SCAN_SIZE + 50))); + + // Check Wrapping + + let scan = fl.scan(); + assert_eq!(scan.as_range(), 0..SCAN_SIZE); + assert_eq!(scan.len(), SCAN_SIZE.into_usize()); + assert_eq!(as_vec(scan), as_vec(0..SCAN_SIZE)); + } + + //-------------// + // Concurrency // + //-------------// + + #[test] + fn concurrent_pop_yields_unique_ids() { + let max = 4096u32; + let fl = Freelist::new(max, NonZeroU32::new(8).unwrap()); + let nthreads = 8; + let barrier = Barrier::new(nthreads); + + let results: Vec> = thread::scope(|s| { + let handles: Vec<_> = (0..nthreads) + .map(|_| { + s.spawn(|| { + let mut out = Vec::new(); + barrier.wait(); + loop { + match fl.pop() { + Id::Found(id) => out.push(id), + Id::Scan => break, + } + } + out + }) + }) + .collect(); + handles.into_iter().map(|h| h.join().unwrap()).collect() + }); + + let mut all: Vec = results.into_iter().flatten().collect(); + all.sort(); + let expected: Vec = (0..max).collect(); + assert_eq!(all, expected, "all ids in [0, max) minted exactly once"); + } + + #[test] + fn concurrent_scan_partitions_one_pass() { + let max = SCAN_SIZE * 4; + let fl = Freelist::new(max, NonZeroU32::new(4).unwrap()); + let num_buckets = max.div_ceil(SCAN_SIZE) as usize; + let nthreads = num_buckets; + let barrier = Barrier::new(nthreads); + + let ids: Vec = thread::scope(|s| { + let handles: Vec<_> = (0..nthreads) + .map(|_| { + 
s.spawn(|| { + barrier.wait(); + fl.scan().collect::>() + }) + }) + .collect(); + handles + .into_iter() + .flat_map(|h| h.join().unwrap()) + .collect() + }); + + let unique: HashSet = ids.iter().copied().collect(); + assert_eq!( + unique.len(), + ids.len(), + "no id appeared twice across threads" + ); + assert_eq!( + unique.len() as u32, + max, + "scans covered every id in [0, max)" + ); + } +} diff --git a/diskann-inmem/src/arbiter/mod.rs b/diskann-inmem/src/sync/mod.rs similarity index 56% rename from diskann-inmem/src/arbiter/mod.rs rename to diskann-inmem/src/sync/mod.rs index 268e57d18..6ec810631 100644 --- a/diskann-inmem/src/arbiter/mod.rs +++ b/diskann-inmem/src/sync/mod.rs @@ -3,13 +3,14 @@ * Licensed under the MIT license. */ -pub(crate) mod buffer; -pub use buffer::{Buffer, RawSlice}; - pub mod epoch; +pub use epoch::Registry; pub mod freelist; pub use freelist::Freelist; -pub mod generation; -pub use generation::Generation; +mod tag; +pub use tag::{AtomicTag, Tag}; + +#[cfg(test)] +mod test; diff --git a/diskann-inmem/src/sync/tag.rs b/diskann-inmem/src/sync/tag.rs new file mode 100644 index 000000000..4b30d1162 --- /dev/null +++ b/diskann-inmem/src/sync/tag.rs @@ -0,0 +1,308 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! State tags for slots participating in the EBR protocol. +//! +//! This module defines [`Tag`] and [`AtomicTag`], a small state machine used to label +//! individual slots in concurrent data structures. Tags pair with the epoch-based +//! reclamation machinery in [`super::epoch`]: epochs decide *when* it is safe to reclaim a +//! slot, while tags decide *whether* a given slot is currently readable, owned, or in +//! transition. +//! +//! Note that the type system does not enforce the tag protocol — only the documented +//! transitions on [`Tag`] are sound, and it is the caller's responsibility to follow them. 
+ +use std::sync::atomic::{AtomicU8, Ordering}; + +/// A tag for controlling concurrent access to data. +/// +/// Tag updates and reads should use [`AtomicTag`]. +/// +/// A reader holding a [`Guard`](super::epoch::Guard) performs an [`Ordering::Acquire`] load +/// on an [`AtomicTag`]; if [`Tag::can_read`] returns `true`, the reader may access the +/// data this tag protects. +/// +/// # Named Tags +/// +/// * [`Tag::PUBLISHED`]: The associated slot has been published and may be freely accessed +/// by readers. +/// +/// * [`Tag::FROZEN`]: This data is protected and is not expected to be mutated. Readers +/// may still freely access this data. `FROZEN` has no defined transitions in this +/// protocol; once a slot is frozen it remains so for the lifetime of the structure. +/// +/// * [`Tag::AVAILABLE`]: The associated slot is not currently storing valid data +/// and is available to use. +/// +/// Ownership is acquired via a CAS from `AVAILABLE` to `OWNED`. +/// +/// * [`Tag::OWNED`]: The associated data is owned by some thread. Only the thread +/// owning this slot may update it. +/// +/// Note that ownership may be transferred between threads as long as this ownership +/// transfer is unambiguous and properly synchronized. +/// +/// In this state, the owning thread may write to the associated data. +/// +/// * [`Tag::RETIRING`]: Indicates that this slot is currently being [retired](super::epoch). +/// Readers may not access associated data after reading this tag, but readers who accessed +/// the tag before retirement may still exist. +/// +/// Only transition away from this value when the corresponding slot is returned from a +/// [`Drain`](super::epoch::Drain). +/// +/// # Allowed Transitions +/// +/// The following protocol must be used when working with [`AtomicTag`]ged data and a +/// [`Registry`](super::Registry). +/// +/// * [`Tag::AVAILABLE`] -> [`Tag::OWNED`]: Use a CAS to ensure unique ownership. 
Once in +/// the owned state, unsynchronized writes can be made to associated data. +/// +/// * [`Tag::OWNED`] -> [`Tag::PUBLISHED`]: Must be done as an [`Ordering::Release`] store +/// and only by the thread that acquired ownership. +/// +/// * [`Tag::PUBLISHED`] -> [`Tag::RETIRING`]: Must be done while under a +/// [`Guard`](super::epoch::Guard) and may be done with relaxed atomics. Writes to +/// associated data may not be made. Place into [`Guard::retire`](super::epoch::Guard::retire) +/// for final reclamation. +/// +/// * [`Tag::RETIRING`] -> [`Tag::AVAILABLE`]: May only be done if the corresponding slot is +/// retrieved from a [`Drain`](super::epoch::Drain). Writes may occur to associated data +/// and if so, this transition must be made with [`Ordering::Release`]. +/// +/// # Reading +/// +/// Checks to [`Tag::can_read`] can be made following [`Ordering::Acquire`] loads. +#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +#[repr(transparent)] +pub struct Tag(u8); + +impl Tag { + //-------------// + // High Values // + //-------------// + + /// The slot is permanently readable and never mutated again. See [`Tag`]. + pub const FROZEN: Self = Self::new(u8::MAX); + + /// The slot has been published and is freely readable. See [`Tag`]. + pub const PUBLISHED: Self = Self::new(u8::MAX - 1); + + //------------// + // Low Values // + //------------// + + /// The slot holds no valid data and may be claimed via CAS to [`Tag::OWNED`]. + /// See [`Tag`]. + pub const AVAILABLE: Self = Self::new(0); + + /// The slot is exclusively owned by a single thread that may write its data. + /// See [`Tag`]. + pub const OWNED: Self = Self::new(1); + + /// The slot is in the process of being retired and is no longer readable to new + /// readers. See [`Tag`]. + pub const RETIRING: Self = Self::new(2); + + const RESERVED: Self = Self::RETIRING; + + /// Return `true` if `self` is one of the protocol's reserved tag values. 
+ /// + /// Reserved tags are part of the protocol's fixed vocabulary and are never delivered + /// as retirement payloads. + #[must_use = "this function has no side-effects"] + pub(crate) fn is_reserved(self) -> bool { + (self <= Self::RESERVED) || (self == Self::FROZEN) + } + + /// Return `true` if `self` is in a state where it is legal to access tagged data. + #[must_use = "this function has no side-effects"] + pub(crate) fn can_read(self) -> bool { + // Tags are split into `high` (readable) and `low` (non-readable) values so this + // check reduces to a single comparison. + self >= Self::PUBLISHED + } + + /// Construct a new [`Tag`] with `value`. + #[inline] + const fn new(value: u8) -> Self { + Self(value) + } + + /// Return the value of `self`. + #[inline] + const fn value(self) -> u8 { + self.0 + } +} + +impl std::fmt::Display for Tag { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let me = *self; + if me == Self::AVAILABLE { + f.write_str("Tag(AVAILABLE)") + } else if me == Self::OWNED { + f.write_str("Tag(OWNED)") + } else if me == Self::RETIRING { + f.write_str("Tag(RETIRING)") + } else if me == Self::FROZEN { + f.write_str("Tag(FROZEN)") + } else if me == Self::PUBLISHED { + f.write_str("Tag(PUBLISHED)") + } else { + write!(f, "Tag({})", me.value()) + } + } +} + +/// An atomic [`Tag`]. +/// +/// Memory orderings are the caller's responsibility and must be chosen consistent with the +/// protocol described on [`Tag`]. +#[derive(Debug)] +#[repr(transparent)] +pub struct AtomicTag(AtomicU8); + +impl AtomicTag { + /// Construct a new [`AtomicTag`] initialized to `tag`. + pub const fn new(tag: Tag) -> Self { + Self(AtomicU8::new(tag.value())) + } + + /// Creates a new reference to a `AtomicTag` from a raw pointer. + /// + /// # Safety + /// + /// * `ptr` must be aligned to `align_of::()`. + /// * `ptr` must be valid for both reads and writes for the whole lifetime `'a`. 
+ /// * The caller chooses `'a`; the underlying allocation must outlive `'a`. + /// * This must adhere to the memory model for atomic accesses. In particular, it must + /// not admit conflicting atomic and non-atomic accesses, or atomic accesses of + /// different sizes without synchronization. + /// + /// See: + pub unsafe fn from_ptr<'a>(ptr: *mut AtomicTag) -> &'a Self { + unsafe { &*ptr } + } + + /// Perform an atomic compare-exchange with the provided orderings. + /// + /// Note that this does not enforce the [`Tag`] transition protocol; the caller must + /// ensure `current` and `new` correspond to a legal transition. + /// + /// See: [`AtomicU8::compare_exchange`]. + pub fn compare_exchange( + &self, + current: Tag, + new: Tag, + success: Ordering, + failure: Ordering, + ) -> Result { + self.0 + .compare_exchange(current.value(), new.value(), success, failure) + .map(Tag::new) + .map_err(Tag::new) + } + + /// Perform an atomic load with the provided ordering. + /// + /// See: [`AtomicU8::load`]. + pub fn load(&self, ordering: Ordering) -> Tag { + Tag::new(self.0.load(ordering)) + } + + /// Perform an atomic store with the provided ordering. + /// + /// See: [`AtomicU8::store`]. + pub fn store(&self, val: Tag, ordering: Ordering) { + self.0.store(val.value(), ordering) + } +} + +/////////// +// Tests // +/////////// + +#[cfg(test)] +mod tests { + use super::*; + + use std::{sync::Barrier, thread}; + + use crate::{ + buffer::Buffer, + num::{Align, Bytes}, + }; + + fn spin_decrement(m: &AtomicTag, count: usize) { + for _ in 0..count { + let mut current = m.load(Ordering::Relaxed); + while let Err(c) = m.compare_exchange( + current, + Tag::new(current.value().wrapping_sub(1)), + Ordering::Relaxed, + Ordering::Relaxed, + ) { + current = c; + } + } + } + + #[test] + fn test_atomic() { + let threads = 4; + let barrier = &Barrier::new(threads); + + // This dance basically verifies that we can view the tag though a proper-aligned + // raw pointer. 
+ let buffer = + Buffer::new(1, Bytes::size_of::(), Align::of::()).unwrap(); + let ptr = buffer.get(0).unwrap().as_mut_ptr().cast::(); + + { + let tag = unsafe { AtomicTag::from_ptr(ptr) }; + tag.store(Tag::FROZEN, Ordering::Relaxed); + } + + let count = 1000; + thread::scope(|s| { + for _ in 0..threads { + s.spawn(|| { + // Re-derive `p` to avoid issues with `Send`. + let p = buffer.get(0).unwrap().as_mut_ptr().cast::(); + let tag = unsafe { AtomicTag::from_ptr(p) }; + barrier.wait(); + spin_decrement(&tag, count); + }); + } + }); + + { + let g = unsafe { AtomicTag::from_ptr(ptr) }.load(Ordering::Relaxed); + assert_eq!(g, Tag::new(u8::MAX.wrapping_sub((count * threads) as u8))); + } + } + + #[test] + fn test_is_reserved() { + assert!(Tag::FROZEN.is_reserved()); + assert!(!Tag::PUBLISHED.is_reserved()); + + assert!(Tag::AVAILABLE.is_reserved()); + assert!(Tag::OWNED.is_reserved()); + assert!(Tag::RETIRING.is_reserved()); + } + + #[test] + fn test_can_read() { + assert!(Tag::FROZEN.can_read()); + assert!(Tag::PUBLISHED.can_read()); + + assert!(!Tag::AVAILABLE.can_read()); + assert!(!Tag::OWNED.can_read()); + assert!(!Tag::RETIRING.can_read()); + } +} diff --git a/diskann-inmem/src/sync/test.rs b/diskann-inmem/src/sync/test.rs new file mode 100644 index 000000000..518ea730a --- /dev/null +++ b/diskann-inmem/src/sync/test.rs @@ -0,0 +1,282 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Directed stress tests for `Registry`. 
+ +use std::{ + cell::UnsafeCell, + mem::MaybeUninit, + sync::atomic::{AtomicUsize, Ordering}, +}; + +use rand::{Rng, distr::StandardUniform}; + +use super::{AtomicTag, Registry, Tag}; + +type Data = [u32; 4]; + +struct Slot { + tag: AtomicTag, + payload: UnsafeCell>>, +} + +impl Slot { + fn new() -> Self { + Self { + tag: AtomicTag::new(Tag::AVAILABLE), + payload: UnsafeCell::new(MaybeUninit::uninit()), + } + } + + fn try_claim(&self, payload: Data, f: F) + where + F: FnOnce(), + { + if let Ok(_) = self.tag.compare_exchange( + Tag::AVAILABLE, + Tag::OWNED, + Ordering::Acquire, + Ordering::Relaxed, + ) { + unsafe { &mut *self.payload.get() }.write(Box::new(payload)); + f(); + self.tag.store(Tag::PUBLISHED, Ordering::Release); + } + } + + unsafe fn try_read(&self) -> Option<&Data> { + if self.tag.load(Ordering::Acquire).can_read() { + let payload = unsafe { &*self.payload.get() }; + Some(&*unsafe { payload.assume_init_ref() }) + } else { + None + } + } + + #[must_use] + fn retire(&self) -> bool { + let tag = self.tag.load(Ordering::Relaxed); + if tag != Tag::PUBLISHED { + return false; + } + + if let Ok(_) = self.tag.compare_exchange( + Tag::PUBLISHED, + Tag::RETIRING, + Ordering::Relaxed, + Ordering::Relaxed, + ) { + true + } else { + false + } + } + + unsafe fn make_available(&self) { + assert_eq!(self.tag.load(Ordering::Relaxed), Tag::RETIRING); + unsafe { (&mut *self.payload.get()).assume_init_drop() }; + + if let Err(_) = self.tag.compare_exchange( + Tag::RETIRING, + Tag::AVAILABLE, + Ordering::Release, + Ordering::Relaxed, + ) { + panic!("concurrency violation"); + } + } +} + +impl Drop for Slot { + fn drop(&mut self) { + if self.tag.load(Ordering::Relaxed) != Tag::AVAILABLE { + let payload = self.payload.get_mut(); + unsafe { payload.assume_init_drop() }; + } + } +} + +// We control concurrency, so can safely share this. 
+unsafe impl Sync for Slot {} + +fn make_payload(epoch: u64, index: usize) -> Data { + [ + index as u32, + epoch as u32, + (epoch >> 32) as u32, + (index as u32) ^ (epoch as u32) ^ ((epoch >> 32) as u32), + ] +} + +fn verify_payload(data: &Data) -> (usize, u64) { + let checksum = data[0] ^ data[1] ^ data[2]; + assert_eq!( + data[3], checksum, + "torn or corrupted read: payload {data:?}, expected checksum {checksum}" + ); + let index = data[0] as usize; + let epoch = data[1] as u64 | ((data[2] as u64) << 32); + (index, epoch) +} + +struct Record { + epoch: u64, + index: usize, + data: Data, +} + +fn read_job( + registry: &Registry, + slots: &[Slot], + stop_at: u64, + retire_rate: f64, + active: &AtomicUsize, +) -> Vec { + assert!(retire_rate > 0.0); + assert!(retire_rate < 1.0); + + let mut records = Vec::new(); + let mut rng = rand::rng(); + + loop { + let mut reads = Vec::<&Data>::new(); + let guard = registry.guard().unwrap(); + if guard.epoch() >= stop_at { + break; + } + + for (i, slot) in slots.iter().enumerate() { + if let Some(read) = unsafe { slot.try_read() } { + reads.push(read); + + let sample: f64 = rng.sample(StandardUniform); + if sample < retire_rate && slot.retire() { + guard.retire(i as u32); + active.fetch_sub(1, Ordering::Release); + + std::thread::yield_now(); + records.push(Record { + epoch: guard.epoch(), + index: i, + data: *read, + }); + } + } + } + } + + records +} + +fn retire_job(registry: &Registry, slots: &[Slot], stop_at: u64, active: &AtomicUsize) { + loop { + let epoch = registry.epoch(); + if epoch >= stop_at { + return; + } + + if active.load(Ordering::Acquire) != 0 { + std::thread::yield_now(); + continue; + } + + if let Some(drain) = registry.try_advance() { + for i in drain { + unsafe { slots[i as usize].make_available() }; + } + } + } +} + +fn write_job(registry: &Registry, slots: &[Slot], stop_at: u64, active: &AtomicUsize) { + loop { + let epoch = registry.epoch(); + if epoch >= stop_at { + return; + } + + for (i, slot) in 
slots.iter().enumerate() { + slot.try_claim(make_payload(epoch, i), || { + active.fetch_add(1, Ordering::Relaxed); + }); + } + + std::thread::yield_now(); + } +} + +#[test] +fn registry_stress_test() { + let registry = Registry::new(); + let slots: Vec<_> = std::iter::repeat_with(Slot::new).take(10).collect(); + let active = AtomicUsize::new(0); + + let stop_at = if cfg!(miri) { 11 } else { 50_000 }; + let retire_rate = if cfg!(miri) { 0.95 } else { 0.1 }; + + // We use two threads for each job to be extra adversarial. + let barrier = std::sync::Barrier::new(6); + let result = std::thread::scope(|s| { + // Spin up readers. + let r0 = s.spawn(|| { + barrier.wait(); + read_job(®istry, &slots, stop_at, retire_rate, &active) + }); + + let r1 = s.spawn(|| { + barrier.wait(); + read_job(®istry, &slots, stop_at, retire_rate, &active) + }); + + // Spin up writers + s.spawn(|| { + barrier.wait(); + write_job(®istry, &slots, stop_at, &active); + }); + + s.spawn(|| { + barrier.wait(); + write_job(®istry, &slots, stop_at, &active); + }); + + // Spin up retirers + s.spawn(|| { + barrier.wait(); + retire_job(®istry, &slots, stop_at, &active); + }); + s.spawn(|| { + barrier.wait(); + retire_job(®istry, &slots, stop_at, &active); + }); + + let mut r0 = r0.join().unwrap(); + let r1 = r1.join().unwrap(); + r0.extend(r1); + r0 + }); + + for record in &result { + let (index, write_epoch) = verify_payload(&record.data); + + // The index encoded in the payload must match the slot we read from. + assert_eq!( + index, record.index, + "slot identity mismatch: payload says slot {index}, record says slot {}", + record.index + ); + + // The slot was written at `write_gen` and read at `record.generation. + // Since generations increase (newer = larger), write_gen <= record.generation + // means the write happened at or before the reader's epoch. + // + // Note that a reader can observe one epoch change during its tenure, so we *can* + // observe writes from one higher epoch. 
+ assert!( + write_epoch <= (record.epoch + 1), + "read data from the future: write_gen={write_epoch}, read_gen={}", + record.epoch + ); + } +} diff --git a/diskann-inmem/src/test.rs b/diskann-inmem/src/test.rs new file mode 100644 index 000000000..81c9cffd9 --- /dev/null +++ b/diskann-inmem/src/test.rs @@ -0,0 +1,86 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +use std::sync::Arc; + +use parking_lot::{Condvar, Mutex}; + +#[derive(Clone)] +pub(crate) struct Sequencer(Arc); + +struct SequencerInner { + state: Mutex, + condvar: Condvar, +} + +#[derive(Debug, Clone, Copy, PartialEq)] +enum State { + Empty, + Parked(usize), + Released(usize), +} + +impl Sequencer { + pub(crate) fn new() -> Self { + Self(Arc::new(SequencerInner { + state: Mutex::new(State::Empty), + condvar: Condvar::new(), + })) + } + + pub(crate) fn wait_for(&self, stage: usize) { + let mut state = self.0.state.lock(); + if stage == 0 { + assert_eq!(*state, State::Empty) + } else { + assert_eq!(*state, State::Released(stage - 1)) + } + + *state = State::Parked(stage); + self.0.condvar.notify_all(); + self.0 + .condvar + .wait_while(&mut state, move |s| *s != State::Released(stage)); + } + + pub(crate) fn advance_past(&self, stage: usize) { + let mut state = self.0.state.lock(); + self.0 + .condvar + .wait_while(&mut state, move |s| Self::check_release(*s, stage)); + *state = State::Released(stage); + self.0.condvar.notify_all(); + } + + pub(crate) fn until_waiting_for(&self, stage: usize) { + let mut state = self.0.state.lock(); + if *state != State::Parked(stage) { + self.0 + .condvar + .wait_while(&mut state, move |s| Self::check_release(*s, stage)) + } + } + + fn check_release(current: State, stage: usize) -> bool { + match current { + State::Empty => { + assert_eq!(stage, 0); + true + } + State::Released(s) => { + if s + 1 != stage { + panic!("observed {:?} while releasing stage {}", current, stage); + } + true + } + State::Parked(s) => { + if s != 
stage { + panic!("observed {:?} while releasing stage {}", current, stage) + } + false + } + } + } +} From 2f03e0da7ddcfaba871d6cdb606d3778a383cb44 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Wed, 17 Jun 2026 15:56:14 -0700 Subject: [PATCH 12/45] Getting closer. --- diskann-inmem/src/buffer.rs | 12 +- diskann-inmem/src/lib.rs | 4 +- diskann-inmem/src/neighbors.rs | 279 +++++++++++++++++++---------- diskann-inmem/src/provider.rs | 8 +- diskann-inmem/src/store.rs | 11 +- diskann-inmem/src/sync/epoch.rs | 90 ++++------ diskann-inmem/src/sync/freelist.rs | 75 ++------ diskann-inmem/src/sync/mod.rs | 8 +- diskann-inmem/src/sync/test.rs | 4 +- 9 files changed, 254 insertions(+), 237 deletions(-) diff --git a/diskann-inmem/src/buffer.rs b/diskann-inmem/src/buffer.rs index 4bee9fb07..bbe2397b5 100644 --- a/diskann-inmem/src/buffer.rs +++ b/diskann-inmem/src/buffer.rs @@ -9,12 +9,13 @@ use crate::num::{Align, Bytes}; /// An unsynchronized row-store for raw data. /// -/// The backing data is stored as a raw pointers and interacted with via [`RawSlice`], which +/// The backing data is stored as raw pointers and interacted with via [`RawSlice`], which /// is also raw pointer based. Careful use of this struct enables safe use of /// [`RawSlice::as_slice`], [`RawSlice::as_mut_slice`], and other accesses from multiple /// threads without undefined behavior. /// -/// Note that `Buffer` is unconditionally `Send` and `Sync`. +/// `Buffer` is unconditionally `Send` and `Sync`: it holds only a pointer plus metadata, +/// and the synchronization burden is shifted to users of [`RawSlice`]. #[derive(Debug)] pub(crate) struct Buffer { ptr: NonNull, @@ -163,8 +164,9 @@ impl<'a> RawSlice<'a> { /// The memory `[ptr, ptr.add(len.value()))` must be part of a single allocation for /// the duration of the lifetime `'a`. /// - /// However, this has the semantics of a pointer: multiple threads can hold a [`RawSlice`] - /// to the same piece of memory without undefined behavior. 
+ /// The underlying allocation is safe to alias from multiple threads. [`RawSlice`] + /// itself is intentionally `!Send + !Sync`; each thread must derive its own from a + /// shared `&Buffer`. unsafe fn new(ptr: NonNull, len: Bytes) -> Self { Self { ptr, @@ -180,7 +182,7 @@ impl<'a> RawSlice<'a> { unsafe { self.truncate_unchecked(self.len.min(n)) } } - /// Shorten the slice to the `n`. + /// Shorten the slice to `n` bytes. /// /// # Safety /// diff --git a/diskann-inmem/src/lib.rs b/diskann-inmem/src/lib.rs index c8d438a5f..9c4b34bd0 100644 --- a/diskann-inmem/src/lib.rs +++ b/diskann-inmem/src/lib.rs @@ -4,6 +4,8 @@ */ mod buffer; +mod neighbors; + pub mod num; mod sync; @@ -11,10 +13,8 @@ pub mod ids; pub mod layers; mod store; -pub mod neighbors; pub mod provider; -pub use neighbors::Neighbors; pub use provider::{Context, Provider, Strategy}; #[cfg(test)] diff --git a/diskann-inmem/src/neighbors.rs b/diskann-inmem/src/neighbors.rs index 7138433f7..9bcefcd58 100644 --- a/diskann-inmem/src/neighbors.rs +++ b/diskann-inmem/src/neighbors.rs @@ -3,6 +3,26 @@ * Licensed under the MIT license. */ +//! A Concurrent Graph Structure +//! +//! The [`Neighbors`] data structure is a concurrent graph managed out of a single allocation. +//! The use of a single allocation puts a hard upper-bound on the length each adjacency list, +//! which is enforced by the types in this module. +//! +//! Concurrency is obtained using sharded read/write locks, with [`Neighbors::get`] and +//! [`Neighbors::set`] acquiring read and write locks respectively. +//! +//! To implement atomic read-modify-write operations, [`Neighbors::lock`] can be used to +//! obtain a [`Lock`]ed list. +//! +//! Due to lock sharding, attempting to acquire multiple [`Lock`]s to a single [`Neighbors`] +//! simultaneously can lead to dead-lock. +//! +//! ## Performance Considerations +//! +//! Adjacency lists written through the APIs exposed in this module are not validated for +//! 
uniqueness nor for being in-bounds. These are the caller's responsibility. + use std::ptr::NonNull; use diskann::{graph::AdjacencyList, utils::IntoUsize}; @@ -27,28 +47,37 @@ type Id = u32; /// Callers must not hold more than one [`Lock`] at a time. const LOCK_GRANULARITY: usize = 16; -fn lock_index(i: usize) -> usize { - i / LOCK_GRANULARITY +fn lock_index(i: u32) -> usize { + i.into_usize() / LOCK_GRANULARITY } +/// A concurrent graph data structure with a fixed number of adjacency lists and a fixed +/// upper-bound for each adjacency list's length. See the [module level docs](self) for +/// more detail. +/// +/// Adjacency lists are indexed by `[0, Neighbors::entries)`. #[derive(Debug)] -pub struct Neighbors { +pub(crate) struct Neighbors { neighbors: Buffer, locks: Vec>, } impl Neighbors { - pub fn new(entries: usize, max_length: usize) -> Result { - // This is exceedingly unlikely and - if max_length > (u32::MAX).into_usize() { - return Err(NeighborsError::AdjacencyListTooLong(max_length)); - } - + /// Construct a new [`Neighbors`] capable of holding `entries` adjacency lists with a + /// maximum length of `max_length`. + /// + /// # Errors + /// + /// Returns an error if `(max_length + 1) * size_of::()` overflows `usize` + /// (unreachable on 64-bit targets) or the resulting allocation would exceed + /// `isize::MAX` bytes. + pub(crate) fn new(entries: u32, max_length: u32) -> Result { let bytes = max_length + .into_usize() .checked_add(1) .and_then(|len| len.checked_mul(std::mem::size_of::())) .map(Bytes::new) - .ok_or(NeighborsError::AdjacencyListTooLong(max_length))?; + .ok_or(NeighborsError::Overflow(max_length))?; // We materialize slices of `Id` into the raw byte buffers. 
// @@ -62,36 +91,47 @@ impl Neighbors { ); } - let neighbors = Buffer::new(entries, bytes, ALIGN)?; + let neighbors = Buffer::new(entries.into_usize(), bytes, ALIGN)?; let locks = std::iter::repeat_with(|| RwLock::new(())) - .take(entries.div_ceil(LOCK_GRANULARITY)) + .take(entries.into_usize().div_ceil(LOCK_GRANULARITY)) .collect(); Ok(Self { neighbors, locks }) } /// Return the maximum length for any adjacency list. - pub fn max_length(&self) -> usize { + pub(crate) fn max_length(&self) -> usize { // We reserve 4 bytes at the beginning for the length of the adjacency list. (self.neighbors.stride().value() - std::mem::size_of::()) / std::mem::size_of::() } - pub fn entries(&self) -> usize { - self.neighbors.len() + /// Return the maximum length for any adjacency list as a 32-bit integer. + pub(crate) fn max_length_u32(&self) -> u32 { + // Lossless by the invariants on `Self::new`. + self.max_length() as u32 } - pub fn get(&self, i: usize, neighbors: &mut AdjacencyList) -> Result<(), OutOfBounds> { + /// Return the number of adjacency lists contained by this graph. + pub(crate) fn entries(&self) -> u32 { + // Cast is lossless by construction. + self.neighbors.len() as u32 + } + + /// Copy the contents of adjacency list `i` into `neighbors`. + /// + /// Returns an error if `i` exceeds [`Self::entries`]. + pub(crate) fn get(&self, i: u32, neighbors: &mut AdjacencyList) -> Result<(), OutOfBounds> { self.check(i)?; let lock = unsafe { self.locks.get_unchecked(lock_index(i)) }; let _guard = lock.read(); - // SAFETY: By consruction `self.buffer` has the same number of entries as + // SAFETY: By construction `self.buffer` has the same number of entries as // `self.locks` and we have already checked that `i` is in-bounds there. 
let (prefix, rest) = - unsafe { self.neighbors.get_unchecked(i) }.split(Bytes::size_of::()); + unsafe { self.neighbors.get_unchecked(i.into_usize()) }.split(Bytes::size_of::()); debug_assert_eq!(prefix.len(), Bytes::size_of::()); debug_assert!(prefix.as_ptr().cast::().is_aligned()); @@ -99,8 +139,7 @@ impl Neighbors { // SAFETY: We hold the read-lock, so reading is safe. From our bounds checks, we // know that this pointer is valid. let len: usize = unsafe { prefix.as_ptr().cast::().read() } - .into_usize() - .min(self.max_length()); + .min(self.max_length_u32()).into_usize(); let mut resizer = neighbors.resize(len); unsafe { @@ -114,33 +153,49 @@ impl Neighbors { Ok(()) } - pub fn lock(&self, i: usize) -> Result, OutOfBounds> { + /// Lock adjacency list `i` for read-modify-write operations. + /// + /// Returns an error if `i` exceeds [`Self::entries`]. + pub(crate) fn lock(&self, i: u32) -> Result, OutOfBounds> { self.check(i)?; Ok(unsafe { self.lock_unchecked(i) }) } - unsafe fn lock_unchecked(&self, i: usize) -> Lock<'_> { + unsafe fn lock_unchecked(&self, i: u32) -> Lock<'_> { let lock = unsafe { self.locks.get_unchecked(lock_index(i)) }.write(); - // SAFETY: By consruction `self.buffer` has the same number of entries as + // SAFETY: By construction `self.buffer` has the same number of entries as // `self.locks` and we have already checked that `i` is in-bounds there. - let slice = unsafe { self.neighbors.get_unchecked(i) }; + let slice = unsafe { self.neighbors.get_unchecked(i.into_usize()) }; debug_assert!(slice.as_ptr().cast::().is_aligned()); Lock { ptr: slice.as_non_null().cast::(), - capacity: self.max_length(), + capacity: self.max_length().into_usize(), _lock: lock, } } - pub fn set(&self, i: usize, neighbors: &[u32]) -> Result<(), SetError> { + /// Overwrite the contents of adjacency list `i` with `neighbors`. + /// + /// # Errors + /// + /// Returns an error if: + /// + /// * `i` exceeds [`Self::entries`]. 
+ /// * `neighbors.len()` exceeds [`Self::max_length_u32`]. + /// + /// If an error is returned, the graph is left unmodified. + pub(crate) fn set(&self, i: u32, neighbors: &[u32]) -> Result<(), SetError> { self.check(i).map_err(SetError::OutOfBounds)?; // We can check the length of `neighbors` before acquiring any locks as an early exit. - if neighbors.len() > self.max_length() { - return Err(SetError::TooLong(TooLong)); + if neighbors.len() > self.max_length().into_usize() { + return Err(SetError::TooLong(TooLong { + got: neighbors.len(), + max: self.max_length_u32(), + })); } let lock = unsafe { self.lock_unchecked(i) }; @@ -148,7 +203,7 @@ impl Neighbors { Ok(()) } - fn check(&self, i: usize) -> Result<(), OutOfBounds> { + fn check(&self, i: u32) -> Result<(), OutOfBounds> { if i >= self.entries() { Err(OutOfBounds(i)) } else { @@ -157,14 +212,49 @@ impl Neighbors { } } +/// Errors returned by [`Neighbors::new`]. +#[derive(Debug, Error)] +pub(crate) enum NeighborsError { + /// Computing the per-list byte size `(max_length + 1) * size_of::()` overflowed + /// `usize`. + /// + /// Unreachable on 64-bit targets. + #[error("adjacency list length of {0} is too long")] + Overflow(u32), + + /// Allocation of the underlying buffer failed. + /// + /// This can occur if the total allocation size (`entries * per-list bytes`) + /// would exceed `isize::MAX`, or if the underlying allocator returns an error. + #[error("neighbor buffer allocation failed")] + AllocationFailed(#[from] BufferError), +} + + +/// Attempted to access a [`Neighbors`] at an out-of-bounds index. #[derive(Debug, Clone, Copy, Error)] #[error("index {} is out-of-bounds", self.0)] -pub struct OutOfBounds(usize); +pub(crate) struct OutOfBounds(u32); + +/// A neighbor list was longer than the configured per-list capacity. +/// +/// `got` is the caller-supplied length (any `usize`); `max` is the per-list capacity, +/// which is bounded by `u32` per [`Neighbors::new`]. 
+#[derive(Debug, Clone, Copy, Error)] +#[error("length {} exceeds the max length {}", self.got, self.max)] +pub(crate) struct TooLong { + got: usize, + max: u32, +} +/// Errors during [`Neighbors::set`]. #[derive(Debug, Clone, Copy, Error)] -pub enum SetError { +pub(crate) enum SetError { + /// Attempted to access an out-of-bounds index. #[error(transparent)] OutOfBounds(OutOfBounds), + + /// The new adjacency list was too long. #[error(transparent)] TooLong(TooLong), } @@ -173,7 +263,7 @@ pub enum SetError { /// /// Callers must not hold more than one `Lock` at a time. See [`LOCK_GRANULARITY`] for /// details on the deadlock hazard. -pub struct Lock<'a> { +pub(crate) struct Lock<'a> { ptr: NonNull, capacity: usize, _lock: RwLockWriteGuard<'a, ()>, @@ -181,50 +271,36 @@ pub struct Lock<'a> { impl Lock<'_> { /// Return the capacity of the neighbor buffer. - pub fn capacity(&self) -> usize { + pub(crate) fn capacity(&self) -> usize { self.capacity } /// Return the current length of the neighbor list. /// - /// This is guaranteed to be less than [`capacity`](Self::capacity). - pub fn len(&self) -> usize { + /// This is guaranteed to be less than or equal to [`capacity`](Self::capacity). + pub(crate) fn len(&self) -> usize { // SAFETY: By construction, `self.raw` has a length of at least 1. // - // The `min` operation is to be conservative. + // The `min` operation defensively clamps in case the stored length has been + // corrupted; under normal operation it should already be `<= capacity`. unsafe { self.ptr.read() }.into_usize().min(self.capacity()) } - /// Return `true` only if `self.len() == 0`. - pub fn is_empty(&self) -> bool { - self.len() == 0 - } - - /// View the current contents of the locked adjacency list as a slice. - pub fn as_slice(&self) -> &[u32] { - let len = self.len(); - unsafe { std::slice::from_raw_parts(self.ptr.add(1).as_ptr().cast_const(), len) } - } - - /// Consume the [`Lock`] - copying the contents of `neighbors`. 
+ /// Consume `self`, appending `neighbors` to the list. /// - /// Returns an error if `neighbors.len() > self.capacity()` without copying any of the - /// contents of `neighbors`. - pub fn write(self, neighbors: &[u32]) -> Result<(), TooLong> { - if neighbors.len() > self.capacity() { - return Err(TooLong); - } - - unsafe { self.write_unchecked(neighbors) }; - Ok(()) - } - - pub fn append(self, neighbors: &[u32]) -> Result<(), TooLong> { + /// Returns an error if the concatenated list would exceed [`Self::capacity`] without + /// modifying the adjacency list. + /// + /// This method does not attempt to deduplicate `neighbors`. + pub(crate) fn append(self, neighbors: &[u32]) -> Result<(), TooLong> { let len = self.len(); let newlen = len.saturating_add(neighbors.len()); if newlen > self.capacity() { - return Err(TooLong); + return Err(TooLong { + got: newlen, + max: self.capacity as u32, + }); } unsafe { @@ -245,6 +321,31 @@ impl Lock<'_> { unsafe { std::ptr::copy_nonoverlapping(neighbors.as_ptr(), self.ptr.as_ptr().add(1), len) } unsafe { self.ptr.write(len as u32) }; } + + #[cfg(test)] + fn as_slice(&self) -> &[u32] { + let len = self.len(); + unsafe { std::slice::from_raw_parts(self.ptr.add(1).as_ptr().cast_const(), len) } + } + + #[cfg(test)] + fn write(self, neighbors: &[u32]) -> Result<(), TooLong> { + if neighbors.len() > self.capacity() { + return Err(TooLong { + got: neighbors.len(), + max: self.capacity as u32, + }); + } + + unsafe { self.write_unchecked(neighbors) }; + Ok(()) + } + + + #[cfg(test)] + fn is_empty(&self) -> bool { + self.len() == 0 + } } impl std::fmt::Debug for Lock<'_> { @@ -257,17 +358,6 @@ impl std::fmt::Debug for Lock<'_> { } } -#[derive(Debug, Error)] -pub enum NeighborsError { - #[error("adjacency list length of {} is too long", 0)] - AdjacencyListTooLong(usize), - #[error("neighbor bufffer allocation failed")] - AllocationFailed(#[from] BufferError), -} - -#[derive(Debug, Clone, Copy, Error)] -#[error("too long")] -pub struct TooLong; 
/////////// // Tests // @@ -279,22 +369,31 @@ mod tests { use crate::test::Sequencer; - // Constructor errors + // -- OutOfBounds checks -- #[test] - fn new_rejects_max_length_exceeding_u32_max() { - let result = Neighbors::new(10, (u32::MAX as usize) + 1); - assert!(matches!( - result, - Err(NeighborsError::AdjacencyListTooLong(_)) - )); + fn out_of_bounds_rejects_indices_beyond_entries() { + let n = Neighbors::new(4, 4).unwrap(); + // entries == 4, so valid indices are 0..=3. + // Regression test: a buggy `check` using `i == entries()` would let + // `entries+1`, `entries+2`, ... slip through to UB. + let mut out = AdjacencyList::with_capacity(4); + for bad in [4u32, 5, 100, u32::MAX] { + assert!(matches!(n.get(bad, &mut out), Err(OutOfBounds(_)))); + assert!(matches!(n.set(bad, &[]), Err(SetError::OutOfBounds(_)))); + assert!(matches!(n.lock(bad), Err(OutOfBounds(_)))); + } } #[test] - fn new_rejects_allocation_overflow() { - // entries * (max_length + 1) * sizeof(Id) overflows. - let result = Neighbors::new(usize::MAX, 64); - assert!(matches!(result, Err(NeighborsError::AllocationFailed(_)))); + fn empty_neighbors_rejects_all_access() { + let n = Neighbors::new(0, 4).unwrap(); + let mut out = AdjacencyList::with_capacity(4); + for i in [0u32, 1, u32::MAX] { + assert!(matches!(n.get(i, &mut out), Err(OutOfBounds(_)))); + assert!(matches!(n.set(i, &[]), Err(SetError::OutOfBounds(_)))); + assert!(matches!(n.lock(i), Err(OutOfBounds(_)))); + } } // TooLong errors @@ -445,14 +544,14 @@ mod tests { Err(SetError::OutOfBounds(_)) )); - let generate = |round: usize, entry: usize| -> Vec { + let generate = |round: u32, entry: u32| -> Vec { (0..(round + 1)) - .map(|r| (entry + r).try_into().unwrap()) + .map(|r| entry + r) .collect() }; // Test mutation via `Neighbors::set`. 
- for round in 0..neighbors.max_length() { + for round in 0..neighbors.max_length_u32() { for i in 0..neighbors.entries() { let v = generate(round, i); neighbors.set(i, &v).unwrap(); @@ -471,7 +570,7 @@ mod tests { clear(&mut neighbors); // Test mutation via `lock + write`. - for round in 0..neighbors.max_length() { + for round in 0..neighbors.max_length_u32() { for i in 0..neighbors.entries() { let v = generate(round, i); neighbors.lock(i).unwrap().write(&v).unwrap(); @@ -490,12 +589,12 @@ mod tests { clear(&mut neighbors); // Test mutation via `lock + append`. - for round in 0..neighbors.max_length() { + for round in 0..neighbors.max_length_u32() { for i in 0..neighbors.entries() { neighbors .lock(i) .unwrap() - .append(&[(round + i).try_into().unwrap()]) + .append(&[round + i]) .unwrap(); } @@ -520,7 +619,7 @@ mod tests { // Verify that holding a `Lock` correctly blocks reads for the same adjacency list. #[test] fn lock_blocks_get() { - for i in 0..10 { + for _ in 0..10 { let neighbors = Neighbors::new(3, 4).unwrap(); let seq = Sequencer::new(); diff --git a/diskann-inmem/src/provider.rs b/diskann-inmem/src/provider.rs index 4f03c9ee0..5bb4e4e3c 100644 --- a/diskann-inmem/src/provider.rs +++ b/diskann-inmem/src/provider.rs @@ -292,7 +292,7 @@ impl glue::SearchAccessor for SearchAccessor<'_> { for i in ids { self.reader .neighbors() - .get(i.into_usize(), &mut self.ids) + .get(i, &mut self.ids) .unwrap(); // Filter out unvisited IDs and ensure that all the IDs we are about @@ -475,7 +475,7 @@ impl provider::NeighborAccessor for PruneAccessor<'_> { Ok(self .reader .neighbors() - .get(id.into_usize(), neighbors) + .get(id, neighbors) .unwrap()) }; ready(work) @@ -492,7 +492,7 @@ impl provider::NeighborAccessorMut for PruneAccessor<'_> { Ok(self .reader .neighbors() - .set(id.into_usize(), neighbors) + .set(id, neighbors) .unwrap()) }; ready(work) @@ -506,7 +506,7 @@ impl provider::NeighborAccessorMut for PruneAccessor<'_> { let work = move || -> ANNResult<()> { 
self.reader .neighbors() - .lock(id.into_usize()) + .lock(id) .unwrap() .append(neighbors) .unwrap(); diff --git a/diskann-inmem/src/store.rs b/diskann-inmem/src/store.rs index deeadae9c..c935e69ee 100644 --- a/diskann-inmem/src/store.rs +++ b/diskann-inmem/src/store.rs @@ -9,7 +9,7 @@ use diskann::utils::IntoUsize; use diskann_utils::views::MatrixView; use crate::{ - Neighbors, + neighbors::Neighbors, buffer::{Buffer, RawSlice}, num::{Align, Bytes}, sync::{AtomicTag, Freelist, Tag, Registry, epoch, freelist}, @@ -48,7 +48,7 @@ impl Primary { let unpadded = bytes.checked_add(SPLIT).unwrap(); let padded_bytes = unpadded.checked_next_multiple_of(Bytes::CACHELINE).unwrap(); - let total = entries.checked_add(init.nrows()).unwrap(); + let total: usize = entries.checked_add(init.nrows()).unwrap(); let this = Self { buffer: Buffer::new(total, padded_bytes, Align::_128).unwrap(), @@ -62,7 +62,7 @@ impl Primary { // we do not want it to release frozen IDs. freelist: Freelist::new(entries.try_into().unwrap(), NonZeroU32::new(1024).unwrap()), registry: Registry::new(), - neighbors: Neighbors::new(total, max_neighbors).unwrap(), + neighbors: Neighbors::new(total.try_into().unwrap(), max_neighbors.try_into().unwrap()).unwrap(), }; // Populate frozen points. @@ -130,7 +130,7 @@ impl Primary { buffer: &self.buffer, unpadded: self.unpadded, neighbors: &self.neighbors, - epoch: self.registry.guard()?, + _epoch: self.registry.guard()?, }) } @@ -264,7 +264,8 @@ pub struct Reader<'a> { buffer: &'a Buffer, unpadded: Bytes, neighbors: &'a Neighbors, - epoch: epoch::Guard<'a>, + // It's important that we hold onto this, even if we don't use it. + _epoch: epoch::Guard<'a>, } impl<'a> Reader<'a> { diff --git a/diskann-inmem/src/sync/epoch.rs b/diskann-inmem/src/sync/epoch.rs index b36c95403..0e5b47b21 100644 --- a/diskann-inmem/src/sync/epoch.rs +++ b/diskann-inmem/src/sync/epoch.rs @@ -56,7 +56,7 @@ const CAPACITY: usize = 256; /// A registry of epoch-based [`Guard`]s. 
See the [module-level docs](self). #[derive(Debug)] -pub struct Registry { +pub(crate) struct Registry { // A record of the active guards. // // * 0 = "available". @@ -105,14 +105,14 @@ fn last_queue(epoch: u64) -> usize { impl Registry { /// Construct a new [`Registry`] with the default number of guard slots (256). - pub fn new() -> Self { + pub(crate) fn new() -> Self { Self::with_capacity(CAPACITY) } /// Construct a new [`Registry`] with `capacity` guard slots. /// /// This is the number of [`Guard`]s that can be registered concurrently. - pub fn with_capacity(capacity: usize) -> Self { + pub(crate) fn with_capacity(capacity: usize) -> Self { Self { guards: (0..capacity).map(|_| AtomicU64::new(0)).collect(), hint: AtomicUsize::new(0), @@ -122,15 +122,10 @@ impl Registry { } } - /// Return the number of [`Guard`]s this [`Registry`] supports. - pub fn capacity(&self) -> usize { - self.guards.len() - } - /// Return the current epoch. /// /// This has [`Ordering::Acquire`] semantics. - pub fn epoch(&self) -> u64 { + pub(crate) fn epoch(&self) -> u64 { self.epoch.load(Ordering::Acquire) } @@ -142,7 +137,7 @@ impl Registry { /// /// Returns an error if the number of currently active guards exceeds [`Self::capacity`] /// and thus a new guard cannot be made. - pub fn guard(&self) -> Result, Unavailable> { + pub(crate) fn guard(&self) -> Result, Unavailable> { self.guard_inner(NoDelay) } @@ -189,6 +184,7 @@ impl Registry { return Ok(Guard { slot: m, retire: &self.retiring[queue(epoch)], + #[cfg(test)] epoch, #[cfg(test)] slot_index: slot, @@ -199,19 +195,7 @@ impl Registry { Err(Unavailable) } - /// Return `true` if the epoch can be advanced. - /// - /// This uses a fast method that may be conservative: it can return `false` even when a - /// subsequent call to [`Self::try_advance`] would succeed (for example, if a guard slot - /// is observed to hold an old epoch but the corresponding `Guard` is about to be - /// dropped). 
- /// - /// This is a synchronizing operation with [`Ordering::Acquire`] semantics. - pub fn can_advance(&self) -> bool { - self.can_advance_inner(&mut NoDelay).0 - } - - fn can_advance_inner(&self, delay: &mut T) -> (bool, u64) + fn can_advance(&self, delay: &mut T) -> (bool, u64) where T: CanAdvanceDelay, { @@ -275,7 +259,7 @@ impl Registry { /// /// Panics if the epoch counter is about to overflow `u64::MAX`. In practice this is /// effectively unreachable. - pub fn try_advance(&self) -> Option> { + pub(crate) fn try_advance(&self) -> Option> { self.try_advance_inner(NoDelay) } @@ -291,7 +275,7 @@ impl Registry { // This can help save an expensive slot scan. let drain = self.drain.try_lock()?; - let (can_advance, current) = self.can_advance_inner(&mut delay); + let (can_advance, current) = self.can_advance(&mut delay); // Don't wrap around! if current == u64::MAX { @@ -312,7 +296,7 @@ impl Registry { debug_assert_eq!(_previous, current, "concurrency violation"); let queue = &self.retiring[last_queue(current)]; - Some(Drain { queue, drain }) + Some(Drain { queue, _drain: drain }) } else { // Previous generation has not completely retired. None @@ -336,7 +320,7 @@ impl Registry { #[cfg(test)] fn waiting(&self) -> u64 { - self.can_advance_inner(&mut NoDelay).1 + self.can_advance(&mut NoDelay).1 } } @@ -347,22 +331,18 @@ impl Registry { /// /// Obtained via [`Registry::guard`]. #[derive(Debug)] -pub struct Guard<'a> { +pub(crate) struct Guard<'a> { slot: &'a AtomicU64, retire: &'a SegQueue, - epoch: u64, + + #[cfg(test)] + pub(super) epoch: u64, #[cfg(test)] slot_index: usize, } impl Guard<'_> { - /// Return the epoch associated with this [`Guard`]'s creation. - #[inline] - pub fn epoch(&self) -> u64 { - self.epoch - } - /// Retire the id `i` at this guard's epoch. /// /// `i` is a caller-defined id (typically an index into external storage).It will be @@ -371,19 +351,9 @@ impl Guard<'_> { /// /// See also: [`Self::retire_all`]. 
#[inline] - pub fn retire(&self, i: u32) { + pub(crate) fn retire(&self, i: u32) { self.retire.push(i) } - - /// Retire all ids in `itr`. See [`Self::retire`]. - pub fn retire_all(&self, itr: I) - where - I: IntoIterator, - { - for i in itr { - self.retire(i) - } - } } impl Drop for Guard<'_> { @@ -397,26 +367,26 @@ impl Drop for Guard<'_> { /// While this drain is alive, no other thread can advance the [`Registry`]'s epoch. Drop /// it promptly after processing. #[derive(Debug)] -pub struct Drain<'a> { +pub(crate) struct Drain<'a> { queue: &'a SegQueue, - drain: MutexGuard<'a, ()>, + _drain: MutexGuard<'a, ()>, } impl Drain<'_> { /// Pop the next id ready for reclamation, or `None` if the drain is empty. #[must_use = "reclaimed ids must be reclaimed"] - pub fn pop(&self) -> Option { + pub(crate) fn pop(&self) -> Option { self.queue.pop() } /// Return the number of ids remaining in this drain. - pub fn len(&self) -> usize { + pub(crate) fn len(&self) -> usize { self.queue.len() } - /// Return `true` if there are no ids remaining in this drain. - pub fn is_empty(&self) -> bool { - self.queue.is_empty() + #[cfg(test)] + fn is_empty(&self) -> bool { + self.len() == 0 } } @@ -439,7 +409,7 @@ impl ExactSizeIterator for Drain<'_> {} /// Returned by [`Registry::guard`] when all guard slots are occupied. #[derive(Debug)] #[non_exhaustive] -pub struct Unavailable; +pub(crate) struct Unavailable; impl std::fmt::Display for Unavailable { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -511,8 +481,6 @@ mod tests { .with_post_fence(|| thread_a_loop_count += 1); let registry = Registry::with_capacity(2); - assert_eq!(registry.capacity(), 2); - std::thread::scope(|s| { // Thread A s.spawn(|| { @@ -565,7 +533,9 @@ mod tests { // Since we hit the CAS loop - this serves as a sanity check that we have // the correct drain buffer. 
guard.retire(10); - guard.retire_all([1, 2, 3]); + guard.retire(1); + guard.retire(2); + guard.retire(3); guard }); @@ -590,7 +560,7 @@ mod tests { // We pause it again because we want to verify that it registers an old generation. seq.advance_past(0); seq.until_waiting_for(1); - let (can_advance, waiter) = registry.can_advance_inner(&mut NoDelay); + let (can_advance, waiter) = registry.can_advance(&mut NoDelay); assert!(!can_advance); assert_eq!( waiter, 1, @@ -603,7 +573,7 @@ mod tests { // The generation should be the last set one - even though this thread was // parked during the transition. let r = handle.join().unwrap(); - assert_eq!(r.epoch(), expected); + assert_eq!(r.epoch, expected); assert_eq!(registry.waiting(), expected); }); @@ -759,7 +729,7 @@ mod tests { // Helper: register, retire one item, drop. Returns the generation we retired at. let retire_at = |id: u32| { let g = registry.guard().unwrap(); - let epoch = g.epoch(); + let epoch = g.epoch; g.retire(id); epoch }; diff --git a/diskann-inmem/src/sync/freelist.rs b/diskann-inmem/src/sync/freelist.rs index c3b12aa08..1c32e8be2 100644 --- a/diskann-inmem/src/sync/freelist.rs +++ b/diskann-inmem/src/sync/freelist.rs @@ -74,7 +74,7 @@ const SCAN_SIZE: u32 = 256; /// /// See [freelist](self) for details. #[derive(Debug)] -pub struct Freelist { +pub(crate) struct Freelist { // Bounded fast queue of retired slots. recycled: ArrayQueue, @@ -95,7 +95,7 @@ impl Freelist { /// The internal fast recycled list will hold up to `recycled` items. /// /// The memory occupied by this struct is `O(recycled)`. - pub fn new(max: u32, recycled: NonZeroU32) -> Self { + pub(crate) fn new(max: u32, recycled: NonZeroU32) -> Self { Self { recycled: ArrayQueue::new(recycled.get().into_usize()), max, @@ -104,15 +104,10 @@ impl Freelist { } } - /// Return the maximum number of slot IDs managed by `self`. - pub fn max(&self) -> u32 { - self.max - } - /// Try to retrieve an id. /// /// If successful, return [`Id::Found`]. 
Otherwise, returns [`Id::Scan`]. - pub fn pop(&self) -> Id { + pub(crate) fn pop(&self) -> Id { if let Some(id) = self.recycled.pop() { return Id::Found(id); } @@ -138,7 +133,7 @@ impl Freelist { /// Attempt to retrieve an ID directly from the recycled list. /// /// This may be used during scans to retrieve IDs found by other threads. - pub fn pop_recycled(&self) -> Option { + pub(crate) fn pop_recycled(&self) -> Option { self.recycled.pop() } @@ -146,7 +141,7 @@ impl Freelist { /// /// This is managed such that multiple threads calling this function will receive /// disjoint ranges to scan. - pub fn scan(&self) -> Scan { + pub(crate) fn scan(&self) -> Scan { if self.max == 0 { return Scan { start: 0, stop: 0 }; } @@ -172,43 +167,19 @@ impl Freelist { /// If `false` is returned, it is likely because the internal recycle buffer is full. /// /// IDs exceeding [`Self::max`] are discarded. - pub fn push(&self, id: u32) -> bool { + pub(crate) fn push(&self, id: u32) -> bool { if id < self.max { self.recycled.push(id).is_ok() } else { false } } - - /// Append items from `itr` into the recycled buffer. Return the number of items - /// actually added. - /// - /// Callers may not assume that `itr` is fully consumed. - /// - /// IDs exceeding [`Self::max`] are discarded. - pub fn append(&self, itr: I) -> usize - where - I: IntoIterator, - { - let mut count = 0; - for id in itr { - if id < self.max { - if let Err(_) = self.recycled.push(id) { - break; - } else { - count += 1; - } - } - } - - count - } } /// The result of [`Freelist::pop`]. #[derive(Debug, Clone, Copy)] #[must_use] -pub enum Id { +pub(crate) enum Id { /// An ID was found directly in the [`Freelist`]. Found(u32), /// No ID was found in the [`Freelist`] and an exhaustive scan is recommended. @@ -231,13 +202,14 @@ impl Id { /// An [`ExactSizeIterator`] over IDs to scan. Returned by [`Freelist::scan`]. 
#[derive(Debug)] -pub struct Scan { +pub(crate) struct Scan { start: u32, stop: u32, } impl Scan { - pub fn as_range(&self) -> std::ops::Range { + #[cfg(test)] + fn as_range(&self) -> std::ops::Range { self.start..self.stop } } @@ -283,7 +255,6 @@ mod tests { #[test] fn pop_mints_sequentially_until_exhausted() { let fl = freelist(4, 8); - assert_eq!(fl.max(), 4); let mut got = Vec::new(); for _ in 0..4 { @@ -298,7 +269,6 @@ mod tests { fn pop_returns_scan_when_max_zero() { let fl = freelist(0, 1); assert!(fl.pop().is_scan()); - assert_eq!(fl.max(), 0); } #[test] @@ -353,31 +323,6 @@ mod tests { assert_eq!(fl.pop().unwrap(), 0); } - //--------// - // Append // - //--------// - - #[test] - fn append_counts_inserted_and_skips_out_of_range() { - let fl = freelist(4, 8); - let count = fl.append([0u32, 4, 1, 7, 2].iter().copied()); - // 4 and 7 are >= max and skipped; 0, 1, 2 are inserted. - assert_eq!(count, 3); - let mut got = Vec::new(); - while let Some(id) = fl.pop_recycled() { - got.push(id); - } - got.sort(); - assert_eq!(got, vec![0, 1, 2]); - } - - #[test] - fn append_stops_when_buffer_full() { - let fl = freelist(16, 2); - let count = fl.append(0u32..16); - assert_eq!(count, 2); - } - //------// // Scan // //------// diff --git a/diskann-inmem/src/sync/mod.rs b/diskann-inmem/src/sync/mod.rs index 6ec810631..e5553f944 100644 --- a/diskann-inmem/src/sync/mod.rs +++ b/diskann-inmem/src/sync/mod.rs @@ -3,11 +3,11 @@ * Licensed under the MIT license. 
*/ -pub mod epoch; -pub use epoch::Registry; +pub(crate) mod epoch; +pub(crate) use epoch::Registry; -pub mod freelist; -pub use freelist::Freelist; +pub(crate) mod freelist; +pub(crate) use freelist::Freelist; mod tag; pub use tag::{AtomicTag, Tag}; diff --git a/diskann-inmem/src/sync/test.rs b/diskann-inmem/src/sync/test.rs index 518ea730a..39297beea 100644 --- a/diskann-inmem/src/sync/test.rs +++ b/diskann-inmem/src/sync/test.rs @@ -143,7 +143,7 @@ fn read_job( loop { let mut reads = Vec::<&Data>::new(); let guard = registry.guard().unwrap(); - if guard.epoch() >= stop_at { + if guard.epoch >= stop_at { break; } @@ -158,7 +158,7 @@ fn read_job( std::thread::yield_now(); records.push(Record { - epoch: guard.epoch(), + epoch: guard.epoch, index: i, data: *read, }); From 89209ef44960afe1d5c993ee766ba28d7b06e79b Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Wed, 17 Jun 2026 16:07:36 -0700 Subject: [PATCH 13/45] Checkpoint. --- diskann-inmem/src/lib.rs | 4 +- diskann-inmem/src/sync/epoch.rs | 61 +++--------------------------- diskann-inmem/src/sync/freelist.rs | 14 +++---- diskann-inmem/src/sync/mod.rs | 2 +- diskann-inmem/src/sync/tag.rs | 26 +++++++------ 5 files changed, 30 insertions(+), 77 deletions(-) diff --git a/diskann-inmem/src/lib.rs b/diskann-inmem/src/lib.rs index 9c4b34bd0..a6b0bb29e 100644 --- a/diskann-inmem/src/lib.rs +++ b/diskann-inmem/src/lib.rs @@ -3,11 +3,13 @@ * Licensed under the MIT license. */ +#![deny(rustdoc::broken_intra_doc_links)] + mod buffer; mod neighbors; +mod sync; pub mod num; -mod sync; pub mod ids; pub mod layers; diff --git a/diskann-inmem/src/sync/epoch.rs b/diskann-inmem/src/sync/epoch.rs index 0e5b47b21..10bf21d81 100644 --- a/diskann-inmem/src/sync/epoch.rs +++ b/diskann-inmem/src/sync/epoch.rs @@ -94,7 +94,7 @@ pub(crate) struct Registry { retiring: [SegQueue; 3], } -// Return the queue index for the `generation`. +// Return the queue index for the `epoch`. 
fn queue(epoch: u64) -> usize { epoch.into_usize() % 3 } @@ -135,8 +135,8 @@ impl Registry { /// /// # Errors /// - /// Returns an error if the number of currently active guards exceeds [`Self::capacity`] - /// and thus a new guard cannot be made. + /// Returns an error if the number of currently active guards exceeds the number of + /// internal guard slots and thus a new guard cannot be made. pub(crate) fn guard(&self) -> Result, Unavailable> { self.guard_inner(NoDelay) } @@ -345,11 +345,9 @@ pub(crate) struct Guard<'a> { impl Guard<'_> { /// Retire the id `i` at this guard's epoch. /// - /// `i` is a caller-defined id (typically an index into external storage).It will be + /// `i` is a caller-defined id (typically an index into external storage). It will be /// returned from a future [`Drain`] once the registry has advanced far enough that no /// reader could observe it. - /// - /// See also: [`Self::retire_all`]. #[inline] pub(crate) fn retire(&self, i: u32) { self.retire.push(i) @@ -656,7 +654,7 @@ mod tests { let s0 = g0.slot_index; let s1 = g1.slot_index; - // Due to how the hint works, the slotws could be acquired in either order. + // Due to how the hint works, the slots could be acquired in either order. if s0 < s1 { assert_eq!((s0, s1), (0, 1)); } else { @@ -829,53 +827,4 @@ mod tests { with_pre_fence => pre_fence, with_post_fence => post_fence, } - - // #[derive(Default)] - // struct TestGuardDelay<'a> { - // post_guard_check: Option<&'a mut dyn FnMut()>, - // pre_cas: Option<&'a mut dyn FnMut()>, - // pre_fence: Option<&'a mut dyn FnMut()>, - // post_fence: Option<&'a mut dyn FnMut()>, - // } - - // macro_rules! builder { - // ($f:ident, $field:ident) => { - // fn $f(mut self, f: &'a mut dyn FnMut()) -> Self { - // self.$field = Some(f); - // self - // } - // } - // } - - // macro_rules! 
forward { - // ($f:ident) => { - // fn $f(&mut self) { - // if let Some(f) = self.$f.as_mut() { - // f() - // } - // } - // } - // } - - // impl<'a> TestGuardDelay<'a> { - // builder!(post_guard_check, post_guard_check); - // builder!(with_pre_cas, pre_cas); - // builder!(with_pre_fence, pre_fence); - // builder!(with_post_fence, post_fence); - // } - - // impl GuardDelay for TestGuardDelay<'_> { - // forward!(post_guard_check); - // forward!(pre_cas); - // forward!(pre_fence); - // forward!(post_fence); - // } - - // struct CanAdvanceDelay; - - // impl CanAdvanceDelay for TestWaitingDelay {} - - // struct TestTryAdvanceDelay; - - // impl TryAdvanceDelay for TestTryAdvanceDelay {} } diff --git a/diskann-inmem/src/sync/freelist.rs b/diskann-inmem/src/sync/freelist.rs index 1c32e8be2..966153c6d 100644 --- a/diskann-inmem/src/sync/freelist.rs +++ b/diskann-inmem/src/sync/freelist.rs @@ -8,16 +8,16 @@ //! When working with slots into an index, finding an available slot efficiently can be //! challenging. This module provides a [`Freelist`] to make this more efficient. //! -//! IDs are retrieved in several orders of precedence: +//! IDs are retrieved in the following order of precedence: //! //! ## Recycles //! -//! Previously reclaimed slots can be recycled and is the preferred way of finding slots. +//! Previously reclaimed slots can be recycled and are the preferred way of finding slots. //! Reclaimed slots IDs live inside an atomic queue and as such, the size of this queue is //! bounded to conserve memory. //! //! ## Minted -//! If not slots live in the recycled queue, new slots can be "minted" up to the configured +//! If no slots live in the recycled queue, new slots can be "minted" up to the configured //! maximum. This simply tracks the maximum slot ID that has been yielded so far and returns //! the next one. //! @@ -33,10 +33,10 @@ //! //! The [`Freelist`] assists with scans in several ways: //! -//! 1. 
[`Freelist::scan`]: Receive a range of managed ID ranges to scan. Multiple threads +//! 1. [`Freelist::scan`]: Receive a range of managed IDs to scan. Multiple threads //! can call this method to receive disjoint ranges to process. //! -//! 2. [`Freelist::push`]/[`Freelist::append`]: Available slots can be placed into the +//! 2. [`Freelist::push`]: Available slots can be placed into the //! freelist for recycling. //! //! 3. [`Freelist::pop_recycled`]: Attempt to retrieve a slot ID directly from the recycled @@ -52,7 +52,7 @@ //! # Non-Authoritative //! //! Note that the [`Freelist`] does not attempt to be authoritative on the list of slots IDs -//! that are used and unused. It's job is mainly to improve performance. +//! that are used and unused. Its job is mainly to improve performance. //! //! An authoritative collection of [`AtomicTag`](super::AtomicTag)s must be used to correctly //! manage slots. @@ -166,7 +166,7 @@ impl Freelist { /// /// If `false` is returned, it is likely because the internal recycle buffer is full. /// - /// IDs exceeding [`Self::max`] are discarded. + /// IDs at or above [`Self::max`] are discarded. pub(crate) fn push(&self, id: u32) -> bool { if id < self.max { self.recycled.push(id).is_ok() diff --git a/diskann-inmem/src/sync/mod.rs b/diskann-inmem/src/sync/mod.rs index e5553f944..104bf71de 100644 --- a/diskann-inmem/src/sync/mod.rs +++ b/diskann-inmem/src/sync/mod.rs @@ -10,7 +10,7 @@ pub(crate) mod freelist; pub(crate) use freelist::Freelist; mod tag; -pub use tag::{AtomicTag, Tag}; +pub(crate) use tag::{AtomicTag, Tag}; #[cfg(test)] mod test; diff --git a/diskann-inmem/src/sync/tag.rs b/diskann-inmem/src/sync/tag.rs index 4b30d1162..81dfd14a7 100644 --- a/diskann-inmem/src/sync/tag.rs +++ b/diskann-inmem/src/sync/tag.rs @@ -78,7 +78,7 @@ use std::sync::atomic::{AtomicU8, Ordering}; /// Checks to [`Tag::can_read`] can be made following [`Ordering::Acquire`] loads. 
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] #[repr(transparent)] -pub struct Tag(u8); +pub(crate) struct Tag(u8); impl Tag { //-------------// @@ -86,10 +86,10 @@ impl Tag { //-------------// /// The slot is permanently readable and never mutated again. See [`Tag`]. - pub const FROZEN: Self = Self::new(u8::MAX); + pub(crate) const FROZEN: Self = Self::new(u8::MAX); /// The slot has been published and is freely readable. See [`Tag`]. - pub const PUBLISHED: Self = Self::new(u8::MAX - 1); + pub(crate) const PUBLISHED: Self = Self::new(u8::MAX - 1); //------------// // Low Values // @@ -97,16 +97,18 @@ impl Tag { /// The slot holds no valid data and may be claimed via CAS to [`Tag::OWNED`]. /// See [`Tag`]. - pub const AVAILABLE: Self = Self::new(0); + pub(crate) const AVAILABLE: Self = Self::new(0); /// The slot is exclusively owned by a single thread that may write its data. /// See [`Tag`]. - pub const OWNED: Self = Self::new(1); + pub(crate) const OWNED: Self = Self::new(1); /// The slot is in the process of being retired and is no longer readable to new /// readers. See [`Tag`]. - pub const RETIRING: Self = Self::new(2); + pub(crate) const RETIRING: Self = Self::new(2); + /// NOTE: We rely on reserved values being contiguous so `is_reserved` can be + /// implemented relatively efficiently. const RESERVED: Self = Self::RETIRING; /// Return `true` if `self` is one of the protocol's reserved tag values. @@ -164,11 +166,11 @@ impl std::fmt::Display for Tag { /// protocol described on [`Tag`]. #[derive(Debug)] #[repr(transparent)] -pub struct AtomicTag(AtomicU8); +pub(crate) struct AtomicTag(AtomicU8); impl AtomicTag { /// Construct a new [`AtomicTag`] initialized to `tag`. - pub const fn new(tag: Tag) -> Self { + pub(crate) const fn new(tag: Tag) -> Self { Self(AtomicU8::new(tag.value())) } @@ -184,7 +186,7 @@ impl AtomicTag { /// different sizes without synchronization. 
/// /// See: - pub unsafe fn from_ptr<'a>(ptr: *mut AtomicTag) -> &'a Self { + pub(crate) unsafe fn from_ptr<'a>(ptr: *mut AtomicTag) -> &'a Self { unsafe { &*ptr } } @@ -194,7 +196,7 @@ impl AtomicTag { /// ensure `current` and `new` correspond to a legal transition. /// /// See: [`AtomicU8::compare_exchange`]. - pub fn compare_exchange( + pub(crate) fn compare_exchange( &self, current: Tag, new: Tag, @@ -210,14 +212,14 @@ impl AtomicTag { /// Perform an atomic load with the provided ordering. /// /// See: [`AtomicU8::load`]. - pub fn load(&self, ordering: Ordering) -> Tag { + pub(crate) fn load(&self, ordering: Ordering) -> Tag { Tag::new(self.0.load(ordering)) } /// Perform an atomic store with the provided ordering. /// /// See: [`AtomicU8::store`]. - pub fn store(&self, val: Tag, ordering: Ordering) { + pub(crate) fn store(&self, val: Tag, ordering: Ordering) { self.0.store(val.value(), ordering) } } From 70f03b9c1a05afb495207c19d5cea208dbaa972e Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Wed, 17 Jun 2026 17:03:53 -0700 Subject: [PATCH 14/45] Reorganize. 
--- diskann-inmem/src/buffer.rs | 6 +- diskann-inmem/src/{sync => }/epoch.rs | 5 +- diskann-inmem/src/{sync => }/freelist.rs | 0 diskann-inmem/src/ids/mod.rs | 7 - diskann-inmem/src/layers/full.rs | 10 +- diskann-inmem/src/lib.rs | 12 +- diskann-inmem/src/neighbors.rs | 25 ++- diskann-inmem/src/num.rs | 5 + diskann-inmem/src/provider.rs | 74 ++++----- diskann-inmem/src/{ids => }/sharded.rs | 0 diskann-inmem/src/store.rs | 144 +++++++++++------- diskann-inmem/src/sync/mod.rs | 16 -- diskann-inmem/src/{sync => }/tag.rs | 0 .../src/{sync/test.rs => test/epoch.rs} | 5 +- diskann-inmem/src/test/mod.rs | 10 ++ .../src/{test.rs => test/sequencer.rs} | 0 16 files changed, 168 insertions(+), 151 deletions(-) rename diskann-inmem/src/{sync => }/epoch.rs (99%) rename diskann-inmem/src/{sync => }/freelist.rs (100%) delete mode 100644 diskann-inmem/src/ids/mod.rs rename diskann-inmem/src/{ids => }/sharded.rs (100%) delete mode 100644 diskann-inmem/src/sync/mod.rs rename diskann-inmem/src/{sync => }/tag.rs (100%) rename diskann-inmem/src/{sync/test.rs => test/epoch.rs} (99%) create mode 100644 diskann-inmem/src/test/mod.rs rename diskann-inmem/src/{test.rs => test/sequencer.rs} (100%) diff --git a/diskann-inmem/src/buffer.rs b/diskann-inmem/src/buffer.rs index bbe2397b5..8c2721e55 100644 --- a/diskann-inmem/src/buffer.rs +++ b/diskann-inmem/src/buffer.rs @@ -33,7 +33,11 @@ impl Buffer { /// /// Returns an error if the number of bytes `bytes_per_entry * entries` rounded up to /// the next multiple of `align` exceeds `isize::MAX`. - pub(crate) fn new(entries: usize, bytes_per_entry: Bytes, align: Align) -> Result { + pub(crate) fn new( + entries: usize, + bytes_per_entry: Bytes, + align: Align, + ) -> Result { // If we overflow `usize::MAX`, we will definitely overflow `isize::MAX`. 
let bytes = bytes_per_entry.checked_mul(entries).ok_or(BufferError)?; diff --git a/diskann-inmem/src/sync/epoch.rs b/diskann-inmem/src/epoch.rs similarity index 99% rename from diskann-inmem/src/sync/epoch.rs rename to diskann-inmem/src/epoch.rs index 10bf21d81..562623eeb 100644 --- a/diskann-inmem/src/sync/epoch.rs +++ b/diskann-inmem/src/epoch.rs @@ -296,7 +296,10 @@ impl Registry { debug_assert_eq!(_previous, current, "concurrency violation"); let queue = &self.retiring[last_queue(current)]; - Some(Drain { queue, _drain: drain }) + Some(Drain { + queue, + _drain: drain, + }) } else { // Previous generation has not completely retired. None diff --git a/diskann-inmem/src/sync/freelist.rs b/diskann-inmem/src/freelist.rs similarity index 100% rename from diskann-inmem/src/sync/freelist.rs rename to diskann-inmem/src/freelist.rs diff --git a/diskann-inmem/src/ids/mod.rs b/diskann-inmem/src/ids/mod.rs deleted file mode 100644 index ab2ae10e8..000000000 --- a/diskann-inmem/src/ids/mod.rs +++ /dev/null @@ -1,7 +0,0 @@ -/* - * Copyright (c) Microsoft Corporation. - * Licensed under the MIT license. 
- */ - -pub(crate) mod sharded; -pub(crate) use sharded::Sharded; diff --git a/diskann-inmem/src/layers/full.rs b/diskann-inmem/src/layers/full.rs index 9edf3f88a..8b16a7c0e 100644 --- a/diskann-inmem/src/layers/full.rs +++ b/diskann-inmem/src/layers/full.rs @@ -76,7 +76,10 @@ where { type Query<'a> = &'a [T]; - fn query_distance<'a>(&'a self, query: &'a [T]) -> ANNResult> { + fn query_distance<'a>( + &'a self, + query: &'a [T], + ) -> ANNResult> { Ok(Box::new(QueryDistance::new(self.distance, query))) } } @@ -90,10 +93,7 @@ where } } -impl layers::Insert for Full where - T: bytemuck::Pod + std::fmt::Debug + Send + Sync -{ -} +impl layers::Insert for Full where T: bytemuck::Pod + std::fmt::Debug + Send + Sync {} ////////////// // Distance // diff --git a/diskann-inmem/src/lib.rs b/diskann-inmem/src/lib.rs index a6b0bb29e..51a56761d 100644 --- a/diskann-inmem/src/lib.rs +++ b/diskann-inmem/src/lib.rs @@ -5,16 +5,18 @@ #![deny(rustdoc::broken_intra_doc_links)] +pub mod num; + mod buffer; +mod epoch; +mod freelist; mod neighbors; -mod sync; - -pub mod num; +mod sharded; +mod tag; -pub mod ids; -pub mod layers; mod store; +pub mod layers; pub mod provider; pub use provider::{Context, Provider, Strategy}; diff --git a/diskann-inmem/src/neighbors.rs b/diskann-inmem/src/neighbors.rs index 9bcefcd58..507fd1b39 100644 --- a/diskann-inmem/src/neighbors.rs +++ b/diskann-inmem/src/neighbors.rs @@ -121,7 +121,11 @@ impl Neighbors { /// Copy the contents of adjacency list `i` into `neighbors`. /// /// Returns an error if `i` exceeds [`Self::entries`]. - pub(crate) fn get(&self, i: u32, neighbors: &mut AdjacencyList) -> Result<(), OutOfBounds> { + pub(crate) fn get( + &self, + i: u32, + neighbors: &mut AdjacencyList, + ) -> Result<(), OutOfBounds> { self.check(i)?; let lock = unsafe { self.locks.get_unchecked(lock_index(i)) }; @@ -139,7 +143,8 @@ impl Neighbors { // SAFETY: We hold the read-lock, so reading is safe. From our bounds checks, we // know that this pointer is valid. 
let len: usize = unsafe { prefix.as_ptr().cast::().read() } - .min(self.max_length_u32()).into_usize(); + .min(self.max_length_u32()) + .into_usize(); let mut resizer = neighbors.resize(len); unsafe { @@ -230,7 +235,6 @@ pub(crate) enum NeighborsError { AllocationFailed(#[from] BufferError), } - /// Attempted to access a [`Neighbors`] at an out-of-bounds index. #[derive(Debug, Clone, Copy, Error)] #[error("index {} is out-of-bounds", self.0)] @@ -341,7 +345,6 @@ impl Lock<'_> { Ok(()) } - #[cfg(test)] fn is_empty(&self) -> bool { self.len() == 0 @@ -358,7 +361,6 @@ impl std::fmt::Debug for Lock<'_> { } } - /////////// // Tests // /////////// @@ -544,11 +546,8 @@ mod tests { Err(SetError::OutOfBounds(_)) )); - let generate = |round: u32, entry: u32| -> Vec { - (0..(round + 1)) - .map(|r| entry + r) - .collect() - }; + let generate = + |round: u32, entry: u32| -> Vec { (0..(round + 1)).map(|r| entry + r).collect() }; // Test mutation via `Neighbors::set`. for round in 0..neighbors.max_length_u32() { @@ -591,11 +590,7 @@ mod tests { // Test mutation via `lock + append`. 
for round in 0..neighbors.max_length_u32() { for i in 0..neighbors.entries() { - neighbors - .lock(i) - .unwrap() - .append(&[round + i]) - .unwrap(); + neighbors.lock(i).unwrap().append(&[round + i]).unwrap(); } for i in 0..neighbors.entries() { diff --git a/diskann-inmem/src/num.rs b/diskann-inmem/src/num.rs index 073e66fd2..16a9a42db 100644 --- a/diskann-inmem/src/num.rs +++ b/diskann-inmem/src/num.rs @@ -38,6 +38,11 @@ impl Bytes { } } + #[inline] + pub const fn div(self, other: NonZeroUsize) -> Bytes { + Bytes::new(self.value() / other.get()) + } + #[inline] pub(crate) const fn unchecked_mul(self, other: usize) -> Bytes { Bytes::new(self.value() * other) diff --git a/diskann-inmem/src/provider.rs b/diskann-inmem/src/provider.rs index 5bb4e4e3c..9d7deeb00 100644 --- a/diskann-inmem/src/provider.rs +++ b/diskann-inmem/src/provider.rs @@ -19,11 +19,10 @@ use diskann::{ use diskann_utils::views::Matrix; use crate::{ - ids, layers::{self, Distance, QueryDistance}, num::Bytes, - store::{self, Primary}, - sync::epoch::Unavailable, + sharded::Sharded, + store::{self, Store}, }; pub trait Id: Send + Sync + Hash + Eq + Clone + 'static {} @@ -34,9 +33,9 @@ pub struct Provider where M: Id, { - primary: Primary, + store: Store, layer: L, - mapping: ids::Sharded, + mapping: Sharded, } impl Provider @@ -56,18 +55,14 @@ where layers::Set::into_bytes(&layer, point, row).unwrap(); } - let primary = Primary::new(capacity, bytes, 32, data.as_view()); - let mapping = ids::Sharded::new(capacity); + let store = Store::new(capacity, bytes, 32, data.as_view()); + let mapping = Sharded::new(capacity); Self { - primary, + store, layer, mapping, } } - - fn reader(&self) -> Result, Unavailable> { - self.primary.reader() - } } /////////////////// @@ -123,7 +118,7 @@ where async fn delete(&self, _context: &Context, gid: &M) -> ANNResult<()> { // TODO: These need to actually happen in lock-step. 
let internal = self.mapping.remove(gid).unwrap(); - assert!(self.primary.delete(internal.into_usize())); + assert!(self.store.delete(internal.into_usize())); Ok(()) } @@ -137,7 +132,7 @@ where id: u32, ) -> ANNResult { if self - .primary + .store .reader() .unwrap() .can_read(id.into_usize()) @@ -173,7 +168,7 @@ where F: FnMut(ANNResult, Self::InternalId) + Send, { let work = move || { - let reader = self.primary.reader().unwrap(); + let reader = self.store.reader().unwrap(); for i in itr { if reader.can_read(i.into_usize()).unwrap() { f(Ok(diskann::provider::ElementStatus::Valid), i) @@ -209,7 +204,7 @@ where element: T, ) -> impl std::future::Future> + Send { let work = move || { - let mut slot = self.primary.acquire().unwrap(); + let mut slot = self.store.acquire().unwrap(); // TODO: Proper cleanup via `Guard` or some other mechanism on the event of // insert failure. @@ -290,10 +285,7 @@ impl glue::SearchAccessor for SearchAccessor<'_> { { let work = move || -> ANNResult<()> { for i in ids { - self.reader - .neighbors() - .get(i, &mut self.ids) - .unwrap(); + self.reader.neighbors().get(i, &mut self.ids).unwrap(); // Filter out unvisited IDs and ensure that all the IDs we are about self.ids @@ -349,15 +341,15 @@ fn dispatch_expand_beam(bytes: Bytes) -> FExpandBeam { const CACHE_LINE_SIZE: usize = 64; -pub unsafe fn test_function( - list: &[u32], - lookahead: usize, - reader: &store::Reader<'_>, - distance: &dyn layers::QueryDistance, - f: &mut dyn FnMut(u32, f32), -) -> ANNResult<()> { - unsafe { expand_beam_inner::<4>(list, lookahead, reader, distance, f) } -} +// pub unsafe fn test_function( +// list: &[u32], +// lookahead: usize, +// reader: &store::Reader<'_>, +// distance: &dyn layers::QueryDistance, +// f: &mut dyn FnMut(u32, f32), +// ) -> ANNResult<()> { +// unsafe { expand_beam_inner::<4>(list, lookahead, reader, distance, f) } +// } /// Safety (no # yet because we need to revisit this - clippy will lint) /// @@ -471,13 +463,7 @@ impl 
provider::NeighborAccessor for PruneAccessor<'_> { id: Self::Id, neighbors: &mut AdjacencyList, ) -> impl std::future::Future> + Send { - let work = move || { - Ok(self - .reader - .neighbors() - .get(id, neighbors) - .unwrap()) - }; + let work = move || Ok(self.reader.neighbors().get(id, neighbors).unwrap()); ready(work) } } @@ -488,13 +474,7 @@ impl provider::NeighborAccessorMut for PruneAccessor<'_> { id: Self::Id, neighbors: &[Self::Id], ) -> impl std::future::Future> + Send { - let work = move || { - Ok(self - .reader - .neighbors() - .set(id, neighbors) - .unwrap()) - }; + let work = move || Ok(self.reader.neighbors().set(id, neighbors).unwrap()); ready(work) } @@ -550,7 +530,7 @@ where query: L::Query<'a>, ) -> ANNResult> { let distance = ::query_distance(&provider.layer, query)?; - let reader = provider.primary.reader()?; + let reader = provider.store.reader()?; let expand_beam = dispatch_expand_beam(reader.bytes()); let accessor = SearchAccessor { reader, @@ -558,7 +538,7 @@ where ids: AdjacencyList::new(), expand_beam, provider, - start_points: provider.primary.frozen(), + start_points: provider.store.frozen(), }; Ok(accessor) } @@ -634,7 +614,7 @@ where _capacity: usize, ) -> ANNResult> { Ok(PruneAccessor { - reader: provider.primary.reader()?, + reader: provider.store.reader()?, distance: ::as_distance(&provider.layer), }) } @@ -692,7 +672,7 @@ where ) -> impl Future> + Send { let work = move || { - let reader = provider.primary.reader().unwrap(); + let reader = provider.store.reader().unwrap(); let mut buf: Box<[_]> = std::iter::repeat_n(0.0, provider.layer.dim()).collect(); let data = reader.read(id.into_usize()).unwrap(); bytemuck::must_cast_slice_mut::(&mut buf).copy_from_slice(data); diff --git a/diskann-inmem/src/ids/sharded.rs b/diskann-inmem/src/sharded.rs similarity index 100% rename from diskann-inmem/src/ids/sharded.rs rename to diskann-inmem/src/sharded.rs diff --git a/diskann-inmem/src/store.rs b/diskann-inmem/src/store.rs index 
c935e69ee..08c000109 100644 --- a/diskann-inmem/src/store.rs +++ b/diskann-inmem/src/store.rs @@ -3,20 +3,22 @@ * Licensed under the MIT license. */ -use std::{iter::repeat_n, num::NonZeroU32, sync::atomic::Ordering}; +use std::{iter::repeat_n, num::{NonZeroU32, NonZeroUsize}, sync::atomic::Ordering}; use diskann::utils::IntoUsize; use diskann_utils::views::MatrixView; use crate::{ - neighbors::Neighbors, buffer::{Buffer, RawSlice}, + epoch::{self, Registry}, + freelist::{self, Freelist}, + neighbors::Neighbors, num::{Align, Bytes}, - sync::{AtomicTag, Freelist, Tag, Registry, epoch, freelist}, + tag::{AtomicTag, Tag}, }; #[derive(Debug)] -pub struct Primary { +pub(crate) struct Store { // The invasive store where concurrency tags are stored inline with the data. // // These tags are mirrored from `tags` - with the latter being used for secondary scans @@ -34,9 +36,10 @@ pub struct Primary { const SPLIT: Bytes = Bytes::size_of::(); const RETRY_LIMIT: usize = 20; +const TWO: NonZeroUsize = NonZeroUsize::new(2).unwrap(); -impl Primary { - pub fn new( +impl Store { + pub(crate) fn new( entries: usize, bytes: Bytes, max_neighbors: usize, @@ -46,7 +49,13 @@ impl Primary { assert_ne!(init.nrows(), 0); let unpadded = bytes.checked_add(SPLIT).unwrap(); - let padded_bytes = unpadded.checked_next_multiple_of(Bytes::CACHELINE).unwrap(); + + // Pad to half a cache line. When data occupies just part of a cache line, this + // results in the same total number of cache lines being fetched while potentially + // enabling more compact memory. + let padded_bytes = unpadded + .checked_next_multiple_of(Bytes::CACHELINE.div(TWO)) + .unwrap(); let total: usize = entries.checked_add(init.nrows()).unwrap(); @@ -62,7 +71,8 @@ impl Primary { // we do not want it to release frozen IDs. 
freelist: Freelist::new(entries.try_into().unwrap(), NonZeroU32::new(1024).unwrap()), registry: Registry::new(), - neighbors: Neighbors::new(total.try_into().unwrap(), max_neighbors.try_into().unwrap()).unwrap(), + neighbors: Neighbors::new(total.try_into().unwrap(), max_neighbors.try_into().unwrap()) + .unwrap(), }; // Populate frozen points. @@ -76,19 +86,19 @@ impl Primary { } /// Return the range of slots containing frozen items in `self`. - pub fn frozen(&self) -> std::ops::Range { + pub(crate) fn frozen(&self) -> std::ops::Range { (self.unfrozen as u32)..(self.buffer.len() as u32) } /// Return the number of unfrozen slots managed by `self`. - pub fn capacity(&self) -> usize { + pub(crate) fn capacity(&self) -> usize { self.buffer.len() - self.unfrozen } /// Attempt to reclaim retired slots. /// /// If successful, returns the number of slots reclaimed. - pub fn try_drain(&self) -> Option { + pub(crate) fn try_drain(&self) -> Option { fn release(tag: &AtomicTag, kind: &'static str) { // Relaxed ordering is sufficient as all readers/writers are synchronized on // the central generation. @@ -125,17 +135,20 @@ impl Primary { /// # Errors /// /// Returns [`epoch::Unavailable`] if there are too many active readers. - pub fn reader(&self) -> Result, epoch::Unavailable> { + pub(crate) fn reader(&self) -> Result, epoch::Unavailable> { Ok(Reader { buffer: &self.buffer, unpadded: self.unpadded, neighbors: &self.neighbors, - _epoch: self.registry.guard()?, + _guard: self.registry.guard()?, }) } - /// Attempt to acquire new slot for writing. - pub fn acquire(&self) -> Option> { + /// Attempt to acquire a new [`Slot`] for writing. + /// + /// This method first consults the freelist and falls back to scanning the tags list + /// if no ID is available from the fast path. 
+ pub(crate) fn acquire(&self) -> Option> { for _ in 0..RETRY_LIMIT { match self.freelist.pop() { freelist::Id::Found(id) => { @@ -154,6 +167,32 @@ impl Primary { None } + pub(crate) fn delete(&self, i: usize) -> bool { + let guard = self.registry.guard().unwrap(); + let tag = self.tags.get(i).unwrap(); + let current = tag.load(Ordering::Relaxed); + + // We can only perform a deletion if the generation is not in a reserved state. + if current.is_reserved() { + return false; + } + + let retiring = Tag::RETIRING; + + // Even if we make this change, we can't access any data until we wait for the + // epoch to be bumped. As such, relaxed semantics are fine. + match tag.compare_exchange(current, retiring, Ordering::Relaxed, Ordering::Relaxed) { + Ok(_) => { + // Set the metadata in the mirror as well. + let (mirror, _) = unsafe { self.data_unchecked(i) }; + mirror.store(retiring, Ordering::Relaxed); + guard.retire(i as u32); + true + } + Err(_) => false, + } + } + fn scan_acquire(&self) -> Option> { // This is potentially quite slow - but stop if we've scanned the entire range // without finding anything. @@ -202,6 +241,12 @@ impl Primary { unsafe { self.try_acquire(tag, i) } } + /// Try to acquire `slot` with the associated `tag`. + /// + /// # Safety + /// + /// Caller asserts that `tag` was obtained from `self.tags[slot]`. This is meant as + /// a perfomance optimization where `tag` is first queried for potential availability. unsafe fn try_acquire<'a>(&'a self, tag: &'a AtomicTag, slot: u32) -> Option> { match tag.compare_exchange( Tag::AVAILABLE, @@ -222,32 +267,11 @@ impl Primary { } } - pub(crate) fn delete(&self, i: usize) -> bool { - let guard = self.registry.guard().unwrap(); - let tag = self.tags.get(i).unwrap(); - let current = tag.load(Ordering::Relaxed); - - // We can only perform a deletion if the generation is not in a reserved state. 
- if current.is_reserved() { - return false; - } - - let retiring = Tag::RETIRING; - - // Even if we make this change, we can't access any data until we wait for the - // epoch to be bumped. As such, relaxed semantics are fine. - match tag.compare_exchange(current, retiring, Ordering::Relaxed, Ordering::Relaxed) { - Ok(_) => { - // Set the metadata in the mirror as well. - let (mirror, _) = unsafe { self.data_unchecked(i) }; - mirror.store(retiring, Ordering::Relaxed); - guard.retire(i as u32); - true - } - Err(_) => false, - } - } - + /// Return the data at position `i` without bound-checking. + /// + /// # Safety + /// + /// The index `i` must be less then `self.buffer.len()`. unsafe fn data_unchecked(&self, i: usize) -> (&AtomicTag, RawSlice<'_>) { let (mirror, data) = unsafe { self.buffer.get_unchecked(i) } .truncate(self.unpadded) @@ -259,13 +283,16 @@ impl Primary { } } +/// An epoch protect reader into [`Store`]. +/// +/// Created via [`Store::reader`]. #[derive(Debug)] -pub struct Reader<'a> { +pub(crate) struct Reader<'a> { buffer: &'a Buffer, unpadded: Bytes, neighbors: &'a Neighbors, // It's important that we hold onto this, even if we don't use it. - _epoch: epoch::Guard<'a>, + _guard: epoch::Guard<'a>, } impl<'a> Reader<'a> { @@ -275,7 +302,7 @@ impl<'a> Reader<'a> { /// 1. Index `i` is out-of-bounds. /// 2. The read cannot be guaranteed to be race-free. #[inline] - pub fn read(&self, i: usize) -> Option<&[u8]> { + pub(crate) fn read(&self, i: usize) -> Option<&[u8]> { if self.is_in_bounds(i) { unsafe { self.read_in_bounds(i) } } else { @@ -286,10 +313,14 @@ impl<'a> Reader<'a> { /// Return `true` if the index `i` is in-bounds. #[inline] #[must_use = "this function has no side-effects"] - pub fn is_in_bounds(&self, i: usize) -> bool { + pub(crate) fn is_in_bounds(&self, i: usize) -> bool { i < self.buffer.len() } + /// Return `true` if it is safe to read the data at position `i`. + /// + /// This guarantee only holds while `self` is alive. 
Construction of a new [`Reader`] + /// requires a separate check. pub(crate) fn can_read(&self, i: usize) -> Option { if !self.is_in_bounds(i) { return None; @@ -303,6 +334,12 @@ impl<'a> Reader<'a> { Some(can_read) } + /// Read the data as position `i` if it is guaranteed to be race-free without bounds + /// checking. + /// + /// # Safety + /// + /// The index `i` must satisfy [`Self::is_in_bounds`]. #[inline] pub(crate) unsafe fn read_in_bounds(&self, i: usize) -> Option<&[u8]> { debug_assert!(self.is_in_bounds(i)); @@ -320,8 +357,8 @@ impl<'a> Reader<'a> { .can_read(); if can_read { - // SAFETY: tags and buffer always have the same length, and we - // verified i < tags.len() above. + // SAFETY: We've passed the `can_read` check - `_guard` will ensure the read + // slice is valid and race-free. Some(unsafe { rest.as_slice() }) } else { None @@ -332,7 +369,7 @@ impl<'a> Reader<'a> { /// /// # Safety /// - /// The index `i` must be in-bounds. + /// The index `i` must be satisfy [`Self::is_in_bounds`]. #[inline] pub(crate) unsafe fn read_raw_unchecked(&self, i: usize) -> RawSlice<'_> { unsafe { self.buffer.get_unchecked(i) }.truncate(self.unpadded) @@ -343,15 +380,15 @@ impl<'a> Reader<'a> { self.unpadded } - // TODO: We may want to lock `Neighbors` in some way to enable exclusive access during - // operations like snapshots. + /// Return [`Neighbors`]. pub(crate) fn neighbors(&self) -> &Neighbors { &self.neighbors } } +/// A writable buffer into the data managed by a [`Store`], obtained from [`Store::Acquire`]. #[derive(Debug)] -pub struct Slot<'a> { +pub(crate) struct Slot<'a> { tag: &'a AtomicTag, mirror: &'a AtomicTag, data: RawSlice<'a>, @@ -359,12 +396,13 @@ pub struct Slot<'a> { } impl<'a> Slot<'a> { - pub fn as_mut_slice(&mut self) -> &mut [u8] { + /// View the managed data as a mutable slice. + pub(crate) fn as_mut_slice(&mut self) -> &mut [u8] { unsafe { self.data.as_mut_slice() } } /// Return the slot associated with this write. 
- pub fn slot(&self) -> u32 { + pub(crate) fn slot(&self) -> u32 { self.slot } diff --git a/diskann-inmem/src/sync/mod.rs b/diskann-inmem/src/sync/mod.rs deleted file mode 100644 index 104bf71de..000000000 --- a/diskann-inmem/src/sync/mod.rs +++ /dev/null @@ -1,16 +0,0 @@ -/* - * Copyright (c) Microsoft Corporation. - * Licensed under the MIT license. - */ - -pub(crate) mod epoch; -pub(crate) use epoch::Registry; - -pub(crate) mod freelist; -pub(crate) use freelist::Freelist; - -mod tag; -pub(crate) use tag::{AtomicTag, Tag}; - -#[cfg(test)] -mod test; diff --git a/diskann-inmem/src/sync/tag.rs b/diskann-inmem/src/tag.rs similarity index 100% rename from diskann-inmem/src/sync/tag.rs rename to diskann-inmem/src/tag.rs diff --git a/diskann-inmem/src/sync/test.rs b/diskann-inmem/src/test/epoch.rs similarity index 99% rename from diskann-inmem/src/sync/test.rs rename to diskann-inmem/src/test/epoch.rs index 39297beea..b4a2a5243 100644 --- a/diskann-inmem/src/sync/test.rs +++ b/diskann-inmem/src/test/epoch.rs @@ -13,7 +13,10 @@ use std::{ use rand::{Rng, distr::StandardUniform}; -use super::{AtomicTag, Registry, Tag}; +use crate::{ + epoch::Registry, + tag::{AtomicTag, Tag}, +}; type Data = [u32; 4]; diff --git a/diskann-inmem/src/test/mod.rs b/diskann-inmem/src/test/mod.rs new file mode 100644 index 000000000..e91e30f81 --- /dev/null +++ b/diskann-inmem/src/test/mod.rs @@ -0,0 +1,10 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +mod sequencer; +pub(crate) use sequencer::Sequencer; + +// Longer Running Tests +mod epoch; diff --git a/diskann-inmem/src/test.rs b/diskann-inmem/src/test/sequencer.rs similarity index 100% rename from diskann-inmem/src/test.rs rename to diskann-inmem/src/test/sequencer.rs From 02cb0596af096694f5dc719d0975c4bd8356f662 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Wed, 17 Jun 2026 17:19:27 -0700 Subject: [PATCH 15/45] Fix up compile errors. 
--- diskann-inmem/src/provider.rs | 7 ------- 1 file changed, 7 deletions(-) diff --git a/diskann-inmem/src/provider.rs b/diskann-inmem/src/provider.rs index 9d7deeb00..46d25fa9f 100644 --- a/diskann-inmem/src/provider.rs +++ b/diskann-inmem/src/provider.rs @@ -634,13 +634,6 @@ where // TODO: This is such a hack. impl glue::InplaceDeleteStrategy, M>> for Strategy where - Self: glue::PruneStrategy, M>>, - Self: for<'a> glue::InsertStrategy< - 'a, - Provider, M>, - &'a [f32], - SearchAccessor = SearchAccessor<'a>, - >, M: Id, { type DeleteElement<'a> = &'a [f32]; From 340d3a5ede3eb0a45569339502ad1d00606712db Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Fri, 19 Jun 2026 16:20:52 -0700 Subject: [PATCH 16/45] Checkpoint. --- Cargo.lock | 1 + Cargo.toml | 4 + diskann-benchmark/src/backend/mod.rs | 23 -- diskann-benchmark/src/index/inmem2.rs | 3 +- diskann-benchmark/src/index/mod.rs | 2 +- diskann-inmem/Cargo.toml | 3 +- diskann-inmem/src/layers/full.rs | 116 +++++++-- diskann-inmem/src/layers/mod.rs | 32 ++- diskann-inmem/src/provider.rs | 242 ++++++++++++------ diskann-inmem/src/store.rs | 150 +++++++++-- .../src/distance/implementations.rs | 8 +- diskann-vector/src/distance/mod.rs | 2 +- 12 files changed, 427 insertions(+), 159 deletions(-) delete mode 100644 diskann-benchmark/src/backend/mod.rs diff --git a/Cargo.lock b/Cargo.lock index a30491ab8..181385e13 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -817,6 +817,7 @@ dependencies = [ "diskann", "diskann-utils", "diskann-vector", + "diskann-wide", "parking_lot", "rand 0.9.4", "thiserror 2.0.17", diff --git a/Cargo.toml b/Cargo.toml index a62586882..b6c4a7e33 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -119,3 +119,7 @@ opt-level = 1 debug = true debug-assertions = true overflow-checks = true + +[profile.samply] +inherits = "release" +debug = true diff --git a/diskann-benchmark/src/backend/mod.rs b/diskann-benchmark/src/backend/mod.rs deleted file mode 100644 index 6bde405ad..000000000 --- 
a/diskann-benchmark/src/backend/mod.rs +++ /dev/null @@ -1,23 +0,0 @@ -/* - * Copyright (c) Microsoft Corporation. - * Licensed under the MIT license. - */ - -use diskann_benchmark_runner::Registry; - -mod disk_index; -mod exhaustive; -mod filters; -mod index; -mod inmem2; -mod multi_vector; - -pub(crate) fn register_benchmarks(registry: &mut Registry) -> anyhow::Result<()> { - exhaustive::register_benchmarks(registry)?; - disk_index::register_benchmarks(registry)?; - index::register_benchmarks(registry)?; - filters::register_benchmarks(registry)?; - multi_vector::register_benchmarks(registry)?; - inmem2::register_benchmarks(registry)?; - Ok(()) -} diff --git a/diskann-benchmark/src/index/inmem2.rs b/diskann-benchmark/src/index/inmem2.rs index ef104a718..c656efd6b 100644 --- a/diskann-benchmark/src/index/inmem2.rs +++ b/diskann-benchmark/src/index/inmem2.rs @@ -33,8 +33,7 @@ use diskann_vector::distance::Metric; use serde::{Deserialize, Serialize}; use crate::{ - backend::index::build::ProgressMeter, inputs::graph_index::DynamicRunbookParams, - utils::datafiles, + index::build::ProgressMeter, inputs::graph_index::DynamicRunbookParams, utils::datafiles, }; pub(crate) fn register_benchmarks(registry: &mut Registry) -> anyhow::Result<()> { diff --git a/diskann-benchmark/src/index/mod.rs b/diskann-benchmark/src/index/mod.rs index e902d8b1b..f3191e6fd 100644 --- a/diskann-benchmark/src/index/mod.rs +++ b/diskann-benchmark/src/index/mod.rs @@ -11,8 +11,8 @@ mod streaming; mod benchmarks; mod inmem; -mod result; mod inmem2; +mod result; #[cfg(feature = "bftree")] mod bftree; diff --git a/diskann-inmem/Cargo.toml b/diskann-inmem/Cargo.toml index c5f8636cd..8f445a019 100644 --- a/diskann-inmem/Cargo.toml +++ b/diskann-inmem/Cargo.toml @@ -13,7 +13,8 @@ crossbeam-queue = "0.3.12" dashmap.workspace = true diskann = { workspace = true } diskann-utils = { workspace = true, default-features = false } -diskann-vector.workspace = true +diskann-vector = { workspace = true } 
+diskann-wide = { workspace = true } parking_lot = "0.12.5" thiserror.workspace = true diff --git a/diskann-inmem/src/layers/full.rs b/diskann-inmem/src/layers/full.rs index 8b16a7c0e..8cd0d77a7 100644 --- a/diskann-inmem/src/layers/full.rs +++ b/diskann-inmem/src/layers/full.rs @@ -3,11 +3,17 @@ * Licensed under the MIT license. */ +use std::{fmt::Debug, marker::PhantomData}; + use diskann::{ANNError, ANNResult}; use diskann_vector::{ UnalignedSlice, distance::{self, DistanceProvider, Metric}, }; +use diskann_wide::{ + ARCH, + arch::{Current, FTarget2}, +}; use thiserror::Error; use crate::{layers, num::Bytes}; @@ -19,6 +25,7 @@ where T: 'static, { distance: Distance, + metric: Metric, _type: std::marker::PhantomData, } @@ -37,6 +44,7 @@ where Self { distance, + metric, _type: std::marker::PhantomData, } } @@ -70,30 +78,21 @@ where } } -impl layers::Search for Full -where - T: std::fmt::Debug + Send + Sync + 'static, -{ - type Query<'a> = &'a [T]; - - fn query_distance<'a>( - &'a self, - query: &'a [T], - ) -> ANNResult> { - Ok(Box::new(QueryDistance::new(self.distance, query))) - } -} - impl layers::AsDistance for Full where - T: std::fmt::Debug + Send + Sync + 'static, + T: Debug + Send + Sync + 'static, { fn as_distance(&self) -> &dyn layers::Distance { &self.distance } } -impl layers::Insert for Full where T: bytemuck::Pod + std::fmt::Debug + Send + Sync {} +impl layers::Insert for Full +where + T: bytemuck::Pod + Debug + Send + Sync + 'static, + Self: for<'a> layers::Search = &'a [T]>, +{ +} ////////////// // Distance // @@ -143,7 +142,7 @@ where impl layers::Distance for Distance where - T: std::fmt::Debug + 'static, + T: Debug + 'static, { fn evaluate(&self, x: &[u8], y: &[u8]) -> ANNResult { let bytes = self.bytes(); @@ -210,7 +209,7 @@ where impl layers::QueryDistance for QueryDistance<'_, T> where - T: std::fmt::Debug + Sync + 'static, + T: Debug + Sync + 'static, { fn evaluate(&self, x: &[u8]) -> ANNResult { if x.len() != self.distance.bytes() { @@ 
-234,3 +233,84 @@ struct QueryDistanceError { expected: usize, xlen: usize, } + +//-----------// +// Version 2 // +//-----------// + +macro_rules! specialize { + ($me:ident, $query:ident, $visitor:ident, $T:ty, $(($var:ident, $N:literal, $f:ty)),* $(,)?) => { + match ($me.metric, $me.dim()) { + $( + (Metric::$var, $N) => { + let wrapped = Wrap::, _>::new($query); + return Ok(unsafe { + $visitor.visit_sized::<{ $N * std::mem::size_of::<$T>() }, _>(wrapped) + }) + }, + )* + _ => {}, + } + } +} + +impl layers::Search for Full { + type Query<'a> = &'a [f32]; + + fn query_distance<'a, V>(&'a self, query: &'a [f32], visitor: V) -> ANNResult + where + V: layers::QueryVisitor<'a>, + { + use diskann_vector::distance::{Specialize, SquaredL2}; + + specialize!(self, query, visitor, f32, (L2, 100, SquaredL2)); + + Ok(visitor.visit(QueryDistance::new(self.distance, query))) + } +} + +impl layers::Search for Full { + type Query<'a> = &'a [u8]; + + fn query_distance<'a, V>(&'a self, query: &'a [u8], visitor: V) -> ANNResult + where + V: layers::QueryVisitor<'a>, + { + use diskann_vector::distance::{Specialize, SquaredL2}; + + specialize!(self, query, visitor, u8, (L2, 128, SquaredL2)); + + Ok(visitor.visit(QueryDistance::new(self.distance, query))) + } +} + +#[derive(Debug)] +struct Wrap<'a, I, T> { + query: &'a [T], + inner: PhantomData, +} + +impl<'a, I, T> Wrap<'a, I, T> { + fn new(query: &'a [T]) -> Self { + Self { + query, + inner: PhantomData, + } + } +} + +impl layers::QueryDistance for Wrap<'_, I, T> +where + I: for<'a> FTarget2, UnalignedSlice<'a, T>> + + Send + + Sync + + Debug, + T: Send + Sync + 'static + Debug, +{ + #[inline(always)] + fn evaluate(&self, x: &[u8]) -> ANNResult { + // TODO: This is not fully valid - we need to check. 
+ let x = unsafe { UnalignedSlice::new(x.as_ptr().cast::(), self.query.len()) }; + Ok(I::run(ARCH, self.query.into(), x)) + } +} diff --git a/diskann-inmem/src/layers/mod.rs b/diskann-inmem/src/layers/mod.rs index 1a82e9d9a..3cc796ea4 100644 --- a/diskann-inmem/src/layers/mod.rs +++ b/diskann-inmem/src/layers/mod.rs @@ -72,17 +72,31 @@ pub trait Set: Layer { pub trait Search: Send + Sync + 'static { type Query<'a>; - fn query_distance<'a>( - &'a self, - query: Self::Query<'a>, - ) -> ANNResult>; + fn query_distance<'a, V>(&'a self, query: Self::Query<'a>, visitor: V) -> ANNResult + where + V: QueryVisitor<'a>; +} + +pub trait QueryVisitor<'a>: Sized { + type Output; + + fn visit(self, distance: T) -> Self::Output + where + T: QueryDistance + 'a; + + unsafe fn visit_sized(self, distance: T) -> Self::Output + where + T: QueryDistance + 'a, + { + self.visit(distance) + } } pub trait Insert: Search + for<'a> Set> + AsDistance { - fn insert_distance<'a>( - &'a self, - query: Self::Query<'a>, - ) -> ANNResult> { - self.query_distance(query) + fn insert_distance<'a, V>(&'a self, query: Self::Query<'a>, visitor: V) -> ANNResult + where + V: QueryVisitor<'a>, + { + self.query_distance(query, visitor) } } diff --git a/diskann-inmem/src/provider.rs b/diskann-inmem/src/provider.rs index 46d25fa9f..d47a3effc 100644 --- a/diskann-inmem/src/provider.rs +++ b/diskann-inmem/src/provider.rs @@ -3,7 +3,7 @@ * Licensed under the MIT license. 
*/ -use std::hash::Hash; +use std::{hash::Hash, marker::PhantomData}; use diskann::{ ANNError, ANNErrorKind, ANNResult, @@ -55,7 +55,7 @@ where layers::Set::into_bytes(&layer, point, row).unwrap(); } - let store = Store::new(capacity, bytes, 32, data.as_view()); + let store = Store::new(capacity, bytes, 32, data.as_view()).unwrap(); let mapping = Sharded::new(capacity); Self { store, @@ -131,13 +131,9 @@ where _context: &Context, id: u32, ) -> ANNResult { - if self - .store - .reader() - .unwrap() - .can_read(id.into_usize()) - .unwrap() - { + // Note that this check is approximate. A full check requires materialization of + // a `reader`. + if self.store.can_read_approximate(id.into_usize()).unwrap() { Ok(diskann::provider::ElementStatus::Valid) } else { Ok(diskann::provider::ElementStatus::Deleted) @@ -156,31 +152,6 @@ where Ok(diskann::provider::ElementStatus::Deleted) } } - - fn statuses_unordered( - &self, - _context: &Self::Context, - itr: Itr, - mut f: F, - ) -> impl std::future::Future> + Send - where - Itr: Iterator + Send, - F: FnMut(ANNResult, Self::InternalId) + Send, - { - let work = move || { - let reader = self.store.reader().unwrap(); - for i in itr { - if reader.can_read(i.into_usize()).unwrap() { - f(Ok(diskann::provider::ElementStatus::Valid), i) - } else { - f(Ok(diskann::provider::ElementStatus::Deleted), i) - } - } - Ok(()) - }; - - ready(work) - } } fn ready(f: F) -> std::future::Ready @@ -225,9 +196,8 @@ where #[derive(Debug)] pub struct SearchAccessor<'a> { reader: store::Reader<'a>, - distance: Box, ids: AdjacencyList, - expand_beam: FExpandBeam, + expand_beam: ExpandBeam<'a>, // The parent provider for the accessor.
provider: &'a (dyn std::any::Any + Send + Sync), @@ -256,7 +226,7 @@ impl glue::SearchAccessor for SearchAccessor<'_> { for p in self.start_points.clone() { match self.reader.read(p.into_usize()) { Some(point) => { - f(p, self.distance.evaluate(point)?); + f(p, self.expand_beam.evaluate(point)?); } None => { return Err(ANNError::message( @@ -292,13 +262,8 @@ impl glue::SearchAccessor for SearchAccessor<'_> { .retain(|i| pred.eval_mut(i) && self.reader.is_in_bounds(i.into_usize())); unsafe { - (self.expand_beam)( - &self.ids, - 8, - &self.reader, - &*self.distance, - &mut on_neighbors, - ) + self.expand_beam + .run(&self.ids, 8, &self.reader, &mut on_neighbors) }?; } @@ -310,32 +275,128 @@ impl glue::SearchAccessor for SearchAccessor<'_> { } type FExpandBeam = unsafe fn( + *const (), &[u32], usize, &store::Reader<'_>, - &dyn layers::QueryDistance, &mut dyn FnMut(u32, f32), ) -> ANNResult<()>; -fn dispatch_expand_beam(bytes: Bytes) -> FExpandBeam { - if bytes <= Bytes::CACHELINE { - expand_beam_inner::<1> - } else if bytes <= Bytes::CACHELINE.unchecked_mul(2) { - expand_beam_inner::<2> - } else if bytes <= Bytes::CACHELINE.unchecked_mul(3) { - expand_beam_inner::<3> - } else if bytes <= Bytes::CACHELINE.unchecked_mul(4) { - expand_beam_inner::<4> - } else if bytes <= Bytes::CACHELINE.unchecked_mul(5) { - expand_beam_inner::<5> - } else if bytes <= Bytes::CACHELINE.unchecked_mul(6) { - expand_beam_inner::<6> - } else if bytes <= Bytes::CACHELINE.unchecked_mul(7) { - expand_beam_inner::<7> - } else if bytes <= Bytes::CACHELINE.unchecked_mul(16) { - expand_beam_inner::<8> - } else { - expand_beam_inner::<16> +#[derive(Debug)] +struct ExpandBeamVisitor { + bytes: Bytes, +} + +impl<'a> layers::QueryVisitor<'a> for ExpandBeamVisitor { + type Output = ExpandBeam<'a>; + + unsafe fn visit_sized(self, distance: T) -> Self::Output + where + T: QueryDistance + 'a, + { + // Make sure there's no lying. 
+ assert_eq!(Bytes::new(BYTES + 1), self.bytes); + + unsafe { ExpandBeam::new(distance, expand_beam_inner::) } + } + + fn visit(self, distance: T) -> Self::Output + where + T: QueryDistance + 'a, + { + unsafe { ExpandBeam::new(distance, expand_beam_inner::) } + } +} + +#[derive(Debug)] +struct ExpandBeam<'a> { + ptr: *const (), + expand_beam: FExpandBeam, + vtable: &'static VTable, + lifetime: PhantomData<&'a ()>, +} + +#[derive(Debug)] +struct VTable { + evaluate: unsafe fn(*const (), &[u8]) -> ANNResult, + drop: unsafe fn(*mut ()), +} + +// SAFETY: We constrain `ptr` to be `Send`. +unsafe impl Send for ExpandBeam<'_> {} + +// SAFETY: We constrain `ptr` to be `Send`. +unsafe impl Sync for ExpandBeam<'_> {} + +impl<'a> ExpandBeam<'a> { + unsafe fn new(x: T, expand_beam: FExpandBeam) -> Self + where + T: layers::QueryDistance + Send + Sync + 'a, + { + let vtable = &VTable { + evaluate: evaluate::, + drop: drop::, + }; + + let ptr: *const T = Box::leak(Box::new(x)); + Self { + ptr: ptr.cast(), + expand_beam, + vtable: &vtable, + lifetime: PhantomData, + } + } + + unsafe fn run( + &self, + list: &[u32], + lookahead: usize, + reader: &store::Reader<'_>, + f: &mut dyn FnMut(u32, f32), + ) -> ANNResult<()> { + unsafe { (self.expand_beam)(self.ptr, list, lookahead, reader, f) } + } + + fn evaluate(&self, x: &[u8]) -> ANNResult { + unsafe { (self.vtable.evaluate)(self.ptr, x) } + } +} + +impl Drop for ExpandBeam<'_> { + fn drop(&mut self) { + unsafe { (self.vtable.drop)(self.ptr.cast_mut()) } + } +} + +unsafe fn drop(ptr: *mut ()) { + let _ = unsafe { Box::from_raw(ptr.cast::()) }; +} + +unsafe fn evaluate(ptr: *const (), x: &[u8]) -> ANNResult +where + T: layers::QueryDistance, +{ + let f = unsafe { &*ptr.cast::() }; + ::evaluate(f, x) +} + +#[inline(always)] +unsafe fn prefetch(ptr: *const u8, len: usize) { + use std::arch::x86_64::*; + + // Fetch the last cache line (the one with the tag) first. 
+ let stride = Bytes::CACHELINE.value(); + let ptr = ptr.cast::(); + let lines = len.div_ceil(stride); + if lines == 0 { + return; + } + + unsafe { _mm_prefetch(ptr.add(stride * (lines - 1)), _MM_HINT_T0) }; + for i in 0..(lines - 1) { + unsafe { + _mm_prefetch(ptr.add(stride * i), _MM_HINT_T0); + } } } @@ -353,37 +414,45 @@ const CACHE_LINE_SIZE: usize = 64; /// Safety (no # yet because we need to revisit this - clippy will lint) /// +/// * The concrete type of `distance` must be `T`. /// * All items in `list` must in-bounds with respect to `reader`. /// * The number of bytes associated with `N` cache lines must "make sense". -unsafe fn expand_beam_inner( +unsafe fn expand_beam_inner( + distance: *const (), list: &[u32], lookahead: usize, reader: &store::Reader<'_>, - distance: &dyn layers::QueryDistance, f: &mut dyn FnMut(u32, f32), -) -> ANNResult<()> { +) -> ANNResult<()> +where + T: layers::QueryDistance, +{ + let distance = unsafe { &*distance.cast::() }; + debug_assert!( - N * CACHE_LINE_SIZE - <= reader - .bytes() - .checked_next_multiple_of(Bytes::CACHELINE) - .unwrap() - .value(), + BYTES + 1 <= reader.bytes().value(), "we really rely on this: {}, bytes = {}", - N, + BYTES + 1, reader.bytes() ); + let bytes = if BYTES == 0 { + reader.bytes().value() + } else { + BYTES + 1 + }; + let len = list.len(); let lookahead = lookahead.min(len); for j in 0..lookahead { unsafe { - diskann_vector::prefetch_exactly::( + prefetch( reader .read_raw_unchecked(list.get_unchecked(j).into_usize()) .as_ptr() .cast(), + bytes, ) } } @@ -392,11 +461,12 @@ unsafe fn expand_beam_inner( for &i in list.iter() { if j != len { unsafe { - diskann_vector::prefetch_exactly::( + prefetch( reader .read_raw_unchecked(list.get_unchecked(j).into_usize()) .as_ptr() .cast(), + bytes, ) } j += 1; @@ -529,12 +599,17 @@ where _context: &'a Context, query: L::Query<'a>, ) -> ANNResult> { - let distance = ::query_distance(&provider.layer, query)?; let reader = provider.store.reader()?; - let 
expand_beam = dispatch_expand_beam(reader.bytes()); + let expand_beam = ::query_distance( + &provider.layer, + query, + ExpandBeamVisitor { + bytes: provider.store.bytes(), + }, + )?; + let accessor = SearchAccessor { reader, - distance, ids: AdjacencyList::new(), expand_beam, provider, @@ -544,6 +619,15 @@ where } } +pub fn test_function<'a>( + x: &'a Provider>, + strategy: &'a Strategy, + context: &'a Context, + query: &'a [u8], +) -> SearchAccessor<'a> { + glue::SearchStrategy::search_accessor(strategy, x, context, query).unwrap() +} + #[derive(Debug, Clone, Copy)] pub struct Translate(std::marker::PhantomData<(L, M)>); diff --git a/diskann-inmem/src/store.rs b/diskann-inmem/src/store.rs index 08c000109..a10267ea2 100644 --- a/diskann-inmem/src/store.rs +++ b/diskann-inmem/src/store.rs @@ -3,10 +3,15 @@ * Licensed under the MIT license. */ -use std::{iter::repeat_n, num::{NonZeroU32, NonZeroUsize}, sync::atomic::Ordering}; +use std::{ + iter::repeat_n, + num::{NonZeroU32, NonZeroUsize}, + sync::atomic::Ordering, +}; use diskann::utils::IntoUsize; use diskann_utils::views::MatrixView; +use thiserror::Error; use crate::{ buffer::{Buffer, RawSlice}, @@ -35,54 +40,80 @@ pub(crate) struct Store { } const SPLIT: Bytes = Bytes::size_of::(); -const RETRY_LIMIT: usize = 20; const TWO: NonZeroUsize = NonZeroUsize::new(2).unwrap(); +// TODO: This is a guess and probably needs tuning. 
+const RETRY_LIMIT: usize = 20; + impl Store { pub(crate) fn new( entries: usize, bytes: Bytes, max_neighbors: usize, init: MatrixView<'_, u8>, - ) -> Self { - assert_eq!(init.ncols(), bytes.value()); - assert_ne!(init.nrows(), 0); + ) -> Result { + if init.ncols() != bytes.value() { + return Err(StoreError::mismatched_frozen_point_dim(init.ncols(), bytes)); + } - let unpadded = bytes.checked_add(SPLIT).unwrap(); + if init.nrows() == 0 { + return Err(StoreError::need_frozen_point()); + } + + let unpadded = bytes + .checked_add(SPLIT) + .expect("unreachable because `init` cannot exceed `isize::MAX` bytes"); // Pad to half a cache line. When data occupies just part of a cache line, this // results in the same total number of cache lines being fetched while potentially // enabling more compact memory. let padded_bytes = unpadded .checked_next_multiple_of(Bytes::CACHELINE.div(TWO)) - .unwrap(); + .expect("unreachabel because `init` cannot exceed `isize::MAX` bytes"); + + let too_many_entries = || StoreError::too_many_entries(entries, init.nrows()); + + // We have a hard upper-bound of `u32::MAX` total slots. + // + // This enforces that bound.
+ let entries: u32 = entries.try_into().map_err(|_| too_many_entries())?; + + let frozen: u32 = init.nrows().try_into().map_err(|_| too_many_entries())?; - let total: usize = entries.checked_add(init.nrows()).unwrap(); + let total: u32 = entries.checked_add(frozen).ok_or_else(too_many_entries)?; - let this = Self { - buffer: Buffer::new(total, padded_bytes, Align::_128).unwrap(), + let max_neighbors: u32 = max_neighbors + .try_into() + .map_err(|_| StoreError::too_many_neighbors(max_neighbors))?; + + let me = Self { + buffer: Buffer::new(total.into_usize(), padded_bytes, Align::_128).unwrap(), unpadded, - unfrozen: entries, - tags: repeat_n(Tag::AVAILABLE, total) + unfrozen: entries.into_usize(), + tags: repeat_n(Tag::AVAILABLE, total.into_usize()) .map(|v| AtomicTag::new(v)) .collect(), // NOTE: The `Freelist` is initialized to `entries` and not `total` because // we do not want it to release frozen IDs. - freelist: Freelist::new(entries.try_into().unwrap(), NonZeroU32::new(1024).unwrap()), + freelist: Freelist::new(entries, NonZeroU32::new(1024).unwrap()), registry: Registry::new(), - neighbors: Neighbors::new(total.try_into().unwrap(), max_neighbors.try_into().unwrap()) - .unwrap(), + neighbors: Neighbors::new(total, max_neighbors).unwrap(), }; // Populate frozen points. for (i, data) in init.row_iter().enumerate() { - let mut slot = this.slot((entries + i).try_into().unwrap()).unwrap(); + // We have checked that the total number of entries fits in `u32`, so this + // arithmetic cannot overflow. + let mut slot = me + .slot(entries + (i as u32)) + .expect("store was just created - claiming the slot must succeed"); + slot.as_mut_slice().copy_from_slice(data); slot.freeze(); } - this + Ok(me) } /// Return the range of slots containing frozen items in `self`. @@ -95,6 +126,11 @@ impl Store { self.buffer.len() - self.unfrozen } + /// Return the number of bytes occupied by each entry. 
+ pub(crate) fn bytes(&self) -> Bytes { + self.unpadded + } + /// Attempt to reclaim retired slots. /// /// If successful, returns the number of slots reclaimed. @@ -273,14 +309,74 @@ impl Store { /// /// The index `i` must be less then `self.buffer.len()`. unsafe fn data_unchecked(&self, i: usize) -> (&AtomicTag, RawSlice<'_>) { - let (mirror, data) = unsafe { self.buffer.get_unchecked(i) } + let (data, mirror) = unsafe { self.buffer.get_unchecked(i) } .truncate(self.unpadded) - .split(SPLIT); + .split(self.unpadded.unchecked_sub(SPLIT)); ( unsafe { AtomicTag::from_ptr(mirror.as_mut_ptr().cast()) }, data, ) } + + /// Return whether or not it is probably okay to read from the slot `i`. + /// + /// This check is approximate and non-synchronizing. To fully check, [`Reader::can_read`] + /// must be used. + /// + /// Returns `None` if index `i` is out-of-bounds. + pub(crate) fn can_read_approximate(&self, i: usize) -> Option { + self.tags + .get(i) + .map(|tag| tag.load(Ordering::Relaxed).can_read()) + } +} + +#[derive(Debug, Error)] +#[error(transparent)] +pub struct StoreError(StoreErrorInner); + +impl StoreError { + fn mismatched_frozen_point_dim(dim: usize, bytes: Bytes) -> Self { + Self(StoreErrorInner::MismatchedFrozenPointDim { dim, bytes }) + } + + fn need_frozen_point() -> Self { + Self(StoreErrorInner::NeedFrozenPoint) + } + + fn too_many_entries(entries: usize, frozen: usize) -> Self { + Self(StoreErrorInner::TooManyEntries { entries, frozen }) + } + + fn too_many_neighbors(neighbors: usize) -> Self { + Self(StoreErrorInner::TooManyNeighbors { neighbors }) + } +} + +impl From for StoreError { + fn from(inner: StoreErrorInner) -> Self { + Self(inner) + } +} + +#[derive(Debug, Error)] +enum StoreErrorInner { + #[error( + "frozen point dim ({}) must have the same dimensionality as requested bytes ({})", + dim, + bytes + )] + MismatchedFrozenPointDim { dim: usize, bytes: Bytes }, + #[error("at least one frozen point must be provided")] + NeedFrozenPoint, +
#[error( + "total points ({} + {} frozen) must not exceed `u32::MAX`", + entries, + frozen + )] + TooManyEntries { entries: usize, frozen: usize }, + #[error("number of neighbors ({}) may not exceed `u32::MAX`", neighbors)] + TooManyNeighbors { neighbors: usize }, } /// An epoch protect reader into [`Store`]. @@ -326,8 +422,14 @@ impl<'a> Reader<'a> { return None; } - let tag_ptr = unsafe { self.buffer.get_unchecked(i).truncate_unchecked(SPLIT) }; - let can_read = unsafe { AtomicTag::from_ptr(tag_ptr.as_mut_ptr().cast()) } + let tag_ptr = unsafe { + self.buffer + .get_unchecked(i) + .as_mut_ptr() + .add(self.unpadded.unchecked_sub(SPLIT).value()) + }; + + let can_read = unsafe { AtomicTag::from_ptr(tag_ptr.cast()) } .load(Ordering::Acquire) .can_read(); @@ -344,11 +446,11 @@ impl<'a> Reader<'a> { pub(crate) unsafe fn read_in_bounds(&self, i: usize) -> Option<&[u8]> { debug_assert!(self.is_in_bounds(i)); - let (tag_ptr, rest) = unsafe { + let (data, tag_ptr) = unsafe { self.buffer .get_unchecked(i) .truncate_unchecked(self.unpadded) - .split_unchecked(SPLIT) + .split_unchecked(self.unpadded.unchecked_sub(SPLIT)) }; // NOTE: Must be `Acquire` to correctly synchronize with writes. @@ -359,7 +461,7 @@ impl<'a> Reader<'a> { if can_read { // SAFETY: We've passed the `can_read` check - `_guard` will ensure the read // slice is valid and race-free. - Some(unsafe { rest.as_slice() }) + Some(unsafe { data.as_slice() }) } else { None } diff --git a/diskann-vector/src/distance/implementations.rs b/diskann-vector/src/distance/implementations.rs index d454b6956..8e119e43c 100644 --- a/diskann-vector/src/distance/implementations.rs +++ b/diskann-vector/src/distance/implementations.rs @@ -49,7 +49,13 @@ macro_rules! architecture_hook { /// A utility for specializing distance computations for fixed-length slices. 
#[derive(Debug, Clone, Copy)] -pub(crate) struct Specialize(std::marker::PhantomData); +pub struct Specialize(std::marker::PhantomData); + +impl Specialize { + pub fn new() -> Self { + Self(std::marker::PhantomData) + } +} impl diskann_wide::arch::FTarget2, UnalignedSlice<'_, R>> diff --git a/diskann-vector/src/distance/mod.rs b/diskann-vector/src/distance/mod.rs index 415f5a67b..5ade0f22d 100644 --- a/diskann-vector/src/distance/mod.rs +++ b/diskann-vector/src/distance/mod.rs @@ -9,7 +9,7 @@ pub mod reference; pub mod simd; pub mod implementations; -pub use implementations::{Cosine, CosineNormalized, FullL2, InnerProduct, SquaredL2}; +pub use implementations::{Cosine, CosineNormalized, FullL2, InnerProduct, Specialize, SquaredL2}; pub mod distance_provider; pub use distance_provider::{Distance, DistanceProvider}; From 135e7430c6e3dd4062297fb34183bbc83a42ea17 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Mon, 22 Jun 2026 12:07:33 -0700 Subject: [PATCH 17/45] Checkpoing before vibes. 
--- Cargo.lock | 4 ++ diskann-inmem/Cargo.toml | 22 ++++++++++ diskann-inmem/integration/main.rs | 11 +++++ diskann-inmem/integration/store.rs | 7 +++ diskann-inmem/src/integration/mod.rs | 6 +++ diskann-inmem/src/integration/store.rs | 61 ++++++++++++++++++++++++++ diskann-inmem/src/lib.rs | 4 ++ diskann-inmem/src/provider.rs | 6 +-- diskann-inmem/src/store.rs | 59 ++++++++++++++++++++++--- 9 files changed, 170 insertions(+), 10 deletions(-) create mode 100644 diskann-inmem/integration/main.rs create mode 100644 diskann-inmem/integration/store.rs create mode 100644 diskann-inmem/src/integration/mod.rs create mode 100644 diskann-inmem/src/integration/store.rs diff --git a/Cargo.lock b/Cargo.lock index 181385e13..15c8d4c8b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -811,15 +811,19 @@ dependencies = [ name = "diskann-inmem" version = "0.54.0" dependencies = [ + "anyhow", "bytemuck", "crossbeam-queue", "dashmap", "diskann", + "diskann-benchmark-runner", "diskann-utils", "diskann-vector", "diskann-wide", "parking_lot", "rand 0.9.4", + "serde", + "serde_json", "thiserror 2.0.17", "tokio", ] diff --git a/diskann-inmem/Cargo.toml b/diskann-inmem/Cargo.toml index 8f445a019..44503505d 100644 --- a/diskann-inmem/Cargo.toml +++ b/diskann-inmem/Cargo.toml @@ -18,9 +18,31 @@ diskann-wide = { workspace = true } parking_lot = "0.12.5" thiserror.workspace = true +# Integration Test Dependencies +diskann-benchmark-runner = { workspace = true, optional = true } +serde = { workspace = true, features = ["derive"], optional = true } +serde_json = { workspace = true, optional = true } +anyhow = { workspace = true, optional = true } + [lints] workspace = true [dev-dependencies] rand.workspace = true tokio = { workspace = true, features = ["macros"] } + +[[bin]] +name = "integration-tests" +path = "integration/main.rs" +required-features = ["integration-test"] + +[features] +default = [] + +# Enable stress test module +integration-test = [ + "dep:diskann-benchmark-runner", + "dep:serde", 
+ "dep:serde_json", + "dep:anyhow", +] diff --git a/diskann-inmem/integration/main.rs b/diskann-inmem/integration/main.rs new file mode 100644 index 000000000..569350f8c --- /dev/null +++ b/diskann-inmem/integration/main.rs @@ -0,0 +1,11 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +mod store; + +pub fn main() -> anyhow::Result<()> { + println!("hello world"); + Ok(()) +} diff --git a/diskann-inmem/integration/store.rs b/diskann-inmem/integration/store.rs new file mode 100644 index 000000000..8f613ddb3 --- /dev/null +++ b/diskann-inmem/integration/store.rs @@ -0,0 +1,7 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +use serde::{Serialize, Deserialize}; + diff --git a/diskann-inmem/src/integration/mod.rs b/diskann-inmem/src/integration/mod.rs new file mode 100644 index 000000000..2f378e125 --- /dev/null +++ b/diskann-inmem/src/integration/mod.rs @@ -0,0 +1,6 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +pub mod store; diff --git a/diskann-inmem/src/integration/store.rs b/diskann-inmem/src/integration/store.rs new file mode 100644 index 000000000..c93451cb1 --- /dev/null +++ b/diskann-inmem/src/integration/store.rs @@ -0,0 +1,61 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. 
+ */ + +use crate::store; + +#[derive(Debug)] +pub struct Store { + store: store::Store, +} + +impl Store { + pub fn acquire(&self) -> Option> { + self.store.acquire().map(Writer::new) + } + + #[must_use = "result indicates success or failure"] + pub fn retire(&self, i: usize) -> bool { + self.store.retire(i).is_ok() + } + + pub fn reader(&self) -> Option> { + match self.store.reader() { + Ok(reader) => Some(Reader::new(reader)), + Err(crate::epoch::Unavailable) => None, + } + } +} + +pub struct Reader<'a> { + reader: store::Reader<'a>, +} + +impl<'a> Reader<'a> { + fn new(reader: store::Reader<'a>) -> Self { + Self { reader } + } + + pub fn read(&self, i: usize) -> Option<&[u8]> { + self.reader.read(i) + } +} + +pub struct Writer<'a> { + slot: store::Slot<'a>, +} + +impl<'a> Writer<'a> { + fn new(slot: store::Slot<'a>) -> Self { + Self { slot } + } + + pub fn as_mut_slice(&mut self) -> &mut [u8] { + self.slot.as_mut_slice() + } + + pub fn slot(&self) -> u32 { + self.slot.slot() + } +} diff --git a/diskann-inmem/src/lib.rs b/diskann-inmem/src/lib.rs index 51a56761d..a94dbec83 100644 --- a/diskann-inmem/src/lib.rs +++ b/diskann-inmem/src/lib.rs @@ -23,3 +23,7 @@ pub use provider::{Context, Provider, Strategy}; #[cfg(test)] mod test; + +#[cfg(feature = "integration-test")] +#[doc(hidden)] +pub mod integration; diff --git a/diskann-inmem/src/provider.rs b/diskann-inmem/src/provider.rs index d47a3effc..239c2b6f0 100644 --- a/diskann-inmem/src/provider.rs +++ b/diskann-inmem/src/provider.rs @@ -118,7 +118,7 @@ where async fn delete(&self, _context: &Context, gid: &M) -> ANNResult<()> { // TODO: These need to actually happen in lock-step. 
let internal = self.mapping.remove(gid).unwrap(); - assert!(self.store.delete(internal.into_usize())); + self.store.retire(internal.into_usize()).unwrap(); Ok(()) } @@ -620,10 +620,10 @@ where } pub fn test_function<'a>( - x: &'a Provider>, + x: &'a Provider>, strategy: &'a Strategy, context: &'a Context, - query: &'a [u8], + query: &'a [f32], ) -> SearchAccessor<'a> { glue::SearchStrategy::search_accessor(strategy, x, context, query).unwrap() } diff --git a/diskann-inmem/src/store.rs b/diskann-inmem/src/store.rs index a10267ea2..d827fd6aa 100644 --- a/diskann-inmem/src/store.rs +++ b/diskann-inmem/src/store.rs @@ -203,16 +203,33 @@ impl Store { None } - pub(crate) fn delete(&self, i: usize) -> bool { - let guard = self.registry.guard().unwrap(); - let tag = self.tags.get(i).unwrap(); + /// Attempt to retire slot `i`. If successful, this slot will be placed in an internal + /// retirement queue for reclamation once we can prove no readers are active that could + /// have observed this transition. + /// + /// Returns `Ok(())` if the slot was successfully retired. + /// + /// # Errors + /// + /// Returns an error in any of the following conditions: + /// + /// * The slot index `i` is out-of-bounds. + /// * The slot is not in a state that can be retired (e.g., it is already retired or + /// is owned by a different thread). + /// * An [`epoch::Guard`] could not be obtained due to registration slot exhaustion. + /// * An attempt to acquire the slot after these checks races with another thread and + /// the race was lost. + pub(crate) fn retire(&self, i: usize) -> Result<(), RetireError> { + let tag = self.tags.get(i).ok_or(RetireError::OutOfBounds)?; let current = tag.load(Ordering::Relaxed); // We can only perform a deletion if the generation is not in a reserved state. 
if current.is_reserved() { - return false; + return Err(RetireError::SlotIsReserved { tag: current }); } + let guard = self.registry.guard().map_err(RetireError::GuardUnavailable)?; + let retiring = Tag::RETIRING; // Even if we make this change, we can't access any data until we wait for the @@ -223,9 +240,9 @@ impl Store { let (mirror, _) = unsafe { self.data_unchecked(i) }; mirror.store(retiring, Ordering::Relaxed); guard.retire(i as u32); - true + Ok(()) } - Err(_) => false, + Err(_) => Err(RetireError::CouldNotClaimSlot), } } @@ -333,7 +350,7 @@ impl Store { #[derive(Debug, Error)] #[error(transparent)] -pub struct StoreError(StoreErrorInner); +pub(crate) struct StoreError(StoreErrorInner); impl StoreError { fn mismatched_frozen_point_dim(dim: usize, bytes: Bytes) -> Self { @@ -379,6 +396,23 @@ enum StoreErrorInner { TooManyNeighbors { neighbors: usize }, } +/// Error conditions for [`Store::retire`]. +#[derive(Debug, Error)] +pub(crate) enum RetireError { + /// Slot index was out-of-bounds. + #[error("index out of bounds")] + OutOfBounds, + /// The slot cannot be retired because it is in a reserved state. + #[error("slot is reserved: {}", tag)] + SlotIsReserved { tag: Tag }, + /// An [`epoch::Guard`] could not be acquired. + #[error(transparent)] + GuardUnavailable(epoch::Unavailable), + /// Another thread won the retirement race. + #[error("could not claim slot")] + CouldNotClaimSlot, +} + /// An epoch protect reader into [`Store`]. /// /// Created via [`Store::reader`]. @@ -521,3 +555,14 @@ impl Drop for Slot<'_> { self.tag.store(Tag::PUBLISHED, Ordering::Release); } } + +/////////// +// Tests // +/////////// + +#[cfg(test)] +mod tests { + use super::*; + + +} From 690d1defa130dc709b24f8d56551325884efba7f Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Mon, 22 Jun 2026 18:20:19 -0700 Subject: [PATCH 18/45] Checkpoint. 
--- Cargo.lock | 2 + diskann-benchmark-runner/src/utils/fmt.rs | 53 ++ diskann-inmem/Cargo.toml | 8 +- .../example/store-stress-test.json | 19 + .../integration/example/store-stress.json | 19 + diskann-inmem/integration/index/datatype.rs | 119 +++++ diskann-inmem/integration/index/mod.rs | 7 + diskann-inmem/integration/index/traits.rs | 85 +++ diskann-inmem/integration/main.rs | 64 ++- diskann-inmem/integration/store.rs | 496 +++++++++++++++++- diskann-inmem/src/epoch.rs | 52 +- diskann-inmem/src/integration/store.rs | 36 +- diskann-inmem/src/layers/full.rs | 163 ++++-- diskann-inmem/src/store.rs | 7 +- 14 files changed, 1070 insertions(+), 60 deletions(-) create mode 100644 diskann-inmem/integration/example/store-stress-test.json create mode 100644 diskann-inmem/integration/example/store-stress.json create mode 100644 diskann-inmem/integration/index/datatype.rs create mode 100644 diskann-inmem/integration/index/mod.rs create mode 100644 diskann-inmem/integration/index/traits.rs diff --git a/Cargo.lock b/Cargo.lock index 15c8d4c8b..d74087d26 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -820,10 +820,12 @@ dependencies = [ "diskann-utils", "diskann-vector", "diskann-wide", + "half", "parking_lot", "rand 0.9.4", "serde", "serde_json", + "tempfile", "thiserror 2.0.17", "tokio", ] diff --git a/diskann-benchmark-runner/src/utils/fmt.rs b/diskann-benchmark-runner/src/utils/fmt.rs index 8113d30b9..462db6257 100644 --- a/diskann-benchmark-runner/src/utils/fmt.rs +++ b/diskann-benchmark-runner/src/utils/fmt.rs @@ -369,6 +369,59 @@ where } } +////////////// +// KeyValue // +////////////// + +/// Display a dynamic list of key-value pairs such that the keys are right-aligned. 
+/// +/// # Examples +/// +/// ``` +/// use diskann_benchmark_runner::utils::fmt::KeyValue; +/// +/// let mut kv = KeyValue::new(); +/// kv.push("a", &1); +/// kv.push("hello", &"world"); +/// +/// let expected = +/// " a: 1 +/// hello: world +/// "; +/// +/// assert_eq!(kv.to_string(), expected); +/// ``` +pub struct KeyValue<'a> { + kv: Vec<(&'a str, &'a dyn std::fmt::Display)>, + max_key_length: usize, +} + +impl<'a> KeyValue<'a> { + /// Create a new empty [`KeyValue`] formatter. + pub fn new() -> Self { + Self { + kv: Vec::new(), + max_key_length: 0, + } + } + + /// Push the key-value pair to `self` for formatting. + pub fn push(&mut self, key: &'a str, value: &'a dyn std::fmt::Display) { + self.max_key_length = self.max_key_length.max(key.len()); + self.kv.push((key, value)) + } +} + +impl std::fmt::Display for KeyValue<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let width = self.max_key_length; + for (k, v) in self.kv.iter() { + writeln!(f, "{:>width$}: {v}", k)?; + } + Ok(()) + } +} + /////////// // Tests // /////////// diff --git a/diskann-inmem/Cargo.toml b/diskann-inmem/Cargo.toml index 44503505d..3f3b61893 100644 --- a/diskann-inmem/Cargo.toml +++ b/diskann-inmem/Cargo.toml @@ -16,20 +16,23 @@ diskann-utils = { workspace = true, default-features = false } diskann-vector = { workspace = true } diskann-wide = { workspace = true } parking_lot = "0.12.5" -thiserror.workspace = true +thiserror = { workspace = true } +half = { workspace = true } # Integration Test Dependencies diskann-benchmark-runner = { workspace = true, optional = true } serde = { workspace = true, features = ["derive"], optional = true } serde_json = { workspace = true, optional = true } anyhow = { workspace = true, optional = true } +rand = { workspace = true, optional = true } [lints] workspace = true [dev-dependencies] -rand.workspace = true +rand = { workspace = true } tokio = { workspace = true, features = ["macros"] } +tempfile = { workspace = true 
} [[bin]] name = "integration-tests" @@ -45,4 +48,5 @@ integration-test = [ "dep:serde", "dep:serde_json", "dep:anyhow", + "dep:rand", ] diff --git a/diskann-inmem/integration/example/store-stress-test.json b/diskann-inmem/integration/example/store-stress-test.json new file mode 100644 index 000000000..b04a008fa --- /dev/null +++ b/diskann-inmem/integration/example/store-stress-test.json @@ -0,0 +1,19 @@ +{ + "search_directories": [], + "jobs": [ + { + "type": "store-stress", + "content": { + "readers": 4, + "writers": 2, + "retirers": 1, + "capacity": 512, + "entry_bytes": 64, + "low_watermark": 128, + "duration_secs": 2, + "max_ops": 2000000, + "seed": 11939873485092837375 + } + } + ] +} diff --git a/diskann-inmem/integration/example/store-stress.json b/diskann-inmem/integration/example/store-stress.json new file mode 100644 index 000000000..6f1d2b836 --- /dev/null +++ b/diskann-inmem/integration/example/store-stress.json @@ -0,0 +1,19 @@ +{ + "search_directories": [], + "jobs": [ + { + "type": "store-stress", + "content": { + "readers": 8, + "writers": 4, + "retirers": 2, + "capacity": 4096, + "entry_bytes": 128, + "low_watermark": 1024, + "duration_secs": 10, + "max_ops": 50000000, + "seed": 11939873485092837375 + } + } + ] +} diff --git a/diskann-inmem/integration/index/datatype.rs b/diskann-inmem/integration/index/datatype.rs new file mode 100644 index 000000000..253ba7a60 --- /dev/null +++ b/diskann-inmem/integration/index/datatype.rs @@ -0,0 +1,119 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. 
+ */ + +use half::f16; +use thiserror::Error; + +////////////// +// DataType // +////////////// + +#[derive(Debug, Clone, Copy)] +pub(crate) enum DataType { + F32, + F16, + U8, + I8, +} + +impl DataType { + fn as_str(self) -> &'static str { + match self { + Self::F32 => "f32", + Self::F16 => "f16", + Self::U8 => "u8", + Self::I8 => "i8", + } + } +} + +impl std::fmt::Display for DataType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(self.as_str()) + } +} + +pub(crate) trait AsDataType { + const DATA_TYPE: DataType; +} + +macro_rules! as_data_type { + ($T:ty, $variant:ident) => { + impl AsDataType for $T { + const DATA_TYPE: DataType = DataType::$variant; + } + }; +} + +as_data_type!(f32, F32); +as_data_type!(f16, F16); +as_data_type!(u8, U8); +as_data_type!(i8, I8); + +#[derive(Debug, Error)] +#[error("wrong data-type: expected {}, got {}", self.expected, self.got)] +pub(crate) struct WrongDataType { + expected: DataType, + got: DataType, +} + +impl WrongDataType { + fn new(expected: DataType, got: DataType) -> Self { + Self { expected, got } + } +} + +/////////// +// Slice // +/////////// + +#[derive(Debug, Clone, Copy)] +pub(crate) enum Slice<'a> { + F32(&'a [f32]), + F16(&'a [f16]), + U8(&'a [u8]), + I8(&'a [i8]), +} + +impl<'a> Slice<'a> { + pub(crate) fn data_type(&self) -> DataType { + match self { + Self::F32(_) => DataType::F32, + Self::F16(_) => DataType::F16, + Self::U8(_) => DataType::U8, + Self::I8(_) => DataType::I8, + } + } + + pub(crate) fn try_cast(self) -> Result<&'a [T], WrongDataType> + where + T: FromSlice, + { + T::from_slice(self) + } +} + +pub(crate) trait FromSlice: Sized { + fn from_slice(slice: Slice<'_>) -> Result<&[Self], WrongDataType>; +} + +macro_rules! 
from_slice { + ($T:ty, $variant:ident) => { + impl FromSlice for $T { + fn from_slice(slice: Slice<'_>) -> Result<&[Self], WrongDataType> { + if let Slice::$variant(s) = slice { + Ok(s) + } else { + Err(WrongDataType::new(DataType::$variant, slice.data_type())) + } + } + } + }; +} + +from_slice!(f32, F32); +from_slice!(f16, F16); +from_slice!(u8, U8); +from_slice!(i8, I8); diff --git a/diskann-inmem/integration/index/mod.rs b/diskann-inmem/integration/index/mod.rs new file mode 100644 index 000000000..46883ab49 --- /dev/null +++ b/diskann-inmem/integration/index/mod.rs @@ -0,0 +1,7 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +mod datatype; +mod traits; diff --git a/diskann-inmem/integration/index/traits.rs b/diskann-inmem/integration/index/traits.rs new file mode 100644 index 000000000..f73e1a041 --- /dev/null +++ b/diskann-inmem/integration/index/traits.rs @@ -0,0 +1,85 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. 
+ */ + +use std::{future::Future, pin::Pin}; + +use diskann::{ + graph::{DiskANNIndex, search::Knn}, + neighbor::Neighbor, +}; +use half::f16; + +use diskann_inmem::{Context, Provider, Strategy, layers}; + +use super::datatype::{AsDataType, DataType, FromSlice, Slice}; + +pub(crate) trait Index { + fn data_type(&self) -> DataType; + + fn search<'a>( + &'a self, + query: Slice<'a>, + knn: Knn, + neighbors: &'a mut Vec>, + ) -> Pin> + 'a>>; + + fn insert<'a>( + &'a self, + vector: Slice<'a>, + id: u64, + ) -> Pin> + 'a>>; + + // fn retire(&self, id: u64) -> anyhow::Result<()>; +} + +/////////// +// Impls // +/////////// + +impl Index for DiskANNIndex, u64>> +where + layers::Full: for<'a> layers::Insert = &'a [T]>, + T: FromSlice + AsDataType + Send + Sync + 'static, +{ + fn data_type(&self) -> DataType { + T::DATA_TYPE + } + + fn search<'a>( + &'a self, + query: Slice<'a>, + knn: Knn, + neighbors: &'a mut Vec>, + ) -> Pin> + 'a>> { + let fut = async move { + let query = query.try_cast()?; + let _ = self + .search(knn, &Strategy, &Context, query, neighbors) + .await?; + + Ok(()) + }; + + Box::pin(fut) + } + + fn insert<'a>( + &'a self, + vector: Slice<'a>, + id: u64, + ) -> Pin> + 'a>> { + let fut = async move { + let vector = vector.try_cast()?; + self.insert(&Strategy, &Context, &id, vector).await?; + + Ok(()) + }; + + Box::pin(fut) + } + + // fn retire(&self, id: u64) -> anyhow::Result<()> { + // } +} diff --git a/diskann-inmem/integration/main.rs b/diskann-inmem/integration/main.rs index 569350f8c..ce2f582f6 100644 --- a/diskann-inmem/integration/main.rs +++ b/diskann-inmem/integration/main.rs @@ -3,9 +3,67 @@ * Licensed under the MIT license. */ +mod index; mod store; -pub fn main() -> anyhow::Result<()> { - println!("hello world"); - Ok(()) +use diskann_benchmark_runner::{App, Registry, output}; + +/// Build a [`Registry`] with all integration benchmarks registered. 
+fn registry() -> anyhow::Result { + let mut registry = Registry::new(); + registry.register("store-stress", store::StoreStress)?; + Ok(registry) +} + +fn main() -> anyhow::Result<()> { + let app = App::parse(); + app.run(®istry()?, &mut output::default()) +} + +/////////// +// Tests // +/////////// + +#[cfg(test)] +mod tests { + use super::*; + + use diskann_benchmark_runner::{app::Commands, output::Memory}; + + // The directory containing the committed example input files. + fn example_directory() -> std::path::PathBuf { + std::path::Path::new(env!("CARGO_MANIFEST_DIR")) + .join("integration") + .join("example") + } + + // Drive the named example through the full runner flow: load the JSON input file, + // dispatch through the registry, run the benchmark, and write results to disk. + fn run_example(name: &str) { + let input_file = example_directory().join(name); + assert!(input_file.exists(), "missing example file: {input_file:?}"); + + let tempdir = tempfile::tempdir().unwrap(); + let output_file = tempdir.path().join("output.json"); + + let command = Commands::Run { + input_file, + output_file: output_file.clone(), + dry_run: false, + // Unit tests are a debug build; bypass the runner's debug-mode guard. + allow_debug: true, + }; + let app = App::from_commands(command); + + let mut output = Memory::new(); + // A benchmark error (e.g. an invariant violation) propagates here and fails the test. + app.run(®istry().unwrap(), &mut output).unwrap(); + + assert!(output_file.exists(), "results file was not written"); + } + + #[test] + fn store_stress_integration() { + run_example("store-stress-test.json"); + } } diff --git a/diskann-inmem/integration/store.rs b/diskann-inmem/integration/store.rs index 8f613ddb3..c82dbf46d 100644 --- a/diskann-inmem/integration/store.rs +++ b/diskann-inmem/integration/store.rs @@ -3,5 +3,499 @@ * Licensed under the MIT license. */ -use serde::{Serialize, Deserialize}; +//! 
Concurrency stress test for the in-memory [`Store`](diskann_inmem::integration::store::Store). +//! +//! Reader, writer, and retirer threads hammer the epoch-based store concurrently while a +//! per-guard invariant checker verifies the store's safety guarantees: +//! +//! 1. Reads are never torn. +//! 2. A readable value is stable for the lifetime of a single reader guard. +//! 3. A slot never resurrects (`readable -> unreadable -> readable`) within one guard. +use std::{ + collections::HashMap, + io::Write, + ops::Range, + sync::{ + Mutex, + atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering::Relaxed}, + }, + time::{Duration, Instant}, +}; + +use diskann_benchmark_runner::{ + Benchmark, Checker, Checkpoint, Input, Output, + benchmark::{FailureScore, MatchScore}, + utils::fmt::KeyValue, +}; +use rand::{Rng, SeedableRng, distr::Uniform, rngs::StdRng}; +use serde::{Deserialize, Serialize}; + +use diskann_inmem::integration::store::Store; + +/// Maximum number of concurrent reader guards supported by the epoch registry. +const GUARD_CAPACITY: usize = 256; + +/// Number of slots a reader inspects per guard. Kept small so guards are short-lived, +/// allowing the epoch to advance and reclamation to make progress. +const READER_WINDOW: usize = 64; + +/// Number of times a reader re-reads its window within a single guard. Re-reading is what +/// exercises the value-stability and no-resurrection invariants. +const READER_PASSES: usize = 4; + +/// How often (in retirer iterations) a retirer attempts to reclaim retired slots. +const RECLAIM_EVERY: u64 = 16; + +/////////// +// Input // +/////////// + +/// Configuration for a [`StoreStress`] run. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct StoreStressInput { + /// Number of reader threads. Must be below [`GUARD_CAPACITY`]. + readers: usize, + /// Number of writer threads. + writers: usize, + /// Number of retirer threads. + retirers: usize, + /// Number of writable (non-frozen) slots. 
+ capacity: usize, + /// Bytes per entry. Must be a non-zero multiple of 8 (the stamp lane width). + entry_bytes: usize, + /// Retirers only retire while the live published population exceeds this watermark. + low_watermark: usize, + /// Wall-clock cap for the run, in seconds. Zero means unbounded (rely on `max_ops`). + duration_secs: u64, + /// Total-operation cap across all worker threads. Zero means unbounded (rely on + /// `duration_secs`). + max_ops: u64, + /// Seed for the worker pseudo-random number generators. + seed: u64, +} + +impl Input for StoreStressInput { + type Raw = Self; + + fn tag() -> &'static str { + "store-stress" + } + + fn from_raw(raw: Self::Raw, _checker: &mut Checker) -> anyhow::Result { + if raw.readers == 0 || raw.writers == 0 { + anyhow::bail!("`readers` and `writers` must be non-zero"); + } + if raw.readers >= GUARD_CAPACITY { + anyhow::bail!( + "`readers` ({}) must be below the epoch guard capacity ({GUARD_CAPACITY})", + raw.readers, + ); + } + if raw.capacity == 0 { + anyhow::bail!("`capacity` must be non-zero"); + } + if raw.entry_bytes == 0 || raw.entry_bytes % 8 != 0 { + anyhow::bail!( + "`entry_bytes` ({}) must be a non-zero multiple of 8", + raw.entry_bytes, + ); + } + if raw.low_watermark > raw.capacity { + anyhow::bail!( + "`low_watermark` ({}) must not exceed `capacity` ({})", + raw.low_watermark, + raw.capacity, + ); + } + if raw.duration_secs == 0 && raw.max_ops == 0 { + anyhow::bail!("at least one of `duration_secs` or `max_ops` must be non-zero"); + } + Ok(raw) + } + + fn serialize(&self) -> anyhow::Result { + Ok(serde_json::to_value(self)?) 
+ } + + fn example() -> Self::Raw { + StoreStressInput { + readers: 8, + writers: 4, + retirers: 2, + capacity: 4096, + entry_bytes: 128, + low_watermark: 1024, + duration_secs: 5, + max_ops: 50_000_000, + seed: 0xA5A5_1234_DEAD_BEEF, + } + } +} + +impl std::fmt::Display for StoreStressInput { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut kv = KeyValue::new(); + kv.push("readers", &self.readers); + kv.push("writers", &self.writers); + kv.push("retirers", &self.retirers); + kv.push("capacity", &self.capacity); + kv.push("entry_bytes", &self.entry_bytes); + kv.push("low_watermark", &self.low_watermark); + kv.push("duration_secs", &self.duration_secs); + kv.push("max_ops", &self.max_ops); + kv.push("seed", &self.seed); + write!(f, "{}", kv) + } +} + +//////////// +// Output // +//////////// + +/// Summary statistics produced by a [`StoreStress`] run. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct StoreStressStats { + elapsed_secs: f64, + reads: u64, + acquires_ok: u64, + acquires_fail: u64, + retires_ok: u64, + retires_fail: u64, + reclaims: u64, + /// Observed `readable -> unreadable` transitions across all reader guards. + transitions: u64, + /// Peak observed live (published, not-yet-retired) population. + peak_live: usize, +} + +impl std::fmt::Display for StoreStressStats { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut kv = KeyValue::new(); + kv.push("elapsed_secs", &self.elapsed_secs); + kv.push("reads", &self.reads); + kv.push("acquires_ok", &self.acquires_ok); + kv.push("acquires_fail", &self.acquires_fail); + kv.push("retires_ok", &self.retires_ok); + kv.push("retires_fail", &self.retires_fail); + kv.push("reclaims", &self.reclaims); + kv.push("transitions", &self.transitions); + kv.push("peak_live", &self.peak_live); + write!(f, "{}", kv) + } +} + +///////////// +// Payload // +///////////// + +/// Fill `buf` with `stamp` replicated across every 8-byte lane. 
+fn write_stamp(buf: &mut [u8], stamp: u64) { + let bytes = stamp.to_ne_bytes(); + for lane in buf.chunks_exact_mut(8) { + lane.copy_from_slice(&bytes); + } +} + +/// Read the stamp from `buf`, returning `Err` if any 8-byte lane disagrees (a torn read). +fn read_stamp(buf: &[u8]) -> Result { + let (lanes, _) = buf.as_chunks::<8>(); + let mut lanes = lanes.iter(); + let first = u64::from_ne_bytes(*lanes.next().ok_or(())?); + for lane in lanes { + if u64::from_ne_bytes(*lane) != first { + return Err(()); + } + } + Ok(first) +} + +//////////////// +// Invariants // +//////////////// + +/// Per-guard observation of a single slot. +#[derive(Debug, Clone, Copy)] +enum SlotObservations { + /// The slot was observed readable with the given stamp. + Readable(u64), + /// The slot was observed readable and then became unreadable (retired). + Retired, +} + +/// Feed a single observation of slot `i` into the per-guard checker, recording a violation +/// on the shared state if a safety invariant is broken. +fn observe( + shared: &Shared, + observed: &mut HashMap, + i: usize, + read: Option<&[u8]>, +) { + match (observed.get(&i).copied(), read) { + // Not yet observed readable; an unreadable slot tells us nothing actionable. + (None, None) => {} + // First readable observation: record the stamp (after a tearing check). + (None, Some(bytes)) => match read_stamp(bytes) { + Ok(stamp) => { + observed.insert(i, SlotObservations::Readable(stamp)); + } + Err(()) => record_violation(shared, format!("torn read at slot {i}")), + }, + // Still readable: the value must be identical and untorn. + (Some(SlotObservations::Readable(prev)), Some(bytes)) => match read_stamp(bytes) { + Ok(stamp) if stamp != prev => record_violation( + shared, + format!("slot {i} value changed within guard: {prev} -> {stamp}"), + ), + Ok(_) => {} + Err(()) => record_violation(shared, format!("torn read at slot {i}")), + }, + // Readable -> unreadable: an allowed, terminal transition. 
+ (Some(SlotObservations::Readable(_)), None) => { + observed.insert(i, SlotObservations::Retired); + shared.transitions.fetch_add(1, Relaxed); + } + // Resurrection: a slot that retired came back to life within the same guard. + (Some(SlotObservations::Retired), Some(_)) => record_violation( + shared, + format!("resurrection at slot {i}: unreadable -> readable within one guard"), + ), + (Some(SlotObservations::Retired), None) => {} + } +} + +//////////// +// Shared // +//////////// + +/// State shared by all worker threads for the duration of a run. +struct Shared { + store: Store, + slots: usize, + readable: Uniform, + writable: Uniform, + low_watermark: usize, + max_ops: u64, + deadline: Instant, + + stop: AtomicBool, + violation: Mutex>, + + stamp: AtomicU64, + live: AtomicUsize, + peak_live: AtomicUsize, + + ops: AtomicU64, + reads: AtomicU64, + acquires_ok: AtomicU64, + acquires_fail: AtomicU64, + retires_ok: AtomicU64, + retires_fail: AtomicU64, + reclaims: AtomicU64, + transitions: AtomicU64, +} + +/// Record an observed invariant violation and signal all workers to stop. +fn record_violation(shared: &Shared, message: String) { + let mut slot = shared.violation.lock().unwrap(); + slot.push(message); + shared.stop.store(true, Relaxed); +} + +/// Return `true` once any termination condition is met. +fn should_stop(shared: &Shared) -> bool { + shared.stop.load(Relaxed) + || shared.ops.load(Relaxed) >= shared.max_ops + || Instant::now() >= shared.deadline +} + +///////////// +// Workers // +///////////// + +fn writer(shared: &Shared) { + while !should_stop(shared) { + shared.ops.fetch_add(1, Relaxed); + match shared.store.acquire() { + Some(mut writer) => { + let stamp = shared.stamp.fetch_add(1, Relaxed); + write_stamp(writer.as_mut_slice(), stamp); + // Dropping the writer publishes the slot.
+ drop(writer); + let live = shared.live.fetch_add(1, Relaxed) + 1; + shared.peak_live.fetch_max(live, Relaxed); + shared.acquires_ok.fetch_add(1, Relaxed); + } + None => { + shared.acquires_fail.fetch_add(1, Relaxed); + std::thread::yield_now(); + } + } + } +} + +fn retirer(shared: &Shared, seed: u64) { + let mut rng = StdRng::seed_from_u64(seed); + let mut iteration: u64 = 0; + + while !should_stop(shared) { + shared.ops.fetch_add(1, Relaxed); + iteration += 1; + + // Flow control: keep a steady readable population. + if shared.live.load(Relaxed) > shared.low_watermark { + let i = rng.sample(&shared.writable); + if shared.store.retire(i) { + shared.live.fetch_sub(1, Relaxed); + shared.retires_ok.fetch_add(1, Relaxed); + } else { + shared.retires_fail.fetch_add(1, Relaxed); + } + } + + if iteration % RECLAIM_EVERY == 0 + && let Some(reclaimed) = shared.store.reclaim() + { + shared.reclaims.fetch_add(reclaimed as u64, Relaxed); + } + + std::thread::yield_now(); + } +} + +fn reader(shared: &Shared, seed: u64) { + let mut rng = StdRng::seed_from_u64(seed); + let slots = shared.slots; + let window = READER_WINDOW.min(slots); + let mut observations = HashMap::with_capacity(window); + + while !should_stop(shared) { + shared.ops.fetch_add(1, Relaxed); + let Some(guard) = shared.store.reader() else { + // All guard slots are occupied; back off and retry. + std::thread::yield_now(); + continue; + }; + + observations.clear(); + let start = rng.sample(&shared.readable); + for _ in 0..READER_PASSES { + for k in 0..window { + let i = (start + k) % slots; + observe(shared, &mut observations, i, guard.read(i)); + shared.reads.fetch_add(1, Relaxed); + } + } + } +} + +/////////////// +// Benchmark // +/////////////// + +/// The store concurrency stress benchmark. 
+#[derive(Debug)] +pub struct StoreStress; + +impl Benchmark for StoreStress { + type Input = StoreStressInput; + type Output = StoreStressStats; + + fn try_match(&self, _input: &StoreStressInput) -> Result { + Ok(MatchScore(0)) + } + + fn description( + &self, + f: &mut std::fmt::Formatter<'_>, + _input: Option<&StoreStressInput>, + ) -> std::fmt::Result { + write!( + f, + "concurrency stress test for the in-memory store (readers/writers/retirers)" + ) + } + + fn run( + &self, + input: &StoreStressInput, + _checkpoint: Checkpoint<'_>, + mut output: &mut dyn Output, + ) -> anyhow::Result { + let store = Store::new(input.capacity, input.entry_bytes); + let writable = store.writable(); + let slots = store.slots(); + let start = Instant::now(); + + let shared = Shared { + store, + slots, + readable: Uniform::new(0, slots)?, + writable: Uniform::try_from(writable)?, + low_watermark: input.low_watermark, + max_ops: if input.max_ops == 0 { + u64::MAX + } else { + input.max_ops + }, + deadline: if input.duration_secs == 0 { + // Effectively unbounded; the op cap terminates the run. + start + Duration::from_secs(u64::from(u32::MAX)) + } else { + start + Duration::from_secs(input.duration_secs) + }, + stop: AtomicBool::new(false), + violation: Mutex::new(Vec::new()), + // Stamp 0 is reserved for the zeroed frozen point. 
+ stamp: AtomicU64::new(1), + live: AtomicUsize::new(0), + peak_live: AtomicUsize::new(0), + ops: AtomicU64::new(0), + reads: AtomicU64::new(0), + acquires_ok: AtomicU64::new(0), + acquires_fail: AtomicU64::new(0), + retires_ok: AtomicU64::new(0), + retires_fail: AtomicU64::new(0), + reclaims: AtomicU64::new(0), + transitions: AtomicU64::new(0), + }; + + writeln!(output, "{}", input)?; + + std::thread::scope(|scope| { + let shared = &shared; + for _ in 0..input.writers { + scope.spawn(move || writer(shared)); + } + for t in 0..input.retirers { + let seed = input.seed ^ (0x2000_0000 + t as u64); + scope.spawn(move || retirer(shared, seed)); + } + for t in 0..input.readers { + let seed = input.seed ^ (0x4000_0000 + t as u64); + scope.spawn(move || reader(shared, seed)); + } + }); + + let errors: Vec<_> = std::mem::take(&mut *shared.violation.lock().unwrap()); + if !errors.is_empty() { + anyhow::bail!("invariants violated: {:?}", errors); + } + + let elapsed = start.elapsed(); + let stats = StoreStressStats { + elapsed_secs: elapsed.as_secs_f64(), + reads: shared.reads.load(Relaxed), + acquires_ok: shared.acquires_ok.load(Relaxed), + acquires_fail: shared.acquires_fail.load(Relaxed), + retires_ok: shared.retires_ok.load(Relaxed), + retires_fail: shared.retires_fail.load(Relaxed), + reclaims: shared.reclaims.load(Relaxed), + transitions: shared.transitions.load(Relaxed), + peak_live: shared.peak_live.load(Relaxed), + }; + + writeln!(output, "{}", stats)?; + Ok(stats) + } +} diff --git a/diskann-inmem/src/epoch.rs b/diskann-inmem/src/epoch.rs index 562623eeb..efc3d63af 100644 --- a/diskann-inmem/src/epoch.rs +++ b/diskann-inmem/src/epoch.rs @@ -82,25 +82,33 @@ pub(crate) struct Registry { // hand out overlapping `Drain`s referring to the same retiring queue. drain: Mutex<()>, - // We use three queues for storing retiring items. + // We use four queues for storing retiring items. The rationale is documented below. // - // 1. 
Belongs to the current generation and is getting filled. - // 2. Ready for the next generation that will be populated on the next `try_advance`. - // Note that after a `try_advance` call, both 1 and 2 can be receiving retired items. - // 3. The queue returned from `try_advance` to be drained. Items drained are safe to - // reclaim. + // ```text + // + // 1. Safe to drain + // +-------------------------- + // Items retired at N-1 can | 2. Epoch N-1 + // be observed by guards at | +----------------------- + // N. If we transition to | | 3. Epoch N + // N+1, guards at N can be +-------------------------- + // active still. Thus it is | 4. Epoch N+1 + // not safe to reclaim items +----------------------- + // from this queue until all 5. Epoch N+2 (reuse #1 queue) + // guards are at least N+1. + // ``` // // We cycle among the queues in a round-robin manner. - retiring: [SegQueue; 3], + retiring: [SegQueue; 4], } // Return the queue index for the `epoch`. fn queue(epoch: u64) -> usize { - epoch.into_usize() % 3 + epoch.into_usize() % 4 } fn last_queue(epoch: u64) -> usize { - queue(epoch.wrapping_sub(1)) + queue(epoch.wrapping_sub(2)) } impl Registry { @@ -587,7 +595,12 @@ mod tests { // Verify that we reclaim the ID flushed by the registering thread. // - // This requires two epoch advancements. + // This requires three epoch advancements. + { + let drain = registry.try_advance().unwrap(); + assert!(drain.is_empty()); + } + { let drain = registry.try_advance().unwrap(); assert!(drain.is_empty()); @@ -752,22 +765,31 @@ mod tests { let gen_b = retire_at(200); assert_eq!(gen_b, gen_a + 1); - // 2nd advance after A (1st after B): drains A's queue → [100]. + // 2nd advance after A: must NOT drain item 100. { - let drained: Vec<_> = registry.try_advance().unwrap().collect(); - assert_eq!(drained, &[100]); + let drain = registry.try_advance().unwrap(); + assert!( + drain.is_empty(), + "100 must not drain on 2nd advance after A" + ); } // Retire 300 at generation C.
let _gen_c = retire_at(300); - // 2nd advance after B: drains B's queue → [200]. + // 3rd advance after A (1st after B): drains A's queue → [100]. + { + let drained: Vec<_> = registry.try_advance().unwrap().collect(); + assert_eq!(drained, &[100]); + } + + // 3rd advance after B: drains B's queue → [200]. { let drained: Vec<_> = registry.try_advance().unwrap().collect(); assert_eq!(drained, &[200]); } - // 2nd advance after C: drains C's queue → [300]. + // 3rd advance after C: drains C's queue → [300]. { let drained: Vec<_> = registry.try_advance().unwrap().collect(); assert_eq!(drained, &[300]); diff --git a/diskann-inmem/src/integration/store.rs b/diskann-inmem/src/integration/store.rs index c93451cb1..c310b5d8f 100644 --- a/diskann-inmem/src/integration/store.rs +++ b/diskann-inmem/src/integration/store.rs @@ -3,7 +3,9 @@ * Licensed under the MIT license. */ -use crate::store; +use diskann_utils::views::Matrix; + +use crate::{num::Bytes, store}; #[derive(Debug)] pub struct Store { @@ -11,6 +13,38 @@ pub struct Store { } impl Store { + /// Construct a store with `capacity` writable slots, each holding `entry_bytes` bytes. + /// + /// A single zeroed frozen point is created internally to satisfy the underlying + /// store's requirement of at least one frozen entry; it occupies the highest slot + /// index and is always readable. + /// + /// # Panics + /// + /// Panics if the underlying store could not be constructed (e.g. `capacity` plus the + /// frozen point exceeds `u32::MAX`). + pub fn new(capacity: usize, entry_bytes: usize) -> Self { + let data = Matrix::new(0u8, 1, entry_bytes); + let store = store::Store::new(capacity, Bytes::new(entry_bytes), 0, data.as_view()) + .expect("failed to construct store"); + Self { store } + } + + /// Return the total number of slots, including the frozen point. + pub fn slots(&self) -> usize { + self.store.frozen().end as usize + } + + /// Return the range of writable (non-frozen) slot indices. 
+ pub fn writable(&self) -> std::ops::Range { + 0..(self.store.frozen().start as usize) + } + + /// Attempt to reclaim retired slots, returning the number reclaimed if any. + pub fn reclaim(&self) -> Option { + self.store.try_drain() + } + pub fn acquire(&self) -> Option> { self.store.acquire().map(Writer::new) } diff --git a/diskann-inmem/src/layers/full.rs b/diskann-inmem/src/layers/full.rs index 8cd0d77a7..a5d3b2fd6 100644 --- a/diskann-inmem/src/layers/full.rs +++ b/diskann-inmem/src/layers/full.rs @@ -8,12 +8,14 @@ use std::{fmt::Debug, marker::PhantomData}; use diskann::{ANNError, ANNResult}; use diskann_vector::{ UnalignedSlice, - distance::{self, DistanceProvider, Metric}, + conversion::SliceCast, + distance::{self, DistanceProvider, InnerProduct, Metric, Specialize, SquaredL2}, }; use diskann_wide::{ ARCH, arch::{Current, FTarget2}, }; +use half::f16; use thiserror::Error; use crate::{layers, num::Bytes}; @@ -99,25 +101,27 @@ where ////////////// #[derive(Debug)] -struct Distance +struct Distance where T: 'static, + U: 'static, { - f: distance::Distance, + f: distance::Distance, dim: usize, } -impl Clone for Distance { +impl Clone for Distance { fn clone(&self) -> Self { *self } } -impl Copy for Distance {} +impl Copy for Distance {} -impl Distance +impl Distance where T: 'static, + U: 'static, { #[cold] #[inline(never)] @@ -174,20 +178,51 @@ struct DistanceError { // QueryDistance // /////////////////// +// A baby [`std::borrow::Cow`]. 
#[derive(Debug)] -struct QueryDistance<'a, T> +enum Calf<'a, T> { + Borrowed(&'a [T]), + Owned(Box<[T]>), +} + +impl std::ops::Deref for Calf<'_, T> { + type Target = [T]; + fn deref(&self) -> &Self::Target { + match self { + Self::Borrowed(slice) => slice, + Self::Owned(boxed) => boxed, + } + } +} + +impl<'a, T> From<&'a [T]> for Calf<'a, T> { + fn from(slice: &'a [T]) -> Self { + Self::Borrowed(slice) + } +} + +impl From> for Calf<'_, T> { + fn from(boxed: Box<[T]>) -> Self { + Self::Owned(boxed) + } +} + +#[derive(Debug)] +struct QueryDistance<'a, T, U> where T: 'static, + U: 'static, { - distance: Distance, - query: &'a [T], + distance: Distance, + query: Calf<'a, T>, } -impl<'a, T> QueryDistance<'a, T> +impl<'a, T, U> QueryDistance<'a, T, U> where T: 'static, + U: 'static, { - fn new(distance: Distance, query: &'a [T]) -> Self { + fn new(distance: Distance, query: Calf<'a, T>) -> Self { if query.len() != distance.dim() { panic!("oops"); } @@ -207,9 +242,10 @@ where } } -impl layers::QueryDistance for QueryDistance<'_, T> +impl layers::QueryDistance for QueryDistance<'_, T, U> where - T: Debug + Sync + 'static, + T: Debug + Sync + Send + 'static, + U: Debug + Sync + Send + 'static, { fn evaluate(&self, x: &[u8]) -> ANNResult { if x.len() != self.distance.bytes() { @@ -217,7 +253,7 @@ where } else { Ok(self.distance.f.call_unaligned( unsafe { UnalignedSlice::new(self.query.as_ptr().cast::(), self.distance.dim) }, - unsafe { UnalignedSlice::new(x.as_ptr().cast::(), self.distance.dim) }, + unsafe { UnalignedSlice::new(x.as_ptr().cast::(), self.distance.dim) }, )) } } @@ -234,18 +270,14 @@ struct QueryDistanceError { xlen: usize, } -//-----------// -// Version 2 // -//-----------// - macro_rules! specialize { - ($me:ident, $query:ident, $visitor:ident, $T:ty, $(($var:ident, $N:literal, $f:ty)),* $(,)?) => { + ($me:ident, $query:ident, $visitor:ident, ($T:ty, $U:ty), $(($var:ident, $N:literal, $f:ty)),* $(,)?) 
=> { match ($me.metric, $me.dim()) { $( (Metric::$var, $N) => { - let wrapped = Wrap::, _>::new($query); + let wrapped = Wrap::, $T, $U>::new($query); return Ok(unsafe { - $visitor.visit_sized::<{ $N * std::mem::size_of::<$T>() }, _>(wrapped) + $visitor.visit_sized::<{ $N * std::mem::size_of::<$U>() }, _>(wrapped) }) }, )* @@ -261,14 +293,52 @@ impl layers::Search for Full { where V: layers::QueryVisitor<'a>, { - use diskann_vector::distance::{Specialize, SquaredL2}; + let query = Calf::Borrowed(query); - specialize!(self, query, visitor, f32, (L2, 100, SquaredL2)); + specialize!( + self, + query, + visitor, + (f32, f16), + (L2, 100, SquaredL2), + (InnerProduct, 768, InnerProduct), + ); + // Fallback Ok(visitor.visit(QueryDistance::new(self.distance, query))) } } +impl layers::Search for Full { + type Query<'a> = &'a [f16]; + + fn query_distance<'a, V>(&'a self, query: &'a [f16], visitor: V) -> ANNResult + where + V: layers::QueryVisitor<'a>, + { + let mut as_f32: Box<[f32]> = std::iter::repeat_n(0.0, self.dim()).collect(); + diskann_wide::arch::dispatch2(SliceCast::new(), &mut *as_f32, query); + let query = Calf::Owned(as_f32); + + specialize!( + self, + query, + visitor, + (f32, f16), + (L2, 100, SquaredL2), + (InnerProduct, 768, InnerProduct), + ); + + // Fallback + let distance = Distance { + f: >::distance_comparer(self.metric, Some(self.dim())), + dim: self.dim(), + }; + + Ok(visitor.visit(QueryDistance::new(distance, query))) + } +} + impl layers::Search for Full { type Query<'a> = &'a [u8]; @@ -276,41 +346,64 @@ impl layers::Search for Full { where V: layers::QueryVisitor<'a>, { - use diskann_vector::distance::{Specialize, SquaredL2}; + let query = Calf::Borrowed(query); + + specialize!(self, query, visitor, (u8, u8), (L2, 128, SquaredL2)); - specialize!(self, query, visitor, u8, (L2, 128, SquaredL2)); + // Fallback + Ok(visitor.visit(QueryDistance::new(self.distance, query))) + } +} +impl layers::Search for Full { + type Query<'a> = &'a [i8]; + + fn 
query_distance<'a, V>(&'a self, query: &'a [i8], visitor: V) -> ANNResult + where + V: layers::QueryVisitor<'a>, + { + let query = Calf::Borrowed(query); Ok(visitor.visit(QueryDistance::new(self.distance, query))) } } #[derive(Debug)] -struct Wrap<'a, I, T> { - query: &'a [T], - inner: PhantomData, +struct Wrap<'a, I, T, U> { + query: Calf<'a, T>, + ps: PhantomData<(I, U)>, } -impl<'a, I, T> Wrap<'a, I, T> { - fn new(query: &'a [T]) -> Self { +impl<'a, I, T, U> Wrap<'a, I, T, U> { + fn new(query: Calf<'a, T>) -> Self { Self { query, - inner: PhantomData, + ps: PhantomData, } } } -impl layers::QueryDistance for Wrap<'_, I, T> +impl layers::QueryDistance for Wrap<'_, I, T, U> where - I: for<'a> FTarget2, UnalignedSlice<'a, T>> + I: for<'a> FTarget2, UnalignedSlice<'a, U>> + Send + Sync + Debug, T: Send + Sync + 'static + Debug, + U: Send + Sync + 'static + Debug, { #[inline(always)] fn evaluate(&self, x: &[u8]) -> ANNResult { // TODO: This is not fully valid - we need to check. - let x = unsafe { UnalignedSlice::new(x.as_ptr().cast::(), self.query.len()) }; - Ok(I::run(ARCH, self.query.into(), x)) + let x = unsafe { UnalignedSlice::new(x.as_ptr().cast::(), self.query.len()) }; + Ok(I::run(ARCH, (*self.query).into(), x)) } } + +/////////// +// Tests // +/////////// + +#[cfg(test)] +mod tests { + use super::*; +} diff --git a/diskann-inmem/src/store.rs b/diskann-inmem/src/store.rs index d827fd6aa..822b02be0 100644 --- a/diskann-inmem/src/store.rs +++ b/diskann-inmem/src/store.rs @@ -228,7 +228,10 @@ impl Store { return Err(RetireError::SlotIsReserved { tag: current }); } - let guard = self.registry.guard().map_err(RetireError::GuardUnavailable)?; + let guard = self + .registry + .guard() + .map_err(RetireError::GuardUnavailable)?; let retiring = Tag::RETIRING; @@ -563,6 +566,4 @@ impl Drop for Slot<'_> { #[cfg(test)] mod tests { use super::*; - - } From 969156cc08751679dbc4d2f5d6d6c1b62ed5edc8 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Tue, 23 Jun 2026 
13:43:43 -0700 Subject: [PATCH 19/45] Checkpoint. --- Cargo.lock | 1 + diskann-inmem/Cargo.toml | 6 +- diskann-inmem/integration/index/datatype.rs | 119 ------ .../integration/index/{traits.rs => index.rs} | 2 +- diskann-inmem/integration/index/mod.rs | 14 +- diskann-inmem/integration/index/runner.rs | 371 ++++++++++++++++++ diskann-inmem/integration/index/tests.rs | 42 ++ diskann-inmem/integration/main.rs | 2 + diskann-inmem/integration/support/datatype.rs | 338 ++++++++++++++++ diskann-inmem/integration/support/io.rs | 54 +++ diskann-inmem/integration/support/mod.rs | 7 + diskann-inmem/src/layers/full.rs | 2 +- diskann-inmem/src/provider.rs | 62 ++- diskann-inmem/src/sharded.rs | 83 +++- 14 files changed, 952 insertions(+), 151 deletions(-) delete mode 100644 diskann-inmem/integration/index/datatype.rs rename diskann-inmem/integration/index/{traits.rs => index.rs} (96%) create mode 100644 diskann-inmem/integration/index/runner.rs create mode 100644 diskann-inmem/integration/index/tests.rs create mode 100644 diskann-inmem/integration/support/datatype.rs create mode 100644 diskann-inmem/integration/support/io.rs create mode 100644 diskann-inmem/integration/support/mod.rs diff --git a/Cargo.lock b/Cargo.lock index d74087d26..074e475e2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -816,6 +816,7 @@ dependencies = [ "crossbeam-queue", "dashmap", "diskann", + "diskann-benchmark-core", "diskann-benchmark-runner", "diskann-utils", "diskann-vector", diff --git a/diskann-inmem/Cargo.toml b/diskann-inmem/Cargo.toml index 3f3b61893..24e510c1f 100644 --- a/diskann-inmem/Cargo.toml +++ b/diskann-inmem/Cargo.toml @@ -25,6 +25,8 @@ serde = { workspace = true, features = ["derive"], optional = true } serde_json = { workspace = true, optional = true } anyhow = { workspace = true, optional = true } rand = { workspace = true, optional = true } +diskann-benchmark-core = { workspace = true, optional = true } +tokio = { workspace = true, optional = true } [lints] workspace = true @@ -35,7 
+37,7 @@ tokio = { workspace = true, features = ["macros"] } tempfile = { workspace = true } [[bin]] -name = "integration-tests" +name = "integration-test" path = "integration/main.rs" required-features = ["integration-test"] @@ -45,6 +47,8 @@ default = [] # Enable stress test module integration-test = [ "dep:diskann-benchmark-runner", + "dep:diskann-benchmark-core", + "dep:tokio", "dep:serde", "dep:serde_json", "dep:anyhow", diff --git a/diskann-inmem/integration/index/datatype.rs b/diskann-inmem/integration/index/datatype.rs deleted file mode 100644 index 253ba7a60..000000000 --- a/diskann-inmem/integration/index/datatype.rs +++ /dev/null @@ -1,119 +0,0 @@ -/* - * Copyright (c) Microsoft Corporation. - * Licensed under the MIT license. - */ - -use half::f16; -use thiserror::Error; - -////////////// -// DataType // -////////////// - -#[derive(Debug, Clone, Copy)] -pub(crate) enum DataType { - F32, - F16, - U8, - I8, -} - -impl DataType { - fn as_str(self) -> &'static str { - match self { - Self::F32 => "f32", - Self::F16 => "f16", - Self::U8 => "u8", - Self::I8 => "i8", - } - } -} - -impl std::fmt::Display for DataType { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str(self.as_str()) - } -} - -pub(crate) trait AsDataType { - const DATA_TYPE: DataType; -} - -macro_rules! 
as_data_type { - ($T:ty, $variant:ident) => { - impl AsDataType for $T { - const DATA_TYPE: DataType = DataType::$variant; - } - }; -} - -as_data_type!(f32, F32); -as_data_type!(f16, F16); -as_data_type!(u8, U8); -as_data_type!(i8, I8); - -#[derive(Debug, Error)] -#[error("wrong data-type: expected {}, got {}", self.expected, self.got)] -pub(crate) struct WrongDataType { - expected: DataType, - got: DataType, -} - -impl WrongDataType { - fn new(expected: DataType, got: DataType) -> Self { - Self { expected, got } - } -} - -/////////// -// Slice // -/////////// - -#[derive(Debug, Clone, Copy)] -pub(crate) enum Slice<'a> { - F32(&'a [f32]), - F16(&'a [f16]), - U8(&'a [u8]), - I8(&'a [i8]), -} - -impl<'a> Slice<'a> { - pub(crate) fn data_type(&self) -> DataType { - match self { - Self::F32(_) => DataType::F32, - Self::F16(_) => DataType::F16, - Self::U8(_) => DataType::U8, - Self::I8(_) => DataType::I8, - } - } - - pub(crate) fn try_cast(self) -> Result<&'a [T], WrongDataType> - where - T: FromSlice, - { - T::from_slice(self) - } -} - -pub(crate) trait FromSlice: Sized { - fn from_slice(slice: Slice<'_>) -> Result<&[Self], WrongDataType>; -} - -macro_rules! 
from_slice { - ($T:ty, $variant:ident) => { - impl FromSlice for $T { - fn from_slice(slice: Slice<'_>) -> Result<&[Self], WrongDataType> { - if let Slice::$variant(s) = slice { - Ok(s) - } else { - Err(WrongDataType::new(DataType::$variant, slice.data_type())) - } - } - } - }; -} - -from_slice!(f32, F32); -from_slice!(f16, F16); -from_slice!(u8, U8); -from_slice!(i8, I8); diff --git a/diskann-inmem/integration/index/traits.rs b/diskann-inmem/integration/index/index.rs similarity index 96% rename from diskann-inmem/integration/index/traits.rs rename to diskann-inmem/integration/index/index.rs index f73e1a041..da5e5a593 100644 --- a/diskann-inmem/integration/index/traits.rs +++ b/diskann-inmem/integration/index/index.rs @@ -13,7 +13,7 @@ use half::f16; use diskann_inmem::{Context, Provider, Strategy, layers}; -use super::datatype::{AsDataType, DataType, FromSlice, Slice}; +use crate::support::datatype::{AsDataType, DataType, FromSlice, Slice}; pub(crate) trait Index { fn data_type(&self) -> DataType; diff --git a/diskann-inmem/integration/index/mod.rs b/diskann-inmem/integration/index/mod.rs index 46883ab49..986734406 100644 --- a/diskann-inmem/integration/index/mod.rs +++ b/diskann-inmem/integration/index/mod.rs @@ -3,5 +3,15 @@ * Licensed under the MIT license. */ -mod datatype; -mod traits; +mod index; +mod runner; +mod tests; + +use index::Index; + +use diskann_benchmark_runner::{Registry, RegistryError}; + +pub(super) fn register(registry: &mut Registry) -> Result<(), RegistryError> { + runner::register(registry)?; + Ok(()) +} diff --git a/diskann-inmem/integration/index/runner.rs b/diskann-inmem/integration/index/runner.rs new file mode 100644 index 000000000..500395fb0 --- /dev/null +++ b/diskann-inmem/integration/index/runner.rs @@ -0,0 +1,371 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. 
+ */ + +use std::sync::Arc; + +use anyhow::Context; +use diskann::graph::DiskANNIndex; +use diskann_benchmark_runner::{ + Checker, Checkpoint, Output, Registry, RegistryError, + benchmark::{FailureScore, MatchScore}, + files::InputFile, +}; +use diskann_vector::distance::Metric; +use diskann_utils::views::Matrix; +use half::f16; + +use diskann_inmem::{Provider, layers}; + +use crate::{ + index::Index, + support::{datatype::DataType, io::load_and_convert}, +}; + +pub(super) fn register(registry: &mut Registry) -> Result<(), RegistryError> { + registry.register("full-precision-integration-test", FullPrecision)?; + Ok(()) +} + +mod dto { + use super::*; + + use serde::{Deserialize, Serialize}; + + #[derive(Debug, Serialize, Deserialize)] + #[serde(rename_all = "kebab-case")] + pub(super) enum SerdeMetric { + L2, + InnerProduct, + Cosine, + } + + impl From for Metric { + fn from(m: SerdeMetric) -> Self { + match m { + SerdeMetric::L2 => Metric::L2, + SerdeMetric::InnerProduct => Metric::InnerProduct, + SerdeMetric::Cosine => Metric::Cosine, + } + } + } + + impl TryFrom for SerdeMetric { + type Error = anyhow::Error; + fn try_from(m: Metric) -> anyhow::Result { + match m { + Metric::L2 => Ok(SerdeMetric::L2), + Metric::InnerProduct => Ok(SerdeMetric::InnerProduct), + Metric::Cosine => Ok(SerdeMetric::Cosine), + Metric::CosineNormalized => anyhow::bail!("cosine normalized is not supported"), + } + } + } + + #[derive(Debug, Serialize, Deserialize)] + pub(super) struct Data { + pub(super) data: InputFile, + pub(super) queries: InputFile, + pub(super) groundtruth: InputFile, + pub(super) metric: SerdeMetric, + pub(super) data_type: DataType, + } + + #[derive(Debug, Serialize, Deserialize)] + pub(super) enum Layer { + FullPrecision { data_type: DataType }, + } + + #[derive(Debug, Serialize, Deserialize)] + pub(super) struct Build { + pub(super) pruned_degree: usize, + pub(super) max_degree: usize, + pub(super) l_build: usize, + pub(super) alpha: f32, + } + + #[derive(Debug, 
Serialize, Deserialize)] + pub(super) struct Test { + pub(super) data: Data, + pub(super) layer: Layer, + pub(super) build: Build, + } +} + +#[derive(Debug)] +struct Data { + data: InputFile, + queries: InputFile, + groundtruth: InputFile, + metric: Metric, + data_type: DataType, +} + +impl Data { + fn from_raw(mut raw: dto::Data, checker: &mut Checker) -> anyhow::Result { + let dto::Data { + mut data, + mut queries, + mut groundtruth, + metric, + data_type, + } = raw; + data.resolve(checker)?; + queries.resolve(checker)?; + groundtruth.resolve(checker)?; + Ok(Self { + data, + queries, + groundtruth, + metric: metric.into(), + data_type, + }) + } + + fn as_raw(&self) -> anyhow::Result { + Ok(dto::Data { + data: self.data.clone(), + queries: self.queries.clone(), + groundtruth: self.groundtruth.clone(), + metric: self.metric.try_into()?, + data_type: self.data_type, + }) + } +} + +#[derive(Debug)] +enum Layer { + FullPrecision { data_type: DataType }, +} + +impl Layer { + fn from_raw(raw: dto::Layer) -> Self { + match raw { + dto::Layer::FullPrecision { data_type } => Self::FullPrecision { data_type }, + } + } + + fn as_raw(&self) -> dto::Layer { + match self { + Self::FullPrecision { data_type } => dto::Layer::FullPrecision { + data_type: *data_type, + }, + } + } +} + +#[derive(Debug)] +struct Build { + config: diskann::graph::Config, +} + +impl Build { + fn from_raw(raw: dto::Build, metric: Metric) -> anyhow::Result { + let dto::Build { + pruned_degree, + max_degree, + l_build, + alpha, + } = raw; + let config = diskann::graph::config::Builder::new_with( + pruned_degree, + diskann::graph::config::MaxDegree::new(max_degree), + l_build, + metric.into(), + |b| { + b.alpha(alpha); + }, + ) + .build()?; + + Ok(Self { config }) + } + + fn as_raw(&self) -> dto::Build { + dto::Build { + pruned_degree: self.config.pruned_degree().get(), + max_degree: self.config.max_degree().get(), + l_build: self.config.l_build().get(), + alpha: self.config.alpha(), + } + } +} + 
+#[derive(Debug)] +struct Test { + data: Data, + layer: Layer, + build: Build, +} + +impl Test { + fn from_raw(raw: dto::Test, checker: &mut Checker) -> anyhow::Result { + let data = Data::from_raw(raw.data, checker)?; + let layer = Layer::from_raw(raw.layer); + let build = Build::from_raw(raw.build, data.metric)?; + + Ok(Self { data, layer, build }) + } + + fn as_raw(&self) -> anyhow::Result { + Ok(dto::Test { + data: self.data.as_raw()?, + layer: self.layer.as_raw(), + build: self.build.as_raw(), + }) + } +} + +/////////////// +// Benchmark // +/////////////// + +impl diskann_benchmark_runner::Input for Test { + type Raw = dto::Test; + + fn tag() -> &'static str { + "integration-test" + } + + fn from_raw(raw: dto::Test, checker: &mut Checker) -> anyhow::Result { + ::from_raw(raw, checker) + } + + fn serialize(&self) -> anyhow::Result { + let raw = self.as_raw()?; + Ok(serde_json::to_value(raw)?) + } + + fn example() -> dto::Test { + dto::Test { + data: dto::Data { + data: InputFile::new("path/to/data"), + queries: InputFile::new("path/to/queries"), + groundtruth: InputFile::new("path/to/groundtruth"), + metric: dto::SerdeMetric::L2, + data_type: DataType::F32, + }, + layer: dto::Layer::FullPrecision { + data_type: DataType::F32, + }, + build: dto::Build { + pruned_degree: 16, + max_degree: 20, + l_build: 50, + alpha: 1.2, + }, + } + } +} + +#[derive(Debug)] +struct FullPrecision; + +impl diskann_benchmark_runner::Benchmark for FullPrecision { + type Input = Test; + type Output = (); + + fn try_match(&self, input: &Test) -> Result { + if let Layer::FullPrecision { .. 
} = input.layer { + Ok(MatchScore(0)) + } else { + Err(FailureScore(1)) + } + } + + fn description( + &self, + f: &mut std::fmt::Formatter<'_>, + input: Option<&Test>, + ) -> std::fmt::Result { + write!(f, "nop") + } + + fn run( + &self, + input: &Test, + checkpoint: Checkpoint<'_>, + output: &mut dyn Output, + ) -> anyhow::Result<()> { + let Layer::FullPrecision { data_type } = input.layer else { + anyhow::bail!("oops"); + }; + + // Load the data and perform any necessary data conversions. + let data = { + let mut io = std::fs::File::open(&*input.data.data) + .with_context(|| format!("could not open {}", input.data.data.display()))?; + + load_and_convert(&mut io, input.data.data_type, data_type)? + }; + + let dim = data.nrows(); + + let config = diskann_inmem::provider::Config::new( + data.nrows(), + input.build.config.max_degree().get() + ); + + fn finish(provider: DP, config: diskann::graph::Config) -> Arc + where + DP: diskann::provider::DataProvider, + DiskANNIndex: Index, + { + Arc::new(DiskANNIndex::new(config, provider, None)) + } + + let index_config = input.build.config.clone(); + let index: Arc = match data_type { + DataType::F32 => { + let start = Matrix::new(0.0f32, dim, 1); + let provider = Provider::new( + layers::Full::::new(dim, input.data.metric), + config, + start.row_iter(), + ); + + finish(provider, index_config) + }, + DataType::F16 => { + let start = Matrix::new(f16::from_f32(0.0f32), dim, 1); + let provider = Provider::new( + layers::Full::::new(dim, input.data.metric), + config, + start.row_iter(), + ); + + finish(provider, index_config) + }, + DataType::U8 => { + let start = Matrix::new(0u8, dim, 1); + let provider = Provider::new( + layers::Full::::new(dim, input.data.metric), + config, + start.row_iter(), + ); + finish(provider, index_config) + }, + DataType::I8 => { + let start = Matrix::new(0i8, dim, 1); + let provider = Provider::new( + layers::Full::::new(dim, input.data.metric), + config, + start.row_iter(), + ); + + 
finish(provider, index_config) + }, + }; + + let rt = diskann_benchmark_core::tokio::runtime(1)?; + + super::tests::insert( + &*index, + data.as_view(), + rt.handle(), + )?; + + Ok(()) + } +} diff --git a/diskann-inmem/integration/index/tests.rs b/diskann-inmem/integration/index/tests.rs new file mode 100644 index 000000000..0142b7ae5 --- /dev/null +++ b/diskann-inmem/integration/index/tests.rs @@ -0,0 +1,42 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +use diskann::graph::search::Knn; +use diskann_benchmark_core::recall::Rows; + +use super::Index; +use crate::support::datatype::DatasetView; + +pub(super) fn insert( + index: &dyn Index, + dataset: DatasetView<'_>, + rt: &tokio::runtime::Handle, +) -> anyhow::Result<()> { + for i in 0..dataset.nrows() { + rt.block_on(index.insert(dataset.row(i).unwrap(), i as u64))?; + } + Ok(()) +} + +fn knn( + index: &dyn Index, + knn: Knn, + queries: DatasetView<'_>, + groundtruth: &dyn Rows, + rt: &tokio::runtime::Handle, +) -> anyhow::Result<()> { + anyhow::ensure!( + queries.nrows() == groundtruth.nrows(), + "number of queries ({}) must match number of groundtruth entries ({})", + queries.nrows(), + groundtruth.nrows(), + ); + + for i in 0..queries.nrows() { + let mut neighbors = Vec::new(); + rt.block_on(index.search(queries.row(i).unwrap(), knn, &mut neighbors))?; + } + Ok(()) +} diff --git a/diskann-inmem/integration/main.rs b/diskann-inmem/integration/main.rs index ce2f582f6..ef1f03e97 100644 --- a/diskann-inmem/integration/main.rs +++ b/diskann-inmem/integration/main.rs @@ -5,6 +5,7 @@ mod index; mod store; +mod support; use diskann_benchmark_runner::{App, Registry, output}; @@ -12,6 +13,7 @@ use diskann_benchmark_runner::{App, Registry, output}; fn registry() -> anyhow::Result { let mut registry = Registry::new(); registry.register("store-stress", store::StoreStress)?; + index::register(&mut registry)?; Ok(registry) } diff --git a/diskann-inmem/integration/support/datatype.rs 
b/diskann-inmem/integration/support/datatype.rs new file mode 100644 index 000000000..8b7eac764 --- /dev/null +++ b/diskann-inmem/integration/support/datatype.rs @@ -0,0 +1,338 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +use diskann_utils::views::{Matrix, MatrixView}; +use half::f16; +use serde::{Deserialize, Serialize}; +use thiserror::Error; + +////////////// +// DataType // +////////////// + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "kebab-case")] +pub(crate) enum DataType { + F32, + F16, + U8, + I8, +} + +impl DataType { + fn as_str(self) -> &'static str { + match self { + Self::F32 => "f32", + Self::F16 => "f16", + Self::U8 => "u8", + Self::I8 => "i8", + } + } +} + +impl std::fmt::Display for DataType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(self.as_str()) + } +} + +pub(crate) trait AsDataType { + const DATA_TYPE: DataType; +} + +#[derive(Debug, Error)] +#[error("wrong data-type: expected {}, got {}", self.expected, self.got)] +pub(crate) struct WrongDataType { + expected: DataType, + got: DataType, +} + +impl WrongDataType { + fn new(expected: DataType, got: DataType) -> Self { + Self { expected, got } + } +} + +/////////// +// Slice // +/////////// + +#[derive(Debug, Clone, Copy)] +pub(crate) enum Slice<'a> { + F32(&'a [f32]), + F16(&'a [f16]), + U8(&'a [u8]), + I8(&'a [i8]), +} + +impl<'a> Slice<'a> { + pub(crate) fn data_type(&self) -> DataType { + match self { + Self::F32(_) => DataType::F32, + Self::F16(_) => DataType::F16, + Self::U8(_) => DataType::U8, + Self::I8(_) => DataType::I8, + } + } + + pub(crate) fn len(&self) -> usize { + match self { + Self::F32(s) => s.len(), + Self::F16(s) => s.len(), + Self::U8(s) => s.len(), + Self::I8(s) => s.len(), + } + } + + pub(crate) fn try_cast(self) -> Result<&'a [T], WrongDataType> + where + T: FromSlice, + { + T::from_slice(self) + } +} + +pub(crate) trait FromSlice: 
Sized { + fn from_slice(slice: Slice<'_>) -> Result<&[Self], WrongDataType>; +} + +////////////// +// SliceMut // +////////////// + +#[derive(Debug)] +pub(crate) enum SliceMut<'a> { + F32(&'a mut [f32]), + F16(&'a mut [f16]), + U8(&'a mut [u8]), + I8(&'a mut [i8]), +} + +fn try_map(dst: &mut [T], src: &[U], f: F) -> anyhow::Result<()> +where + T: std::fmt::Display + AsDataType, + U: std::fmt::Display + AsDataType + Copy, + F: Fn(U) -> Result, +{ + std::iter::zip(dst.iter_mut(), src.iter()).try_for_each(|(d, s)| { + let converted = match f(*s) { + Ok(c) => c, + Err(e) => anyhow::bail!( + "could not losslessly convert {} {} to {}", + U::DATA_TYPE, + s, + T::DATA_TYPE, + ), + }; + *d = converted; + Ok(()) + }) +} + +fn f32_to_f16(x: f32) -> Result { + let y = f16::from_f32(x); + let z = f32::from(y); + if z != x { Err(()) } else { Ok(y) } +} + +fn f32_to_u8(x: f32) -> Result { + let y = x as u8; + let z = f32::from(y); + if z != x { Err(()) } else { Ok(y) } +} + +fn f32_to_i8(x: f32) -> Result { + let y = x as i8; + let z = f32::from(y); + if z != x { Err(()) } else { Ok(y) } +} + +fn f16_to_u8(x: f16) -> Result { + f32_to_u8(x.into()) +} + +fn f16_to_i8(x: f16) -> Result { + f32_to_i8(x.into()) +} + +impl<'a> SliceMut<'a> { + fn len(&self) -> usize { + match self { + Self::F32(s) => s.len(), + Self::F16(s) => s.len(), + Self::U8(s) => s.len(), + Self::I8(s) => s.len(), + } + } + + pub(crate) fn convert_lossless(&mut self, rhs: Slice<'_>) -> anyhow::Result<()> { + if self.len() != rhs.len() { + anyhow::bail!( + "lhs len {} must be equal to rhs len {}", + self.len(), + rhs.len() + ); + } + + match (self, rhs) { + (SliceMut::F32(dst), Slice::F32(src)) => dst.copy_from_slice(src), + (SliceMut::F32(dst), Slice::F16(src)) => try_map(dst, src, |x| x.try_into())?, + (SliceMut::F32(dst), Slice::U8(src)) => try_map(dst, src, |x| x.try_into())?, + (SliceMut::F32(dst), Slice::I8(src)) => try_map(dst, src, |x| x.try_into())?, + + (SliceMut::F16(dst), Slice::F32(src)) => 
try_map(dst, src, f32_to_f16)?, + (SliceMut::F16(dst), Slice::F16(src)) => dst.copy_from_slice(src), + (SliceMut::F16(dst), Slice::U8(src)) => try_map(dst, src, |x| x.try_into())?, + (SliceMut::F16(dst), Slice::I8(src)) => try_map(dst, src, |x| x.try_into())?, + + (SliceMut::U8(dst), Slice::F32(src)) => try_map(dst, src, f32_to_u8)?, + (SliceMut::U8(dst), Slice::F16(src)) => try_map(dst, src, f16_to_u8)?, + (SliceMut::U8(dst), Slice::U8(src)) => dst.copy_from_slice(src), + (SliceMut::U8(dst), Slice::I8(src)) => try_map(dst, src, |x| x.try_into())?, + + (SliceMut::I8(dst), Slice::F32(src)) => try_map(dst, src, f32_to_i8)?, + (SliceMut::I8(dst), Slice::F16(src)) => try_map(dst, src, f16_to_i8)?, + (SliceMut::I8(dst), Slice::U8(src)) => try_map(dst, src, |x| x.try_into())?, + (SliceMut::I8(dst), Slice::I8(src)) => dst.copy_from_slice(src), + }; + + Ok(()) + } +} + +///////////// +// Dataset // +///////////// + +#[derive(Debug)] +pub(crate) enum Dataset { + F32(Matrix), + F16(Matrix), + U8(Matrix), + I8(Matrix), +} + +impl Dataset { + pub(crate) fn nrows(&self) -> usize { + self.as_view().nrows() + } + + pub(crate) fn ncols(&self) -> usize { + self.as_view().ncols() + } + + pub(crate) fn row(&self, i: usize) -> Option> { + match self { + Self::F32(m) => m.get_row(i).map(Slice::from), + Self::F16(m) => m.get_row(i).map(Slice::from), + Self::U8(m) => m.get_row(i).map(Slice::from), + Self::I8(m) => m.get_row(i).map(Slice::from), + } + } + + pub(crate) fn as_view(&self) -> DatasetView<'_> { + match self { + Self::F32(m) => DatasetView::F32(m.as_view()), + Self::F16(m) => DatasetView::F16(m.as_view()), + Self::U8(m) => DatasetView::U8(m.as_view()), + Self::I8(m) => DatasetView::I8(m.as_view()), + } + } + + pub(crate) fn as_slice(&self) -> Slice<'_> { + match self { + Self::F32(m) => m.as_slice().into(), + Self::F16(m) => m.as_slice().into(), + Self::U8(m) => m.as_slice().into(), + Self::I8(m) => m.as_slice().into(), + } + } +} + +///////////////// +// DatasetView // 
+///////////////// + +#[derive(Debug, Clone, Copy)] +pub(crate) enum DatasetView<'a> { + F32(MatrixView<'a, f32>), + F16(MatrixView<'a, f16>), + U8(MatrixView<'a, u8>), + I8(MatrixView<'a, i8>), +} + +impl<'a> DatasetView<'a> { + pub(crate) fn nrows(&self) -> usize { + match self { + Self::F32(m) => m.nrows(), + Self::F16(m) => m.nrows(), + Self::U8(m) => m.nrows(), + Self::I8(m) => m.nrows(), + } + } + + pub(crate) fn ncols(&self) -> usize { + match self { + Self::F32(m) => m.ncols(), + Self::F16(m) => m.ncols(), + Self::U8(m) => m.ncols(), + Self::I8(m) => m.ncols(), + } + } + + pub(crate) fn row(&self, i: usize) -> Option> { + match self { + Self::F32(m) => m.get_row(i).map(Slice::from), + Self::F16(m) => m.get_row(i).map(Slice::from), + Self::U8(m) => m.get_row(i).map(Slice::from), + Self::I8(m) => m.get_row(i).map(Slice::from), + } + } +} + +//------// +// Impl // +//------// + +macro_rules! define { + ($T:ty, $variant:ident) => { + impl AsDataType for $T { + const DATA_TYPE: DataType = DataType::$variant; + } + + impl<'a> From<&'a [$T]> for Slice<'a> { + fn from(s: &'a [$T]) -> Self { + Self::$variant(s) + } + } + + impl<'a> From<&'a mut [$T]> for SliceMut<'a> { + fn from(s: &'a mut [$T]) -> Self { + Self::$variant(s) + } + } + + impl FromSlice for $T { + fn from_slice(slice: Slice<'_>) -> Result<&[Self], WrongDataType> { + if let Slice::$variant(s) = slice { + Ok(s) + } else { + Err(WrongDataType::new(DataType::$variant, slice.data_type())) + } + } + } + + impl From> for Dataset { + fn from(m: Matrix<$T>) -> Self { + Self::$variant(m) + } + } + }; +} + +define!(f32, F32); +define!(f16, F16); +define!(u8, U8); +define!(i8, I8); diff --git a/diskann-inmem/integration/support/io.rs b/diskann-inmem/integration/support/io.rs new file mode 100644 index 000000000..50109fc7a --- /dev/null +++ b/diskann-inmem/integration/support/io.rs @@ -0,0 +1,54 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. 
+ */ + +use diskann_utils::{io::read_bin, views::Matrix}; +use half::f16; + +use super::datatype::{DataType, Dataset, Slice, SliceMut}; + +pub(crate) fn load_and_convert( + io: &mut IO, + src: DataType, + target: DataType, +) -> anyhow::Result +where + IO: std::io::Read + std::io::Seek, +{ + let data = match src { + DataType::F32 => Dataset::from(read_bin::(io)?), + DataType::F16 => Dataset::from(read_bin::(io)?), + DataType::U8 => Dataset::from(read_bin::(io)?), + DataType::I8 => Dataset::from(read_bin::(io)?), + }; + + if src == target { + return Ok(data); + } + + let dst = match target { + DataType::F32 => { + let mut dst = Matrix::new(0.0f32, data.nrows(), data.ncols()); + SliceMut::from(dst.as_mut_slice()).convert_lossless(data.as_slice().into())?; + Dataset::from(dst) + } + DataType::F16 => { + let mut dst = Matrix::new(f16::from_f32(0.0f32), data.nrows(), data.ncols()); + SliceMut::from(dst.as_mut_slice()).convert_lossless(data.as_slice().into())?; + Dataset::from(dst) + } + DataType::U8 => { + let mut dst = Matrix::new(0u8, data.nrows(), data.ncols()); + SliceMut::from(dst.as_mut_slice()).convert_lossless(data.as_slice().into())?; + Dataset::from(dst) + } + DataType::I8 => { + let mut dst = Matrix::new(0i8, data.nrows(), data.ncols()); + SliceMut::from(dst.as_mut_slice()).convert_lossless(data.as_slice().into())?; + Dataset::from(dst) + } + }; + + Ok(dst) +} diff --git a/diskann-inmem/integration/support/mod.rs b/diskann-inmem/integration/support/mod.rs new file mode 100644 index 000000000..9b6eeae43 --- /dev/null +++ b/diskann-inmem/integration/support/mod.rs @@ -0,0 +1,7 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. 
+ */ + +pub(crate) mod datatype; +pub(crate) mod io; diff --git a/diskann-inmem/src/layers/full.rs b/diskann-inmem/src/layers/full.rs index a5d3b2fd6..ab62b7c21 100644 --- a/diskann-inmem/src/layers/full.rs +++ b/diskann-inmem/src/layers/full.rs @@ -299,7 +299,7 @@ impl layers::Search for Full { self, query, visitor, - (f32, f16), + (f32, f32), (L2, 100, SquaredL2), (InnerProduct, 768, InnerProduct), ); diff --git a/diskann-inmem/src/provider.rs b/diskann-inmem/src/provider.rs index 239c2b6f0..6f6c5458c 100644 --- a/diskann-inmem/src/provider.rs +++ b/diskann-inmem/src/provider.rs @@ -42,7 +42,7 @@ impl Provider where M: Id, { - pub fn new(layer: L, capacity: usize, start_points: I) -> Self + pub fn new(layer: L, config: Config, start_points: I) -> Self where I: IntoIterator, L: layers::Set, @@ -55,8 +55,16 @@ where layers::Set::into_bytes(&layer, point, row).unwrap(); } - let store = Store::new(capacity, bytes, 32, data.as_view()).unwrap(); - let mapping = Sharded::new(capacity); + let store = Store::new( + config.capacity(), + bytes, + config.max_degree(), + data.as_view(), + ) + .unwrap(); + + let mapping = Sharded::new(config.capacity()); + Self { store, layer, @@ -65,6 +73,29 @@ where } } +#[derive(Debug)] +pub struct Config { + capacity: usize, + max_degree: usize, +} + +impl Config { + pub fn new(capacity: usize, max_degree: usize) -> Self { + Self { + capacity, + max_degree, + } + } + + pub fn capacity(&self) -> usize { + self.capacity + } + + pub fn max_degree(&self) -> usize { + self.max_degree + } +} + /////////////////// // Data Provider // /////////////////// @@ -117,9 +148,24 @@ where { async fn delete(&self, _context: &Context, gid: &M) -> ANNResult<()> { // TODO: These need to actually happen in lock-step. 
- let internal = self.mapping.remove(gid).unwrap(); - self.store.retire(internal.into_usize()).unwrap(); - Ok(()) + let entry = match self.mapping.occupied_entry(gid.clone()) { + None => { + return Err(ANNError::message( + ANNErrorKind::Opaque, + "id already deleted", + )); + } + Some(e) => e, + }; + + match self.store.retire(entry.internal().into_usize()) { + Ok(()) => { + // Successfully retired the internal slot. We can safely release the ID mapping. + entry.delete(); + Ok(()) + } + Err(err) => Err(ANNError::opaque(err)), + } } async fn release(&self, _context: &Context, _id: Self::InternalId) -> ANNResult<()> { @@ -777,7 +823,9 @@ mod tests { let full = Full::::new(1, Metric::L2); let start_points: [&[f32]; _] = [&[1.0], &[2.0]]; - let provider = Provider::new(full, 10, start_points); + let config = Config::new(10, 16); + + let provider = Provider::new(full, config, start_points); let config = diskann::graph::config::Builder::new( 10, diff --git a/diskann-inmem/src/sharded.rs b/diskann-inmem/src/sharded.rs index 2714a1192..ef64fa01c 100644 --- a/diskann-inmem/src/sharded.rs +++ b/diskann-inmem/src/sharded.rs @@ -5,9 +5,12 @@ use std::hash::Hash; -use dashmap::{DashMap, mapref::entry::Entry}; +use dashmap::{ + DashMap, + mapref::entry::{self, OccupiedEntry}, +}; use diskann::utils::IntoUsize; -use parking_lot::RwLock; +use parking_lot::{RwLock, RwLockWriteGuard}; use thiserror::Error; const SHARD_SIZE: usize = 1024; @@ -69,8 +72,8 @@ where // on the dashmap shard, and another thread racing on the same `internal` will // block on the backward shard's write lock. 
let forward = match self.forward.entry(external.clone()) { - Entry::Occupied(_) => return Err(InsertError::ExternalExists), - Entry::Vacant(vacant) => vacant, + entry::Entry::Occupied(_) => return Err(InsertError::ExternalExists), + entry::Entry::Vacant(vacant) => vacant, }; let mut shard = self.backward[outer].write(); @@ -113,24 +116,33 @@ where self.backward[outer].read()[inner].clone() } - /// Remove the mapping for `external`. Returns the freed internal id, or `None` if - /// no such mapping existed. - pub(crate) fn remove(&self, external: &Q) -> Option + /// Validate that a mapping exists for `external` and return an [`Entry`] if successful. + /// + /// The [`Entry`] provides a means of error-free deferred deletion to enable coordinated + /// deletion of slots among multiple stores. + pub(crate) fn occupied_entry(&self, external: I) -> Option> where - I: Eq + Hash + std::borrow::Borrow, - Q: Eq + Hash + ?Sized, + I: Eq + Hash, { - let (_, internal) = self.forward.remove(external)?; - let Shard { outer, inner } = self.shard(internal); - - // The backward slot should be populated by the `insert` invariant. - // - // If not - this is a program bug. - let mut shard = self.backward[outer].write(); - assert!(shard[inner].is_some(), "id {} removed improperly", internal); - shard[inner] = None; - - Some(internal) + match self.forward.entry(external) { + entry::Entry::Vacant(_) => None, + entry::Entry::Occupied(forward) => { + let internal = *forward.get(); + let Shard { outer, inner } = self.shard(internal); + let backward = self.backward[outer].write(); + assert!( + backward[inner].is_some(), + "id {} removed improperly", + internal + ); + + Some(Entry { + forward, + backward, + entry: inner, + }) + } + } } fn shard(&self, i: u32) -> Shard { @@ -156,3 +168,34 @@ pub(crate) enum InsertError { #[error("the internal id is already mapped")] InternalExists, } + +/// A handle to a valid entry in a [`Sharded`]. 
+/// +/// This can be used to guarantee the presence of an entry prior to deletion to support +/// atomic deletes. +pub(crate) struct Entry<'a, I> +where + I: Eq + Hash, +{ + forward: OccupiedEntry<'a, I, u32>, + backward: RwLockWriteGuard<'a, Box<[Option]>>, + entry: usize, +} + +impl<'a, I> Entry<'a, I> +where + I: Eq + Hash, +{ + pub(crate) fn internal(&self) -> u32 { + *self.forward.get() + } + + pub(crate) fn external(&self) -> &I { + self.forward.key() + } + + pub(crate) fn delete(mut self) { + self.forward.remove(); + self.backward[self.entry] = None; + } +} From a1b23197abfe18c0dfdb44c825305a4a56567b73 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Tue, 23 Jun 2026 14:09:09 -0700 Subject: [PATCH 20/45] Checkpoint. --- diskann-inmem/integration/index/runner.rs | 138 +++++++++--------- diskann-inmem/integration/support/datatype.rs | 24 ++- diskann-inmem/src/layers/full.rs | 2 +- diskann-utils/src/views.rs | 13 ++ 4 files changed, 108 insertions(+), 69 deletions(-) diff --git a/diskann-inmem/integration/index/runner.rs b/diskann-inmem/integration/index/runner.rs index 500395fb0..ff0b79b35 100644 --- a/diskann-inmem/integration/index/runner.rs +++ b/diskann-inmem/integration/index/runner.rs @@ -12,15 +12,18 @@ use diskann_benchmark_runner::{ benchmark::{FailureScore, MatchScore}, files::InputFile, }; -use diskann_vector::distance::Metric; use diskann_utils::views::Matrix; +use diskann_vector::distance::Metric; use half::f16; use diskann_inmem::{Provider, layers}; use crate::{ index::Index, - support::{datatype::DataType, io::load_and_convert}, + support::{ + datatype::{DataType, DatasetView}, + io::load_and_convert, + }, }; pub(super) fn register(registry: &mut Registry) -> Result<(), RegistryError> { @@ -111,9 +114,11 @@ impl Data { metric, data_type, } = raw; + data.resolve(checker)?; queries.resolve(checker)?; groundtruth.resolve(checker)?; + Ok(Self { data, queries, @@ -168,6 +173,7 @@ impl Build { l_build, alpha, } = raw; + let config = 
diskann::graph::config::Builder::new_with( pruned_degree, diskann::graph::config::MaxDegree::new(max_degree), @@ -215,6 +221,62 @@ impl Test { build: self.build.as_raw(), }) } + + fn index( + &self, + capacity: usize, + start_points: DatasetView<'_>, + ) -> anyhow::Result> { + match self.layer { + Layer::FullPrecision { data_type } => { + if start_points.data_type() != data_type { + anyhow::bail!( + "mismatched data types for start point - expected {}, got {}", + data_type, + start_points.data_type(), + ); + } + + let dim = start_points.ncols(); + let metric = self.data.metric; + let config = diskann_inmem::provider::Config::new( + capacity, + self.build.config.max_degree().get(), + ); + + let index_config = self.build.config.clone(); + + let index = match start_points { + DatasetView::F32(v) => finish( + Provider::new(layers::Full::::new(dim, metric), config, v.row_iter()), + index_config, + ), + DatasetView::F16(v) => finish( + Provider::new(layers::Full::::new(dim, metric), config, v.row_iter()), + index_config, + ), + DatasetView::U8(v) => finish( + Provider::new(layers::Full::::new(dim, metric), config, v.row_iter()), + index_config, + ), + DatasetView::I8(v) => finish( + Provider::new(layers::Full::::new(dim, metric), config, v.row_iter()), + index_config, + ), + }; + + Ok(index) + } + } + } +} + +fn finish(provider: DP, config: diskann::graph::Config) -> Arc +where + DP: diskann::provider::DataProvider, + DiskANNIndex: Index, +{ + Arc::new(DiskANNIndex::new(config, provider, None)) } /////////////// @@ -259,6 +321,10 @@ impl diskann_benchmark_runner::Input for Test { } } +//////////////// +// Benchmarks // +//////////////// + #[derive(Debug)] struct FullPrecision; @@ -289,7 +355,7 @@ impl diskann_benchmark_runner::Benchmark for FullPrecision { output: &mut dyn Output, ) -> anyhow::Result<()> { let Layer::FullPrecision { data_type } = input.layer else { - anyhow::bail!("oops"); + anyhow::bail!("expected full-precision"); }; // Load the data and perform any 
necessary data conversions. @@ -300,71 +366,9 @@ impl diskann_benchmark_runner::Benchmark for FullPrecision { load_and_convert(&mut io, input.data.data_type, data_type)? }; - let dim = data.nrows(); - - let config = diskann_inmem::provider::Config::new( - data.nrows(), - input.build.config.max_degree().get() - ); - - fn finish(provider: DP, config: diskann::graph::Config) -> Arc - where - DP: diskann::provider::DataProvider, - DiskANNIndex: Index, - { - Arc::new(DiskANNIndex::new(config, provider, None)) - } - - let index_config = input.build.config.clone(); - let index: Arc = match data_type { - DataType::F32 => { - let start = Matrix::new(0.0f32, dim, 1); - let provider = Provider::new( - layers::Full::::new(dim, input.data.metric), - config, - start.row_iter(), - ); - - finish(provider, index_config) - }, - DataType::F16 => { - let start = Matrix::new(f16::from_f32(0.0f32), dim, 1); - let provider = Provider::new( - layers::Full::::new(dim, input.data.metric), - config, - start.row_iter(), - ); - - finish(provider, index_config) - }, - DataType::U8 => { - let start = Matrix::new(0u8, dim, 1); - let provider = Provider::new( - layers::Full::::new(dim, input.data.metric), - config, - start.row_iter(), - ); - finish(provider, index_config) - }, - DataType::I8 => { - let start = Matrix::new(0i8, dim, 1); - let provider = Provider::new( - layers::Full::::new(dim, input.data.metric), - config, - start.row_iter(), - ); - - finish(provider, index_config) - }, - }; - + let index = input.index(data.nrows(), data.medoid().as_view())?; let rt = diskann_benchmark_core::tokio::runtime(1)?; - - super::tests::insert( - &*index, - data.as_view(), - rt.handle(), - )?; + super::tests::insert(&*index, data.as_view(), rt.handle())?; Ok(()) } diff --git a/diskann-inmem/integration/support/datatype.rs b/diskann-inmem/integration/support/datatype.rs index 8b7eac764..e15c786e0 100644 --- a/diskann-inmem/integration/support/datatype.rs +++ b/diskann-inmem/integration/support/datatype.rs 
@@ -3,7 +3,7 @@ * Licensed under the MIT license. */ -use diskann_utils::views::{Matrix, MatrixView}; +use diskann_utils::{sampling::medoid::ComputeMedoid, views::{Matrix, MatrixView}}; use half::f16; use serde::{Deserialize, Serialize}; use thiserror::Error; @@ -249,6 +249,10 @@ impl Dataset { Self::I8(m) => m.as_slice().into(), } } + + pub(crate) fn medoid(&self) -> Dataset { + self.as_view().medoid() + } } ///////////////// @@ -264,6 +268,15 @@ pub(crate) enum DatasetView<'a> { } impl<'a> DatasetView<'a> { + pub(crate) fn data_type(&self) -> DataType { + match self { + Self::F32(_) => DataType::F32, + Self::F16(_) => DataType::F16, + Self::U8(_) => DataType::U8, + Self::I8(_) => DataType::I8, + } + } + pub(crate) fn nrows(&self) -> usize { match self { Self::F32(m) => m.nrows(), @@ -290,6 +303,15 @@ impl<'a> DatasetView<'a> { Self::I8(m) => m.get_row(i).map(Slice::from), } } + + pub(crate) fn medoid(&self) -> Dataset { + match self { + Self::F32(v) => Matrix::row_vector(Box::from(f32::compute_medoid(*v))).into(), + Self::F16(v) => Matrix::row_vector(Box::from(f16::compute_medoid(*v))).into(), + Self::U8(v) => Matrix::row_vector(Box::from(u8::compute_medoid(*v))).into(), + Self::I8(v) => Matrix::row_vector(Box::from(i8::compute_medoid(*v))).into(), + } + } } //------// diff --git a/diskann-inmem/src/layers/full.rs b/diskann-inmem/src/layers/full.rs index ab62b7c21..2f67da695 100644 --- a/diskann-inmem/src/layers/full.rs +++ b/diskann-inmem/src/layers/full.rs @@ -140,7 +140,7 @@ where } fn bytes(&self) -> usize { - self.dim * std::mem::size_of::() + self.dim * std::mem::size_of::() } } diff --git a/diskann-utils/src/views.rs b/diskann-utils/src/views.rs index a9352918c..844420e82 100644 --- a/diskann-utils/src/views.rs +++ b/diskann-utils/src/views.rs @@ -344,6 +344,19 @@ where unsafe { self.get_row_unchecked_mut(row) } } + /// Return row `row` as a mutable slice. 
+ pub fn get_row_mut(&mut self, row: usize) -> Option<&mut [T::Elem]> + where + T: MutDenseData, + { + if row < self.nrows() { + // SAFETY: `row` is in-bounds. + Some(unsafe { self.get_row_unchecked_mut(row) }) + } else { + None + } + } + /// Returns the requested row without boundschecking. /// /// # Safety From b52d9aaf2297c7c8958e895280411f181d13ff96 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Tue, 23 Jun 2026 17:24:28 -0700 Subject: [PATCH 21/45] Checkpoint. --- diskann-benchmark-runner/src/utils/fmt.rs | 13 +- diskann-inmem/integration/index/index.rs | 127 +++++++++++- diskann-inmem/integration/index/mod.rs | 2 +- diskann-inmem/integration/index/runner.rs | 184 ++++++++++++++++-- diskann-inmem/integration/index/tests.rs | 98 ++++++++-- diskann-inmem/integration/support/datatype.rs | 5 +- diskann-inmem/src/counters.rs | 178 +++++++++++++++++ diskann-inmem/src/integration/counters.rs | 17 ++ diskann-inmem/src/integration/mod.rs | 1 + diskann-inmem/src/lib.rs | 1 + diskann-inmem/src/provider.rs | 87 ++++++++- 11 files changed, 672 insertions(+), 41 deletions(-) create mode 100644 diskann-inmem/src/counters.rs create mode 100644 diskann-inmem/src/integration/counters.rs diff --git a/diskann-benchmark-runner/src/utils/fmt.rs b/diskann-benchmark-runner/src/utils/fmt.rs index 462db6257..f7e2a9225 100644 --- a/diskann-benchmark-runner/src/utils/fmt.rs +++ b/diskann-benchmark-runner/src/utils/fmt.rs @@ -410,13 +410,24 @@ impl<'a> KeyValue<'a> { self.max_key_length = self.max_key_length.max(key.len()); self.kv.push((key, value)) } + + pub fn render(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self) + } } impl std::fmt::Display for KeyValue<'_> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let width = self.max_key_length; + let mut prefix = ""; for (k, v) in self.kv.iter() { - writeln!(f, "{:>width$}: {v}", k)?; + let rendered = v.to_string(); + if rendered.contains('\n') { + write!(f, 
"{}{:>width$}:\n{}", prefix, k, Indent::new(&rendered, 2))? + } else { + write!(f, "{}{:>width$}: {rendered}", prefix, k)?; + } + prefix = "\n"; } Ok(()) } diff --git a/diskann-inmem/integration/index/index.rs b/diskann-inmem/integration/index/index.rs index da5e5a593..ca0f7c8c8 100644 --- a/diskann-inmem/integration/index/index.rs +++ b/diskann-inmem/integration/index/index.rs @@ -8,10 +8,14 @@ use std::{future::Future, pin::Pin}; use diskann::{ graph::{DiskANNIndex, search::Knn}, neighbor::Neighbor, + utils::IntoUsize, }; +use diskann_benchmark_runner::utils::fmt::KeyValue; use half::f16; +use serde::{Deserialize, Serialize}; +use thiserror::Error; -use diskann_inmem::{Context, Provider, Strategy, layers}; +use diskann_inmem::{Context, Provider, Strategy, integration, layers}; use crate::support::datatype::{AsDataType, DataType, FromSlice, Slice}; @@ -23,7 +27,7 @@ pub(crate) trait Index { query: Slice<'a>, knn: Knn, neighbors: &'a mut Vec>, - ) -> Pin> + 'a>>; + ) -> Pin> + 'a>>; fn insert<'a>( &'a self, @@ -31,9 +35,118 @@ pub(crate) trait Index { id: u64, ) -> Pin> + 'a>>; + fn counters(&self) -> Counters; // fn retire(&self, id: u64) -> anyhow::Result<()>; } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub(crate) struct KnnSearch { + hops: usize, + cmps: usize, +} + +impl KnnSearch { + pub(crate) fn new() -> Self { + Self { hops: 0, cmps: 0 } + } +} + +impl From for KnnSearch { + fn from(stats: diskann::graph::index::SearchStats) -> Self { + Self { + hops: stats.hops.into_usize(), + cmps: stats.cmps.into_usize(), + } + } +} + +impl std::ops::AddAssign for KnnSearch { + fn add_assign(&mut self, rhs: Self) { + self.hops = self.hops.wrapping_add(rhs.hops); + self.cmps = self.cmps.wrapping_add(rhs.cmps); + } +} + +impl std::fmt::Display for KnnSearch { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "hops = {}, cmps = {}", self.hops, self.cmps) + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub(crate) struct Counters { 
+ query_distance: u64, + distance: u64, + get_vector: u64, + set_vector: u64, + get_neighbors: u64, + set_neighbors: u64, + append_neighbors: u64, +} + +impl Counters { + pub(crate) fn delta(&self, after: &Counters) -> anyhow::Result { + #[derive(Debug, Error)] + #[error( + "counter \"{}\" non-monotonically increasing from {} to {}", + self.0, + self.1, + self.2 + )] + struct NonMonotonic(&'static str, u64, u64); + + fn check(before: u64, after: u64, field: &'static str) -> Result { + after + .checked_sub(before) + .ok_or(NonMonotonic(field, before, after)) + } + + let delta = Self { + query_distance: check(self.query_distance, after.query_distance, "query_distance")?, + distance: check(self.distance, after.distance, "distance")?, + get_vector: check(self.get_vector, after.get_vector, "get_vector")?, + set_vector: check(self.set_vector, after.set_vector, "set_vector")?, + get_neighbors: check(self.get_neighbors, after.get_neighbors, "get_neighbors")?, + set_neighbors: check(self.set_neighbors, after.set_neighbors, "set_neighbors")?, + append_neighbors: check( + self.append_neighbors, + after.append_neighbors, + "append_neighbors", + )?, + }; + + Ok(delta) + } +} + +impl std::fmt::Display for Counters { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut kv = KeyValue::new(); + kv.push("query_distance", &self.query_distance); + kv.push("distance", &self.distance); + kv.push("get_vector", &self.get_vector); + kv.push("set_vector", &self.set_vector); + kv.push("get_neighbors", &self.get_neighbors); + kv.push("set_neighbors", &self.set_neighbors); + kv.push("append_neighbors", &self.append_neighbors); + kv.render(f) + } +} + +impl From for Counters { + fn from(snapshot: integration::counters::CounterSnapshot) -> Self { + Self { + query_distance: snapshot.query_distance, + distance: snapshot.distance, + get_vector: snapshot.get_vector, + set_vector: snapshot.set_vector, + get_neighbors: snapshot.get_neighbors, + set_neighbors: 
snapshot.set_neighbors, + append_neighbors: snapshot.append_neighbors, + } + } +} + /////////// // Impls // /////////// @@ -52,14 +165,14 @@ where query: Slice<'a>, knn: Knn, neighbors: &'a mut Vec>, - ) -> Pin> + 'a>> { + ) -> Pin> + 'a>> { let fut = async move { let query = query.try_cast()?; - let _ = self + let stats = self .search(knn, &Strategy, &Context, query, neighbors) .await?; - Ok(()) + Ok(stats.into()) }; Box::pin(fut) @@ -80,6 +193,10 @@ where Box::pin(fut) } + fn counters(&self) -> Counters { + self.provider().counters().into() + } + // fn retire(&self, id: u64) -> anyhow::Result<()> { // } } diff --git a/diskann-inmem/integration/index/mod.rs b/diskann-inmem/integration/index/mod.rs index 986734406..ef15e8fe1 100644 --- a/diskann-inmem/integration/index/mod.rs +++ b/diskann-inmem/integration/index/mod.rs @@ -7,7 +7,7 @@ mod index; mod runner; mod tests; -use index::Index; +use index::{Counters, Index, KnnSearch}; use diskann_benchmark_runner::{Registry, RegistryError}; diff --git a/diskann-inmem/integration/index/runner.rs b/diskann-inmem/integration/index/runner.rs index ff0b79b35..d46247211 100644 --- a/diskann-inmem/integration/index/runner.rs +++ b/diskann-inmem/integration/index/runner.rs @@ -3,25 +3,27 @@ * Licensed under the MIT license. 
*/ -use std::sync::Arc; +use std::{io::Write, sync::Arc}; use anyhow::Context; -use diskann::graph::DiskANNIndex; +use diskann::graph::{DiskANNIndex, search::Knn}; use diskann_benchmark_runner::{ Checker, Checkpoint, Output, Registry, RegistryError, benchmark::{FailureScore, MatchScore}, files::InputFile, + utils::fmt::Indent, }; use diskann_utils::views::Matrix; use diskann_vector::distance::Metric; use half::f16; +use serde::{Deserialize, Serialize}; use diskann_inmem::{Provider, layers}; use crate::{ - index::Index, + index::{Counters, Index}, support::{ - datatype::{DataType, DatasetView}, + datatype::{DataType, Dataset, DatasetView}, io::load_and_convert, }, }; @@ -88,11 +90,25 @@ mod dto { pub(super) alpha: f32, } + #[derive(Debug, Serialize, Deserialize)] + pub(super) struct KnnSearch { + pub(super) knn: usize, + pub(super) search_l: usize, + #[serde(deserialize_with = "Deserialize::deserialize")] + pub(super) beam_width: Option, + } + + #[derive(Debug, Serialize, Deserialize)] + pub(super) struct Search { + pub(super) knn: Vec, + } + #[derive(Debug, Serialize, Deserialize)] pub(super) struct Test { pub(super) data: Data, pub(super) layer: Layer, pub(super) build: Build, + pub(super) search: Search, } } @@ -137,6 +153,43 @@ impl Data { data_type: self.data_type, }) } + + fn load_as(&self, data_type: DataType) -> anyhow::Result { + let data = { + let mut io = std::fs::File::open(&*self.data) + .with_context(|| format!("could not open {}", self.data.display()))?; + + load_and_convert(&mut io, self.data_type, data_type)? + }; + + let queries = { + let mut io = std::fs::File::open(&*self.queries) + .with_context(|| format!("could not open {}", self.queries.display()))?; + + load_and_convert(&mut io, self.data_type, data_type)? 
+ }; + + let groundtruth = { + let mut io = std::fs::File::open(&*self.groundtruth) + .with_context(|| format!("could not open {}", self.queries.display()))?; + + let raw = diskann_utils::io::read_bin::(&mut io)?; + raw.map(|&x| u64::from(x)) + }; + + Ok(Bundle { + data, + queries, + groundtruth, + }) + } +} + +#[derive(Debug)] +struct Bundle { + data: Dataset, + queries: Dataset, + groundtruth: Matrix, } #[derive(Debug)] @@ -198,11 +251,47 @@ impl Build { } } +#[derive(Debug)] +struct Search { + knn: Vec, +} + +impl Search { + fn from_raw(raw: dto::Search) -> anyhow::Result { + fn make_knn(raw: &dto::KnnSearch) -> anyhow::Result { + Ok(Knn::new(raw.knn, raw.search_l, raw.beam_width)?) + } + + Ok(Self { + knn: raw + .knn + .iter() + .map(make_knn) + .collect::>>()?, + }) + } + + fn as_raw(&self) -> dto::Search { + fn make_knn(knn: &Knn) -> dto::KnnSearch { + dto::KnnSearch { + knn: knn.k_value().get(), + search_l: knn.l_value().get(), + beam_width: Some(knn.beam_width().get()), + } + } + + dto::Search { + knn: self.knn.iter().map(make_knn).collect(), + } + } +} + #[derive(Debug)] struct Test { data: Data, layer: Layer, build: Build, + search: Search, } impl Test { @@ -210,8 +299,14 @@ impl Test { let data = Data::from_raw(raw.data, checker)?; let layer = Layer::from_raw(raw.layer); let build = Build::from_raw(raw.build, data.metric)?; + let search = Search::from_raw(raw.search)?; - Ok(Self { data, layer, build }) + Ok(Self { + data, + layer, + build, + search, + }) } fn as_raw(&self) -> anyhow::Result { @@ -219,6 +314,7 @@ impl Test { data: self.data.as_raw()?, layer: self.layer.as_raw(), build: self.build.as_raw(), + search: self.search.as_raw(), }) } @@ -317,6 +413,25 @@ impl diskann_benchmark_runner::Input for Test { l_build: 50, alpha: 1.2, }, + search: dto::Search { + knn: vec![ + dto::KnnSearch { + knn: 10, + search_l: 50, + beam_width: None, + }, + dto::KnnSearch { + knn: 10, + search_l: 50, + beam_width: Some(3), + }, + dto::KnnSearch { + knn: 20, + 
search_l: 100, + beam_width: Some(3), + }, + ], + }, } } } @@ -330,7 +445,7 @@ struct FullPrecision; impl diskann_benchmark_runner::Benchmark for FullPrecision { type Input = Test; - type Output = (); + type Output = BuildAndSearch; fn try_match(&self, input: &Test) -> Result { if let Layer::FullPrecision { .. } = input.layer { @@ -352,23 +467,62 @@ impl diskann_benchmark_runner::Benchmark for FullPrecision { &self, input: &Test, checkpoint: Checkpoint<'_>, - output: &mut dyn Output, - ) -> anyhow::Result<()> { + mut output: &mut dyn Output, + ) -> anyhow::Result { let Layer::FullPrecision { data_type } = input.layer else { anyhow::bail!("expected full-precision"); }; // Load the data and perform any necessary data conversions. - let data = { - let mut io = std::fs::File::open(&*input.data.data) - .with_context(|| format!("could not open {}", input.data.data.display()))?; - - load_and_convert(&mut io, input.data.data_type, data_type)? - }; + let Bundle { + data, + queries, + groundtruth, + } = input.data.load_as(data_type)?; let index = input.index(data.nrows(), data.medoid().as_view())?; let rt = diskann_benchmark_core::tokio::runtime(1)?; - super::tests::insert(&*index, data.as_view(), rt.handle())?; + let build = super::tests::insert(&*index, data.as_view(), rt.handle())?; + + let mut knn = Vec::new(); + for param in input.search.knn.iter() { + let stats = super::tests::knn( + &*index, + param.clone(), + queries.as_view(), + &groundtruth.as_view(), + rt.handle(), + )?; + + knn.push(stats); + } + + let build_and_search = BuildAndSearch { build, knn }; + + writeln!(output, "{}", build_and_search)?; + + Ok(build_and_search) + } +} + +//////////// +// Output // +//////////// + +#[derive(Debug, Serialize, Deserialize)] +struct BuildAndSearch { + build: Counters, + knn: Vec, +} + +impl std::fmt::Display for BuildAndSearch { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + writeln!(f, "build stats")?; + writeln!(f, "{}", 
Indent::new(&self.build.to_string(), 4))?; + writeln!(f, "knn stats")?; + for k in self.knn.iter() { + writeln!(f, "{}", k)?; + } Ok(()) } diff --git a/diskann-inmem/integration/index/tests.rs b/diskann-inmem/integration/index/tests.rs index 0142b7ae5..9819bf4de 100644 --- a/diskann-inmem/integration/index/tests.rs +++ b/diskann-inmem/integration/index/tests.rs @@ -4,29 +4,35 @@ */ use diskann::graph::search::Knn; -use diskann_benchmark_core::recall::Rows; +use diskann_benchmark_core::recall::{RecallMetrics, Rows}; +use diskann_benchmark_runner::utils::fmt::KeyValue; +use diskann_utils::views::Matrix; +use serde::{Deserialize, Serialize}; -use super::Index; -use crate::support::datatype::DatasetView; +use crate::{ + index::{Counters, Index, KnnSearch}, + support::datatype::DatasetView, +}; pub(super) fn insert( index: &dyn Index, dataset: DatasetView<'_>, rt: &tokio::runtime::Handle, -) -> anyhow::Result<()> { +) -> anyhow::Result { + let before = index.counters(); for i in 0..dataset.nrows() { rt.block_on(index.insert(dataset.row(i).unwrap(), i as u64))?; } - Ok(()) + Ok(before.delta(&index.counters())?) 
} -fn knn( +pub(super) fn knn( index: &dyn Index, knn: Knn, queries: DatasetView<'_>, groundtruth: &dyn Rows, rt: &tokio::runtime::Handle, -) -> anyhow::Result<()> { +) -> anyhow::Result { anyhow::ensure!( queries.nrows() == groundtruth.nrows(), "number of queries ({}) must match number of groundtruth entries ({})", @@ -34,9 +40,79 @@ fn knn( groundtruth.nrows(), ); - for i in 0..queries.nrows() { - let mut neighbors = Vec::new(); - rt.block_on(index.search(queries.row(i).unwrap(), knn, &mut neighbors))?; + let mut ids = Matrix::new(u64::MAX, queries.nrows(), knn.k_value().get()); + + let before = index.counters(); + let mut misc = KnnSearch::new(); + let mut neighbors = Vec::new(); + for (i, out) in ids.row_iter_mut().enumerate() { + neighbors.clear(); + + let stats = rt.block_on(index.search(queries.row(i).unwrap(), knn, &mut neighbors))?; + misc += stats; + + std::iter::zip(out.iter_mut(), neighbors.iter()).for_each(|(d, s)| *d = s.id); + } + let counters = before.delta(&index.counters())?; + + let recall = diskann_benchmark_core::recall::knn( + groundtruth, + None, + &ids.as_view(), + knn.k_value().get(), + knn.k_value().get(), + diskann_benchmark_core::recall::GroundTruthMode::Fixed, + )?; + + Ok(KnnStats { + counters, + recall: recall.into(), + misc, + }) +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub(crate) struct KnnRecall { + recall_k: usize, + recall_n: usize, + num_queries: usize, + average: f64, +} + +impl From for KnnRecall { + fn from(metrics: RecallMetrics) -> Self { + Self { + recall_k: metrics.recall_k, + recall_n: metrics.recall_n, + num_queries: metrics.num_queries, + average: metrics.average, + } + } +} + +impl std::fmt::Display for KnnRecall { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "recall = {:.4}, recall_k = {}, recall_n = {}, num_queries = {}", + self.average, self.recall_k, self.recall_n, self.num_queries + ) + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub(crate) struct 
KnnStats { + recall: KnnRecall, + counters: Counters, + misc: KnnSearch, +} + +impl std::fmt::Display for KnnStats { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut kv = KeyValue::new(); + kv.push("counters", &self.counters); + kv.push("recall", &self.recall); + kv.push("misc", &self.misc); + kv.render(f) } - Ok(()) } diff --git a/diskann-inmem/integration/support/datatype.rs b/diskann-inmem/integration/support/datatype.rs index e15c786e0..9ba051552 100644 --- a/diskann-inmem/integration/support/datatype.rs +++ b/diskann-inmem/integration/support/datatype.rs @@ -3,7 +3,10 @@ * Licensed under the MIT license. */ -use diskann_utils::{sampling::medoid::ComputeMedoid, views::{Matrix, MatrixView}}; +use diskann_utils::{ + sampling::medoid::ComputeMedoid, + views::{Matrix, MatrixView}, +}; use half::f16; use serde::{Deserialize, Serialize}; use thiserror::Error; diff --git a/diskann-inmem/src/counters.rs b/diskann-inmem/src/counters.rs new file mode 100644 index 000000000..736ccb491 --- /dev/null +++ b/diskann-inmem/src/counters.rs @@ -0,0 +1,178 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. 
+ */ + +pub(crate) use inner::{Counters, LocalCounters}; + +#[cfg(not(feature = "integration-test"))] +mod inner { + use std::marker::PhantomData; + + #[derive(Debug, Default)] + pub(crate) struct Counters; + + impl Counters { + pub(crate) fn new() -> Self { + Self + } + + pub(crate) fn local(&self) -> LocalCounters<'_> { + LocalCounters::new() + } + } + + #[derive(Debug)] + pub(crate) struct LocalCounters<'a> { + _marker: PhantomData<&'a ()>, + } + + impl LocalCounters<'_> { + fn new() -> Self { + Self { + _marker: PhantomData, + } + } + + pub(crate) fn fork(&self) -> Self { + Self::new() + } + + pub(crate) fn query_distance(&mut self, _i: u64) {} + pub(crate) fn distance(&mut self, _i: u64) {} + pub(crate) fn distance_ref(&self, _i: u64) {} + pub(crate) fn get_vector(&mut self, _i: u64) {} + pub(crate) fn get_vector_ref(&self, _i: u64) {} + pub(crate) fn set_vector(&mut self, _i: u64) {} + pub(crate) fn get_neighbors(&mut self, _i: u64) {} + pub(crate) fn set_neighbors(&mut self, _i: u64) {} + pub(crate) fn append_vector(&mut self, _i: u64) {} + } +} + +#[cfg(feature = "integration-test")] +mod inner { + use std::sync::atomic::{AtomicU64, Ordering::Relaxed}; + + #[derive(Debug, Default)] + pub(crate) struct Counters { + query_distance: AtomicU64, + distance: AtomicU64, + get_vector: AtomicU64, + set_vector: AtomicU64, + get_neighbors: AtomicU64, + set_neighbors: AtomicU64, + append_neighbors: AtomicU64, + } + + impl Counters { + pub(crate) fn new() -> Self { + Self::default() + } + + pub(crate) fn local(&self) -> LocalCounters<'_> { + LocalCounters::new(self) + } + + pub(crate) fn snapshot(&self) -> crate::integration::counters::CounterSnapshot { + crate::integration::counters::CounterSnapshot { + query_distance: self.query_distance.load(Relaxed), + distance: self.distance.load(Relaxed), + get_vector: self.get_vector.load(Relaxed), + set_vector: self.set_vector.load(Relaxed), + get_neighbors: self.get_neighbors.load(Relaxed), + set_neighbors: 
self.set_neighbors.load(Relaxed), + append_neighbors: self.append_neighbors.load(Relaxed), + } + } + } + + #[derive(Debug)] + pub(crate) struct LocalCounters<'a> { + query_distance: u64, + // This fields needs to be `AtomicU64` because we increment in some loops where we + // have to increment it behind a shared reference. + distance: AtomicU64, + // This fields needs to be `AtomicU64` because we increment in some loops where we + // have to increment it behind a shared reference. + get_vector: AtomicU64, + set_vector: u64, + get_neighbors: u64, + set_neighbors: u64, + append_neighbors: u64, + parent: &'a Counters, + } + + impl<'a> LocalCounters<'a> { + fn new(parent: &'a Counters) -> Self { + Self { + query_distance: 0, + distance: AtomicU64::new(0), + get_vector: AtomicU64::new(0), + set_vector: 0, + get_neighbors: 0, + set_neighbors: 0, + append_neighbors: 0, + parent, + } + } + + pub(crate) fn fork(&self) -> LocalCounters<'a> { + Self::new(self.parent) + } + + pub(crate) fn query_distance(&mut self, i: u64) { + self.query_distance += i; + } + + pub(crate) fn distance(&mut self, i: u64) { + *self.distance.get_mut() += i; + } + + pub(crate) fn distance_ref(&self, i: u64) { + self.distance.fetch_add(i, Relaxed); + } + + pub(crate) fn get_vector(&mut self, i: u64) { + *self.get_vector.get_mut() += i; + } + + pub(crate) fn get_vector_ref(&self, i: u64) { + self.get_vector.fetch_add(i, Relaxed); + } + + pub(crate) fn set_vector(&mut self, i: u64) { + self.set_vector += i; + } + + pub(crate) fn get_neighbors(&mut self, i: u64) { + self.get_neighbors += i; + } + + pub(crate) fn set_neighbors(&mut self, i: u64) { + self.set_neighbors += i; + } + + pub(crate) fn append_vector(&mut self, i: u64) { + self.append_neighbors += i; + } + } + + impl Drop for LocalCounters<'_> { + fn drop(&mut self) { + let parent = self.parent; + + fn update(dst: &AtomicU64, src: u64) { + dst.fetch_add(src, Relaxed); + } + + update(&parent.query_distance, self.query_distance); + 
update(&parent.distance, *self.distance.get_mut()); + update(&parent.get_vector, *self.get_vector.get_mut()); + update(&parent.set_vector, self.set_vector); + update(&parent.get_neighbors, self.get_neighbors); + update(&parent.set_neighbors, self.set_neighbors); + update(&parent.append_neighbors, self.append_neighbors); + } + } +} diff --git a/diskann-inmem/src/integration/counters.rs b/diskann-inmem/src/integration/counters.rs new file mode 100644 index 000000000..b6af15ef4 --- /dev/null +++ b/diskann-inmem/src/integration/counters.rs @@ -0,0 +1,17 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +/// A snapshot of global counters. +#[derive(Debug, Clone)] +#[non_exhaustive] +pub struct CounterSnapshot { + pub query_distance: u64, + pub distance: u64, + pub get_vector: u64, + pub set_vector: u64, + pub get_neighbors: u64, + pub set_neighbors: u64, + pub append_neighbors: u64, +} diff --git a/diskann-inmem/src/integration/mod.rs b/diskann-inmem/src/integration/mod.rs index 2f378e125..312e41f4c 100644 --- a/diskann-inmem/src/integration/mod.rs +++ b/diskann-inmem/src/integration/mod.rs @@ -3,4 +3,5 @@ * Licensed under the MIT license. 
*/ +pub mod counters; pub mod store; diff --git a/diskann-inmem/src/lib.rs b/diskann-inmem/src/lib.rs index a94dbec83..b0f2c4d32 100644 --- a/diskann-inmem/src/lib.rs +++ b/diskann-inmem/src/lib.rs @@ -8,6 +8,7 @@ pub mod num; mod buffer; +mod counters; mod epoch; mod freelist; mod neighbors; diff --git a/diskann-inmem/src/provider.rs b/diskann-inmem/src/provider.rs index 6f6c5458c..977613792 100644 --- a/diskann-inmem/src/provider.rs +++ b/diskann-inmem/src/provider.rs @@ -19,7 +19,8 @@ use diskann::{ use diskann_utils::views::Matrix; use crate::{ - layers::{self, Distance, QueryDistance}, + counters::{Counters, LocalCounters}, + layers::{self, QueryDistance}, num::Bytes, sharded::Sharded, store::{self, Store}, @@ -36,6 +37,10 @@ where store: Store, layer: L, mapping: Sharded, + + // `Counters` is only non-trivial under the `integration-test` feature flag. Otherwise, + // all counter related operations are no-ops. + counters: Counters, } impl Provider @@ -69,8 +74,19 @@ where store, layer, mapping, + counters: Counters::new(), } } + + fn local_counters(&self) -> LocalCounters<'_> { + self.counters.local() + } + + /// Return a snapshot of the current event counters. + #[cfg(feature = "integration-test")] + pub fn counters(&self) -> crate::integration::counters::CounterSnapshot { + self.counters.snapshot() + } } #[derive(Debug)] @@ -228,6 +244,12 @@ where >::into_bytes(&self.layer, element, slot.as_mut_slice())?; self.mapping.insert(id.clone(), slot.slot()).unwrap(); + // This is a rather expensive update. + // + // However, counters are only active with the `integration-test` feature, which + // is not expected to be enabled for general use. + self.local_counters().set_vector(1); + Ok(diskann::provider::NoopGuard::new(slot.slot())) }; @@ -248,6 +270,7 @@ pub struct SearchAccessor<'a> { // The parent provider for the accessor. 
provider: &'a (dyn std::any::Any + Send + Sync), start_points: std::ops::Range, + counters: LocalCounters<'a>, } impl diskann::provider::HasId for SearchAccessor<'_> { @@ -272,6 +295,10 @@ impl glue::SearchAccessor for SearchAccessor<'_> { for p in self.start_points.clone() { match self.reader.read(p.into_usize()) { Some(point) => { + // Counters are no-ops without `integration-test`. + self.counters.get_vector(1); + self.counters.query_distance(1); + f(p, self.expand_beam.evaluate(point)?); } None => { @@ -302,11 +329,21 @@ impl glue::SearchAccessor for SearchAccessor<'_> { let work = move || -> ANNResult<()> { for i in ids { self.reader.neighbors().get(i, &mut self.ids).unwrap(); + self.counters.get_neighbors(1); // Filter out unvisited IDs and ensure that all the IDs we are about self.ids .retain(|i| pred.eval_mut(i) && self.reader.is_in_bounds(i.into_usize())); + // TODO: Move to an external buffer to avoid any dynamic dispatcn in + // `expand_beam_inner` - then we can do a bulk-update on the counters. 
+ let mut on_neighbors = |id, distance| { + self.counters.get_vector(1); + self.counters.query_distance(1); + + on_neighbors(id, distance); + }; + unsafe { self.expand_beam .run(&self.ids, 8, &self.reader, &mut on_neighbors) @@ -533,7 +570,28 @@ where #[derive(Debug)] pub struct PruneAccessor<'a> { reader: store::Reader<'a>, - distance: &'a dyn Distance, + distance: &'a dyn layers::Distance, + counters: LocalCounters<'a>, +} + +#[derive(Debug)] +pub struct Distance<'a> { + distance: &'a dyn layers::Distance, + counters: LocalCounters<'a>, +} + +impl<'a> Distance<'a> { + fn new(distance: &'a dyn layers::Distance, counters: LocalCounters<'a>) -> Self { + Self { distance, counters } + } +} + +impl diskann_vector::DistanceFunction<&[u8], &[u8], f32> for Distance<'_> { + #[inline] + fn evaluate_similarity(&self, x: &[u8], y: &[u8]) -> f32 { + self.counters.distance_ref(1); + self.distance.evaluate(x, y).unwrap() + } } impl diskann::provider::HasId for PruneAccessor<'_> { @@ -554,7 +612,7 @@ impl glue::PruneAccessor for PruneAccessor<'_> { Self: 'a; type Distance<'a> - = &'a dyn Distance + = Distance<'a> where Self: 'a; @@ -569,7 +627,7 @@ impl glue::PruneAccessor for PruneAccessor<'_> { where Itr: ExactSizeIterator + Clone + Send + Sync, { - Ok((self, &*self.distance)) + Ok((self, Distance::new(self.distance, self.counters.fork()))) } } @@ -579,7 +637,10 @@ impl provider::NeighborAccessor for PruneAccessor<'_> { id: Self::Id, neighbors: &mut AdjacencyList, ) -> impl std::future::Future> + Send { - let work = move || Ok(self.reader.neighbors().get(id, neighbors).unwrap()); + let work = move || { + self.counters.get_neighbors(1); + Ok(self.reader.neighbors().get(id, neighbors).unwrap()) + }; ready(work) } } @@ -590,7 +651,10 @@ impl provider::NeighborAccessorMut for PruneAccessor<'_> { id: Self::Id, neighbors: &[Self::Id], ) -> impl std::future::Future> + Send { - let work = move || Ok(self.reader.neighbors().set(id, neighbors).unwrap()); + let work = move || { + 
self.counters.set_neighbors(1); + Ok(self.reader.neighbors().set(id, neighbors).unwrap()) + }; ready(work) } @@ -600,6 +664,7 @@ impl provider::NeighborAccessorMut for PruneAccessor<'_> { neighbors: &[Self::Id], ) -> impl std::future::Future> + Send { let work = move || -> ANNResult<()> { + self.counters.append_vector(1); self.reader .neighbors() .lock(id) @@ -620,7 +685,13 @@ impl workingset::View for &PruneAccessor<'_> { where Self: 'a; fn get(&self, id: u32) -> Option<&[u8]> { - self.reader.read(id.into_usize()) + match self.reader.read(id.into_usize()) { + Some(data) => { + self.counters.get_vector_ref(1); + Some(data) + } + None => None, + } } } @@ -660,6 +731,7 @@ where expand_beam, provider, start_points: provider.store.frozen(), + counters: provider.local_counters(), }; Ok(accessor) } @@ -746,6 +818,7 @@ where Ok(PruneAccessor { reader: provider.store.reader()?, distance: ::as_distance(&provider.layer), + counters: provider.local_counters(), }) } } From 5efb805c4b856ca8524ea72027190e50f79dcc40 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Tue, 23 Jun 2026 17:47:33 -0700 Subject: [PATCH 22/45] More progress. 
--- diskann-benchmark-runner/src/app.rs | 6 +- diskann-benchmark-runner/src/utils/fmt.rs | 25 +++++---- diskann-inmem/integration/index/runner.rs | 2 +- diskann-inmem/integration/support/exact.rs | 65 ++++++++++++++++++++++ diskann-inmem/integration/support/mod.rs | 1 + 5 files changed, 84 insertions(+), 15 deletions(-) create mode 100644 diskann-inmem/integration/support/exact.rs diff --git a/diskann-benchmark-runner/src/app.rs b/diskann-benchmark-runner/src/app.rs index 42f146ca5..09fd58df6 100644 --- a/diskann-benchmark-runner/src/app.rs +++ b/diskann-benchmark-runner/src/app.rs @@ -228,7 +228,7 @@ impl App { writeln!(output)?; } else { writeln!(output)?; - write!(output, "{}", Indent::new(&description, 8))?; + writeln!(output, "{}", Indent::new(&description, 8))?; } } } @@ -258,8 +258,8 @@ impl App { )?; writeln!(output, "Closest matches:\n")?; for (i, mismatch) in mismatches.into_iter().enumerate() { - writeln!(output, " {}. \"{}\":", i + 1, mismatch.method(),)?; - writeln!(output, "{}", Indent::new(mismatch.reason(), 8),)?; + writeln!(output, " {}. 
\"{}\":", i + 1, mismatch.method())?; + writeln!(output, "{}\n", Indent::new(mismatch.reason(), 8))?; } writeln!(output)?; diff --git a/diskann-benchmark-runner/src/utils/fmt.rs b/diskann-benchmark-runner/src/utils/fmt.rs index f7e2a9225..92fa6d7d7 100644 --- a/diskann-benchmark-runner/src/utils/fmt.rs +++ b/diskann-benchmark-runner/src/utils/fmt.rs @@ -203,7 +203,7 @@ impl std::fmt::Display for Banner<'_> { /// use diskann_benchmark_runner::utils::fmt::Indent; /// /// let indented = Indent::new("hello\nworld", 4).to_string(); -/// assert_eq!(indented, " hello\n world\n"); +/// assert_eq!(indented, " hello\n world"); /// ``` #[derive(Debug, Clone, Copy)] pub struct Indent<'a> { @@ -221,9 +221,15 @@ impl<'a> Indent<'a> { impl std::fmt::Display for Indent<'_> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let spaces = self.spaces; - self.string - .lines() - .try_for_each(|ln| writeln!(f, "{: >spaces$}{}", "", ln)) + let mut first = true; + for ln in self.string.lines() { + if !first { + writeln!(f)?; + } + write!(f, "{: >spaces$}{}", "", ln)?; + first = false; + } + Ok(()) } } @@ -384,10 +390,7 @@ where /// kv.push("a", &1); /// kv.push("hello", &"world"); /// -/// let expected = -/// " a: 1 -/// hello: world -/// "; +/// let expected = " a: 1\nhello: world"; /// /// assert_eq!(kv.to_string(), expected); /// ``` @@ -575,19 +578,19 @@ string, , string #[test] fn test_indent_single_line() { let s = Indent::new("hello", 4).to_string(); - assert_eq!(s, " hello\n"); + assert_eq!(s, " hello"); } #[test] fn test_indent_multi_line() { let s = Indent::new("hello\nworld\nfoo", 2).to_string(); - assert_eq!(s, " hello\n world\n foo\n"); + assert_eq!(s, " hello\n world\n foo"); } #[test] fn test_indent_zero_spaces() { let s = Indent::new("hello\nworld", 0).to_string(); - assert_eq!(s, "hello\nworld\n"); + assert_eq!(s, "hello\nworld"); } #[test] diff --git a/diskann-inmem/integration/index/runner.rs b/diskann-inmem/integration/index/runner.rs index 
d46247211..626541e50 100644 --- a/diskann-inmem/integration/index/runner.rs +++ b/diskann-inmem/integration/index/runner.rs @@ -521,7 +521,7 @@ impl std::fmt::Display for BuildAndSearch { writeln!(f, "{}", Indent::new(&self.build.to_string(), 4))?; writeln!(f, "knn stats")?; for k in self.knn.iter() { - writeln!(f, "{}", k)?; + writeln!(f, "{}\n", k)?; } Ok(()) diff --git a/diskann-inmem/integration/support/exact.rs b/diskann-inmem/integration/support/exact.rs new file mode 100644 index 000000000..5b4aec9f8 --- /dev/null +++ b/diskann-inmem/integration/support/exact.rs @@ -0,0 +1,65 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +use std::{fmt::Write, borrow::Cow}; + +trait ExactMatch { + fn exact_match(&self, other: &Self, matcher: Matcher<'_>); +} + +struct Mismatch { + path: String, + expected: String, + got: String, + remark: Option>, +} + +pub(crate) struct Matcher<'a> { + mismatches: &'a mut Vec, + path: &'a mut String, + len: usize, +} + +impl<'a> Matcher<'a> { + pub(crate) fn push(&mut self, field: &D) -> Matcher<'_> + where + D: std::fmt::Display, + { + let len = self.path.len(); + if len == 0 { + write!(self.path, "{}", field).unwrap(); + } else { + write!(self.path, ".{}", field).unwrap(); + } + + Matcher { + mismatches: self.mismatches, + path: self.path, + len, + } + } + + pub(crate) fn mismatch(&mut self, expected: &D, got: &D, remark: Option) + where + D: std::fmt::Display, + R: Into>, + { + let mismatch = Mismatch { + path: self.path.clone(), + expected: expected.to_string(), + got: got.to_string(), + remark: remark.map(|x| x.into()) + }; + + self.mismatches.push(mismatch); + } +} + +impl Drop for Matcher<'_> { + fn drop(&mut self) { + self.path.truncate(self.len); + } +} + diff --git a/diskann-inmem/integration/support/mod.rs b/diskann-inmem/integration/support/mod.rs index 9b6eeae43..6f04269d8 100644 --- a/diskann-inmem/integration/support/mod.rs +++ b/diskann-inmem/integration/support/mod.rs @@ -5,3 
+5,4 @@ pub(crate) mod datatype; pub(crate) mod io; +pub(crate) mod exact; From d5a24f4a766e8a6668e44a3c9cbd5553330eb3eb Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Wed, 24 Jun 2026 15:04:48 -0700 Subject: [PATCH 23/45] Unit testing is almost there ... --- diskann-inmem/Cargo.toml | 1 + diskann-inmem/integration/index/index.rs | 42 +- diskann-inmem/integration/index/runner.rs | 33 +- diskann-inmem/integration/index/tests.rs | 27 +- .../integration/jsons/integration.json | 134 +++++++ .../{example => jsons}/store-stress-test.json | 0 .../{example => jsons}/store-stress.json | 0 diskann-inmem/integration/main.rs | 67 +++- diskann-inmem/integration/support/check.rs | 371 ++++++++++++++++++ diskann-inmem/integration/support/exact.rs | 65 --- diskann-inmem/integration/support/mod.rs | 3 +- .../integration/support/tolerance.rs | 32 ++ diskann-inmem/src/provider.rs | 4 +- 13 files changed, 704 insertions(+), 75 deletions(-) create mode 100644 diskann-inmem/integration/jsons/integration.json rename diskann-inmem/integration/{example => jsons}/store-stress-test.json (100%) rename diskann-inmem/integration/{example => jsons}/store-stress.json (100%) create mode 100644 diskann-inmem/integration/support/check.rs delete mode 100644 diskann-inmem/integration/support/exact.rs create mode 100644 diskann-inmem/integration/support/tolerance.rs diff --git a/diskann-inmem/Cargo.toml b/diskann-inmem/Cargo.toml index 24e510c1f..1d1bc00d3 100644 --- a/diskann-inmem/Cargo.toml +++ b/diskann-inmem/Cargo.toml @@ -53,4 +53,5 @@ integration-test = [ "dep:serde_json", "dep:anyhow", "dep:rand", + "diskann-utils/testing", ] diff --git a/diskann-inmem/integration/index/index.rs b/diskann-inmem/integration/index/index.rs index ca0f7c8c8..9f8410084 100644 --- a/diskann-inmem/integration/index/index.rs +++ b/diskann-inmem/integration/index/index.rs @@ -17,7 +17,10 @@ use thiserror::Error; use diskann_inmem::{Context, Provider, Strategy, integration, layers}; -use 
crate::support::datatype::{AsDataType, DataType, FromSlice, Slice}; +use crate::support::{ + check::{CheckMatch, Match, MatchBuilder, check_all_fields}, + datatype::{AsDataType, DataType, FromSlice, Slice}, +}; pub(crate) trait Index { fn data_type(&self) -> DataType; @@ -36,7 +39,6 @@ pub(crate) trait Index { ) -> Pin> + 'a>>; fn counters(&self) -> Counters; - // fn retire(&self, id: u64) -> anyhow::Result<()>; } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -73,6 +75,20 @@ impl std::fmt::Display for KnnSearch { } } +impl CheckMatch for KnnSearch { + fn check_match(&self, previous: &Self) -> Match { + let builder = check_all_fields!( + self, + previous, + { hops, cmps }, + ); + + builder.finish_with_remark(Some( + "check assumes deterministic (usually single-threaded) execution".into(), + )) + } +} + #[derive(Debug, Serialize, Deserialize)] pub(crate) struct Counters { query_distance: u64, @@ -147,6 +163,28 @@ impl From for Counters { } } +impl CheckMatch for Counters { + fn check_match(&self, previous: &Self) -> Match { + let builder = check_all_fields!( + self, + previous, + { + query_distance, + distance, + get_vector, + set_vector, + get_neighbors, + set_neighbors, + append_neighbors + } + ); + + builder.finish_with_remark(Some( + "check assumes deterministic (usually single-threaded) execution".into(), + )) + } +} + /////////// // Impls // /////////// diff --git a/diskann-inmem/integration/index/runner.rs b/diskann-inmem/integration/index/runner.rs index 626541e50..191c8dc9a 100644 --- a/diskann-inmem/integration/index/runner.rs +++ b/diskann-inmem/integration/index/runner.rs @@ -9,7 +9,7 @@ use anyhow::Context; use diskann::graph::{DiskANNIndex, search::Knn}; use diskann_benchmark_runner::{ Checker, Checkpoint, Output, Registry, RegistryError, - benchmark::{FailureScore, MatchScore}, + benchmark::{FailureScore, MatchScore, Regression, PassFail}, files::InputFile, utils::fmt::Indent, }; @@ -23,13 +23,15 @@ use diskann_inmem::{Provider, layers}; use 
crate::{ index::{Counters, Index}, support::{ + check::{CheckMatch, Match, MatchBuilder, check_all_fields}, datatype::{DataType, Dataset, DatasetView}, io::load_and_convert, + tolerance, }, }; pub(super) fn register(registry: &mut Registry) -> Result<(), RegistryError> { - registry.register("full-precision-integration-test", FullPrecision)?; + registry.register_regression("full-precision-integration-test", FullPrecision)?; Ok(()) } @@ -505,6 +507,22 @@ impl diskann_benchmark_runner::Benchmark for FullPrecision { } } +impl Regression for FullPrecision { + type Tolerances = tolerance::Empty; + type Pass = Match; + type Fail = Match; + + fn check( + &self, + _tolerances: &Self::Tolerances, + _input: &Self::Input, + before: &Self::Output, + after: &Self::Output, + ) -> anyhow::Result> { + Ok(before.check_match(after).pass_fail()) + } +} + //////////// // Output // //////////// @@ -527,3 +545,14 @@ impl std::fmt::Display for BuildAndSearch { Ok(()) } } + +impl CheckMatch for BuildAndSearch { + fn check_match(&self, previous: &Self) -> Match { + let builder = check_all_fields!( + self, + previous, + { build, knn }, + ); + builder.finish() + } +} diff --git a/diskann-inmem/integration/index/tests.rs b/diskann-inmem/integration/index/tests.rs index 9819bf4de..3354c5022 100644 --- a/diskann-inmem/integration/index/tests.rs +++ b/diskann-inmem/integration/index/tests.rs @@ -11,7 +11,10 @@ use serde::{Deserialize, Serialize}; use crate::{ index::{Counters, Index, KnnSearch}, - support::datatype::DatasetView, + support::{ + check::{CheckMatch, Match, MatchBuilder, check_all_fields}, + datatype::DatasetView, + }, }; pub(super) fn insert( @@ -90,6 +93,17 @@ impl From for KnnRecall { } } +impl CheckMatch for KnnRecall { + fn check_match(&self, previous: &Self) -> Match { + let builder = check_all_fields!( + self, + previous, + { recall_k, recall_n, num_queries, average } + ); + builder.finish() + } +} + impl std::fmt::Display for KnnRecall { fn fmt(&self, f: &mut 
std::fmt::Formatter<'_>) -> std::fmt::Result { write!( @@ -116,3 +130,14 @@ impl std::fmt::Display for KnnStats { kv.render(f) } } + +impl CheckMatch for KnnStats { + fn check_match(&self, previous: &Self) -> Match { + let builder = check_all_fields!( + self, + previous, + { recall, counters, misc } + ); + builder.finish() + } +} diff --git a/diskann-inmem/integration/jsons/integration.json b/diskann-inmem/integration/jsons/integration.json new file mode 100644 index 000000000..84369ab38 --- /dev/null +++ b/diskann-inmem/integration/jsons/integration.json @@ -0,0 +1,134 @@ +{ + "search_directories": [ + "disk_index_search" + ], + "output_directory": null, + "jobs": [ + { + "type": "integration-test", + "content": { + "build": { + "alpha": 1.2000000476837158, + "l_build": 50, + "max_degree": 20, + "pruned_degree": 16 + }, + "data": { + "data": "disk_index_siftsmall_learn_256pts_data.fbin", + "data_type": "f32", + "groundtruth": "disk_index_10pts_idx_uint32_truth_search_res.bin", + "metric": "l2", + "queries": "disk_index_sample_query_10pts.fbin" + }, + "layer": { + "FullPrecision": { + "data_type": "f32" + } + }, + "search": { + "knn": [ + { + "beam_width": null, + "knn": 10, + "search_l": 50 + }, + { + "beam_width": 3, + "knn": 10, + "search_l": 50 + }, + { + "beam_width": 3, + "knn": 10, + "search_l": 100 + } + ] + } + } + }, + { + "type": "integration-test", + "content": { + "build": { + "alpha": 1.2000000476837158, + "l_build": 50, + "max_degree": 20, + "pruned_degree": 16 + }, + "data": { + "data": "disk_index_siftsmall_learn_256pts_data.fbin", + "data_type": "f32", + "groundtruth": "disk_index_10pts_idx_uint32_truth_search_res.bin", + "metric": "l2", + "queries": "disk_index_sample_query_10pts.fbin" + }, + "layer": { + "FullPrecision": { + "data_type": "f16" + } + }, + "search": { + "knn": [ + { + "beam_width": null, + "knn": 10, + "search_l": 50 + }, + { + "beam_width": 3, + "knn": 10, + "search_l": 50 + }, + { + "beam_width": 3, + "knn": 10, + "search_l": 
100 + } + ] + } + } + }, + { + "type": "integration-test", + "content": { + "build": { + "alpha": 1.2000000476837158, + "l_build": 50, + "max_degree": 20, + "pruned_degree": 16 + }, + "data": { + "data": "disk_index_siftsmall_learn_256pts_data.fbin", + "data_type": "f32", + "groundtruth": "disk_index_10pts_idx_uint32_truth_search_res.bin", + "metric": "l2", + "queries": "disk_index_sample_query_10pts.fbin" + }, + "layer": { + "FullPrecision": { + "data_type": "u8" + } + }, + "search": { + "knn": [ + { + "beam_width": null, + "knn": 10, + "search_l": 50 + }, + { + "beam_width": 3, + "knn": 10, + "search_l": 50 + }, + { + "beam_width": 3, + "knn": 10, + "search_l": 100 + } + ] + } + } + } + ] +} diff --git a/diskann-inmem/integration/example/store-stress-test.json b/diskann-inmem/integration/jsons/store-stress-test.json similarity index 100% rename from diskann-inmem/integration/example/store-stress-test.json rename to diskann-inmem/integration/jsons/store-stress-test.json diff --git a/diskann-inmem/integration/example/store-stress.json b/diskann-inmem/integration/jsons/store-stress.json similarity index 100% rename from diskann-inmem/integration/example/store-stress.json rename to diskann-inmem/integration/jsons/store-stress.json diff --git a/diskann-inmem/integration/main.rs b/diskann-inmem/integration/main.rs index ef1f03e97..afbf7af18 100644 --- a/diskann-inmem/integration/main.rs +++ b/diskann-inmem/integration/main.rs @@ -30,13 +30,68 @@ fn main() -> anyhow::Result<()> { mod tests { use super::*; + use std::path::Path; + use diskann_benchmark_runner::{app::Commands, output::Memory}; + use diskann_utils::test_data_root; + use serde::{Serialize, Deserialize}; // The directory containing the committed example input files. fn example_directory() -> std::path::PathBuf { std::path::Path::new(env!("CARGO_MANIFEST_DIR")) .join("integration") - .join("example") + .join("jsons") + } + + // TODO: add first class `diskann-benchmark-runner` support for this. 
+ fn load_from_file(path: &std::path::Path) -> T + where + T: for<'a> Deserialize<'a>, + { + let file = std::fs::File::open(path).unwrap(); + let reader = std::io::BufReader::new(file); + serde_json::from_reader(reader).unwrap() + } + + fn value_from_file(path: &std::path::Path) -> serde_json::Value { + load_from_file(path) + } + + fn save_to_file(path: &std::path::Path, value: &T) + where + T: Serialize + ?Sized, + { + if path.exists() { + panic!("path {} already exists!", path.display()); + } + let buffer = std::fs::File::create(path).unwrap(); + serde_json::to_writer_pretty(buffer, value).unwrap(); + } + + fn prefix_search_directories(raw: &mut serde_json::Value, root: &std::path::Path) { + let key = "search_directories"; + if let serde_json::Value::Object(obj) = raw { + let value = obj + .get_mut(key) + .expect("key \"search-directories\" should exist"); + if let serde_json::Value::Array(directories) = value { + for value in directories.iter_mut() { + if let serde_json::Value::String(dir) = value { + *dir = root.join(&dir).to_str().unwrap().into(); + } + } + } else { + panic!("Expected an Array - got {}", raw); + } + } else { + panic!("Expected an Object - got {}", raw); + } + } + + fn prepend(input: &Path, output: &Path) { + let mut v = value_from_file(input); + prefix_search_directories(&mut v, &test_data_root()); + save_to_file(output, &v); } // Drive the named example through the full runner flow: load the JSON input file, @@ -46,10 +101,13 @@ mod tests { assert!(input_file.exists(), "missing example file: {input_file:?}"); let tempdir = tempfile::tempdir().unwrap(); + let modified_input_file = tempdir.path().join("input.json"); let output_file = tempdir.path().join("output.json"); + prepend(&input_file, &modified_input_file); + let command = Commands::Run { - input_file, + input_file: modified_input_file, output_file: output_file.clone(), dry_run: false, // Unit tests are a debug build; bypass the runner's debug-mode guard. 
@@ -68,4 +126,9 @@ mod tests { fn store_stress_integration() { run_example("store-stress-test.json"); } + + #[test] + fn graph_index() { + run_example("integration.json"); + } } diff --git a/diskann-inmem/integration/support/check.rs b/diskann-inmem/integration/support/check.rs new file mode 100644 index 000000000..f517a12d2 --- /dev/null +++ b/diskann-inmem/integration/support/check.rs @@ -0,0 +1,371 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +use std::{ + borrow::Cow, + fmt::{Display, Write}, +}; + +use diskann_benchmark_runner::{utils::fmt::Table, benchmark::PassFail}; +use serde::{Serialize, Serializer}; + +pub(crate) trait CheckMatch { + fn check_match(&self, previous: &Self) -> Match; +} + +#[must_use = "this is a result type"] +#[derive(Debug, Serialize)] +#[serde(rename_all = "kebab-case")] +pub(crate) enum Match { + Ok, + Mismatch { + got: String, + expected: String, + remark: Option>, + }, + Nested { + children: Vec<(Key, Match)>, + remark: Option>, + }, +} + +impl Match { + #[must_use = "this has no side-effects"] + pub(crate) fn is_ok(&self) -> bool { + matches!(self, Self::Ok) + } + + pub(crate) fn mismatch(got: &dyn Display, expected: &dyn Display) -> Self { + Self::mismatch_with_remark(got, expected, None) + } + + pub(crate) fn mismatch_with_remark( + got: &dyn Display, + expected: &dyn Display, + remark: Option>, + ) -> Self { + Self::Mismatch { + expected: expected.to_string(), + got: got.to_string(), + remark, + } + } + + pub(crate) fn pass_fail(self) -> PassFail { + if self.is_ok() { + PassFail::Pass(self) + } else { + PassFail::Fail(self) + } + } +} + +impl std::fmt::Display for Match { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Ok => f.write_str("ok"), + Self::Mismatch { + got, + expected, + remark, + } => { + let header = ["got", "expected", "remark"]; + let mut table = Table::new(header, 1); + let mut row = table.row(0); + 
row.insert(got.clone(), 0); + row.insert(expected.clone(), 1); + if let Some(remark) = remark { + row.insert(remark.clone(), 2); + } + + table.fmt(f) + } + Self::Nested { children, remark } => { + let mut records = Vec::new(); + if let Some(remark) = remark { + records.push(Record { + path: String::new(), + got: "", + expected: "", + remark, + }); + } + + let mut buf = String::new(); + gather_mismatches(children, &mut records, Stack::new(&mut buf)); + + let mut table = Table::new(["path", "got", "expected", "remark"], records.len()); + for (i, r) in records.into_iter().enumerate() { + let mut row = table.row(i); + row.insert(r.path, 0); + row.insert(r.got.to_owned(), 1); + row.insert(r.expected.to_owned(), 2); + row.insert(r.remark.to_owned(), 3); + } + + table.fmt(f) + } + } + } +} + +fn gather_mismatches<'a>( + mismatches: &'a [(Key, Match)], + records: &mut Vec>, + mut path: Stack<'_>, +) { + for (k, m) in mismatches.iter() { + match m { + Match::Ok => continue, + Match::Mismatch { + got, + expected, + remark, + } => { + let record = Record { + path: path.push(k).get(), + got, + expected, + remark: remark.as_deref().unwrap_or(""), + }; + records.push(record); + } + Match::Nested { children, remark } => { + let path = path.push(k); + + if let Some(remark) = remark { + records.push(Record { + path: path.get(), + got: "", + expected: "", + remark, + }) + } + + gather_mismatches(children, records, path) + } + } + } +} + +#[derive(Debug)] +struct Stack<'a> { + s: &'a mut String, + len: usize, +} + +impl<'a> Stack<'a> { + fn new(s: &'a mut String) -> Self { + s.clear(); + Self { s, len: 0 } + } + + fn push(&mut self, key: &Key) -> Stack<'_> { + let len = self.s.len(); + if len == 0 { + write!(self.s, "{}", key).unwrap(); + } else { + write!(self.s, ".{}", key).unwrap(); + } + + Stack { s: self.s, len } + } + + fn get(&self) -> String { + self.s.clone() + } +} + +impl Drop for Stack<'_> { + fn drop(&mut self) { + self.s.truncate(self.len) + } +} + +#[derive(Debug)] 
+struct Record<'a> { + path: String, + got: &'a str, + expected: &'a str, + remark: &'a str, +} + +///////// +// Key // +///////// + +#[derive(Debug, Clone)] +pub(crate) enum Key { + Str(&'static str), + Position(usize), + String(String), +} + +impl Key { + pub(crate) fn display(key: &D) -> Self + where + D: std::fmt::Display, + { + Key::String(key.to_string()) + } +} + +impl std::fmt::Display for Key { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Str(s) => f.write_str(s), + Self::Position(i) => write!(f, "{}", i), + Self::String(s) => f.write_str(s), + } + } +} + +impl Serialize for Key { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + match self { + Self::Str(s) => serializer.serialize_str(s), + Self::Position(i) => serializer.serialize_u64(*i as u64), + Self::String(s) => serializer.serialize_str(s), + } + } +} + +impl From<&'static str> for Key { + fn from(s: &'static str) -> Key { + Key::Str(s) + } +} + +impl From for Key { + fn from(i: usize) -> Key { + Key::Position(i) + } +} + +impl From for Key { + fn from(s: String) -> Key { + Key::String(s) + } +} + + +///////////// +// Builder // +///////////// + +#[derive(Debug)] +pub(crate) struct MatchBuilder { + children: Vec<(Key, Match)>, +} + +impl MatchBuilder { + pub(crate) fn new() -> Self { + Self { + children: Vec::new(), + } + } + + pub(crate) fn push(&mut self, key: Key, child: Match) { + if !child.is_ok() { + self.children.push((key, child)); + } + } + + pub(crate) fn finish(self) -> Match { + self.finish_with_remark(None) + } + + pub(crate) fn finish_with_remark(self, remark: Option>) -> Match { + if self.children.is_empty() { + Match::Ok + } else { + Match::Nested { + children: self.children, + remark: remark, + } + } + } +} + +macro_rules! 
check_match_impl { + ($T:ty) => { + impl CheckMatch for $T { + fn check_match( + &self, + previous: &Self, + ) -> Match { + if self == previous { + Match::Ok + } else { + Match::mismatch(self, previous) + } + } + } + }; + ($($Ts:ty),+ $(,)?) => { + $(check_match_impl!($Ts);)+ + } +} + +check_match_impl!( + bool, u8, u16, u32, u64, usize, i8, i16, i32, i64, isize, f32, f64, &str, String +); + +impl CheckMatch for [T] +where + T: CheckMatch, +{ + fn check_match(&self, previous: &[T]) -> Match { + if self.len() != previous.len() { + return Match::mismatch_with_remark( + &self.len(), + &previous.len(), + Some("number of results is different between runs".into()), + ); + } + + let mut builder = MatchBuilder::new(); + for (i, (got, expected)) in std::iter::zip(self.iter(), previous.iter()).enumerate() { + builder.push(Key::from(i), got.check_match(expected)); + } + + builder.finish() + } +} + +impl CheckMatch for Vec +where + T: CheckMatch, +{ + fn check_match(&self, previous: &Vec) -> Match { + self.as_slice().check_match(previous.as_slice()) + } +} + +//////////// +// Macros // +//////////// + +macro_rules! check_all_fields { + ($self:expr, $prev:expr, { $($field:ident),+ $(,)? } $(,)?) => {{ + let Self { $($field),+ } = $self; + let mut builder = $crate::support::check::MatchBuilder::new(); + $( + builder.push( + stringify!($field).into(), + <_ as $crate::support::check::CheckMatch>::check_match( + $field, + &$prev.$field + ), + ); + )+ + builder + }}; +} + +pub(crate) use check_all_fields; diff --git a/diskann-inmem/integration/support/exact.rs b/diskann-inmem/integration/support/exact.rs deleted file mode 100644 index 5b4aec9f8..000000000 --- a/diskann-inmem/integration/support/exact.rs +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Copyright (c) Microsoft Corporation. - * Licensed under the MIT license. 
- */ - -use std::{fmt::Write, borrow::Cow}; - -trait ExactMatch { - fn exact_match(&self, other: &Self, matcher: Matcher<'_>); -} - -struct Mismatch { - path: String, - expected: String, - got: String, - remark: Option>, -} - -pub(crate) struct Matcher<'a> { - mismatches: &'a mut Vec, - path: &'a mut String, - len: usize, -} - -impl<'a> Matcher<'a> { - pub(crate) fn push(&mut self, field: &D) -> Matcher<'_> - where - D: std::fmt::Display, - { - let len = self.path.len(); - if len == 0 { - write!(self.path, "{}", field).unwrap(); - } else { - write!(self.path, ".{}", field).unwrap(); - } - - Matcher { - mismatches: self.mismatches, - path: self.path, - len, - } - } - - pub(crate) fn mismatch(&mut self, expected: &D, got: &D, remark: Option) - where - D: std::fmt::Display, - R: Into>, - { - let mismatch = Mismatch { - path: self.path.clone(), - expected: expected.to_string(), - got: got.to_string(), - remark: remark.map(|x| x.into()) - }; - - self.mismatches.push(mismatch); - } -} - -impl Drop for Matcher<'_> { - fn drop(&mut self) { - self.path.truncate(self.len); - } -} - diff --git a/diskann-inmem/integration/support/mod.rs b/diskann-inmem/integration/support/mod.rs index 6f04269d8..982329f46 100644 --- a/diskann-inmem/integration/support/mod.rs +++ b/diskann-inmem/integration/support/mod.rs @@ -3,6 +3,7 @@ * Licensed under the MIT license. */ +pub(crate) mod check; pub(crate) mod datatype; pub(crate) mod io; -pub(crate) mod exact; +pub(crate) mod tolerance; diff --git a/diskann-inmem/integration/support/tolerance.rs b/diskann-inmem/integration/support/tolerance.rs new file mode 100644 index 000000000..7f8efa870 --- /dev/null +++ b/diskann-inmem/integration/support/tolerance.rs @@ -0,0 +1,32 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. 
+ */ + +use diskann_benchmark_runner::{Checker, Input}; +use serde::{Deserialize, Serialize}; + +/// A tolerance [`Input`] for [`diskann_benchmark_runner::benchmark::Regression`]s that +/// do not need any external tolerances. +#[derive(Debug, Clone, Copy, Serialize, Deserialize)] +pub(crate) struct Empty; + +impl Input for Empty { + type Raw = Self; + + fn tag() -> &'static str { + "empty-tolerance" + } + + fn from_raw(raw: Self::Raw, _: &mut Checker) -> anyhow::Result { + Ok(raw) + } + + fn serialize(&self) -> anyhow::Result { + Ok(serde_json::to_value(self)?) + } + + fn example() -> Self::Raw { + Self + } +} diff --git a/diskann-inmem/src/provider.rs b/diskann-inmem/src/provider.rs index 977613792..27c3c2f1e 100644 --- a/diskann-inmem/src/provider.rs +++ b/diskann-inmem/src/provider.rs @@ -738,10 +738,10 @@ where } pub fn test_function<'a>( - x: &'a Provider>, + x: &'a Provider>, strategy: &'a Strategy, context: &'a Context, - query: &'a [f32], + query: &'a [u8], ) -> SearchAccessor<'a> { glue::SearchStrategy::search_accessor(strategy, x, context, query).unwrap() } From b822031297a4bda59d873950dcc1b674f1bfb1dc Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Wed, 24 Jun 2026 18:12:28 -0700 Subject: [PATCH 24/45] The end is in sight. 
--- diskann-benchmark/src/index/inmem2.rs | 7 +- diskann-inmem/Cargo.toml | 2 +- diskann-inmem/integration/index/index.rs | 3 +- diskann-inmem/integration/index/runner.rs | 47 ++- diskann-inmem/integration/index/tests.rs | 2 +- diskann-inmem/integration/jsons/checks.json | 14 + .../jsons/integration-baseline.json | 362 ++++++++++++++++++ diskann-inmem/integration/main.rs | 133 ++++++- diskann-inmem/integration/store.rs | 1 - diskann-inmem/integration/support/check.rs | 44 ++- diskann-inmem/integration/support/datatype.rs | 2 +- diskann-inmem/integration/support/io.rs | 2 +- diskann-inmem/src/counters.rs | 5 - diskann-inmem/src/epoch.rs | 15 +- diskann-inmem/src/integration/counters.rs | 2 +- diskann-inmem/src/integration/store.rs | 4 - diskann-inmem/src/layers/full.rs | 4 +- diskann-inmem/src/layers/mod.rs | 2 +- diskann-inmem/src/lib.rs | 14 + diskann-inmem/src/neighbors.rs | 6 + diskann-inmem/src/provider.rs | 37 +- diskann-inmem/src/store.rs | 234 ++++++++++- 22 files changed, 846 insertions(+), 96 deletions(-) create mode 100644 diskann-inmem/integration/jsons/checks.json create mode 100644 diskann-inmem/integration/jsons/integration-baseline.json diff --git a/diskann-benchmark/src/index/inmem2.rs b/diskann-benchmark/src/index/inmem2.rs index c656efd6b..9e703d4d2 100644 --- a/diskann-benchmark/src/index/inmem2.rs +++ b/diskann-benchmark/src/index/inmem2.rs @@ -160,7 +160,8 @@ impl Benchmark for Inmem2 { let metric = Metric::L2; let exact_max_degree = (input.max_degree as f32 * 1.3) as usize; let layer = Full::::new(dim, metric); - let provider = Provider::new(layer, num_points, start.row_iter()); + let config = diskann_inmem::provider::Config::new(num_points, exact_max_degree); + let provider = Provider::new(layer, config, start.row_iter()); let config = graph::config::Builder::new_with( input.max_degree, @@ -422,7 +423,9 @@ impl Benchmark for Inmem2Stream { let metric = Metric::L2; let exact_max_degree = (input.max_degree as f32 * 1.3) as usize; let layer = 
Full::::new(dim, metric); - let provider = Provider::new(layer, max_points, start.row_iter()); + + let config = diskann_inmem::provider::Config::new(max_points, exact_max_degree); + let provider = Provider::new(layer, config, start.row_iter()); let config = graph::config::Builder::new_with( input.max_degree, diff --git a/diskann-inmem/Cargo.toml b/diskann-inmem/Cargo.toml index 1d1bc00d3..82d17e1a8 100644 --- a/diskann-inmem/Cargo.toml +++ b/diskann-inmem/Cargo.toml @@ -20,7 +20,7 @@ thiserror = { workspace = true } half = { workspace = true } # Integration Test Dependencies -diskann-benchmark-runner = { workspace = true, optional = true } +diskann-benchmark-runner = { workspace = true, optional = true, features = ["ux-tools"] } serde = { workspace = true, features = ["derive"], optional = true } serde_json = { workspace = true, optional = true } anyhow = { workspace = true, optional = true } diff --git a/diskann-inmem/integration/index/index.rs b/diskann-inmem/integration/index/index.rs index 9f8410084..6a630a2e9 100644 --- a/diskann-inmem/integration/index/index.rs +++ b/diskann-inmem/integration/index/index.rs @@ -11,14 +11,13 @@ use diskann::{ utils::IntoUsize, }; use diskann_benchmark_runner::utils::fmt::KeyValue; -use half::f16; use serde::{Deserialize, Serialize}; use thiserror::Error; use diskann_inmem::{Context, Provider, Strategy, integration, layers}; use crate::support::{ - check::{CheckMatch, Match, MatchBuilder, check_all_fields}, + check::{CheckMatch, Match, check_all_fields}, datatype::{AsDataType, DataType, FromSlice, Slice}, }; diff --git a/diskann-inmem/integration/index/runner.rs b/diskann-inmem/integration/index/runner.rs index 191c8dc9a..bd2a68335 100644 --- a/diskann-inmem/integration/index/runner.rs +++ b/diskann-inmem/integration/index/runner.rs @@ -9,7 +9,7 @@ use anyhow::Context; use diskann::graph::{DiskANNIndex, search::Knn}; use diskann_benchmark_runner::{ Checker, Checkpoint, Output, Registry, RegistryError, - 
benchmark::{FailureScore, MatchScore, Regression, PassFail}, + benchmark::{FailureScore, MatchScore, PassFail, Regression}, files::InputFile, utils::fmt::Indent, }; @@ -23,7 +23,7 @@ use diskann_inmem::{Provider, layers}; use crate::{ index::{Counters, Index}, support::{ - check::{CheckMatch, Match, MatchBuilder, check_all_fields}, + check::{CheckMatch, Match, check_all_fields}, datatype::{DataType, Dataset, DatasetView}, io::load_and_convert, tolerance, @@ -124,7 +124,7 @@ struct Data { } impl Data { - fn from_raw(mut raw: dto::Data, checker: &mut Checker) -> anyhow::Result { + fn from_raw(raw: dto::Data, checker: Option<&mut Checker>) -> anyhow::Result { let dto::Data { mut data, mut queries, @@ -133,9 +133,11 @@ impl Data { data_type, } = raw; - data.resolve(checker)?; - queries.resolve(checker)?; - groundtruth.resolve(checker)?; + if let Some(checker) = checker { + data.resolve(checker)?; + queries.resolve(checker)?; + groundtruth.resolve(checker)?; + } Ok(Self { data, @@ -297,7 +299,7 @@ struct Test { } impl Test { - fn from_raw(raw: dto::Test, checker: &mut Checker) -> anyhow::Result { + fn from_raw(raw: dto::Test, checker: Option<&mut Checker>) -> anyhow::Result { let data = Data::from_raw(raw.data, checker)?; let layer = Layer::from_raw(raw.layer); let build = Build::from_raw(raw.build, data.metric)?; @@ -389,7 +391,7 @@ impl diskann_benchmark_runner::Input for Test { } fn from_raw(raw: dto::Test, checker: &mut Checker) -> anyhow::Result { - ::from_raw(raw, checker) + ::from_raw(raw, Some(checker)) } fn serialize(&self) -> anyhow::Result { @@ -450,17 +452,14 @@ impl diskann_benchmark_runner::Benchmark for FullPrecision { type Output = BuildAndSearch; fn try_match(&self, input: &Test) -> Result { - if let Layer::FullPrecision { .. } = input.layer { - Ok(MatchScore(0)) - } else { - Err(FailureScore(1)) - } + let Layer::FullPrecision { .. 
} = input.layer; + Ok(MatchScore(0)) } fn description( &self, f: &mut std::fmt::Formatter<'_>, - input: Option<&Test>, + _input: Option<&Test>, ) -> std::fmt::Result { write!(f, "nop") } @@ -468,12 +467,10 @@ impl diskann_benchmark_runner::Benchmark for FullPrecision { fn run( &self, input: &Test, - checkpoint: Checkpoint<'_>, + _checkpoint: Checkpoint<'_>, mut output: &mut dyn Output, ) -> anyhow::Result { - let Layer::FullPrecision { data_type } = input.layer else { - anyhow::bail!("expected full-precision"); - }; + let Layer::FullPrecision { data_type } = input.layer; // Load the data and perform any necessary data conversions. let Bundle { @@ -556,3 +553,17 @@ impl CheckMatch for BuildAndSearch { builder.finish() } } + +/////////// +// Tests // +/////////// + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn example_parses() { + let _ = Test::from_raw(::example(), None).unwrap(); + } +} diff --git a/diskann-inmem/integration/index/tests.rs b/diskann-inmem/integration/index/tests.rs index 3354c5022..dd9b686c1 100644 --- a/diskann-inmem/integration/index/tests.rs +++ b/diskann-inmem/integration/index/tests.rs @@ -12,7 +12,7 @@ use serde::{Deserialize, Serialize}; use crate::{ index::{Counters, Index, KnnSearch}, support::{ - check::{CheckMatch, Match, MatchBuilder, check_all_fields}, + check::{CheckMatch, Match, check_all_fields}, datatype::DatasetView, }, }; diff --git a/diskann-inmem/integration/jsons/checks.json b/diskann-inmem/integration/jsons/checks.json new file mode 100644 index 000000000..ed53e8d06 --- /dev/null +++ b/diskann-inmem/integration/jsons/checks.json @@ -0,0 +1,14 @@ +{ + "checks": [ + { + "input": { + "type": "integration-test", + "content": {} + }, + "tolerance": { + "type": "empty-tolerance", + "content": null + } + } + ] +} diff --git a/diskann-inmem/integration/jsons/integration-baseline.json b/diskann-inmem/integration/jsons/integration-baseline.json new file mode 100644 index 000000000..6f8f5b014 --- /dev/null +++ 
b/diskann-inmem/integration/jsons/integration-baseline.json @@ -0,0 +1,362 @@ +[ + { + "input": { + "content": { + "build": { + "alpha": 1.2000000476837158, + "l_build": 50, + "max_degree": 20, + "pruned_degree": 16 + }, + "data": { + "data": "/disk_index_search/disk_index_siftsmall_learn_256pts_data.fbin", + "data_type": "f32", + "groundtruth": "/disk_index_search/disk_index_10pts_idx_uint32_truth_search_res.bin", + "metric": "l2", + "queries": "/disk_index_search/disk_index_sample_query_10pts.fbin" + }, + "layer": { + "FullPrecision": { + "data_type": "f32" + } + }, + "search": { + "knn": [ + { + "beam_width": 1, + "knn": 10, + "search_l": 50 + }, + { + "beam_width": 3, + "knn": 10, + "search_l": 50 + }, + { + "beam_width": 3, + "knn": 10, + "search_l": 100 + } + ] + } + }, + "type": "integration-test" + }, + "results": { + "build": { + "append_neighbors": 2447, + "distance": 59957, + "get_neighbors": 14477, + "get_vector": 42151, + "query_distance": 22813, + "set_neighbors": 430, + "set_vector": 256 + }, + "knn": [ + { + "counters": { + "append_neighbors": 0, + "distance": 0, + "get_neighbors": 514, + "get_vector": 1449, + "query_distance": 1449, + "set_neighbors": 0, + "set_vector": 0 + }, + "misc": { + "cmps": 1449, + "hops": 514 + }, + "recall": { + "average": 0.91, + "num_queries": 10, + "recall_k": 10, + "recall_n": 10 + } + }, + { + "counters": { + "append_neighbors": 0, + "distance": 0, + "get_neighbors": 519, + "get_vector": 1450, + "query_distance": 1450, + "set_neighbors": 0, + "set_vector": 0 + }, + "misc": { + "cmps": 1450, + "hops": 519 + }, + "recall": { + "average": 0.91, + "num_queries": 10, + "recall_k": 10, + "recall_n": 10 + } + }, + { + "counters": { + "append_neighbors": 0, + "distance": 0, + "get_neighbors": 1013, + "get_vector": 1615, + "query_distance": 1615, + "set_neighbors": 0, + "set_vector": 0 + }, + "misc": { + "cmps": 1615, + "hops": 1013 + }, + "recall": { + "average": 0.91, + "num_queries": 10, + "recall_k": 10, + "recall_n": 10 
+ } + } + ] + } + }, + { + "input": { + "content": { + "build": { + "alpha": 1.2000000476837158, + "l_build": 50, + "max_degree": 20, + "pruned_degree": 16 + }, + "data": { + "data": "/disk_index_search/disk_index_siftsmall_learn_256pts_data.fbin", + "data_type": "f32", + "groundtruth": "/disk_index_search/disk_index_10pts_idx_uint32_truth_search_res.bin", + "metric": "l2", + "queries": "/disk_index_search/disk_index_sample_query_10pts.fbin" + }, + "layer": { + "FullPrecision": { + "data_type": "f16" + } + }, + "search": { + "knn": [ + { + "beam_width": 1, + "knn": 10, + "search_l": 50 + }, + { + "beam_width": 3, + "knn": 10, + "search_l": 50 + }, + { + "beam_width": 3, + "knn": 10, + "search_l": 100 + } + ] + } + }, + "type": "integration-test" + }, + "results": { + "build": { + "append_neighbors": 2447, + "distance": 59957, + "get_neighbors": 14477, + "get_vector": 42151, + "query_distance": 22813, + "set_neighbors": 430, + "set_vector": 256 + }, + "knn": [ + { + "counters": { + "append_neighbors": 0, + "distance": 0, + "get_neighbors": 514, + "get_vector": 1449, + "query_distance": 1449, + "set_neighbors": 0, + "set_vector": 0 + }, + "misc": { + "cmps": 1449, + "hops": 514 + }, + "recall": { + "average": 0.91, + "num_queries": 10, + "recall_k": 10, + "recall_n": 10 + } + }, + { + "counters": { + "append_neighbors": 0, + "distance": 0, + "get_neighbors": 519, + "get_vector": 1450, + "query_distance": 1450, + "set_neighbors": 0, + "set_vector": 0 + }, + "misc": { + "cmps": 1450, + "hops": 519 + }, + "recall": { + "average": 0.91, + "num_queries": 10, + "recall_k": 10, + "recall_n": 10 + } + }, + { + "counters": { + "append_neighbors": 0, + "distance": 0, + "get_neighbors": 1013, + "get_vector": 1615, + "query_distance": 1615, + "set_neighbors": 0, + "set_vector": 0 + }, + "misc": { + "cmps": 1615, + "hops": 1013 + }, + "recall": { + "average": 0.91, + "num_queries": 10, + "recall_k": 10, + "recall_n": 10 + } + } + ] + } + }, + { + "input": { + "content": { + 
"build": { + "alpha": 1.2000000476837158, + "l_build": 50, + "max_degree": 20, + "pruned_degree": 16 + }, + "data": { + "data": "/disk_index_search/disk_index_siftsmall_learn_256pts_data.fbin", + "data_type": "f32", + "groundtruth": "/disk_index_search/disk_index_10pts_idx_uint32_truth_search_res.bin", + "metric": "l2", + "queries": "/disk_index_search/disk_index_sample_query_10pts.fbin" + }, + "layer": { + "FullPrecision": { + "data_type": "u8" + } + }, + "search": { + "knn": [ + { + "beam_width": 1, + "knn": 10, + "search_l": 50 + }, + { + "beam_width": 3, + "knn": 10, + "search_l": 50 + }, + { + "beam_width": 3, + "knn": 10, + "search_l": 100 + } + ] + } + }, + "type": "integration-test" + }, + "results": { + "build": { + "append_neighbors": 2447, + "distance": 59957, + "get_neighbors": 14477, + "get_vector": 42151, + "query_distance": 22813, + "set_neighbors": 430, + "set_vector": 256 + }, + "knn": [ + { + "counters": { + "append_neighbors": 0, + "distance": 0, + "get_neighbors": 514, + "get_vector": 1449, + "query_distance": 1449, + "set_neighbors": 0, + "set_vector": 0 + }, + "misc": { + "cmps": 1449, + "hops": 514 + }, + "recall": { + "average": 0.91, + "num_queries": 10, + "recall_k": 10, + "recall_n": 10 + } + }, + { + "counters": { + "append_neighbors": 0, + "distance": 0, + "get_neighbors": 519, + "get_vector": 1450, + "query_distance": 1450, + "set_neighbors": 0, + "set_vector": 0 + }, + "misc": { + "cmps": 1450, + "hops": 519 + }, + "recall": { + "average": 0.91, + "num_queries": 10, + "recall_k": 10, + "recall_n": 10 + } + }, + { + "counters": { + "append_neighbors": 0, + "distance": 0, + "get_neighbors": 1013, + "get_vector": 1615, + "query_distance": 1615, + "set_neighbors": 0, + "set_vector": 0 + }, + "misc": { + "cmps": 1615, + "hops": 1013 + }, + "recall": { + "average": 0.91, + "num_queries": 10, + "recall_k": 10, + "recall_n": 10 + } + } + ] + } + } +] \ No newline at end of file diff --git a/diskann-inmem/integration/main.rs 
b/diskann-inmem/integration/main.rs index afbf7af18..446543808 100644 --- a/diskann-inmem/integration/main.rs +++ b/diskann-inmem/integration/main.rs @@ -32,9 +32,33 @@ mod tests { use std::path::Path; - use diskann_benchmark_runner::{app::Commands, output::Memory}; + use diskann_benchmark_runner::{ + app::{Check, Commands}, + output::Memory, + }; use diskann_utils::test_data_root; - use serde::{Serialize, Deserialize}; + use serde::{Deserialize, Serialize}; + use serde_json::Value; + + // Environment variable used to regenerate committed regression baselines. + const DISKANN_TEST_ENV: &str = "DISKANN_TEST"; + + // Return `true` if `DISKANN_TEST=overwrite` is set, instructing regression tests to + // overwrite their committed baselines instead of checking against them. + // + // If `DISKANN_TEST` is set to anything other than `overwrite`, panic. + fn overwrite_baselines() -> bool { + match std::env::var(DISKANN_TEST_ENV) { + Ok(v) if v == "overwrite" => true, + Ok(v) => { + panic!("unknown value for {DISKANN_TEST_ENV}: \"{v}\". Expected \"overwrite\"") + } + Err(std::env::VarError::NotPresent) => false, + Err(std::env::VarError::NotUnicode(_)) => { + panic!("value for {DISKANN_TEST_ENV} is not unicode") + } + } + } // The directory containing the committed example input files. 
fn example_directory() -> std::path::PathBuf { @@ -57,11 +81,11 @@ mod tests { load_from_file(path) } - fn save_to_file(path: &std::path::Path, value: &T) + fn save_to_file(path: &std::path::Path, value: &T, force: bool) where T: Serialize + ?Sized, { - if path.exists() { + if path.exists() && !force { panic!("path {} already exists!", path.display()); } let buffer = std::fs::File::create(path).unwrap(); @@ -88,10 +112,10 @@ mod tests { } } - fn prepend(input: &Path, output: &Path) { + fn prepend(input: &Path, output: &Path, root: &Path) { let mut v = value_from_file(input); - prefix_search_directories(&mut v, &test_data_root()); - save_to_file(output, &v); + prefix_search_directories(&mut v, root); + save_to_file(output, &v, false); } // Drive the named example through the full runner flow: load the JSON input file, @@ -104,7 +128,7 @@ mod tests { let modified_input_file = tempdir.path().join("input.json"); let output_file = tempdir.path().join("output.json"); - prepend(&input_file, &modified_input_file); + prepend(&input_file, &modified_input_file, &test_data_root()); let command = Commands::Run { input_file: modified_input_file, @@ -122,6 +146,93 @@ mod tests { assert!(output_file.exists(), "results file was not written"); } + // Drive the named example through the runner, then run a regression check comparing the + // freshly produced results against a committed baseline. + // + // By default this fails the test if the regression check reports a negative result. When + // `DISKANN_TEST=overwrite` is set, the committed baseline is instead overwritten with the + // freshly produced results (enabling future migrations) and no check is performed. 
+ fn run_regression_example(input_name: &str, tolerances_name: &str, baseline_name: &str) { + let input_file = example_directory().join(input_name); + let tolerances_file = example_directory().join(tolerances_name); + let baseline_file = example_directory().join(baseline_name); + assert!(input_file.exists(), "missing example file: {input_file:?}"); + assert!( + tolerances_file.exists(), + "missing tolerances file: {tolerances_file:?}" + ); + + let tempdir = tempfile::tempdir().unwrap(); + let modified_input_file = tempdir.path().join("input.json"); + let output_file = tempdir.path().join("output.json"); + + prepend(&input_file, &modified_input_file, &test_data_root()); + + // Run the benchmark to produce the "after" results. + let command = Commands::Run { + input_file: modified_input_file.clone(), + output_file: output_file.clone(), + dry_run: false, + // Unit tests are a debug build; bypass the runner's debug-mode guard. + allow_debug: true, + }; + let mut output = Memory::new(); + App::from_commands(command) + .run(®istry().unwrap(), &mut output) + .unwrap(); + assert!(output_file.exists(), "results file was not written"); + + // In overwrite mode, replace the committed baseline and skip the check. + if overwrite_baselines() { + // When over-writing, we need to scrub the file paths of the test directory. + // + // Otherwise, we end up with absolute paths in the baselines. + let mut v = value_from_file(&output_file); + scrub(&mut v, &test_data_root()); + save_to_file(&baseline_file, &v, true); + + return; + } + + assert!( + baseline_file.exists(), + "missing baseline {baseline_file:?}; regenerate it with {DISKANN_TEST_ENV}=overwrite" + ); + + // Run the regression check. A negative result (or any error) propagates here and + // fails the test. 
+ let command = Commands::Check(Check::Run { + tolerances: tolerances_file, + input_file: modified_input_file, + before: baseline_file, + after: output_file, + output_file: None, + }); + let mut output = Memory::new(); + + if let Err(err) = App::from_commands(command).run(®istry().unwrap(), &mut output) { + panic!( + "Regression check failed:\n\n{}\n\n{}", + err, + String::from_utf8(output.into_inner()).unwrap() + ); + } + } + + fn scrub(value: &mut Value, root: &Path) { + let mut values = vec![value]; + while let Some(value) = values.pop() { + match value { + Value::Null | Value::Bool(_) | Value::Number(_) => {} + Value::String(s) => { + *s = diskann_benchmark_runner::ux::scrub_path(s.clone(), root, ""); + } + Value::Array(v) => v.iter_mut().for_each(|v| values.push(v)), + Value::Object(m) => m.values_mut().for_each(|v| values.push(v)), + } + } + } + #[test] fn store_stress_integration() { run_example("store-stress-test.json"); @@ -129,6 +240,10 @@ mod tests { #[test] fn graph_index() { - run_example("integration.json"); + run_regression_example( + "integration.json", + "checks.json", + "integration-baseline.json", + ); } } diff --git a/diskann-inmem/integration/store.rs b/diskann-inmem/integration/store.rs index c82dbf46d..af3d99fee 100644 --- a/diskann-inmem/integration/store.rs +++ b/diskann-inmem/integration/store.rs @@ -15,7 +15,6 @@ use std::{ collections::HashMap, io::Write, - ops::Range, sync::{ Mutex, atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering::Relaxed}, diff --git a/diskann-inmem/integration/support/check.rs b/diskann-inmem/integration/support/check.rs index f517a12d2..f807827e2 100644 --- a/diskann-inmem/integration/support/check.rs +++ b/diskann-inmem/integration/support/check.rs @@ -3,28 +3,46 @@ * Licensed under the MIT license. */ +//! # Baseline Checking +//! +//! The [`Regression`](diskann_benchmark_runner::benchmark::Regression) provides a means +//! of performing before/after comparisons against previously generated results. 
However, +//! presentation of these results is largely left to the devices of the implementors. +//! +//! This module provides a means of aggregating all match failures (if any) and presenting +//! all failures as a single unit. + use std::{ borrow::Cow, fmt::{Display, Write}, }; -use diskann_benchmark_runner::{utils::fmt::Table, benchmark::PassFail}; +use diskann_benchmark_runner::{benchmark::PassFail, utils::fmt::Table}; use serde::{Serialize, Serializer}; +/// Perform a basline check on `self` and a `previous`ly saved result. pub(crate) trait CheckMatch { fn check_match(&self, previous: &Self) -> Match; } +/// The result of a basline. #[must_use = "this is a result type"] #[derive(Debug, Serialize)] #[serde(rename_all = "kebab-case")] pub(crate) enum Match { + /// Successful match. Ok, + + /// A mismatch on a specific field. Mismatch { got: String, expected: String, remark: Option>, }, + + /// A collection of mismatches for an aggregate data type or collection. + /// + /// Use [`MatchBuilder`] to easier construction. Nested { children: Vec<(Key, Match)>, remark: Option>, @@ -32,15 +50,22 @@ pub(crate) enum Match { } impl Match { + /// Return `true` if `self` is [`Match::Ok`]. #[must_use = "this has no side-effects"] pub(crate) fn is_ok(&self) -> bool { matches!(self, Self::Ok) } + /// Record a single mismatch between the retrieved value `got` and the `expected` result. pub(crate) fn mismatch(got: &dyn Display, expected: &dyn Display) -> Self { Self::mismatch_with_remark(got, expected, None) } + /// Record a single mismatch between the retrieved value `got` and the `expected` result + /// with an additional optional remark. + /// + /// The remark can be used for contexts where matches are more complex than simple + /// equality. pub(crate) fn mismatch_with_remark( got: &dyn Display, expected: &dyn Display, @@ -53,6 +78,9 @@ impl Match { } } + /// Convert `self` into a [`PassFail`] for regression checks. + /// + /// Returns `PassFail::Pass` only if `self.is_ok`. 
pub(crate) fn pass_fail(self) -> PassFail { if self.is_ok() { PassFail::Pass(self) @@ -196,6 +224,10 @@ struct Record<'a> { // Key // ///////// +/// A key to develop the full hierarchical path for a match. +/// +/// Keys can either be strings or positional indices. The latter are used when traversing +/// arrays. #[derive(Debug, Clone)] pub(crate) enum Key { Str(&'static str), @@ -203,15 +235,6 @@ pub(crate) enum Key { String(String), } -impl Key { - pub(crate) fn display(key: &D) -> Self - where - D: std::fmt::Display, - { - Key::String(key.to_string()) - } -} - impl std::fmt::Display for Key { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { @@ -253,7 +276,6 @@ impl From for Key { } } - ///////////// // Builder // ///////////// diff --git a/diskann-inmem/integration/support/datatype.rs b/diskann-inmem/integration/support/datatype.rs index 9ba051552..2a399c6e3 100644 --- a/diskann-inmem/integration/support/datatype.rs +++ b/diskann-inmem/integration/support/datatype.rs @@ -122,7 +122,7 @@ where std::iter::zip(dst.iter_mut(), src.iter()).try_for_each(|(d, s)| { let converted = match f(*s) { Ok(c) => c, - Err(e) => anyhow::bail!( + Err(_) => anyhow::bail!( "could not losslessly convert {} {} to {}", U::DATA_TYPE, s, diff --git a/diskann-inmem/integration/support/io.rs b/diskann-inmem/integration/support/io.rs index 50109fc7a..c00a9ca00 100644 --- a/diskann-inmem/integration/support/io.rs +++ b/diskann-inmem/integration/support/io.rs @@ -6,7 +6,7 @@ use diskann_utils::{io::read_bin, views::Matrix}; use half::f16; -use super::datatype::{DataType, Dataset, Slice, SliceMut}; +use super::datatype::{DataType, Dataset, SliceMut}; pub(crate) fn load_and_convert( io: &mut IO, diff --git a/diskann-inmem/src/counters.rs b/diskann-inmem/src/counters.rs index 736ccb491..b53940dc2 100644 --- a/diskann-inmem/src/counters.rs +++ b/diskann-inmem/src/counters.rs @@ -39,7 +39,6 @@ mod inner { } pub(crate) fn query_distance(&mut self, _i: u64) {} - 
pub(crate) fn distance(&mut self, _i: u64) {} pub(crate) fn distance_ref(&self, _i: u64) {} pub(crate) fn get_vector(&mut self, _i: u64) {} pub(crate) fn get_vector_ref(&self, _i: u64) {} @@ -125,10 +124,6 @@ mod inner { self.query_distance += i; } - pub(crate) fn distance(&mut self, i: u64) { - *self.distance.get_mut() += i; - } - pub(crate) fn distance_ref(&self, i: u64) { self.distance.fetch_add(i, Relaxed); } diff --git a/diskann-inmem/src/epoch.rs b/diskann-inmem/src/epoch.rs index efc3d63af..e146d4a87 100644 --- a/diskann-inmem/src/epoch.rs +++ b/diskann-inmem/src/epoch.rs @@ -321,14 +321,6 @@ impl Registry { } } - #[cfg(test)] - fn snapshot(&self) -> Vec { - self.guards - .iter() - .map(|s| s.load(Ordering::Relaxed)) - .collect() - } - #[cfg(test)] fn waiting(&self) -> u64 { self.can_advance(&mut NoDelay).1 @@ -428,12 +420,7 @@ impl std::fmt::Display for Unavailable { impl std::error::Error for Unavailable {} -impl From for diskann::ANNError { - #[track_caller] - fn from(unavailable: Unavailable) -> Self { - diskann::ANNError::opaque(unavailable) - } -} +crate::opaque!(Unavailable); // Delays // diff --git a/diskann-inmem/src/integration/counters.rs b/diskann-inmem/src/integration/counters.rs index b6af15ef4..fcb16d6cd 100644 --- a/diskann-inmem/src/integration/counters.rs +++ b/diskann-inmem/src/integration/counters.rs @@ -3,7 +3,7 @@ * Licensed under the MIT license. */ -/// A snapshot of global counters. +/// A snapshot of global [`Counters`](crate::counters::Counters). 
#[derive(Debug, Clone)] #[non_exhaustive] pub struct CounterSnapshot { diff --git a/diskann-inmem/src/integration/store.rs b/diskann-inmem/src/integration/store.rs index c310b5d8f..eba8a9632 100644 --- a/diskann-inmem/src/integration/store.rs +++ b/diskann-inmem/src/integration/store.rs @@ -88,8 +88,4 @@ impl<'a> Writer<'a> { pub fn as_mut_slice(&mut self) -> &mut [u8] { self.slot.as_mut_slice() } - - pub fn slot(&self) -> u32 { - self.slot.slot() - } } diff --git a/diskann-inmem/src/layers/full.rs b/diskann-inmem/src/layers/full.rs index 2f67da695..f1099316b 100644 --- a/diskann-inmem/src/layers/full.rs +++ b/diskann-inmem/src/layers/full.rs @@ -276,9 +276,9 @@ macro_rules! specialize { $( (Metric::$var, $N) => { let wrapped = Wrap::, $T, $U>::new($query); - return Ok(unsafe { + return Ok( $visitor.visit_sized::<{ $N * std::mem::size_of::<$U>() }, _>(wrapped) - }) + ) }, )* _ => {}, diff --git a/diskann-inmem/src/layers/mod.rs b/diskann-inmem/src/layers/mod.rs index 3cc796ea4..e7e1c18f2 100644 --- a/diskann-inmem/src/layers/mod.rs +++ b/diskann-inmem/src/layers/mod.rs @@ -84,7 +84,7 @@ pub trait QueryVisitor<'a>: Sized { where T: QueryDistance + 'a; - unsafe fn visit_sized(self, distance: T) -> Self::Output + fn visit_sized(self, distance: T) -> Self::Output where T: QueryDistance + 'a, { diff --git a/diskann-inmem/src/lib.rs b/diskann-inmem/src/lib.rs index b0f2c4d32..5a2934761 100644 --- a/diskann-inmem/src/lib.rs +++ b/diskann-inmem/src/lib.rs @@ -28,3 +28,17 @@ mod test; #[cfg(feature = "integration-test")] #[doc(hidden)] pub mod integration; + +macro_rules! 
opaque { + ($T:ty) => { + impl From<$T> for diskann::ANNError { + #[track_caller] + #[cold] + fn from(err: $T) -> diskann::ANNError { + diskann::ANNError::opaque(err) + } + } + }; +} + +pub(crate) use opaque; diff --git a/diskann-inmem/src/neighbors.rs b/diskann-inmem/src/neighbors.rs index 507fd1b39..3425a2fc1 100644 --- a/diskann-inmem/src/neighbors.rs +++ b/diskann-inmem/src/neighbors.rs @@ -240,6 +240,8 @@ pub(crate) enum NeighborsError { #[error("index {} is out-of-bounds", self.0)] pub(crate) struct OutOfBounds(u32); +crate::opaque!(OutOfBounds); + /// A neighbor list was longer than the configured per-list capacity. /// /// `got` is the caller-supplied length (any `usize`); `max` is the per-list capacity, @@ -251,6 +253,8 @@ pub(crate) struct TooLong { max: u32, } +crate::opaque!(TooLong); + /// Errors during [`Neighbors::set`]. #[derive(Debug, Clone, Copy, Error)] pub(crate) enum SetError { @@ -263,6 +267,8 @@ pub(crate) enum SetError { TooLong(TooLong), } +crate::opaque!(SetError); + /// A locked adjacency list to implement atomic read-modify-write operations. /// /// Callers must not hold more than one `Lock` at a time. 
See [`LOCK_GRANULARITY`] for diff --git a/diskann-inmem/src/provider.rs b/diskann-inmem/src/provider.rs index 27c3c2f1e..51d82dd4a 100644 --- a/diskann-inmem/src/provider.rs +++ b/diskann-inmem/src/provider.rs @@ -373,7 +373,7 @@ struct ExpandBeamVisitor { impl<'a> layers::QueryVisitor<'a> for ExpandBeamVisitor { type Output = ExpandBeam<'a>; - unsafe fn visit_sized(self, distance: T) -> Self::Output + fn visit_sized(self, distance: T) -> Self::Output where T: QueryDistance + 'a, { @@ -421,9 +421,14 @@ impl<'a> ExpandBeam<'a> { drop: drop::, }; - let ptr: *const T = Box::leak(Box::new(x)); + // NOTE: It's really important that we coerce the leaked `&mut` to a `*mut` rather + // than going straight to `*const` as the latter converts `&mut` to `&` first + // and then Miri gets upset when we try to drop it, thinking it's a shared reference + // rather than unique one. + let ptr: *mut T = Box::leak(Box::new(x)); + Self { - ptr: ptr.cast(), + ptr: ptr.cast_const().cast(), expand_beam, vtable: &vtable, lifetime: PhantomData, @@ -639,7 +644,7 @@ impl provider::NeighborAccessor for PruneAccessor<'_> { ) -> impl std::future::Future> + Send { let work = move || { self.counters.get_neighbors(1); - Ok(self.reader.neighbors().get(id, neighbors).unwrap()) + Ok(self.reader.neighbors().get(id, neighbors)?) }; ready(work) } @@ -653,7 +658,7 @@ impl provider::NeighborAccessorMut for PruneAccessor<'_> { ) -> impl std::future::Future> + Send { let work = move || { self.counters.set_neighbors(1); - Ok(self.reader.neighbors().set(id, neighbors).unwrap()) + Ok(self.reader.neighbors().set(id, neighbors)?) 
}; ready(work) } @@ -665,12 +670,22 @@ impl provider::NeighborAccessorMut for PruneAccessor<'_> { ) -> impl std::future::Future> + Send { let work = move || -> ANNResult<()> { self.counters.append_vector(1); - self.reader - .neighbors() - .lock(id) - .unwrap() - .append(neighbors) - .unwrap(); + let lock = self.reader.neighbors().lock(id)?; + + // Due to race conditions between calls to `get_neighbors` and `append_vector` + // in `diskann` - it's possible that the state of the adjacency list has changed + // and we're now trying to add too many neighbors. + // + // We take care of that here by simply truncating. + // + // TODO: Introduce proper atomicity in the core algorithm. + if lock.len() + neighbors.len() > lock.capacity() { + let slack = lock.capacity() - lock.len(); + lock.append(&neighbors[..slack])?; + } else { + lock.append(neighbors)?; + } + Ok(()) }; diff --git a/diskann-inmem/src/store.rs b/diskann-inmem/src/store.rs index 822b02be0..69ef3fc28 100644 --- a/diskann-inmem/src/store.rs +++ b/diskann-inmem/src/store.rs @@ -121,11 +121,6 @@ impl Store { (self.unfrozen as u32)..(self.buffer.len() as u32) } - /// Return the number of unfrozen slots managed by `self`. - pub(crate) fn capacity(&self) -> usize { - self.buffer.len() - self.unfrozen - } - /// Return the number of bytes occupied by each entry. pub(crate) fn bytes(&self) -> Bytes { self.unpadded @@ -349,6 +344,12 @@ impl Store { .get(i) .map(|tag| tag.load(Ordering::Relaxed).can_read()) } + + #[cfg(test)] + fn writable(&self) -> std::ops::Range { + 0..self.unfrozen as u32 + } + } #[derive(Debug, Error)] @@ -373,12 +374,6 @@ impl StoreError { } } -impl From for StoreError { - fn from(inner: StoreErrorInner) -> Self { - Self(inner) - } -} - #[derive(Debug, Error)] enum StoreErrorInner { #[error( @@ -563,7 +558,224 @@ impl Drop for Slot<'_> { // Tests // /////////// +/// These tests are basic functionality tests for the store. +/// +/// Longer running conurrency tests are in the integration test suite. 
#[cfg(test)] mod tests { use super::*; + + use diskann_utils::views::Matrix; + + // Build a store with `entries` writable slots of `entry_bytes` each, backed by `frozen` + // zeroed frozen points. The frozen points occupy the highest slot indices. + fn store(entries: usize, entry_bytes: usize, frozen: usize) -> Result { + let mut data = Matrix::new(0u8, frozen, entry_bytes); + let mut base = 0u8; + for row in data.row_iter_mut() { + row.fill(base); + base = base.wrapping_add(1); + } + + Store::new(entries, Bytes::new(entry_bytes), 0, data.as_view()) + } + + //------------------------// + // Constructor validation // + //------------------------// + + #[test] + fn new_rejects_mismatched_frozen_dim() { + // Frozen point has 8 columns but the store is asked for 16-byte entries. + let data = Matrix::new(0u8, 1, 8); + let err = Store::new(4, Bytes::new(16), 0, data.as_view()).unwrap_err(); + assert!(matches!( + err.0, + StoreErrorInner::MismatchedFrozenPointDim { dim: 8, .. } + )); + } + + #[test] + fn new_requires_a_frozen_point() { + let err = store(4, 8, 0).unwrap_err(); + assert!(matches!(err.0, StoreErrorInner::NeedFrozenPoint)); + } + + #[test] + fn new_rejects_total_slot_overflow() { + // `entries` alone fits in u32, but `entries + frozen` overflows it. + let data = Matrix::new(0u8, 1, 8); + let err = Store::new(u32::MAX as usize, Bytes::new(8), 0, data.as_view()).unwrap_err(); + assert!(matches!(err.0, StoreErrorInner::TooManyEntries { .. })); + } + + #[test] + fn new_rejects_too_many_neighbors() { + let data = Matrix::new(0u8, 1, 8); + let err = + Store::new(4, Bytes::new(8), u32::MAX.into_usize() + 1, data.as_view()).unwrap_err(); + assert!(matches!(err.0, StoreErrorInner::TooManyNeighbors { .. })); + } + + //--------// + // Layout // + //--------// + + #[test] + fn frozen_range_follows_writable_slots() { + let s = store(4, 8, 2).unwrap(); + + // Writable slots are [0, 4); frozen points occupy [4, 6). 
+ assert_eq!(s.frozen(), 4..6); + + let reader = s.reader().unwrap(); + for i in 0..4 { + assert!(!s.can_read_approximate(i).unwrap()); + assert!(!reader.can_read(i).unwrap()); + assert!(reader.read(i).is_none()); + } + + assert!(s.can_read_approximate(4).unwrap()); + assert!(reader.can_read(4).unwrap()); + assert_eq!(reader.read(4).unwrap(), &[0, 0, 0, 0, 0, 0, 0, 0]); + + assert!(s.can_read_approximate(5).unwrap()); + assert!(reader.can_read(5).unwrap()); + assert_eq!(reader.read(5).unwrap(), &[1, 1, 1, 1, 1, 1, 1, 1]); + + assert!(s.can_read_approximate(6).is_none()); + assert!(reader.can_read(6).is_none()); + assert!(reader.read(6).is_none()); + } + + /////////////// + // Lifecycle // + /////////////// + + #[test] + fn acquire_write_publish_read_roundtrip() { + let s = store(4, 8, 1).unwrap(); + + let reader = s.reader().expect("reader guard available"); + + let idx = { + let mut slot = s.acquire().expect("a fresh store has free slots"); + let idx = slot.slot() as usize; + slot.as_mut_slice() + .copy_from_slice(&[1, 2, 3, 4, 5, 6, 7, 8]); + + // Before the slot is dropped - we should not be able to read it. + assert!(reader.read(idx).is_none()); + assert!(!s.can_read_approximate(idx).unwrap()); + + idx + }; + + assert_eq!(reader.read(idx), Some([1, 2, 3, 4, 5, 6, 7, 8].as_slice())); + assert!(s.can_read_approximate(idx).unwrap()); + } + + #[test] + fn acquire_exhausts_then_reports_none() { + let s = store(2, 8, 1).unwrap(); + // Hold the guards so the slots stay owned. 
+ let _a = s.acquire().expect("first writable slot"); + let _b = s.acquire().expect("second writable slot"); + assert!( + s.acquire().is_none(), + "all writable slots are owned, so acquire must fail" + ); + } + + //--------// + // Retire // + //--------// + + #[test] + fn retire_out_of_bounds() { + let s = store(4, 8, 1).unwrap(); + assert!(matches!(s.retire(999), Err(RetireError::OutOfBounds))); + } + + #[test] + fn retire_rejects_reserved_slots() { + let s = store(4, 8, 1).unwrap(); + // An untouched writable slot is AVAILABLE, which is a reserved state. + assert!(matches!( + s.retire(0), + Err(RetireError::SlotIsReserved { .. }) + )); + // A frozen slot is likewise reserved. + let frozen = s.frozen().start as usize; + assert!(matches!( + s.retire(frozen), + Err(RetireError::SlotIsReserved { .. }) + )); + // An owned slot is not retirable. + let slot = s.acquire().unwrap(); + assert!(matches!( + s.retire(slot.slot() as usize), + Err(RetireError::SlotIsReserved { .. }) + )); + } + + #[test] + fn retire_published_slot_then_unreadable() { + let s = store(4, 8, 1).unwrap(); + + let idx = { + let slot = s.acquire().unwrap(); + slot.slot() as usize + // Publish + }; + + assert!(s.retire(idx).is_ok()); + + // A reader opened after retirement must not observe the retired slot. + let reader = s.reader().unwrap(); + assert_eq!(reader.read(idx), None); + assert_eq!(reader.can_read(idx), Some(false)); + + // The slot can also not be retired again. + assert!(matches!( + s.retire(idx), + Err(RetireError::SlotIsReserved { .. }) + )); + } + + //---------// + // Recycle // + //---------// + + #[test] + fn test_recycling() { + let entries = if cfg!(miri) { + 16 + } else { + 2048 + }; + + let s = store(entries, 4, 2).unwrap(); + + // Claim all slots. + let mut count = 0; + while let Some(slot) = s.acquire() { + count += 1; + } + + assert_eq!(count, s.writable().len()); + + // Now that all slots are claimed - retire all slots. 
+ for i in s.writable() { + s.retire(i.into_usize()).unwrap(); + } + + // Verify that we can claim all slots again. + let mut count = 0; + while let Some(slot) = s.acquire() { + count += 1; + } + + assert_eq!(count, s.writable().len()); + } } From e030eec26c8cefd7d030772f249b3144d97909b1 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Thu, 25 Jun 2026 16:26:46 -0700 Subject: [PATCH 25/45] Checkpoint. --- diskann-inmem/integration/main.rs | 1 + diskann-inmem/src/layers/full.rs | 459 ++++++++++++++++++++++-------- diskann-inmem/src/layers/mod.rs | 113 +++++--- diskann-inmem/src/provider.rs | 5 + diskann-inmem/src/sharded.rs | 177 +++++++++++- diskann-inmem/src/store.rs | 72 ++++- 6 files changed, 645 insertions(+), 182 deletions(-) diff --git a/diskann-inmem/integration/main.rs b/diskann-inmem/integration/main.rs index 446543808..ddc69289d 100644 --- a/diskann-inmem/integration/main.rs +++ b/diskann-inmem/integration/main.rs @@ -239,6 +239,7 @@ mod tests { } #[test] + #[cfg(not(miri))] fn graph_index() { run_regression_example( "integration.json", diff --git a/diskann-inmem/src/layers/full.rs b/diskann-inmem/src/layers/full.rs index f1099316b..07c491dc4 100644 --- a/diskann-inmem/src/layers/full.rs +++ b/diskann-inmem/src/layers/full.rs @@ -9,7 +9,10 @@ use diskann::{ANNError, ANNResult}; use diskann_vector::{ UnalignedSlice, conversion::SliceCast, - distance::{self, DistanceProvider, InnerProduct, Metric, Specialize, SquaredL2}, + distance::{ + self, Cosine, CosineNormalized, DistanceProvider, InnerProduct, Metric, Specialize, + SquaredL2, + }, }; use diskann_wide::{ ARCH, @@ -28,13 +31,13 @@ where { distance: Distance, metric: Metric, - _type: std::marker::PhantomData, } impl Full where T: 'static, { + /// Create a new full-precision layer for data with the given `dim` and `metric`. 
pub fn new(dim: usize, metric: Metric) -> Self where T: DistanceProvider, @@ -44,20 +47,29 @@ where dim, }; - Self { - distance, - metric, - _type: std::marker::PhantomData, - } + Self { distance, metric } } + /// Return the logical dimension of the data handled by this [`Layer`]. pub fn dim(&self) -> usize { self.distance.dim } + /// Return the number of bytes of the data handles by this [`Layer`]. pub fn bytes(&self) -> Bytes { Bytes::new(self.dim() * std::mem::size_of::()) } + + fn check_dim(&self, dim: usize) -> Result<(), QueryDistanceError> { + if self.dim() != dim { + Err(QueryDistanceError { + expected: self.dim(), + xlen: dim, + }) + } else { + Ok(()) + } + } } impl layers::Layer for Full @@ -74,12 +86,41 @@ where T: bytemuck::Pod + Send + Sync, { fn into_bytes(&self, v: &[T], bytes: &mut [u8]) -> ANNResult<()> { - assert_eq!(self.dim(), v.len()); - bytes.copy_from_slice(bytemuck::must_cast_slice::(v)); - Ok(()) + if v.len() != self.dim() { + Err(ANNError::from(SetError::Dim { + got: v.len(), + expected: self.dim(), + })) + } else if bytes.len() != self.bytes().value() { + Err(ANNError::from(SetError::Bytes { + got: bytes.len(), + expected: self.bytes().value(), + })) + } else { + bytes.copy_from_slice(bytemuck::must_cast_slice::(v)); + Ok(()) + } } } +#[derive(Debug, Error)] +enum SetError { + #[error( + "data of dimension {} does not match full precision layer's dimension {}", + got, + expected + )] + Dim { got: usize, expected: usize }, + #[error( + "raw byte slice of length {} does not match expected length {}", + got, + expected + )] + Bytes { got: usize, expected: usize }, +} + +crate::opaque!(SetError); + impl layers::AsDistance for Full where T: Debug + Send + Sync + 'static, @@ -140,7 +181,7 @@ where } fn bytes(&self) -> usize { - self.dim * std::mem::size_of::() + self.dim() * std::mem::size_of::() } } @@ -195,66 +236,62 @@ impl std::ops::Deref for Calf<'_, T> { } } -impl<'a, T> From<&'a [T]> for Calf<'a, T> { - fn from(slice: &'a [T]) -> Self { 
- Self::Borrowed(slice) - } -} - -impl From> for Calf<'_, T> { - fn from(boxed: Box<[T]>) -> Self { - Self::Owned(boxed) - } -} - +/// A fused query distance based on [`PureDistanceFunction`] to enable inlining of the final +/// distance function (`D`). +/// +/// The type of the embedded query (`T`) is distinct from the expected data-set (`U`) to +/// allow `f16` queries to be pre-converted to `f32`, saving on-the-fly conversion that +/// would otherwise be needed. #[derive(Debug)] -struct QueryDistance<'a, T, U> -where - T: 'static, - U: 'static, -{ - distance: Distance, +struct QueryDistance<'a, T, U, D> { query: Calf<'a, T>, + // The type of the data in the original dataset. + _data: PhantomData, + // The type of the `PureDistanceFunction` used for the implementation. + _distance: PhantomData, } -impl<'a, T, U> QueryDistance<'a, T, U> -where - T: 'static, - U: 'static, -{ - fn new(distance: Distance, query: Calf<'a, T>) -> Self { - if query.len() != distance.dim() { - panic!("oops"); +impl<'a, T, U, D> QueryDistance<'a, T, U, D> { + fn new(query: Calf<'a, T>) -> Self { + Self { + query, + _data: PhantomData, + _distance: PhantomData, } + } - Self { distance, query } + fn bytes(&self) -> usize { + std::mem::size_of::() * self.query.len() } - #[cold] #[inline(never)] - fn error(&self, x: &[u8]) -> ANNResult { + fn error(&self, len: usize) -> ANNResult { let error = QueryDistanceError { - expected: self.distance.bytes(), - xlen: x.len(), + expected: self.bytes(), + xlen: len, }; Err(ANNError::opaque(error)) } } -impl layers::QueryDistance for QueryDistance<'_, T, U> +impl layers::QueryDistance for QueryDistance<'_, T, U, D> where - T: Debug + Sync + Send + 'static, - U: Debug + Sync + Send + 'static, + T: Send + Sync + 'static + Debug, + U: Send + Sync + 'static + Debug, + D: for<'a> FTarget2, UnalignedSlice<'a, U>> + + Send + + Sync + + Debug, { + #[inline(always)] fn evaluate(&self, x: &[u8]) -> ANNResult { - if x.len() != self.distance.bytes() { - self.error(x) 
+ if x.len() != self.bytes() { + self.error(x.len()) } else { - Ok(self.distance.f.call_unaligned( - unsafe { UnalignedSlice::new(self.query.as_ptr().cast::(), self.distance.dim) }, - unsafe { UnalignedSlice::new(x.as_ptr().cast::(), self.distance.dim) }, - )) + // SAFETY: We've validated that `x` has the correct length. + let x = unsafe { UnalignedSlice::new(x.as_ptr().cast::(), self.query.len()) }; + Ok(D::run(ARCH, (*self.query).into(), x)) } } } @@ -270,20 +307,23 @@ struct QueryDistanceError { xlen: usize, } -macro_rules! specialize { - ($me:ident, $query:ident, $visitor:ident, ($T:ty, $U:ty), $(($var:ident, $N:literal, $f:ty)),* $(,)?) => { - match ($me.metric, $me.dim()) { - $( - (Metric::$var, $N) => { - let wrapped = Wrap::, $T, $U>::new($query); - return Ok( - $visitor.visit_sized::<{ $N * std::mem::size_of::<$U>() }, _>(wrapped) - ) - }, - )* - _ => {}, - } - } +crate::opaque!(QueryDistanceError); + +macro_rules! mint { + ($query:ident, $visitor:ident, $T:ty => { $N:literal, $f:ident }) => {{ + mint!($query, $visitor, { $T, $T } => { $N x $f }) + }}; + ($query:ident, $visitor:ident, { $T:ty, $U:ty } => { $N:literal x $f:ident }) => {{ + let inner = QueryDistance::<$T, $U, Specialize<$N, $f>>::new($query); + $visitor.visit_sized::<{ $N * std::mem::size_of::<$U>() }, _>(inner) + }}; + ($query:ident, $visitor:ident, $T:ty => $f:ident) => {{ + mint!($query, $visitor, { $T, $T } => $f) + }}; + ($query:ident, $visitor:ident, { $T:ty, $U:ty } => $f:ident) => {{ + let inner = QueryDistance::<$T, $U, $f>::new($query); + $visitor.visit(inner) + }}; } impl layers::Search for Full { @@ -293,19 +333,24 @@ impl layers::Search for Full { where V: layers::QueryVisitor<'a>, { + self.check_dim(query.len())?; + let query = Calf::Borrowed(query); - specialize!( - self, - query, - visitor, - (f32, f32), - (L2, 100, SquaredL2), - (InnerProduct, 768, InnerProduct), - ); + let output = match self.metric { + Metric::L2 => { + if self.dim() == 100 { + mint!(query, visitor, f32 => 
{ 100, SquaredL2 }) + } else { + mint!(query, visitor, f32 => SquaredL2) + } + } + Metric::InnerProduct => mint!(query, visitor, f32 => InnerProduct), + Metric::Cosine => mint!(query, visitor, f32 => Cosine), + Metric::CosineNormalized => mint!(query, visitor, f32 => CosineNormalized), + }; - // Fallback - Ok(visitor.visit(QueryDistance::new(self.distance, query))) + Ok(output) } } @@ -316,26 +361,20 @@ impl layers::Search for Full { where V: layers::QueryVisitor<'a>, { + self.check_dim(query.len())?; + let mut as_f32: Box<[f32]> = std::iter::repeat_n(0.0, self.dim()).collect(); diskann_wide::arch::dispatch2(SliceCast::new(), &mut *as_f32, query); let query = Calf::Owned(as_f32); - specialize!( - self, - query, - visitor, - (f32, f16), - (L2, 100, SquaredL2), - (InnerProduct, 768, InnerProduct), - ); - - // Fallback - let distance = Distance { - f: >::distance_comparer(self.metric, Some(self.dim())), - dim: self.dim(), + let output = match self.metric { + Metric::L2 => mint!(query, visitor, { f32, f16 } => SquaredL2), + Metric::InnerProduct => mint!(query, visitor, { f32, f16 } => InnerProduct), + Metric::Cosine => mint!(query, visitor, { f32, f16 } => Cosine), + Metric::CosineNormalized => mint!(query, visitor, { f32, f16 } => CosineNormalized), }; - Ok(visitor.visit(QueryDistance::new(distance, query))) + Ok(output) } } @@ -346,12 +385,18 @@ impl layers::Search for Full { where V: layers::QueryVisitor<'a>, { + self.check_dim(query.len())?; + let query = Calf::Borrowed(query); - specialize!(self, query, visitor, (u8, u8), (L2, 128, SquaredL2)); + let output = match self.metric { + Metric::L2 => mint!(query, visitor, u8 => SquaredL2), + Metric::InnerProduct => mint!(query, visitor, u8 => InnerProduct), + Metric::Cosine => mint!(query, visitor, u8 => Cosine), + Metric::CosineNormalized => mint!(query, visitor, u8 => Cosine), + }; - // Fallback - Ok(visitor.visit(QueryDistance::new(self.distance, query))) + Ok(output) } } @@ -362,40 +407,18 @@ impl layers::Search for 
Full { where V: layers::QueryVisitor<'a>, { - let query = Calf::Borrowed(query); - Ok(visitor.visit(QueryDistance::new(self.distance, query))) - } -} + self.check_dim(query.len())?; -#[derive(Debug)] -struct Wrap<'a, I, T, U> { - query: Calf<'a, T>, - ps: PhantomData<(I, U)>, -} + let query = Calf::Borrowed(query); -impl<'a, I, T, U> Wrap<'a, I, T, U> { - fn new(query: Calf<'a, T>) -> Self { - Self { - query, - ps: PhantomData, - } - } -} + let output = match self.metric { + Metric::L2 => mint!(query, visitor, i8 => SquaredL2), + Metric::InnerProduct => mint!(query, visitor, i8 => InnerProduct), + Metric::Cosine => mint!(query, visitor, i8 => Cosine), + Metric::CosineNormalized => mint!(query, visitor, i8 => Cosine), + }; -impl layers::QueryDistance for Wrap<'_, I, T, U> -where - I: for<'a> FTarget2, UnalignedSlice<'a, U>> - + Send - + Sync - + Debug, - T: Send + Sync + 'static + Debug, - U: Send + Sync + 'static + Debug, -{ - #[inline(always)] - fn evaluate(&self, x: &[u8]) -> ANNResult { - // TODO: This is not fully valid - we need to check. - let x = unsafe { UnalignedSlice::new(x.as_ptr().cast::(), self.query.len()) }; - Ok(I::run(ARCH, (*self.query).into(), x)) + Ok(output) } } @@ -405,5 +428,195 @@ where #[cfg(test)] mod tests { + use std::fmt::Display; + + use rand::{Rng, SeedableRng, rngs::StdRng}; + use super::*; + // Bring the inherent-call traits into method scope. The `Distance` / `QueryDistance` + // traits are not imported: their methods are reached through `&dyn _` trait objects, + // which does not require the trait to be in scope. + use crate::layers::{AsDistance as _, QueryVisitor, Search as _, Set as _}; + + /// Generate random elements of a layer's data type from a seeded RNG. 
+ trait Sample: bytemuck::Pod { + fn sample(rng: &mut R) -> Self; + } + + impl Sample for f32 { + fn sample(rng: &mut R) -> Self { + rng.random_range(-1.0f32..1.0f32) + } + } + + impl Sample for f16 { + fn sample(rng: &mut R) -> Self { + f16::from_f32(rng.random_range(-1.0f32..1.0f32)) + } + } + + impl Sample for u8 { + fn sample(rng: &mut R) -> Self { + rng.random() + } + } + + impl Sample for i8 { + fn sample(rng: &mut R) -> Self { + rng.random() + } + } + + fn gen_vec(rng: &mut R, dim: usize) -> Vec { + (0..dim).map(|_| T::sample(rng)).collect() + } + + /// A [`QueryVisitor`] that simply boxes the minted kernel so the test can probe it + /// directly. Exercises both `visit` (dynamic) and `visit_sized` (specialized) paths. + struct Collect; + + impl<'a> QueryVisitor<'a> for Collect { + type Output = Box; + + fn visit(self, distance: Q) -> Self::Output + where + Q: layers::QueryDistance + 'a, + { + Box::new(distance) + } + } + + /// Compare two distances allowing for floating-point reassociation between the + /// specialized / converted kernels and the dynamic reference. + fn approx_eq(got: f32, want: f32) -> bool { + (got - want).abs() <= 1e-3 + 1e-4 * want.abs() + } + + /// Exercise every `Full` API across dimensions `1..=max_dim`. + /// + /// For each dimension we check that `bytes`/`into_bytes` agree, that `distance` and + /// `query_distance` are consistent with `DistanceProvider`, and that all of these + /// reject byte slices that are too long or too short. 
+ fn test_impl(max_dim: usize, ctx: &dyn Display) + where + T: Sample + Debug + Send + Sync + DistanceProvider + 'static, + Full: for<'a> layers::Search = &'a [T]>, + { + let mut rng = StdRng::seed_from_u64(0x0D15_0ACE ^ max_dim as u64); + let metrics = [ + Metric::L2, + Metric::InnerProduct, + Metric::Cosine, + Metric::CosineNormalized, + ]; + + for dim in 1..=max_dim { + let a = gen_vec::(&mut rng, dim); + let b = gen_vec::(&mut rng, dim); + + // `bytes` and `into_bytes` agree: the encoded buffer equals the raw cast bytes. + let layer = Full::::new(dim, Metric::L2); + assert_eq!( + layer.bytes().value(), + dim * std::mem::size_of::(), + "{ctx}: dim {dim}: unexpected byte length", + ); + + let mut a_bytes = vec![0u8; layer.bytes().value()]; + layer.into_bytes(&a, &mut a_bytes).unwrap(); + assert_eq!( + a_bytes.as_slice(), + bytemuck::cast_slice::(&a), + "{ctx}: dim {dim}: into_bytes mismatch", + ); + + let mut b_bytes = vec![0u8; layer.bytes().value()]; + layer.into_bytes(&b, &mut b_bytes).unwrap(); + + for metric in metrics { + let full = Full::::new(dim, metric); + + // Reference value straight from `DistanceProvider`. + let reference = + >::distance_comparer(metric, Some(dim)).call(&a, &b); + + // `distance` is built from the same comparer, so it must match exactly. + let distance = full.as_distance(); + let via_distance = distance.evaluate(&a_bytes, &b_bytes).unwrap(); + assert_eq!( + via_distance, reference, + "{ctx}: dim {dim}, metric {metric:?}: distance != DistanceProvider", + ); + + // `query_distance` computes the same geometry. Specialized and f16-converted + // kernels may reassociate the summation, so compare approximately. 
+ let query = full.query_distance(a.as_slice(), Collect).unwrap(); + let via_query = query.evaluate(&b_bytes).unwrap(); + assert!( + approx_eq(via_query, via_distance), + "{ctx}: dim {dim}, metric {metric:?}: query {via_query} != distance {via_distance}", + ); + + // Every distance API rejects byte slices that are too long or too short. + let short = &a_bytes[..a_bytes.len() - 1]; + let mut long = a_bytes.clone(); + long.push(0); + + assert!(distance.evaluate(short, &b_bytes).is_err()); + assert!(distance.evaluate(&long, &b_bytes).is_err()); + assert!(distance.evaluate(&a_bytes, short).is_err()); + assert!(distance.evaluate(&a_bytes, &long).is_err()); + + assert!(query.evaluate(short).is_err()); + assert!(query.evaluate(&long).is_err()); + } + + // `into_bytes` rejects mis-sized element and buffer slices. + let mut buf = vec![0u8; layer.bytes().value()]; + let too_many = gen_vec::(&mut rng, dim + 1); + assert!( + layer.into_bytes(&too_many, &mut buf).is_err(), + "{ctx}: dim {dim}: into_bytes accepted an over-long element slice", + ); + + assert!( + layer.query_distance(&too_many, Collect).is_err(), + "{ctx}: dim {dim}: incorrect query lengths should be rejected" + ); + + let mut short_buf = vec![0u8; layer.bytes().value().saturating_sub(1)]; + assert!( + layer.into_bytes(&a, &mut short_buf).is_err(), + "{ctx}: dim {dim}: into_bytes accepted an under-sized buffer", + ); + + let too_few = gen_vec::(&mut rng, dim - 1); + assert!( + layer.query_distance(&too_few, Collect).is_err(), + "{ctx}: dim {dim}: incorrect query lengths should be rejected" + ); + } + } + + // `max_dim` must exceed the largest specialized dimension for each type so the + // const-generic (`visit_sized`) paths are covered alongside the dynamic ones. 
+ #[test] + fn full_f32() { + test_impl::(256, &"f32"); + } + + #[test] + fn full_f16() { + test_impl::(256, &"f16"); + } + + #[test] + fn full_u8() { + test_impl::(160, &"u8"); + } + + #[test] + fn full_i8() { + test_impl::(160, &"i8"); + } } diff --git a/diskann-inmem/src/layers/mod.rs b/diskann-inmem/src/layers/mod.rs index e7e1c18f2..b3b1cc187 100644 --- a/diskann-inmem/src/layers/mod.rs +++ b/diskann-inmem/src/layers/mod.rs @@ -3,87 +3,107 @@ * Licensed under the MIT license. */ +//! Distance layers indexing. +//! +//! An important assumption made by this module is that the data within each layer is +//! uniformly sized: each entry occupies the same number of bytes. Furthermore, the data +//! to be stored may not assume any particular alignment. Implementations will strive to +//! achieve a reasonable alignment, but this may not be relied on. +//! +//! # Query Distance Specialization +//! +//! The design of this module allows aggressive optimization of graph search kernels via +//! the [`Search`] and [`QueryVisitor`] pairs of traits. +//! +//! Implementations of [`Search`] can pass an [`QueryDistance`] kernels specialized to +//! to a specific geometry (dimensionality or metric type) which upstream [`QueryVisitor`] +//! will fuse into larger kernels. While this allows for high performance graph kernels, +//! some considerations should be taken into account: +//! +//! 1. For correctness purposes, upstream callers cannot do any kind of caching. As such, +//! the dispatch layer used to select the kernel passed to the [`QueryVisitor`] should +//! be relatively efficient. +//! +//! 2. Keep the number of specializations bounded for compile time reasons. 
+ use diskann::ANNResult; -use diskann_vector::DistanceFunction; use crate::num::Bytes; -pub(crate) mod full; +mod full; pub use full::Full; -pub trait AddLifetime: Send + Sync + 'static { - type Of<'a>: Send + Sync; -} - -#[derive(Debug)] -pub struct Slice(std::marker::PhantomData); - -impl Slice { - pub fn new() -> Self { - Self(std::marker::PhantomData) - } -} - -impl Clone for Slice { - fn clone(&self) -> Self { - *self - } +/// Base layer for data representations. +pub trait Layer: Send + Sync + 'static { + /// Return the number of bytes needed by this layer representation. + /// + /// To be well-behaved, this function must be idempotent. + fn bytes(&self) -> Bytes; } -impl Copy for Slice {} - -impl Default for Slice { - fn default() -> Self { - Self::new() - } +/// Store a element of type `T` into a raw byte buffer. +/// +/// Implementations may assume that `bytes.len()` is equal to [`Layer::bytes`]. +pub trait Set: Layer { + /// Write into the stored representation. + fn into_bytes(&self, element: T, bytes: &mut [u8]) -> ANNResult<()>; } +/// A distance computation on raw byte slices. +/// +/// When paired with [`Layer`] via helpers like [`AsDistance`], implementations may assume +/// that `x` and `y` have length [`Layer::bytes`]. +/// +/// No alignment guarantees are made for `x` and `y`, though in practice they are likely +/// to be aligned to 32 or 64 bytes. pub trait Distance: Send + Sync + std::fmt::Debug { fn evaluate(&self, x: &[u8], y: &[u8]) -> ANNResult; } +/// Return a [`Distance`] function for a [`Layer`]. pub trait AsDistance: Send + Sync + std::fmt::Debug { fn as_distance(&self) -> &dyn Distance; } -impl DistanceFunction<&[u8], &[u8]> for &dyn Distance { - fn evaluate_similarity(&self, x: &[u8], y: &[u8]) -> f32 { - self.evaluate(x, y).unwrap() - } -} - +/// A unary query distance on raw bytes slices. +/// +/// When paired with [`Layer`] via helpers like [`Search`], implementations may assume +/// that `x` and `y` have length [`Layer::bytes`]. 
+/// +/// No alignment guarantees are made for `x` and `y`, though in practice they are likely +/// to be aligned to 32 or 64 bytes. pub trait QueryDistance: Send + Sync + std::fmt::Debug { fn evaluate(&self, x: &[u8]) -> ANNResult; } -pub trait Layer: Send + Sync + 'static { - /// Return the number of bytes needed by this layer representation. - /// - /// To be well-behaved, this function must be idempotent. - fn bytes(&self) -> Bytes; -} - -pub trait Set: Layer { - /// Write into the stored representation. - fn into_bytes(&self, element: T, bytes: &mut [u8]) -> ANNResult<()>; -} - -// Meta traits for `Search` and `Insert` compatibility. +/// Enable search over vectors defined by a [`Layer`]. pub trait Search: Send + Sync + 'static { + /// The type of the query. This should be equivalent to the generic parameter in + /// [`Set`], but needs to be replicated here due to limitations in the current trait + /// design. type Query<'a>; + /// Create a distance computer specialized for `query` and provide it to `visitor`. fn query_distance<'a, V>(&'a self, query: Self::Query<'a>, visitor: V) -> ANNResult where V: QueryVisitor<'a>; } +/// Specialize a kernel around a [`QueryDistance`] implementation. pub trait QueryVisitor<'a>: Sized { + /// The type of the type-erased output. type Output; + /// Specialize [`Self::Output`] for `distance`. fn visit(self, distance: T) -> Self::Output where T: QueryDistance + 'a; + /// Specialize [`Self::Output`] for `distance` accepting a hint that `distance` has been + /// specialized to work on data elements of exactly `BYTES` bytes long. + /// + /// This can be used to tailor surrounding code (e.g. software prefetches) for exactly + /// the length of the data being processed. fn visit_sized(self, distance: T) -> Self::Output where T: QueryDistance + 'a, @@ -92,7 +112,12 @@ pub trait QueryVisitor<'a>: Sized { } } +/// A insert-specific specialization of [`Search`]. 
+/// +/// Note that the bounds for this trait are unnecessarily complicated, but rely on changes +/// to `diskann` to full resolve. pub trait Insert: Search + for<'a> Set> + AsDistance { + /// A specialization of [`Search::query_distance`] targeting vector insert specifically. fn insert_distance<'a, V>(&'a self, query: Self::Query<'a>, visitor: V) -> ANNResult where V: QueryVisitor<'a>, diff --git a/diskann-inmem/src/provider.rs b/diskann-inmem/src/provider.rs index 51d82dd4a..a8e08c191 100644 --- a/diskann-inmem/src/provider.rs +++ b/diskann-inmem/src/provider.rs @@ -357,6 +357,11 @@ impl glue::SearchAccessor for SearchAccessor<'_> { } } +// trait ExpandBeam2: Send + Sync + Debug { +// fn evaluate(&self, x: &[u8]) -> ANNResult; +// +// } + type FExpandBeam = unsafe fn( *const (), &[u32], diff --git a/diskann-inmem/src/sharded.rs b/diskann-inmem/src/sharded.rs index ef64fa01c..946d484db 100644 --- a/diskann-inmem/src/sharded.rs +++ b/diskann-inmem/src/sharded.rs @@ -45,10 +45,6 @@ where } } - pub(crate) fn capacity(&self) -> usize { - self.capacity - } - /// Establish a mapping between `external` and `internal`. 
/// /// # Errors @@ -152,6 +148,12 @@ where inner: i % SHARD_SIZE, } } + + #[cfg(test)] + fn capacity(&self) -> usize { + self.capacity + } + } struct Shard { @@ -190,12 +192,171 @@ where *self.forward.get() } - pub(crate) fn external(&self) -> &I { - self.forward.key() - } - pub(crate) fn delete(mut self) { self.forward.remove(); self.backward[self.entry] = None; } + + #[cfg(test)] + fn external(&self) -> &I { + self.forward.key() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn new_reports_capacity() { + for capacity in [0, 1, SHARD_SIZE - 1, SHARD_SIZE, SHARD_SIZE + 1, 3 * SHARD_SIZE] { + let map = Sharded::::new(capacity); + assert_eq!(map.capacity(), capacity); + } + } + + #[test] + fn insert_round_trips() { + let map = Sharded::::new(16); + assert!(map.insert(100, 3).is_ok()); + + assert_eq!(map.to_internal(&100), Some(3)); + assert_eq!(map.to_external(3), Some(100)); + assert!(map.contains_external(&100)); + + // Unmapped ids return nothing. + assert_eq!(map.to_internal(&101), None); + assert_eq!(map.to_external(4), None); + assert!(!map.contains_external(&101)); + } + + #[test] + fn insert_rejects_out_of_bounds_internal() { + let map = Sharded::::new(16); + assert!(matches!(map.insert(0, 16), Err(InsertError::OutOfBounds))); + assert!(matches!( + map.insert(0, u32::MAX), + Err(InsertError::OutOfBounds) + )); + + // The largest in-bounds id is accepted. + assert!(map.insert(0, 15).is_ok()); + } + + #[test] + fn insert_rejects_duplicate_external_and_preserves_state() { + let map = Sharded::::new(16); + map.insert(7, 5).unwrap(); + + assert!(matches!( + map.insert(7, 6), + Err(InsertError::ExternalExists) + )); + + // The failed insert must not have established any partial mapping. 
+ assert_eq!(map.to_internal(&7), Some(5)); + assert_eq!(map.to_external(6), None); + assert!(!map.contains_external(&6)); + } + + #[test] + fn insert_rejects_duplicate_internal_and_preserves_state() { + let map = Sharded::::new(16); + map.insert(7, 5).unwrap(); + + assert!(matches!( + map.insert(8, 5), + Err(InsertError::InternalExists) + )); + + // The failed insert must not have established any partial mapping. + assert_eq!(map.to_external(5), Some(7)); + assert_eq!(map.to_internal(&8), None); + assert!(!map.contains_external(&8)); + } + + #[test] + fn to_external_handles_bounds_and_empty_slots() { + let map = Sharded::::new(16); + // In-bounds but unmapped slot. + assert_eq!(map.to_external(5), None); + // Out-of-bounds slot. + assert_eq!(map.to_external(16), None); + } + + #[test] + fn mappings_span_shard_boundaries() { + let capacity = 3 * SHARD_SIZE; + let map = Sharded::::new(capacity); + + // Ids straddling every internal shard boundary. + let ids: [u32; 6] = [ + 0, + (SHARD_SIZE - 1) as u32, + SHARD_SIZE as u32, + (2 * SHARD_SIZE - 1) as u32, + (2 * SHARD_SIZE) as u32, + (capacity - 1) as u32, + ]; + + for (external, &internal) in ids.iter().enumerate() { + map.insert(external as u32, internal).unwrap(); + } + + for (external, &internal) in ids.iter().enumerate() { + assert_eq!(map.to_internal(&(external as u32)), Some(internal)); + assert_eq!(map.to_external(internal), Some(external as u32)); + } + } + + #[test] + fn lookup_supports_borrowed_query() { + let map = Sharded::::new(16); + map.insert("alpha".to_string(), 1).unwrap(); + + // Borrowed `&str` lookups against `String` keys. 
+ assert!(map.contains_external("alpha")); + assert_eq!(map.to_internal("alpha"), Some(1)); + assert!(!map.contains_external("beta")); + assert_eq!(map.to_internal("beta"), None); + } + + #[test] + fn occupied_entry_exposes_mapping() { + let map = Sharded::::new(16); + map.insert(42, 9).unwrap(); + + let entry = map.occupied_entry(42).expect("entry should exist"); + assert_eq!(entry.internal(), 9); + assert_eq!(*entry.external(), 42); + } + + #[test] + fn occupied_entry_absent_for_unmapped() { + let map = Sharded::::new(16); + assert!(map.occupied_entry(42).is_none()); + } + + #[test] + fn entry_delete_clears_both_directions() { + let map = Sharded::::new(16); + map.insert(42, 9).unwrap(); + + // Just creating and dropping an `occupied_entry` does not clear it. + { + let _ = map.occupied_entry(42).unwrap(); + assert!(map.contains_external(&42)); + assert_eq!(map.to_internal(&42), Some(9)); + assert_eq!(map.to_external(9), Some(42)); + } + + map.occupied_entry(42).expect("entry should exist").delete(); + + assert!(!map.contains_external(&42)); + assert_eq!(map.to_internal(&42), None); + assert_eq!(map.to_external(9), None); + + // The freed external and internal ids can be reused. + assert!(map.insert(42, 9).is_ok()); + } } diff --git a/diskann-inmem/src/store.rs b/diskann-inmem/src/store.rs index 69ef3fc28..f5d5ab38b 100644 --- a/diskann-inmem/src/store.rs +++ b/diskann-inmem/src/store.rs @@ -3,6 +3,54 @@ * Licensed under the MIT license. */ +//! A concurrent in-memory data store for uniformly sized data. +//! +//! This supports concurrent data access, deletes, and inserts through a safe interface. +//! Data is stored internally in slots indexed from `[0..N)` with `K` points reserved at the +//! end at positions `[N..N+K)`. +//! +//! ## Reading +//! +//! Read access requires a [`Reader`] produced by [`Store::reader`]. [`Reader::read`] +//! provides read-only access to data at slot `i` if the data is valid for reads. +//! +//! ## Writing +//! +//! 
[`Store::acquire`] is used to find and claim an unused internal [`Slot`]. A [`Slot`] +//! provides write access to its coresponding data which is published when the [`Slot`] is +//! dropped. +//! +//! The index of the slot chosen may be obtained via [`Slot::slot`].k +//! +//! ## Deleting +//! +//! Data is deleted via [`Store::retire`]. This immediately marks the corresponding slot as +//! unavailable for future readers. However, the retired slot will not be reused until the +//! [`Store`] can guarantee that no [`Reader`]s that could be using the data are active. +//! +//! Slots are automatically reclaimed as part of slot acquisition in the "writing" phase. +//! +//! ## Neighbor Access +//! +//! The [`Store`] also contains a [`Neighbors`] instance to store adjacency lists. Since +//! neighbors are generally accessed less frequently than data with a higher volume of write +//! traffic, fine-grained locks are used for this data structure. +//! +//! # Details +//! +//! This uses an implementation of the epoch-based reclamation (EBR) provided by [`Registry`]. +//! Concurrency tags are mirrored inline with the stored data (just after the data payload) +//! to keep memory access localized. As such, high-performance implementations will want to +//! fetch the last cache line of data first to ensure the tag is resident in cache for faster +//! data checks. +//! +//! The EBR scheme allows readers to safely access data while only generating read traffic to +//! the CPU caches. The cost is that there is a delay between when slots are retired and when +//! they can be reused, with a long lived [`Reader`] blocking this reclamation. As such, +//! users of this data structure should ensure that [`Reader`]s are reasonably short lived. +//! +//! Internally, the data belongs to a single allocation. + use std::{ iter::repeat_n, num::{NonZeroU32, NonZeroUsize}, @@ -22,20 +70,32 @@ use crate::{ tag::{AtomicTag, Tag}, }; +/// A concurrent data and graph store. 
#[derive(Debug)] pub(crate) struct Store { // The invasive store where concurrency tags are stored inline with the data. // // These tags are mirrored from `tags` - with the latter being used for secondary scans // offering slightly better locality. + // + // The inline tags are stored after the data. buffer: Buffer, + + // The unpadded size of each row in `buffer`. This includes both the data **and** the + // 1-byte tag. Tags are located at byte `unpadded - 1`. unpadded: Bytes, // The number of unfrozen points. This is guaranteed to be less than `buffer`. unfrozen: usize, + + // The authoritative source of truth for the state of each slot. tags: Vec, freelist: Freelist, + + // EBR registry. registry: Registry, + + // Graph. neighbors: Neighbors, } @@ -46,6 +106,8 @@ const TWO: NonZeroUsize = NonZeroUsize::new(2).unwrap(); const RETRY_LIMIT: usize = 20; impl Store { + /// Create a new [`Store`] capable of holding [`entries`] non-frozen slots each of + /// length `bytes`. pub(crate) fn new( entries: usize, bytes: Bytes, @@ -349,9 +411,9 @@ impl Store { fn writable(&self) -> std::ops::Range { 0..self.unfrozen as u32 } - } +/// Errors occurring during [`Store::new`]. #[derive(Debug, Error)] #[error(transparent)] pub(crate) struct StoreError(StoreErrorInner); @@ -411,7 +473,7 @@ pub(crate) enum RetireError { CouldNotClaimSlot, } -/// An epoch protect reader into [`Store`]. +/// An epoch protected reader into a [`Store`]. /// /// Created via [`Store::reader`]. #[derive(Debug)] @@ -749,11 +811,7 @@ mod tests { #[test] fn test_recycling() { - let entries = if cfg!(miri) { - 16 - } else { - 2048 - }; + let entries = if cfg!(miri) { 16 } else { 2048 }; let s = store(entries, 4, 2).unwrap(); From 6170d3b38fb7f090a1ad20d0994ab96e513daca7 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Thu, 25 Jun 2026 17:53:26 -0700 Subject: [PATCH 26/45] Test coverage looking good. 
--- diskann-inmem/Cargo.toml | 1 + diskann-inmem/integration/index/index.rs | 8 +- diskann-inmem/integration/support/datatype.rs | 9 - diskann-inmem/src/epoch.rs | 1 - diskann-inmem/src/layers/full.rs | 9 +- diskann-inmem/src/num.rs | 39 +- diskann-inmem/src/provider.rs | 379 ++++++++++++------ diskann-inmem/src/sharded.rs | 22 +- diskann-inmem/src/store.rs | 54 ++- diskann-inmem/src/tag.rs | 13 + 10 files changed, 339 insertions(+), 196 deletions(-) diff --git a/diskann-inmem/Cargo.toml b/diskann-inmem/Cargo.toml index 82d17e1a8..d61b9f889 100644 --- a/diskann-inmem/Cargo.toml +++ b/diskann-inmem/Cargo.toml @@ -32,6 +32,7 @@ tokio = { workspace = true, optional = true } workspace = true [dev-dependencies] +diskann = { workspace = true, features = ["testing"] } rand = { workspace = true } tokio = { workspace = true, features = ["macros"] } tempfile = { workspace = true } diff --git a/diskann-inmem/integration/index/index.rs b/diskann-inmem/integration/index/index.rs index 6a630a2e9..43bd0a7e5 100644 --- a/diskann-inmem/integration/index/index.rs +++ b/diskann-inmem/integration/index/index.rs @@ -18,12 +18,10 @@ use diskann_inmem::{Context, Provider, Strategy, integration, layers}; use crate::support::{ check::{CheckMatch, Match, check_all_fields}, - datatype::{AsDataType, DataType, FromSlice, Slice}, + datatype::{AsDataType, FromSlice, Slice}, }; pub(crate) trait Index { - fn data_type(&self) -> DataType; - fn search<'a>( &'a self, query: Slice<'a>, @@ -193,10 +191,6 @@ where layers::Full: for<'a> layers::Insert = &'a [T]>, T: FromSlice + AsDataType + Send + Sync + 'static, { - fn data_type(&self) -> DataType { - T::DATA_TYPE - } - fn search<'a>( &'a self, query: Slice<'a>, diff --git a/diskann-inmem/integration/support/datatype.rs b/diskann-inmem/integration/support/datatype.rs index 2a399c6e3..d82c4579a 100644 --- a/diskann-inmem/integration/support/datatype.rs +++ b/diskann-inmem/integration/support/datatype.rs @@ -226,15 +226,6 @@ impl Dataset { 
self.as_view().ncols() } - pub(crate) fn row(&self, i: usize) -> Option> { - match self { - Self::F32(m) => m.get_row(i).map(Slice::from), - Self::F16(m) => m.get_row(i).map(Slice::from), - Self::U8(m) => m.get_row(i).map(Slice::from), - Self::I8(m) => m.get_row(i).map(Slice::from), - } - } - pub(crate) fn as_view(&self) -> DatasetView<'_> { match self { Self::F32(m) => DatasetView::F32(m.as_view()), diff --git a/diskann-inmem/src/epoch.rs b/diskann-inmem/src/epoch.rs index e146d4a87..ac0a17814 100644 --- a/diskann-inmem/src/epoch.rs +++ b/diskann-inmem/src/epoch.rs @@ -834,7 +834,6 @@ mod tests { TestGuardDelay, GuardDelay, post_guard_check => post_guard_check, - with_pre_cas => pre_cas, with_post_cas => post_cas, with_pre_fence => pre_fence, with_post_fence => post_fence, diff --git a/diskann-inmem/src/layers/full.rs b/diskann-inmem/src/layers/full.rs index 07c491dc4..cffc2fdcf 100644 --- a/diskann-inmem/src/layers/full.rs +++ b/diskann-inmem/src/layers/full.rs @@ -390,7 +390,13 @@ impl layers::Search for Full { let query = Calf::Borrowed(query); let output = match self.metric { - Metric::L2 => mint!(query, visitor, u8 => SquaredL2), + Metric::L2 => { + if self.dim() == 128 { + mint!(query, visitor, u8 => { 128, SquaredL2 }) + } else { + mint!(query, visitor, u8 => SquaredL2) + } + } Metric::InnerProduct => mint!(query, visitor, u8 => InnerProduct), Metric::Cosine => mint!(query, visitor, u8 => Cosine), Metric::CosineNormalized => mint!(query, visitor, u8 => Cosine), @@ -427,6 +433,7 @@ impl layers::Search for Full { /////////// #[cfg(test)] +#[cfg(not(miri))] mod tests { use std::fmt::Display; diff --git a/diskann-inmem/src/num.rs b/diskann-inmem/src/num.rs index 16a9a42db..676af91c0 100644 --- a/diskann-inmem/src/num.rs +++ b/diskann-inmem/src/num.rs @@ -23,7 +23,7 @@ impl Bytes { } #[inline] - pub const fn checked_add(self, other: Bytes) -> Option { + pub(crate) const fn checked_add(self, other: Bytes) -> Option { match self.value().checked_add(other.value()) 
{ Some(v) => Some(Bytes::new(v)), None => None, @@ -31,7 +31,7 @@ impl Bytes { } #[inline] - pub const fn checked_mul(self, other: usize) -> Option { + pub(crate) const fn checked_mul(self, other: usize) -> Option { match self.value().checked_mul(other) { Some(v) => Some(Bytes::new(v)), None => None, @@ -39,30 +39,17 @@ impl Bytes { } #[inline] - pub const fn div(self, other: NonZeroUsize) -> Bytes { + pub(crate) const fn div(self, other: NonZeroUsize) -> Bytes { Bytes::new(self.value() / other.get()) } - #[inline] - pub(crate) const fn unchecked_mul(self, other: usize) -> Bytes { - Bytes::new(self.value() * other) - } - - #[inline] - pub const fn checked_sub(self, other: Bytes) -> Option { - match self.value().checked_sub(other.value()) { - Some(v) => Some(Bytes::new(v)), - None => None, - } - } - #[inline] pub(crate) const fn unchecked_sub(self, other: Bytes) -> Bytes { Self::new(self.value() - other.value()) } #[inline] - pub const fn checked_next_multiple_of(self, other: Bytes) -> Option { + pub(crate) const fn checked_next_multiple_of(self, other: Bytes) -> Option { match self.value().checked_next_multiple_of(other.value()) { Some(v) => Some(Bytes::new(v)), None => None, @@ -183,19 +170,6 @@ mod tests { assert_eq!(Bytes::new(usize::MAX).checked_add(Bytes::new(1)), None); } - #[test] - fn checked_sub_success() { - assert_eq!( - Bytes::new(30).checked_sub(Bytes::new(10)), - Some(Bytes::new(20)) - ); - } - - #[test] - fn checked_sub_underflow() { - assert_eq!(Bytes::new(5).checked_sub(Bytes::new(10)), None); - } - #[test] fn checked_mul_success() { assert_eq!(Bytes::new(64).checked_mul(4), Some(Bytes::new(256))); @@ -211,11 +185,6 @@ mod tests { assert_eq!(Bytes::new(100).checked_mul(0), Some(Bytes::new(0))); } - #[test] - fn unchecked_mul() { - assert_eq!(Bytes::new(64).unchecked_mul(3), Bytes::new(192)); - } - #[test] fn unchecked_sub() { assert_eq!( diff --git a/diskann-inmem/src/provider.rs b/diskann-inmem/src/provider.rs index a8e08c191..24db529da 100644 --- 
a/diskann-inmem/src/provider.rs +++ b/diskann-inmem/src/provider.rs @@ -3,7 +3,7 @@ * Licensed under the MIT license. */ -use std::{hash::Hash, marker::PhantomData}; +use std::hash::Hash; use diskann::{ ANNError, ANNErrorKind, ANNResult, @@ -82,6 +82,11 @@ where self.counters.local() } + /// Return the maximum number of neighbors that can be stored in the provider's graph. + pub fn max_degree(&self) -> usize { + self.store.max_degree() + } + /// Return a snapshot of the current event counters. #[cfg(feature = "integration-test")] pub fn counters(&self) -> crate::integration::counters::CounterSnapshot { @@ -137,8 +142,10 @@ where _context: &Self::Context, gid: &M, ) -> Result { - let id = self.mapping.to_internal(gid).unwrap(); - Ok(id) + match self.mapping.to_internal(gid) { + Some(id) => Ok(id), + None => Err(ANNError::message(ANNErrorKind::Opaque, "no mapping")), + } } /// Translate an internal id to its corresponding external id. @@ -147,8 +154,10 @@ where _context: &Self::Context, id: Self::InternalId, ) -> Result { - let gid = self.mapping.to_external(id).unwrap(); - Ok(gid) + match self.mapping.to_external(id) { + Some(gid) => Ok(gid), + None => Err(ANNError::message(ANNErrorKind::Opaque, "no mapping")), + } } } @@ -237,12 +246,18 @@ where element: T, ) -> impl std::future::Future> + Send { let work = move || { - let mut slot = self.store.acquire().unwrap(); + let mut slot = self.store.acquire().ok_or_else(|| { + ANNError::message(ANNErrorKind::Opaque, "could not allocate a new slot") + })?; // TODO: Proper cleanup via `Guard` or some other mechanism on the event of // insert failure. >::into_bytes(&self.layer, element, slot.as_mut_slice())?; - self.mapping.insert(id.clone(), slot.slot()).unwrap(); + self.mapping.insert(id.clone(), slot.slot())?; + + // Now that insert has succeeded - publish the slot. This method cannot fail, so + // we do not need to worry about potentially unwinding the ID mapping. 
+ let id = slot.publish(); // This is a rather expensive update. // @@ -250,7 +265,7 @@ where // is not expected to be enabled for general use. self.local_counters().set_vector(1); - Ok(diskann::provider::NoopGuard::new(slot.slot())) + Ok(diskann::provider::NoopGuard::new(id)) }; ready(work) @@ -265,7 +280,7 @@ where pub struct SearchAccessor<'a> { reader: store::Reader<'a>, ids: AdjacencyList, - expand_beam: ExpandBeam<'a>, + expand_beam: Box, // The parent provider for the accessor. provider: &'a (dyn std::any::Any + Send + Sync), @@ -346,7 +361,7 @@ impl glue::SearchAccessor for SearchAccessor<'_> { unsafe { self.expand_beam - .run(&self.ids, 8, &self.reader, &mut on_neighbors) + .expand_beam(&self.ids, 8, &self.reader, &mut on_neighbors) }?; } @@ -357,18 +372,41 @@ impl glue::SearchAccessor for SearchAccessor<'_> { } } -// trait ExpandBeam2: Send + Sync + Debug { -// fn evaluate(&self, x: &[u8]) -> ANNResult; -// -// } +trait ExpandBeam2: Send + Sync + std::fmt::Debug { + /// Evaluate a raw distance function. 
+ fn evaluate(&self, x: &[u8]) -> ANNResult; + + unsafe fn expand_beam( + &self, + list: &[u32], + lookahead: usize, + reader: &store::Reader<'_>, + f: &mut dyn FnMut(u32, f32), + ) -> ANNResult<()>; +} + +#[derive(Debug)] +#[repr(transparent)] +struct ExpandBeamImpl(T); + +impl ExpandBeam2 for ExpandBeamImpl +where + T: layers::QueryDistance, +{ + fn evaluate(&self, x: &[u8]) -> ANNResult { + self.0.evaluate(x) + } -type FExpandBeam = unsafe fn( - *const (), - &[u32], - usize, - &store::Reader<'_>, - &mut dyn FnMut(u32, f32), -) -> ANNResult<()>; + unsafe fn expand_beam( + &self, + list: &[u32], + lookahead: usize, + reader: &store::Reader<'_>, + f: &mut dyn FnMut(u32, f32), + ) -> ANNResult<()> { + unsafe { expand_beam_inner::(&self.0, list, lookahead, reader, f) } + } +} #[derive(Debug)] struct ExpandBeamVisitor { @@ -376,7 +414,7 @@ struct ExpandBeamVisitor { } impl<'a> layers::QueryVisitor<'a> for ExpandBeamVisitor { - type Output = ExpandBeam<'a>; + type Output = Box; fn visit_sized(self, distance: T) -> Self::Output where @@ -384,95 +422,17 @@ impl<'a> layers::QueryVisitor<'a> for ExpandBeamVisitor { { // Make sure there's no lying. assert_eq!(Bytes::new(BYTES + 1), self.bytes); - - unsafe { ExpandBeam::new(distance, expand_beam_inner::) } + Box::new(ExpandBeamImpl::<_, BYTES>(distance)) } fn visit(self, distance: T) -> Self::Output where T: QueryDistance + 'a, { - unsafe { ExpandBeam::new(distance, expand_beam_inner::) } + Box::new(ExpandBeamImpl::<_, 0>(distance)) } } -#[derive(Debug)] -struct ExpandBeam<'a> { - ptr: *const (), - expand_beam: FExpandBeam, - vtable: &'static VTable, - lifetime: PhantomData<&'a ()>, -} - -#[derive(Debug)] -struct VTable { - evaluate: unsafe fn(*const (), &[u8]) -> ANNResult, - drop: unsafe fn(*mut ()), -} - -// SAFETY: We constrain `ptr` to be `Send`. -unsafe impl Send for ExpandBeam<'_> {} - -// SAFETY: We constrain `ptr` to be `Send`. 
-unsafe impl Sync for ExpandBeam<'_> {} - -impl<'a> ExpandBeam<'a> { - unsafe fn new(x: T, expand_beam: FExpandBeam) -> Self - where - T: layers::QueryDistance + Send + Sync + 'a, - { - let vtable = &VTable { - evaluate: evaluate::, - drop: drop::, - }; - - // NOTE: It's really important that we coerce the leaked `&mut` to a `*mut` rather - // than going straight to `*const` as the latter converts `&mut` to `&` first - // and then Miri gets upset when we try to drop it, thinking it's a shared reference - // rather than unique one. - let ptr: *mut T = Box::leak(Box::new(x)); - - Self { - ptr: ptr.cast_const().cast(), - expand_beam, - vtable: &vtable, - lifetime: PhantomData, - } - } - - unsafe fn run( - &self, - list: &[u32], - lookahead: usize, - reader: &store::Reader<'_>, - f: &mut dyn FnMut(u32, f32), - ) -> ANNResult<()> { - unsafe { (self.expand_beam)(self.ptr, list, lookahead, reader, f) } - } - - fn evaluate(&self, x: &[u8]) -> ANNResult { - unsafe { (self.vtable.evaluate)(self.ptr, x) } - } -} - -impl Drop for ExpandBeam<'_> { - fn drop(&mut self) { - unsafe { (self.vtable.drop)(self.ptr.cast_mut()) } - } -} - -unsafe fn drop(ptr: *mut ()) { - let _ = unsafe { Box::from_raw(ptr.cast::()) }; -} - -unsafe fn evaluate(ptr: *const (), x: &[u8]) -> ANNResult -where - T: layers::QueryDistance, -{ - let f = unsafe { &*ptr.cast::() }; - ::evaluate(f, x) -} - #[inline(always)] unsafe fn prefetch(ptr: *const u8, len: usize) { use std::arch::x86_64::*; @@ -493,25 +453,14 @@ unsafe fn prefetch(ptr: *const u8, len: usize) { } } -const CACHE_LINE_SIZE: usize = 64; - -// pub unsafe fn test_function( -// list: &[u32], -// lookahead: usize, -// reader: &store::Reader<'_>, -// distance: &dyn layers::QueryDistance, -// f: &mut dyn FnMut(u32, f32), -// ) -> ANNResult<()> { -// unsafe { expand_beam_inner::<4>(list, lookahead, reader, distance, f) } -// } - /// Safety (no # yet because we need to revisit this - clippy will lint) /// /// * The concrete type of `distance` must be 
`T`. /// * All items in `list` must in-bounds with respect to `reader`. /// * The number of bytes associated with `N` cache lines must "make sense". +#[inline] unsafe fn expand_beam_inner( - distance: *const (), + distance: &T, list: &[u32], lookahead: usize, reader: &store::Reader<'_>, @@ -520,8 +469,6 @@ unsafe fn expand_beam_inner( where T: layers::QueryDistance, { - let distance = unsafe { &*distance.cast::() }; - debug_assert!( BYTES + 1 <= reader.bytes().value(), "we really rely on this: {}, bytes = {}", @@ -577,6 +524,9 @@ where // Insert // //////////// +/// The [`glue::PruneAccessor`] implementation for [`Provider`]. +/// +/// This type implements zero-copy access to the data within its parent provider during prunes. #[derive(Debug)] pub struct PruneAccessor<'a> { reader: store::Reader<'a>, @@ -584,6 +534,7 @@ pub struct PruneAccessor<'a> { counters: LocalCounters<'a>, } +/// The distance computer for [`PruneAccessor`]. #[derive(Debug)] pub struct Distance<'a> { distance: &'a dyn layers::Distance, @@ -906,24 +857,52 @@ where mod tests { use super::*; - use diskann::graph::DiskANNIndex; + use diskann::{ + graph::{DiskANNIndex, InplaceDeleteMethod, search::Knn, test::synthetic::Grid}, + neighbor::Neighbor, + provider::{DataProvider, Delete}, + }; use diskann_vector::distance::Metric; use crate::layers::Full; + /// The true tests live in the integration tests for this repo. + /// + /// The smoke test here uses a 2D grid of points to verify that our provider + /// implementations are more-or-less correct. + /// + /// Note that since `Provider` separates internal and external IDs, we multiply the + /// coordinates of each element in the grid by 10 and add 1 to verify that the ID + /// translation is behaving properly. 
+ /// + /// For clarity, the expected structure of the grid is as follows: + /// + /// + /// 41 91 141 191 241 + /// 31 81 131 181 231 + /// 21 71 121 171 221 + /// 11 61 111 161 211 + /// 1 51 101 151 201 + /// #[tokio::test] async fn smoke() { - let full = Full::::new(1, Metric::L2); - let start_points: [&[f32]; _] = [&[1.0], &[2.0]]; + let grid = Grid::Two; + let size = 5; + let data = grid.data(size); + let start = grid.start_point(size); + let degree = 6; - let config = Config::new(10, 16); + let full = Full::::new(grid.dim().into(), Metric::L2); - let provider = Provider::new(full, config, start_points); + let config = Config::new(grid.num_points(size), degree); + + let provider = Provider::<_, u64>::new(full, config, std::iter::once(start.as_slice())); + assert_eq!(provider.max_degree(), degree); let config = diskann::graph::config::Builder::new( + 2 * (grid.dim() as usize), + diskann::graph::config::MaxDegree::new(provider.max_degree()), 10, - diskann::graph::config::MaxDegree::Same, - 100, (Metric::L2).into(), ) .build() @@ -931,8 +910,154 @@ mod tests { let index = DiskANNIndex::new(config, provider, None); - index.insert(&Strategy, &Context, &0, &[3.0]).await.unwrap(); - index.insert(&Strategy, &Context, &1, &[4.0]).await.unwrap(); - index.insert(&Strategy, &Context, &2, &[5.0]).await.unwrap(); + for (i, data) in data.row_iter().enumerate() { + index + .insert(&Strategy, &Context, &((10 * i + 1) as u64), data) + .await + .unwrap(); + } + + // Verify that each ID round trips. 
+ for i in 0..data.nrows() { + let i = (10 * i + 1) as u64; + let internal = index.provider().to_internal_id(&Context, &i).unwrap(); + assert_ne!(internal as u64, i); + assert_eq!( + index.provider().to_external_id(&Context, internal).unwrap(), + i + ); + + assert!( + !index + .provider() + .status_by_external_id(&Context, &i) + .await + .unwrap() + .is_deleted() + ); + assert!( + !index + .provider() + .status_by_internal_id(&Context, internal) + .await + .unwrap() + .is_deleted() + ); + } + + // Assert that out-of-bounds translations returns errors. + assert!(index.provider().to_internal_id(&Context, &0).is_err()); + assert!(index.provider().to_external_id(&Context, 26).is_err()); + + // Searches should return something reasonable. + let knn = Knn::new(10, 10, None).unwrap(); + let mut neighbors = Vec::>::new(); + index + .search(knn, &Strategy, &Context, &[0.0, 0.0], &mut neighbors) + .await + .unwrap(); + + assert_eq!(neighbors[0].as_tuple(), (1, 0.0)); + assert_eq!(neighbors[1].as_tuple(), (11, 1.0)); // this can be swapped with 2 + assert_eq!(neighbors[2].as_tuple(), (51, 1.0)); + assert_eq!(neighbors[3].as_tuple(), (61, 2.0)); + + // If we run inplace delete on point 61, it longer be present. + index + .inplace_delete( + Strategy, + &Context, + &61, + 3, + InplaceDeleteMethod::VisitedAndTopK { + k_value: 10, + l_value: 10, + }, + ) + .await + .unwrap(); + + assert!( + index + .provider() + .status_by_external_id(&Context, &61) + .await + .unwrap() + .is_deleted() + ); + + // We can't delete the same thing twice. + assert!( + index + .inplace_delete( + Strategy, + &Context, + &61, + 3, + InplaceDeleteMethod::VisitedAndTopK { + k_value: 10, + l_value: 10 + }, + ) + .await + .is_err() + ); + + // Rerun search - the point 61 should now be gone. 
+ let mut neighbors = Vec::>::new(); + index + .search(knn, &Strategy, &Context, &[0.0, 0.0], &mut neighbors) + .await + .unwrap(); + + assert_eq!(neighbors[0].as_tuple(), (1, 0.0)); + assert_eq!(neighbors[1].as_tuple(), (51, 1.0)); // this can be swapped with 2 + assert_eq!(neighbors[2].as_tuple(), (11, 1.0)); + assert_eq!(neighbors[3].as_tuple(), (101, 4.0)); // we can also accept "21" + + // We can't insert an existing ID. + assert!( + index + .insert(&Strategy, &Context, &1, &[10.0, 10.0]) + .await + .is_err() + ); + + // If we insert a new ID but the query vector is too long - make sure we leave the + // provider untouched. + assert!( + index + .insert(&Strategy, &Context, &2, &[10.0, 10.0, 10.0]) + .await + .is_err() + ); + + // Check that we can reinsert the same point with a different ID and have it be + // returned from search. + index + .insert(&Strategy, &Context, &62, &[1.0, 1.0]) + .await + .unwrap(); + + // We can't insert an ID - but this time it's because we don't have any more internal + // slots. + assert!( + index + .insert(&Strategy, &Context, &62, &[0.0, 0.0]) + .await + .is_err() + ); + + // Rerun search - the point 62 should be present. + let mut neighbors = Vec::>::new(); + index + .search(knn, &Strategy, &Context, &[0.0, 0.0], &mut neighbors) + .await + .unwrap(); + + assert_eq!(neighbors[0].as_tuple(), (1, 0.0)); + assert_eq!(neighbors[1].as_tuple(), (11, 1.0)); // this can be swapped with 2 + assert_eq!(neighbors[2].as_tuple(), (51, 1.0)); + assert_eq!(neighbors[3].as_tuple(), (62, 2.0)); } } diff --git a/diskann-inmem/src/sharded.rs b/diskann-inmem/src/sharded.rs index 946d484db..50a59b9e4 100644 --- a/diskann-inmem/src/sharded.rs +++ b/diskann-inmem/src/sharded.rs @@ -153,7 +153,6 @@ where fn capacity(&self) -> usize { self.capacity } - } struct Shard { @@ -171,6 +170,8 @@ pub(crate) enum InsertError { InternalExists, } +crate::opaque!(InsertError); + /// A handle to a valid entry in a [`Sharded`]. 
/// /// This can be used to guarantee the presence of an entry prior to deletion to support @@ -209,7 +210,14 @@ mod tests { #[test] fn new_reports_capacity() { - for capacity in [0, 1, SHARD_SIZE - 1, SHARD_SIZE, SHARD_SIZE + 1, 3 * SHARD_SIZE] { + for capacity in [ + 0, + 1, + SHARD_SIZE - 1, + SHARD_SIZE, + SHARD_SIZE + 1, + 3 * SHARD_SIZE, + ] { let map = Sharded::::new(capacity); assert_eq!(map.capacity(), capacity); } @@ -248,10 +256,7 @@ mod tests { let map = Sharded::::new(16); map.insert(7, 5).unwrap(); - assert!(matches!( - map.insert(7, 6), - Err(InsertError::ExternalExists) - )); + assert!(matches!(map.insert(7, 6), Err(InsertError::ExternalExists))); // The failed insert must not have established any partial mapping. assert_eq!(map.to_internal(&7), Some(5)); @@ -264,10 +269,7 @@ mod tests { let map = Sharded::::new(16); map.insert(7, 5).unwrap(); - assert!(matches!( - map.insert(8, 5), - Err(InsertError::InternalExists) - )); + assert!(matches!(map.insert(8, 5), Err(InsertError::InternalExists))); // The failed insert must not have established any partial mapping. assert_eq!(map.to_external(5), Some(7)); diff --git a/diskann-inmem/src/store.rs b/diskann-inmem/src/store.rs index f5d5ab38b..8815b5236 100644 --- a/diskann-inmem/src/store.rs +++ b/diskann-inmem/src/store.rs @@ -188,6 +188,11 @@ impl Store { self.unpadded } + /// Return the maximum degree that can be stored in the graph. + pub(crate) fn max_degree(&self) -> usize { + self.neighbors.max_length() + } + /// Attempt to reclaim retired slots. /// /// If successful, returns the number of slots reclaimed. @@ -511,6 +516,7 @@ impl<'a> Reader<'a> { /// /// This guarantee only holds while `self` is alive. Construction of a new [`Reader`] /// requires a separate check. 
+ #[cfg(test)] pub(crate) fn can_read(&self, i: usize) -> Option { if !self.is_in_bounds(i) { return None; @@ -582,7 +588,7 @@ impl<'a> Reader<'a> { } } -/// A writable buffer into the data managed by a [`Store`], obtained from [`Store::Acquire`]. +/// A writable buffer into the data managed by a [`Store`], obtained from [`Store::acquire`]. #[derive(Debug)] pub(crate) struct Slot<'a> { tag: &'a AtomicTag, @@ -607,12 +613,23 @@ impl<'a> Slot<'a> { me.mirror.store(Tag::FROZEN, Ordering::Release); me.tag.store(Tag::FROZEN, Ordering::Release); } + + /// Consume the slot and publish the written data for all readers. + /// + /// Return the internal slot ID. + pub(crate) fn publish(self) -> u32 { + let id = self.slot(); + let me = std::mem::ManuallyDrop::new(self); + me.mirror.store(Tag::PUBLISHED, Ordering::Release); + me.tag.store(Tag::PUBLISHED, Ordering::Release); + id + } } impl Drop for Slot<'_> { fn drop(&mut self) { - self.mirror.store(Tag::PUBLISHED, Ordering::Release); - self.tag.store(Tag::PUBLISHED, Ordering::Release); + self.mirror.store(Tag::AVAILABLE, Ordering::Release); + self.tag.store(Tag::AVAILABLE, Ordering::Release); } } @@ -729,7 +746,7 @@ mod tests { // Before the slot is dropped - we should not be able to read it. assert!(reader.read(idx).is_none()); assert!(!s.can_read_approximate(idx).unwrap()); - + slot.publish(); idx }; @@ -737,6 +754,30 @@ mod tests { assert!(s.can_read_approximate(idx).unwrap()); } + #[test] + fn unpublished_slots_are_immediately_available() { + let s = store(4, 8, 1).unwrap(); + + let reader = s.reader().expect("reader guard available"); + + let idx = { + let mut slot = s.acquire().expect("a fresh store has free slots"); + let idx = slot.slot() as usize; + slot.as_mut_slice() + .copy_from_slice(&[1, 2, 3, 4, 5, 6, 7, 8]); + + // Before the slot is dropped - we should not be able to read it. 
+ assert!(reader.read(idx).is_none()); + assert!(!s.can_read_approximate(idx).unwrap()); + + // NOTE: We do not explicitly publish the slot. + idx + }; + + assert!(reader.read(idx).is_none()); + assert!(!s.can_read_approximate(idx).unwrap()); + } + #[test] fn acquire_exhausts_then_reports_none() { let s = store(2, 8, 1).unwrap(); @@ -787,8 +828,7 @@ mod tests { let idx = { let slot = s.acquire().unwrap(); - slot.slot() as usize - // Publish + slot.publish() as usize }; assert!(s.retire(idx).is_ok()); @@ -818,6 +858,7 @@ mod tests { // Claim all slots. let mut count = 0; while let Some(slot) = s.acquire() { + slot.publish(); count += 1; } @@ -831,6 +872,7 @@ mod tests { // Verify that we can claim all slots again. let mut count = 0; while let Some(slot) = s.acquire() { + slot.publish(); count += 1; } diff --git a/diskann-inmem/src/tag.rs b/diskann-inmem/src/tag.rs index 81dfd14a7..e0867252b 100644 --- a/diskann-inmem/src/tag.rs +++ b/diskann-inmem/src/tag.rs @@ -307,4 +307,17 @@ mod tests { assert!(!Tag::OWNED.can_read()); assert!(!Tag::RETIRING.can_read()); } + + #[test] + fn test_display() { + assert_eq!(Tag::AVAILABLE.to_string(), "Tag(AVAILABLE)"); + assert_eq!(Tag::OWNED.to_string(), "Tag(OWNED)"); + assert_eq!(Tag::RETIRING.to_string(), "Tag(RETIRING)"); + assert_eq!(Tag::FROZEN.to_string(), "Tag(FROZEN)"); + assert_eq!(Tag::PUBLISHED.to_string(), "Tag(PUBLISHED)"); + + // Guard against future changes. + assert_eq!(Tag::new(Tag::RETIRING.value() + 1).to_string(), "Tag(3)"); + assert_eq!(Tag::new(Tag::PUBLISHED.value() - 1).to_string(), "Tag(253)"); + } } From f1d6dfd007b993f309dcc785a068e01db1daed0d Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Thu, 25 Jun 2026 18:35:17 -0700 Subject: [PATCH 27/45] Clippy!. 
--- diskann-inmem/integration/index/mod.rs | 4 +- .../integration/index/{index.rs => object.rs} | 3 -- diskann-inmem/integration/index/runner.rs | 2 +- diskann-inmem/integration/index/tests.rs | 2 +- diskann-inmem/integration/store.rs | 6 +-- diskann-inmem/integration/support/check.rs | 2 +- diskann-inmem/integration/support/datatype.rs | 21 ++++++-- diskann-inmem/integration/support/io.rs | 8 +-- diskann-inmem/src/buffer.rs | 22 ++++++++- diskann-inmem/src/epoch.rs | 2 +- diskann-inmem/src/freelist.rs | 7 +-- diskann-inmem/src/layers/full.rs | 33 +++++++------ diskann-inmem/src/layers/mod.rs | 2 +- diskann-inmem/src/neighbors.rs | 39 +++++++++++++++ diskann-inmem/src/num.rs | 7 +++ diskann-inmem/src/provider.rs | 12 ++--- diskann-inmem/src/store.rs | 49 ++++++++++++++++--- diskann-inmem/src/tag.rs | 7 ++- diskann-inmem/src/test/epoch.rs | 35 +++++++------ 19 files changed, 192 insertions(+), 71 deletions(-) rename diskann-inmem/integration/index/{index.rs => object.rs} (98%) diff --git a/diskann-inmem/integration/index/mod.rs b/diskann-inmem/integration/index/mod.rs index ef15e8fe1..735ca820b 100644 --- a/diskann-inmem/integration/index/mod.rs +++ b/diskann-inmem/integration/index/mod.rs @@ -3,11 +3,11 @@ * Licensed under the MIT license. 
*/ -mod index; +mod object; mod runner; mod tests; -use index::{Counters, Index, KnnSearch}; +use object::{Counters, Index, KnnSearch}; use diskann_benchmark_runner::{Registry, RegistryError}; diff --git a/diskann-inmem/integration/index/index.rs b/diskann-inmem/integration/index/object.rs similarity index 98% rename from diskann-inmem/integration/index/index.rs rename to diskann-inmem/integration/index/object.rs index 43bd0a7e5..95505366d 100644 --- a/diskann-inmem/integration/index/index.rs +++ b/diskann-inmem/integration/index/object.rs @@ -227,7 +227,4 @@ where fn counters(&self) -> Counters { self.provider().counters().into() } - - // fn retire(&self, id: u64) -> anyhow::Result<()> { - // } } diff --git a/diskann-inmem/integration/index/runner.rs b/diskann-inmem/integration/index/runner.rs index bd2a68335..68824cec6 100644 --- a/diskann-inmem/integration/index/runner.rs +++ b/diskann-inmem/integration/index/runner.rs @@ -487,7 +487,7 @@ impl diskann_benchmark_runner::Benchmark for FullPrecision { for param in input.search.knn.iter() { let stats = super::tests::knn( &*index, - param.clone(), + *param, queries.as_view(), &groundtruth.as_view(), rt.handle(), diff --git a/diskann-inmem/integration/index/tests.rs b/diskann-inmem/integration/index/tests.rs index dd9b686c1..61d5166bd 100644 --- a/diskann-inmem/integration/index/tests.rs +++ b/diskann-inmem/integration/index/tests.rs @@ -26,7 +26,7 @@ pub(super) fn insert( for i in 0..dataset.nrows() { rt.block_on(index.insert(dataset.row(i).unwrap(), i as u64))?; } - Ok(before.delta(&index.counters())?) + before.delta(&index.counters()) } pub(super) fn knn( diff --git a/diskann-inmem/integration/store.rs b/diskann-inmem/integration/store.rs index af3d99fee..7c4d183a2 100644 --- a/diskann-inmem/integration/store.rs +++ b/diskann-inmem/integration/store.rs @@ -343,7 +343,7 @@ fn retirer(shared: &Shared, seed: u64) { // Flow control: keep a steady readable population. 
if shared.live.load(Relaxed) > shared.low_watermark { - let i = rng.sample(&shared.writable); + let i = rng.sample(shared.writable); if shared.store.retire(i) { shared.live.fetch_sub(1, Relaxed); shared.retires_ok.fetch_add(1, Relaxed); @@ -352,7 +352,7 @@ fn retirer(shared: &Shared, seed: u64) { } } - if iteration % RECLAIM_EVERY == 0 + if iteration.is_multiple_of(RECLAIM_EVERY) && let Some(reclaimed) = shared.store.reclaim() { shared.reclaims.fetch_add(reclaimed as u64, Relaxed); @@ -377,7 +377,7 @@ fn reader(shared: &Shared, seed: u64) { }; observations.clear(); - let start = rng.sample(&shared.readable); + let start = rng.sample(shared.readable); for _ in 0..READER_PASSES { for k in 0..window { let i = (start + k) % slots; diff --git a/diskann-inmem/integration/support/check.rs b/diskann-inmem/integration/support/check.rs index f807827e2..435f0a012 100644 --- a/diskann-inmem/integration/support/check.rs +++ b/diskann-inmem/integration/support/check.rs @@ -308,7 +308,7 @@ impl MatchBuilder { } else { Match::Nested { children: self.children, - remark: remark, + remark, } } } diff --git a/diskann-inmem/integration/support/datatype.rs b/diskann-inmem/integration/support/datatype.rs index d82c4579a..6864f434c 100644 --- a/diskann-inmem/integration/support/datatype.rs +++ b/diskann-inmem/integration/support/datatype.rs @@ -113,6 +113,17 @@ pub(crate) enum SliceMut<'a> { I8(&'a mut [i8]), } +fn map(dst: &mut [T], src: &[U], f: F) +where + T: std::fmt::Display + AsDataType, + U: std::fmt::Display + AsDataType + Copy, + F: Fn(U) -> T, +{ + std::iter::zip(dst.iter_mut(), src.iter()).for_each(|(d, s)| { + *d = f(*s); + }) +} + fn try_map(dst: &mut [T], src: &[U], f: F) -> anyhow::Result<()> where T: std::fmt::Display + AsDataType, @@ -181,14 +192,14 @@ impl<'a> SliceMut<'a> { match (self, rhs) { (SliceMut::F32(dst), Slice::F32(src)) => dst.copy_from_slice(src), - (SliceMut::F32(dst), Slice::F16(src)) => try_map(dst, src, |x| x.try_into())?, - (SliceMut::F32(dst), 
Slice::U8(src)) => try_map(dst, src, |x| x.try_into())?, - (SliceMut::F32(dst), Slice::I8(src)) => try_map(dst, src, |x| x.try_into())?, + (SliceMut::F32(dst), Slice::F16(src)) => map(dst, src, |x| x.into()), + (SliceMut::F32(dst), Slice::U8(src)) => map(dst, src, |x| x.into()), + (SliceMut::F32(dst), Slice::I8(src)) => map(dst, src, |x| x.into()), (SliceMut::F16(dst), Slice::F32(src)) => try_map(dst, src, f32_to_f16)?, (SliceMut::F16(dst), Slice::F16(src)) => dst.copy_from_slice(src), - (SliceMut::F16(dst), Slice::U8(src)) => try_map(dst, src, |x| x.try_into())?, - (SliceMut::F16(dst), Slice::I8(src)) => try_map(dst, src, |x| x.try_into())?, + (SliceMut::F16(dst), Slice::U8(src)) => map(dst, src, |x| x.into()), + (SliceMut::F16(dst), Slice::I8(src)) => map(dst, src, |x| x.into()), (SliceMut::U8(dst), Slice::F32(src)) => try_map(dst, src, f32_to_u8)?, (SliceMut::U8(dst), Slice::F16(src)) => try_map(dst, src, f16_to_u8)?, diff --git a/diskann-inmem/integration/support/io.rs b/diskann-inmem/integration/support/io.rs index c00a9ca00..335ead7cf 100644 --- a/diskann-inmem/integration/support/io.rs +++ b/diskann-inmem/integration/support/io.rs @@ -30,22 +30,22 @@ where let dst = match target { DataType::F32 => { let mut dst = Matrix::new(0.0f32, data.nrows(), data.ncols()); - SliceMut::from(dst.as_mut_slice()).convert_lossless(data.as_slice().into())?; + SliceMut::from(dst.as_mut_slice()).convert_lossless(data.as_slice())?; Dataset::from(dst) } DataType::F16 => { let mut dst = Matrix::new(f16::from_f32(0.0f32), data.nrows(), data.ncols()); - SliceMut::from(dst.as_mut_slice()).convert_lossless(data.as_slice().into())?; + SliceMut::from(dst.as_mut_slice()).convert_lossless(data.as_slice())?; Dataset::from(dst) } DataType::U8 => { let mut dst = Matrix::new(0u8, data.nrows(), data.ncols()); - SliceMut::from(dst.as_mut_slice()).convert_lossless(data.as_slice().into())?; + SliceMut::from(dst.as_mut_slice()).convert_lossless(data.as_slice())?; Dataset::from(dst) } DataType::I8 
=> { let mut dst = Matrix::new(0i8, data.nrows(), data.ncols()); - SliceMut::from(dst.as_mut_slice()).convert_lossless(data.as_slice().into())?; + SliceMut::from(dst.as_mut_slice()).convert_lossless(data.as_slice())?; Dataset::from(dst) } }; diff --git a/diskann-inmem/src/buffer.rs b/diskann-inmem/src/buffer.rs index 8c2721e55..32e6e4930 100644 --- a/diskann-inmem/src/buffer.rs +++ b/diskann-inmem/src/buffer.rs @@ -89,7 +89,11 @@ impl Buffer { #[inline] pub(crate) unsafe fn get_unchecked(&self, i: usize) -> RawSlice<'_> { debug_assert!(i < self.entries); + + // SAFETY: The caller asserts that `i` is in-bounds; the computed pointer stays + // within a single allocated object. let ptr = unsafe { self.ptr.add(self.stride().value() * i) }; + RawSlice { ptr, len: self.stride, @@ -203,7 +207,7 @@ impl<'a> RawSlice<'a> { /// where `m = n.min(self.len())`. #[inline] pub(crate) fn split(&self, n: Bytes) -> (RawSlice<'a>, RawSlice<'a>) { - // SAFETY: The argument is <= `self.len()`. + // SAFETY: the argument is <= `self.len()`. unsafe { self.split_unchecked(self.len.min(n)) } } @@ -215,6 +219,8 @@ impl<'a> RawSlice<'a> { #[inline] pub(crate) unsafe fn split_unchecked(&self, n: Bytes) -> (RawSlice<'a>, RawSlice<'a>) { debug_assert!(n <= self.len); + + // SAFETY: the argument is <= `self.len()`. unsafe { ( Self::new(self.ptr, n), @@ -258,6 +264,7 @@ impl<'a> RawSlice<'a> { /// slice does not violate Rust's borrowing rules. #[inline] pub(crate) unsafe fn as_slice(&self) -> &'a [u8] { + // SAFETY: Inherited from caller. unsafe { std::slice::from_raw_parts(self.ptr.as_ptr(), self.len.value()) } } @@ -272,6 +279,7 @@ impl<'a> RawSlice<'a> { /// slice does not violate Rust's borrowing rules. #[inline] pub(crate) unsafe fn as_mut_slice(&mut self) -> &'a mut [u8] { + // SAFETY: Inherited from caller. 
unsafe { std::slice::from_raw_parts_mut(self.ptr.as_ptr(), self.len.value()) } } } @@ -344,6 +352,7 @@ mod tests { let mut raw_slice = buffer.get(i).unwrap(); assert_eq!(raw_slice.len(), buffer.stride()); + // SAFETY: See safety note. let slice = unsafe { raw_slice.as_mut_slice() }; assert_eq!(slice.len(), buffer.stride().value()); slice.fill(0); @@ -373,6 +382,7 @@ mod tests { ctx ); + // SAFETY: See safety note. let slice = unsafe { raw_slice.as_slice() }; assert_eq!(slice.len(), buffer.stride().value()); assert!(slice.iter().all(|&i| i == 0), "{}", ctx); @@ -397,12 +407,14 @@ mod tests { // truncate // + // SAFETY: see safety note. iota(unsafe { raw.as_mut_slice() }, base); for i in 0..raw.len().value() + base_usize { let expected = i.min(raw.len().value()); let truncated = raw.truncate(Bytes::new(i)); assert_eq!(truncated.len().value(), expected, "{}", ctx); + // SAFETY: see safety note. assert!(is_iota(unsafe { truncated.as_slice() }, base), "{}", ctx); } @@ -417,8 +429,11 @@ mod tests { assert_eq!(prefix.len().value(), first, "{}", ctx); assert_eq!(suffix.len().value(), last, "{}", ctx); + // SAFETY: see safety note. assert!(is_iota(unsafe { prefix.as_slice() }, base), "{}", ctx); + assert!( + // SAFETY: see safety note. is_iota(unsafe { suffix.as_slice() }, base.wrapping_add(i as u8)), "{}", ctx @@ -429,13 +444,17 @@ mod tests { // SAFETY: `prefix` and `suffix` are non-overlapping sub-ranges of the same // entry, so materializing both as mutable is sound. { + // SAFETY: see above let prefix = unsafe { prefix.as_mut_slice() }; + // SAFETY: see above let suffix = unsafe { suffix.as_mut_slice() }; suffix.fill(0); prefix.fill(0); } + // SAFETY: see safety note. assert!(unsafe { raw.as_slice() }.iter().all(|i| *i == 0), "{}", ctx); + // SAFETY: see safety note. iota(unsafe { raw.as_mut_slice() }, base); } } @@ -467,6 +486,7 @@ mod tests { } for i in 0..spawns { + // SAFETY: at this point we have exclusive access to `buffer`. 
let slice = unsafe { buffer.get(i).unwrap().as_slice() }; assert!(is_iota(slice, i as u8), "i = {} -- {}", i, ctx); } diff --git a/diskann-inmem/src/epoch.rs b/diskann-inmem/src/epoch.rs index ac0a17814..ec006d4b0 100644 --- a/diskann-inmem/src/epoch.rs +++ b/diskann-inmem/src/epoch.rs @@ -164,7 +164,7 @@ impl Registry { let m = &self.guards[slot]; delay.pre_cas(); - if let Ok(_) = m.compare_exchange(0, epoch, Ordering::Relaxed, Ordering::Relaxed) { + if m.compare_exchange(0, epoch, Ordering::Relaxed, Ordering::Relaxed).is_ok() { delay.post_cas(); let mut reset = false; loop { diff --git a/diskann-inmem/src/freelist.rs b/diskann-inmem/src/freelist.rs index 966153c6d..a0400e672 100644 --- a/diskann-inmem/src/freelist.rs +++ b/diskann-inmem/src/freelist.rs @@ -390,11 +390,8 @@ mod tests { s.spawn(|| { let mut out = Vec::new(); barrier.wait(); - loop { - match fl.pop() { - Id::Found(id) => out.push(id), - Id::Scan => break, - } + while let Id::Found(id) = fl.pop() { + out.push(id); } out }) diff --git a/diskann-inmem/src/layers/full.rs b/diskann-inmem/src/layers/full.rs index cffc2fdcf..83ceb5b20 100644 --- a/diskann-inmem/src/layers/full.rs +++ b/diskann-inmem/src/layers/full.rs @@ -85,7 +85,7 @@ impl layers::Set<&[T]> for Full where T: bytemuck::Pod + Send + Sync, { - fn into_bytes(&self, v: &[T], bytes: &mut [u8]) -> ANNResult<()> { + fn set(&self, v: &[T], bytes: &mut [u8]) -> ANNResult<()> { if v.len() != self.dim() { Err(ANNError::from(SetError::Dim { got: v.len(), @@ -194,10 +194,13 @@ where if x.len() != bytes || y.len() != bytes { self.error(x, y) } else { - Ok(self.f.call_unaligned( - unsafe { UnalignedSlice::new(x.as_ptr().cast::(), self.dim) }, - unsafe { UnalignedSlice::new(y.as_ptr().cast::(), self.dim) }, - )) + // SAFETY: We've checked that both `x` and `y` are valid for + // `size_of::() * self.dim` bytes. 
+ let ux = unsafe { UnalignedSlice::new(x.as_ptr().cast::(), self.dim) }; + + // SAFETY: Same as above + let uy = unsafe { UnalignedSlice::new(y.as_ptr().cast::(), self.dim) }; + Ok(self.f.call_unaligned(ux, uy)) } } } @@ -501,7 +504,7 @@ mod tests { /// Exercise every `Full` API across dimensions `1..=max_dim`. /// - /// For each dimension we check that `bytes`/`into_bytes` agree, that `distance` and + /// For each dimension we check that `bytes`/`set` agree, that `distance` and /// `query_distance` are consistent with `DistanceProvider`, and that all of these /// reject byte slices that are too long or too short. fn test_impl(max_dim: usize, ctx: &dyn Display) @@ -521,7 +524,7 @@ mod tests { let a = gen_vec::(&mut rng, dim); let b = gen_vec::(&mut rng, dim); - // `bytes` and `into_bytes` agree: the encoded buffer equals the raw cast bytes. + // `bytes` and `set` agree: the encoded buffer equals the raw cast bytes. let layer = Full::::new(dim, Metric::L2); assert_eq!( layer.bytes().value(), @@ -530,15 +533,15 @@ mod tests { ); let mut a_bytes = vec![0u8; layer.bytes().value()]; - layer.into_bytes(&a, &mut a_bytes).unwrap(); + layer.set(&a, &mut a_bytes).unwrap(); assert_eq!( a_bytes.as_slice(), bytemuck::cast_slice::(&a), - "{ctx}: dim {dim}: into_bytes mismatch", + "{ctx}: dim {dim}: set mismatch", ); let mut b_bytes = vec![0u8; layer.bytes().value()]; - layer.into_bytes(&b, &mut b_bytes).unwrap(); + layer.set(&b, &mut b_bytes).unwrap(); for metric in metrics { let full = Full::::new(dim, metric); @@ -578,12 +581,12 @@ mod tests { assert!(query.evaluate(&long).is_err()); } - // `into_bytes` rejects mis-sized element and buffer slices. + // `set` rejects mis-sized element and buffer slices. 
let mut buf = vec![0u8; layer.bytes().value()]; let too_many = gen_vec::(&mut rng, dim + 1); assert!( - layer.into_bytes(&too_many, &mut buf).is_err(), - "{ctx}: dim {dim}: into_bytes accepted an over-long element slice", + layer.set(&too_many, &mut buf).is_err(), + "{ctx}: dim {dim}: set accepted an over-long element slice", ); assert!( @@ -593,8 +596,8 @@ mod tests { let mut short_buf = vec![0u8; layer.bytes().value().saturating_sub(1)]; assert!( - layer.into_bytes(&a, &mut short_buf).is_err(), - "{ctx}: dim {dim}: into_bytes accepted an under-sized buffer", + layer.set(&a, &mut short_buf).is_err(), + "{ctx}: dim {dim}: set accepted an under-sized buffer", ); let too_few = gen_vec::(&mut rng, dim - 1); diff --git a/diskann-inmem/src/layers/mod.rs b/diskann-inmem/src/layers/mod.rs index b3b1cc187..b67c1ada1 100644 --- a/diskann-inmem/src/layers/mod.rs +++ b/diskann-inmem/src/layers/mod.rs @@ -46,7 +46,7 @@ pub trait Layer: Send + Sync + 'static { /// Implementations may assume that `bytes.len()` is equal to [`Layer::bytes`]. pub trait Set: Layer { /// Write into the stored representation. - fn into_bytes(&self, element: T, bytes: &mut [u8]) -> ANNResult<()>; + fn set(&self, element: T, bytes: &mut [u8]) -> ANNResult<()>; } /// A distance computation on raw byte slices. diff --git a/diskann-inmem/src/neighbors.rs b/diskann-inmem/src/neighbors.rs index 3425a2fc1..14600502e 100644 --- a/diskann-inmem/src/neighbors.rs +++ b/diskann-inmem/src/neighbors.rs @@ -128,6 +128,7 @@ impl Neighbors { ) -> Result<(), OutOfBounds> { self.check(i)?; + // SAFETY: We've checked that `i` is in-bounds. let lock = unsafe { self.locks.get_unchecked(lock_index(i)) }; let _guard = lock.read(); @@ -147,6 +148,9 @@ impl Neighbors { .into_usize(); let mut resizer = neighbors.resize(len); + + // SAFETY: We've validated that the two slices are valid. They cannot overlap + // because `neighbors` is provided externally by exclusive reference. 
unsafe { std::ptr::copy_nonoverlapping( rest.as_mut_ptr(), @@ -163,10 +167,18 @@ impl Neighbors { /// Returns an error if `i` exceeds [`Self::entries`]. pub(crate) fn lock(&self, i: u32) -> Result, OutOfBounds> { self.check(i)?; + + // SAFETY: `i` is in-bounds. Ok(unsafe { self.lock_unchecked(i) }) } + /// Lock adjacency-list `i` without bounds-checking. + /// + /// # SAFETY + /// + /// `i` must be in-bounds. unsafe fn lock_unchecked(&self, i: u32) -> Lock<'_> { + // SAFETY: `i` is in-bounds. let lock = unsafe { self.locks.get_unchecked(lock_index(i)) }.write(); // SAFETY: By construction `self.buffer` has the same number of entries as @@ -203,7 +215,10 @@ impl Neighbors { })); } + // SAFETY: We've checked `i` is in-bounds. let lock = unsafe { self.lock_unchecked(i) }; + + // SAFETY: `neighbors.len() <= self.max_length()`. unsafe { lock.write_unchecked(neighbors) }; Ok(()) } @@ -313,6 +328,10 @@ impl Lock<'_> { }); } + // SAFETY: We've verified that both regions are in-bounds. + // + // The slices have to be disjoint because `self` effectively owns its data while + // it is alive and this method receives by-value. unsafe { std::ptr::copy_nonoverlapping( neighbors.as_ptr(), @@ -321,20 +340,39 @@ impl Lock<'_> { ) } + // SAFETY: `self.ptr` is guaranteed to be valid for at least 4-bytes, and we own the + // underlying data until `drop`. unsafe { self.ptr.write(newlen as u32) }; + Ok(()) } + /// Write the contents of `neighbors` into `self` without validating lenghts. + /// + /// # Safety + /// + /// `neighbors.len() <= self.capacity()`. unsafe fn write_unchecked(self, neighbors: &[u32]) { let len = neighbors.len(); debug_assert!(len <= self.capacity()); + + // SAFETY: the caller asserts that the pointer arithmetic is sound. + // + // The slices are disjoint because `self` owns its data and this method receives + // by value. 
unsafe { std::ptr::copy_nonoverlapping(neighbors.as_ptr(), self.ptr.as_ptr().add(1), len) } + + // SAFETY: `self.ptr` is guaranteed to be valid for at least 4-bytes, and we own the + // underlying data until `drop`. unsafe { self.ptr.write(len as u32) }; } #[cfg(test)] fn as_slice(&self) -> &[u32] { let len = self.len(); + + // SAFETY: by construction - this access is in-bounds and `Lock` has exclusive + // access too its data, so we're free to hand out a raw slice. unsafe { std::slice::from_raw_parts(self.ptr.add(1).as_ptr().cast_const(), len) } } @@ -347,6 +385,7 @@ impl Lock<'_> { }); } + // SAFETY: We've checked that `neighbors.len() <= self.capacity()`. unsafe { self.write_unchecked(neighbors) }; Ok(()) } diff --git a/diskann-inmem/src/num.rs b/diskann-inmem/src/num.rs index 676af91c0..eeba142ac 100644 --- a/diskann-inmem/src/num.rs +++ b/diskann-inmem/src/num.rs @@ -94,8 +94,15 @@ impl Align { self.0.get() } + /// Construct a new [`Align`] with the raw `value`. + /// + /// # Safety + /// + /// `value` must be a power of two. pub const unsafe fn new_unchecked(value: usize) -> Self { debug_assert!(value.is_power_of_two()); + + // SAFETY: powers of two must be non-zero. Self(unsafe { NonZeroUsize::new_unchecked(value) }) } diff --git a/diskann-inmem/src/provider.rs b/diskann-inmem/src/provider.rs index 24db529da..42819432a 100644 --- a/diskann-inmem/src/provider.rs +++ b/diskann-inmem/src/provider.rs @@ -57,7 +57,7 @@ where let mut data = Matrix::new(0u8, start_points.len(), bytes.value()); for (row, point) in std::iter::zip(data.row_iter_mut(), start_points.into_iter()) { - layers::Set::into_bytes(&layer, point, row).unwrap(); + layers::Set::set(&layer, point, row).unwrap(); } let store = Store::new( @@ -252,7 +252,7 @@ where // TODO: Proper cleanup via `Guard` or some other mechanism on the event of // insert failure. 
- >::into_bytes(&self.layer, element, slot.as_mut_slice())?; + >::set(&self.layer, element, slot.as_mut_slice())?; self.mapping.insert(id.clone(), slot.slot())?; // Now that insert has succeeded - publish the slot. This method cannot fail, so @@ -421,7 +421,7 @@ impl<'a> layers::QueryVisitor<'a> for ExpandBeamVisitor { T: QueryDistance + 'a, { // Make sure there's no lying. - assert_eq!(Bytes::new(BYTES + 1), self.bytes); + assert_eq!(Bytes::new(BYTES + store::TAG_SIZE.value()), self.bytes); Box::new(ExpandBeamImpl::<_, BYTES>(distance)) } @@ -470,16 +470,16 @@ where T: layers::QueryDistance, { debug_assert!( - BYTES + 1 <= reader.bytes().value(), + BYTES + store::TAG_SIZE.value() <= reader.bytes().value(), "we really rely on this: {}, bytes = {}", - BYTES + 1, + BYTES + store::TAG_SIZE.value(), reader.bytes() ); let bytes = if BYTES == 0 { reader.bytes().value() } else { - BYTES + 1 + BYTES + store::TAG_SIZE.value() }; let len = list.len(); diff --git a/diskann-inmem/src/store.rs b/diskann-inmem/src/store.rs index 8815b5236..93480d9dc 100644 --- a/diskann-inmem/src/store.rs +++ b/diskann-inmem/src/store.rs @@ -99,7 +99,9 @@ pub(crate) struct Store { neighbors: Neighbors, } -const SPLIT: Bytes = Bytes::size_of::(); +/// The number of bytes occupied by the in-line concurrency tag. +pub(crate) const TAG_SIZE: Bytes = Bytes::size_of::(); + const TWO: NonZeroUsize = NonZeroUsize::new(2).unwrap(); // TODO: This is a guess and probably needs tuning. @@ -123,7 +125,7 @@ impl Store { } let unpadded = bytes - .checked_add(SPLIT) + .checked_add(TAG_SIZE) .expect("unreachable because `init` cannot exceed `isize::MAX` bytes"); // Pad to half a cache line. 
When data occupies just part of a cache line, this @@ -153,7 +155,7 @@ impl Store { unpadded, unfrozen: entries.into_usize(), tags: repeat_n(Tag::AVAILABLE, total.into_usize()) - .map(|v| AtomicTag::new(v)) + .map(AtomicTag::new) .collect(), // NOTE: The `Freelist` is initialized to `entries` and not `total` because @@ -218,8 +220,17 @@ impl Store { let drain = self.registry.try_advance()?; let items = drain.len(); for i in drain { + assert!( + i.into_usize() < self.buffer.len(), + "received an invalid ID ({}) while reclaiming slots - max allowed is {}", + i, + self.buffer.len(), + ); + // We release the mirror before the main tag. The other direction would // prematurely advertise availability. + // + // SAFETY: We've verified that `i` is in-bounds. let (mirror, _) = unsafe { self.data_unchecked(i.into_usize()) }; release(mirror, "mirror"); release(&self.tags[i.into_usize()], "tag"); @@ -302,6 +313,8 @@ impl Store { match tag.compare_exchange(current, retiring, Ordering::Relaxed, Ordering::Relaxed) { Ok(_) => { // Set the metadata in the mirror as well. + // + // SAFETY: We've checked that `i` is in-bounds. let (mirror, _) = unsafe { self.data_unchecked(i) }; mirror.store(retiring, Ordering::Relaxed); guard.retire(i as u32); @@ -330,6 +343,7 @@ impl Store { // freelist for other threads. if tag.load(Ordering::Relaxed) == Tag::AVAILABLE { if acquired.is_none() { + // SAFETY: We're guaranteed that `tag` belongs to `slot`. acquired = unsafe { self.try_acquire(tag, slot) }; } else { self.freelist.push(slot); @@ -356,6 +370,7 @@ impl Store { fn slot(&self, i: u32) -> Option> { let tag = &self.tags.get(i.into_usize()).unwrap(); + // SAFETY: We've guaranteed that `tag` belongs to `slot`. unsafe { self.try_acquire(tag, i) } } @@ -373,6 +388,7 @@ impl Store { Ordering::Relaxed, ) { Ok(_) => { + // SAFETY: Inherited from caller - `slot` is in-bounds. 
let (mirror, data) = unsafe { self.data_unchecked(slot.into_usize()) }; Some(Slot { tag, @@ -391,10 +407,13 @@ impl Store { /// /// The index `i` must be less then `self.buffer.len()`. unsafe fn data_unchecked(&self, i: usize) -> (&AtomicTag, RawSlice<'_>) { + // SAFETY: inherited from caller. let (data, mirror) = unsafe { self.buffer.get_unchecked(i) } .truncate(self.unpadded) - .split(self.unpadded.unchecked_sub(SPLIT)); + .split(self.unpadded.unchecked_sub(TAG_SIZE)); ( + // SAFETY: We're careful in this module to ensure the inline tags are only + // ever accessed atomically. unsafe { AtomicTag::from_ptr(mirror.as_mut_ptr().cast()) }, data, ) @@ -499,6 +518,7 @@ impl<'a> Reader<'a> { #[inline] pub(crate) fn read(&self, i: usize) -> Option<&[u8]> { if self.is_in_bounds(i) { + // SAFETY: `i` is in-bounds. unsafe { self.read_in_bounds(i) } } else { None @@ -522,13 +542,18 @@ impl<'a> Reader<'a> { return None; } + // SAFETY: We've checked that `i` is in-bounds. + // + // Further, we guarantee that `self.unpadded >= TAG_SIZE`, so the pointer arithmetic + // is in-bounds. let tag_ptr = unsafe { self.buffer .get_unchecked(i) .as_mut_ptr() - .add(self.unpadded.unchecked_sub(SPLIT).value()) + .add(self.unpadded.unchecked_sub(TAG_SIZE).value()) }; + // SAFETY: We only access tag pointers atomically. let can_read = unsafe { AtomicTag::from_ptr(tag_ptr.cast()) } .load(Ordering::Acquire) .can_read(); @@ -546,14 +571,22 @@ impl<'a> Reader<'a> { pub(crate) unsafe fn read_in_bounds(&self, i: usize) -> Option<&[u8]> { debug_assert!(self.is_in_bounds(i)); + // SAFETY: + // + // * The caller asserts `i` is in-bounds. + // * We maintain an internal invariant that `self.buffer.stride() <= self.unpadded`. + // * Further, we maintain that `self.unpadded >= TAG_SIZE`. 
let (data, tag_ptr) = unsafe { self.buffer .get_unchecked(i) .truncate_unchecked(self.unpadded) - .split_unchecked(self.unpadded.unchecked_sub(SPLIT)) + .split_unchecked(self.unpadded.unchecked_sub(TAG_SIZE)) }; // NOTE: Must be `Acquire` to correctly synchronize with writes. + // + // SAFETY: We are careful in this module to ensure that inline tags are only accessed + // atomically. let can_read = unsafe { AtomicTag::from_ptr(tag_ptr.as_mut_ptr().cast()) } .load(Ordering::Acquire) .can_read(); @@ -574,6 +607,7 @@ impl<'a> Reader<'a> { /// The index `i` must be satisfy [`Self::is_in_bounds`]. #[inline] pub(crate) unsafe fn read_raw_unchecked(&self, i: usize) -> RawSlice<'_> { + // SAFETY: Inherited from caller: `i` is inbounds. unsafe { self.buffer.get_unchecked(i) }.truncate(self.unpadded) } @@ -584,7 +618,7 @@ impl<'a> Reader<'a> { /// Return [`Neighbors`]. pub(crate) fn neighbors(&self) -> &Neighbors { - &self.neighbors + self.neighbors } } @@ -600,6 +634,7 @@ pub(crate) struct Slot<'a> { impl<'a> Slot<'a> { /// View the managed data as a mutable slice. pub(crate) fn as_mut_slice(&mut self) -> &mut [u8] { + // SAFETY: The slot guarantees exclusive access to its corresponding data. unsafe { self.data.as_mut_slice() } } diff --git a/diskann-inmem/src/tag.rs b/diskann-inmem/src/tag.rs index e0867252b..e96e234f7 100644 --- a/diskann-inmem/src/tag.rs +++ b/diskann-inmem/src/tag.rs @@ -187,6 +187,7 @@ impl AtomicTag { /// /// See: pub(crate) unsafe fn from_ptr<'a>(ptr: *mut AtomicTag) -> &'a Self { + // SAFETY: inherited from caller. unsafe { &*ptr } } @@ -265,6 +266,7 @@ mod tests { let ptr = buffer.get(0).unwrap().as_mut_ptr().cast::(); { + // SAFETY: We only access these atomically. let tag = unsafe { AtomicTag::from_ptr(ptr) }; tag.store(Tag::FROZEN, Ordering::Relaxed); } @@ -275,14 +277,17 @@ mod tests { s.spawn(|| { // Re-derive `p` to avoid issues with `Send`. let p = buffer.get(0).unwrap().as_mut_ptr().cast::(); + + // SAFETY: We only access this atomically. 
let tag = unsafe { AtomicTag::from_ptr(p) }; barrier.wait(); - spin_decrement(&tag, count); + spin_decrement(tag, count); }); } }); { + // SAFETY: We only access this atomically. let g = unsafe { AtomicTag::from_ptr(ptr) }.load(Ordering::Relaxed); assert_eq!(g, Tag::new(u8::MAX.wrapping_sub((count * threads) as u8))); } diff --git a/diskann-inmem/src/test/epoch.rs b/diskann-inmem/src/test/epoch.rs index b4a2a5243..c6fa9104e 100644 --- a/diskann-inmem/src/test/epoch.rs +++ b/diskann-inmem/src/test/epoch.rs @@ -37,22 +37,27 @@ impl Slot { where F: FnOnce(), { - if let Ok(_) = self.tag.compare_exchange( + if self.tag.compare_exchange( Tag::AVAILABLE, Tag::OWNED, Ordering::Acquire, Ordering::Relaxed, - ) { + ).is_ok() { + // SAFETY: By transitioning from AVAILABLE to OWNED, we've acquired ownership + // of this slot and are thus free to write to the `UnsafeCell`. unsafe { &mut *self.payload.get() }.write(Box::new(payload)); f(); self.tag.store(Tag::PUBLISHED, Ordering::Release); } } - unsafe fn try_read(&self) -> Option<&Data> { + fn try_read(&self) -> Option<&Data> { if self.tag.load(Ordering::Acquire).can_read() { + // SAFETY: We've checked that we can read this cell. let payload = unsafe { &*self.payload.get() }; - Some(&*unsafe { payload.assume_init_ref() }) + + // SAFETY: Items that can be read **must** be initialized. + Some(unsafe { payload.assume_init_ref() }) } else { None } @@ -65,28 +70,26 @@ impl Slot { return false; } - if let Ok(_) = self.tag.compare_exchange( + self.tag.compare_exchange( Tag::PUBLISHED, Tag::RETIRING, Ordering::Relaxed, Ordering::Relaxed, - ) { - true - } else { - false - } + ).is_ok() } unsafe fn make_available(&self) { assert_eq!(self.tag.load(Ordering::Relaxed), Tag::RETIRING); + + // SAFETY: Items tagged as `RETIRING` must be initialized. 
unsafe { (&mut *self.payload.get()).assume_init_drop() }; - if let Err(_) = self.tag.compare_exchange( + if self.tag.compare_exchange( Tag::RETIRING, Tag::AVAILABLE, Ordering::Release, Ordering::Relaxed, - ) { + ).is_err() { panic!("concurrency violation"); } } @@ -96,12 +99,15 @@ impl Drop for Slot { fn drop(&mut self) { if self.tag.load(Ordering::Relaxed) != Tag::AVAILABLE { let payload = self.payload.get_mut(); + + // SAFETY: We have exclusive access and by convention, if the tag is not + // available, then the corresponding payload is initialized. unsafe { payload.assume_init_drop() }; } } } -// We control concurrency, so can safely share this. +// SAFETY: We control concurrency, so can safely share this. unsafe impl Sync for Slot {} fn make_payload(epoch: u64, index: usize) -> Data { @@ -151,7 +157,7 @@ fn read_job( } for (i, slot) in slots.iter().enumerate() { - if let Some(read) = unsafe { slot.try_read() } { + if let Some(read) = slot.try_read() { reads.push(read); let sample: f64 = rng.sample(StandardUniform); @@ -187,6 +193,7 @@ fn retire_job(registry: &Registry, slots: &[Slot], stop_at: u64, active: &Atomic if let Some(drain) = registry.try_advance() { for i in drain { + // SAFETY: retrieving from the drain gives us exclusive access. unsafe { slots[i as usize].make_available() }; } } From 2d022db62d1d7f82e752a733ea83105156f9da5b Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Fri, 26 Jun 2026 11:13:11 -0700 Subject: [PATCH 28/45] Clippy! 
--- diskann-inmem/.clippy.toml | 3 + diskann-inmem/Cargo.toml | 8 +- diskann-inmem/integration/index/runner.rs | 45 ++- diskann-inmem/integration/index/tests.rs | 8 +- .../jsons/integration-baseline.json | 127 +++++++ .../integration/jsons/integration.json | 55 ++- diskann-inmem/integration/store.rs | 63 +++- diskann-inmem/integration/support/check.rs | 327 ++++++++++++++++- diskann-inmem/integration/support/datatype.rs | 331 +++++++++++++++++- diskann-inmem/integration/support/io.rs | 9 +- diskann-inmem/src/epoch.rs | 8 +- diskann-inmem/src/integration/store.rs | 5 + diskann-inmem/src/provider.rs | 162 +++++++-- diskann-inmem/src/store.rs | 50 ++- diskann-inmem/src/test/epoch.rs | 46 ++- 15 files changed, 1150 insertions(+), 97 deletions(-) create mode 100644 diskann-inmem/.clippy.toml diff --git a/diskann-inmem/.clippy.toml b/diskann-inmem/.clippy.toml new file mode 100644 index 000000000..7bada5473 --- /dev/null +++ b/diskann-inmem/.clippy.toml @@ -0,0 +1,3 @@ +allow-unwrap-in-tests = true +allow-expect-in-tests = true +allow-panic-in-tests = true diff --git a/diskann-inmem/Cargo.toml b/diskann-inmem/Cargo.toml index d61b9f889..00e02dbf2 100644 --- a/diskann-inmem/Cargo.toml +++ b/diskann-inmem/Cargo.toml @@ -28,8 +28,12 @@ rand = { workspace = true, optional = true } diskann-benchmark-core = { workspace = true, optional = true } tokio = { workspace = true, optional = true } -[lints] -workspace = true +[lints.clippy] +undocumented_unsafe_blocks = "warn" +unwrap_used = "warn" +expect_used = "warn" +panic = "warn" +uninlined_format_args = "allow" [dev-dependencies] diskann = { workspace = true, features = ["testing"] } diff --git a/diskann-inmem/integration/index/runner.rs b/diskann-inmem/integration/index/runner.rs index 68824cec6..ce85e106d 100644 --- a/diskann-inmem/integration/index/runner.rs +++ b/diskann-inmem/integration/index/runner.rs @@ -24,7 +24,7 @@ use crate::{ index::{Counters, Index}, support::{ check::{CheckMatch, Match, check_all_fields}, - 
datatype::{DataType, Dataset, DatasetView}, + datatype::{self, DataType, Dataset, DatasetView}, io::load_and_convert, tolerance, }, @@ -70,6 +70,31 @@ mod dto { } } + #[derive(Debug, Serialize, Deserialize)] + #[serde(rename_all = "kebab-case")] + pub(super) enum Preprocess { + Halve, + Floor, + } + + impl From for datatype::Preprocess { + fn from(op: Preprocess) -> Self { + match op { + Preprocess::Halve => datatype::Preprocess::Halve, + Preprocess::Floor => datatype::Preprocess::Floor, + } + } + } + + impl From<&datatype::Preprocess> for Preprocess { + fn from(op: &datatype::Preprocess) -> Self { + match op { + datatype::Preprocess::Halve => Preprocess::Halve, + datatype::Preprocess::Floor => Preprocess::Floor, + } + } + } + #[derive(Debug, Serialize, Deserialize)] pub(super) struct Data { pub(super) data: InputFile, @@ -77,6 +102,7 @@ mod dto { pub(super) groundtruth: InputFile, pub(super) metric: SerdeMetric, pub(super) data_type: DataType, + pub(super) preprocess: Vec, } #[derive(Debug, Serialize, Deserialize)] @@ -121,6 +147,7 @@ struct Data { groundtruth: InputFile, metric: Metric, data_type: DataType, + preprocess: Vec, } impl Data { @@ -131,6 +158,7 @@ impl Data { mut groundtruth, metric, data_type, + preprocess, } = raw; if let Some(checker) = checker { @@ -145,6 +173,7 @@ impl Data { groundtruth, metric: metric.into(), data_type, + preprocess: preprocess.into_iter().map(From::from).collect(), }) } @@ -155,6 +184,7 @@ impl Data { groundtruth: self.groundtruth.clone(), metric: self.metric.try_into()?, data_type: self.data_type, + preprocess: self.preprocess.iter().map(From::from).collect(), }) } @@ -163,14 +193,14 @@ impl Data { let mut io = std::fs::File::open(&*self.data) .with_context(|| format!("could not open {}", self.data.display()))?; - load_and_convert(&mut io, self.data_type, data_type)? + load_and_convert(&mut io, self.data_type, data_type, &self.preprocess)? 
}; let queries = { let mut io = std::fs::File::open(&*self.queries) .with_context(|| format!("could not open {}", self.queries.display()))?; - load_and_convert(&mut io, self.data_type, data_type)? + load_and_convert(&mut io, self.data_type, data_type, &self.preprocess)? }; let groundtruth = { @@ -348,19 +378,19 @@ impl Test { let index = match start_points { DatasetView::F32(v) => finish( - Provider::new(layers::Full::::new(dim, metric), config, v.row_iter()), + Provider::new(layers::Full::::new(dim, metric), config, v.row_iter())?, index_config, ), DatasetView::F16(v) => finish( - Provider::new(layers::Full::::new(dim, metric), config, v.row_iter()), + Provider::new(layers::Full::::new(dim, metric), config, v.row_iter())?, index_config, ), DatasetView::U8(v) => finish( - Provider::new(layers::Full::::new(dim, metric), config, v.row_iter()), + Provider::new(layers::Full::::new(dim, metric), config, v.row_iter())?, index_config, ), DatasetView::I8(v) => finish( - Provider::new(layers::Full::::new(dim, metric), config, v.row_iter()), + Provider::new(layers::Full::::new(dim, metric), config, v.row_iter())?, index_config, ), }; @@ -407,6 +437,7 @@ impl diskann_benchmark_runner::Input for Test { groundtruth: InputFile::new("path/to/groundtruth"), metric: dto::SerdeMetric::L2, data_type: DataType::F32, + preprocess: vec![], }, layer: dto::Layer::FullPrecision { data_type: DataType::F32, diff --git a/diskann-inmem/integration/index/tests.rs b/diskann-inmem/integration/index/tests.rs index 61d5166bd..7b0277a0d 100644 --- a/diskann-inmem/integration/index/tests.rs +++ b/diskann-inmem/integration/index/tests.rs @@ -23,8 +23,8 @@ pub(super) fn insert( rt: &tokio::runtime::Handle, ) -> anyhow::Result { let before = index.counters(); - for i in 0..dataset.nrows() { - rt.block_on(index.insert(dataset.row(i).unwrap(), i as u64))?; + for (i, r) in dataset.iter().enumerate() { + rt.block_on(index.insert(r, i as u64))?; } before.delta(&index.counters()) } @@ -48,10 +48,10 @@ 
pub(super) fn knn( let before = index.counters(); let mut misc = KnnSearch::new(); let mut neighbors = Vec::new(); - for (i, out) in ids.row_iter_mut().enumerate() { + for (out, query) in std::iter::zip(ids.row_iter_mut(), queries.iter()) { neighbors.clear(); - let stats = rt.block_on(index.search(queries.row(i).unwrap(), knn, &mut neighbors))?; + let stats = rt.block_on(index.search(query, knn, &mut neighbors))?; misc += stats; std::iter::zip(out.iter_mut(), neighbors.iter()).for_each(|(d, s)| *d = s.id); diff --git a/diskann-inmem/integration/jsons/integration-baseline.json b/diskann-inmem/integration/jsons/integration-baseline.json index 6f8f5b014..a21a83b8e 100644 --- a/diskann-inmem/integration/jsons/integration-baseline.json +++ b/diskann-inmem/integration/jsons/integration-baseline.json @@ -13,6 +13,7 @@ "data_type": "f32", "groundtruth": "/disk_index_search/disk_index_10pts_idx_uint32_truth_search_res.bin", "metric": "l2", + "preprocess": [], "queries": "/disk_index_search/disk_index_sample_query_10pts.fbin" }, "layer": { @@ -133,6 +134,7 @@ "data_type": "f32", "groundtruth": "/disk_index_search/disk_index_10pts_idx_uint32_truth_search_res.bin", "metric": "l2", + "preprocess": [], "queries": "/disk_index_search/disk_index_sample_query_10pts.fbin" }, "layer": { @@ -253,6 +255,7 @@ "data_type": "f32", "groundtruth": "/disk_index_search/disk_index_10pts_idx_uint32_truth_search_res.bin", "metric": "l2", + "preprocess": [], "queries": "/disk_index_search/disk_index_sample_query_10pts.fbin" }, "layer": { @@ -358,5 +361,129 @@ } ] } + }, + { + "input": { + "content": { + "build": { + "alpha": 1.2000000476837158, + "l_build": 50, + "max_degree": 20, + "pruned_degree": 16 + }, + "data": { + "data": "/disk_index_search/disk_index_siftsmall_learn_256pts_data.fbin", + "data_type": "f32", + "groundtruth": "/disk_index_search/disk_index_10pts_idx_uint32_truth_search_res.bin", + "metric": "l2", + "preprocess": [ + "halve", + "floor" + ], + "queries": 
"/disk_index_search/disk_index_sample_query_10pts.fbin" + }, + "layer": { + "FullPrecision": { + "data_type": "i8" + } + }, + "search": { + "knn": [ + { + "beam_width": 1, + "knn": 10, + "search_l": 50 + }, + { + "beam_width": 3, + "knn": 10, + "search_l": 50 + }, + { + "beam_width": 3, + "knn": 10, + "search_l": 100 + } + ] + } + }, + "type": "integration-test" + }, + "results": { + "build": { + "append_neighbors": 2449, + "distance": 59803, + "get_neighbors": 14471, + "get_vector": 42075, + "query_distance": 22829, + "set_neighbors": 428, + "set_vector": 256 + }, + "knn": [ + { + "counters": { + "append_neighbors": 0, + "distance": 0, + "get_neighbors": 517, + "get_vector": 1434, + "query_distance": 1434, + "set_neighbors": 0, + "set_vector": 0 + }, + "misc": { + "cmps": 1434, + "hops": 517 + }, + "recall": { + "average": 0.9, + "num_queries": 10, + "recall_k": 10, + "recall_n": 10 + } + }, + { + "counters": { + "append_neighbors": 0, + "distance": 0, + "get_neighbors": 523, + "get_vector": 1462, + "query_distance": 1462, + "set_neighbors": 0, + "set_vector": 0 + }, + "misc": { + "cmps": 1462, + "hops": 523 + }, + "recall": { + "average": 0.9, + "num_queries": 10, + "recall_k": 10, + "recall_n": 10 + } + }, + { + "counters": { + "append_neighbors": 0, + "distance": 0, + "get_neighbors": 1016, + "get_vector": 1625, + "query_distance": 1625, + "set_neighbors": 0, + "set_vector": 0 + }, + "misc": { + "cmps": 1625, + "hops": 1016 + }, + "recall": { + "average": 0.9, + "num_queries": 10, + "recall_k": 10, + "recall_n": 10 + } + } + ] + } } ] \ No newline at end of file diff --git a/diskann-inmem/integration/jsons/integration.json b/diskann-inmem/integration/jsons/integration.json index 84369ab38..dd9e6e4bf 100644 --- a/diskann-inmem/integration/jsons/integration.json +++ b/diskann-inmem/integration/jsons/integration.json @@ -18,7 +18,8 @@ "data_type": "f32", "groundtruth": "disk_index_10pts_idx_uint32_truth_search_res.bin", "metric": "l2", - "queries": 
"disk_index_sample_query_10pts.fbin" + "queries": "disk_index_sample_query_10pts.fbin", + "preprocess": [] }, "layer": { "FullPrecision": { @@ -60,7 +61,8 @@ "data_type": "f32", "groundtruth": "disk_index_10pts_idx_uint32_truth_search_res.bin", "metric": "l2", - "queries": "disk_index_sample_query_10pts.fbin" + "queries": "disk_index_sample_query_10pts.fbin", + "preprocess": [] }, "layer": { "FullPrecision": { @@ -102,7 +104,8 @@ "data_type": "f32", "groundtruth": "disk_index_10pts_idx_uint32_truth_search_res.bin", "metric": "l2", - "queries": "disk_index_sample_query_10pts.fbin" + "queries": "disk_index_sample_query_10pts.fbin", + "preprocess": [] }, "layer": { "FullPrecision": { @@ -129,6 +132,52 @@ ] } } + }, + { + "type": "integration-test", + "content": { + "build": { + "alpha": 1.2000000476837158, + "l_build": 50, + "max_degree": 20, + "pruned_degree": 16 + }, + "data": { + "data": "disk_index_siftsmall_learn_256pts_data.fbin", + "data_type": "f32", + "groundtruth": "disk_index_10pts_idx_uint32_truth_search_res.bin", + "metric": "l2", + "queries": "disk_index_sample_query_10pts.fbin", + "preprocess": [ + "halve", + "floor" + ] + }, + "layer": { + "FullPrecision": { + "data_type": "i8" + } + }, + "search": { + "knn": [ + { + "beam_width": null, + "knn": 10, + "search_l": 50 + }, + { + "beam_width": 3, + "knn": 10, + "search_l": 50 + }, + { + "beam_width": 3, + "knn": 10, + "search_l": 100 + } + ] + } + } } ] } diff --git a/diskann-inmem/integration/store.rs b/diskann-inmem/integration/store.rs index 7c4d183a2..b8f6d7e8b 100644 --- a/diskann-inmem/integration/store.rs +++ b/diskann-inmem/integration/store.rs @@ -12,6 +12,11 @@ //! 2. A readable value is stable for the lifetime of a single reader guard. //! 3. A slot never resurrects (`readable -> unreadable -> readable`) within one guard. 
+#![expect( + clippy::unwrap_used, + reason = "this code works mainly as an integration test" +)] + use std::{ collections::HashMap, io::Write, @@ -74,43 +79,49 @@ pub struct StoreStressInput { seed: u64, } -impl Input for StoreStressInput { - type Raw = Self; - - fn tag() -> &'static str { - "store-stress" - } - - fn from_raw(raw: Self::Raw, _checker: &mut Checker) -> anyhow::Result { - if raw.readers == 0 || raw.writers == 0 { +impl StoreStressInput { + fn check(self) -> anyhow::Result { + if self.readers == 0 || self.writers == 0 { anyhow::bail!("`readers` and `writers` must be non-zero"); } - if raw.readers >= GUARD_CAPACITY { + if self.readers >= GUARD_CAPACITY { anyhow::bail!( "`readers` ({}) must be below the epoch guard capacity ({GUARD_CAPACITY})", - raw.readers, + self.readers, ); } - if raw.capacity == 0 { + if self.capacity == 0 { anyhow::bail!("`capacity` must be non-zero"); } - if raw.entry_bytes == 0 || raw.entry_bytes % 8 != 0 { + if self.entry_bytes == 0 || !self.entry_bytes.is_multiple_of(8) { anyhow::bail!( "`entry_bytes` ({}) must be a non-zero multiple of 8", - raw.entry_bytes, + self.entry_bytes, ); } - if raw.low_watermark > raw.capacity { + if self.low_watermark > self.capacity { anyhow::bail!( "`low_watermark` ({}) must not exceed `capacity` ({})", - raw.low_watermark, - raw.capacity, + self.low_watermark, + self.capacity, ); } - if raw.duration_secs == 0 && raw.max_ops == 0 { + if self.duration_secs == 0 && self.max_ops == 0 { anyhow::bail!("at least one of `duration_secs` or `max_ops` must be non-zero"); } - Ok(raw) + Ok(self) + } +} + +impl Input for StoreStressInput { + type Raw = Self; + + fn tag() -> &'static str { + "store-stress" + } + + fn from_raw(raw: Self::Raw, _checker: &mut Checker) -> anyhow::Result { + Self::check(raw) } fn serialize(&self) -> anyhow::Result { @@ -498,3 +509,17 @@ impl Benchmark for StoreStress { Ok(stats) } } + +/////////// +// Tests // +/////////// + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + 
fn make_sure_example_parses() { + let _ = StoreStressInput::check(StoreStressInput::example()).unwrap(); + } +} diff --git a/diskann-inmem/integration/support/check.rs b/diskann-inmem/integration/support/check.rs index 435f0a012..60fde4b31 100644 --- a/diskann-inmem/integration/support/check.rs +++ b/diskann-inmem/integration/support/check.rs @@ -20,12 +20,12 @@ use std::{ use diskann_benchmark_runner::{benchmark::PassFail, utils::fmt::Table}; use serde::{Serialize, Serializer}; -/// Perform a basline check on `self` and a `previous`ly saved result. +/// Perform a baseline check on `self` and a `previous`ly saved result. pub(crate) trait CheckMatch { fn check_match(&self, previous: &Self) -> Match; } -/// The result of a basline. +/// The result of a basline check. #[must_use = "this is a result type"] #[derive(Debug, Serialize)] #[serde(rename_all = "kebab-case")] @@ -190,6 +190,7 @@ impl<'a> Stack<'a> { Self { s, len: 0 } } + #[expect(clippy::unwrap_used, reason = "formatting shouldn't be failing here")] fn push(&mut self, key: &Key) -> Stack<'_> { let len = self.s.len(); if len == 0 { @@ -228,7 +229,7 @@ struct Record<'a> { /// /// Keys can either be strings or positional indices. The latter are used when traversing /// arrays. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Eq, PartialEq)] pub(crate) enum Key { Str(&'static str), Position(usize), @@ -280,28 +281,37 @@ impl From for Key { // Builder // ///////////// +/// A utility for building a nested [`Match`]. #[derive(Debug)] pub(crate) struct MatchBuilder { children: Vec<(Key, Match)>, } impl MatchBuilder { + /// Construct a new empty collection of matches. pub(crate) fn new() -> Self { Self { children: Vec::new(), } } + /// Push the [`Match`] into the collection only if [`Match::is_ok`] fails. pub(crate) fn push(&mut self, key: Key, child: Match) { if !child.is_ok() { self.children.push((key, child)); } } + /// Package the collection of matches into a single [`Match`]. 
+ /// + /// If no failing matches have been aggregated, returns [`Match::Ok`]. pub(crate) fn finish(self) -> Match { self.finish_with_remark(None) } + /// Package the collection of matches into a single [`Match`] with a remark. + /// + /// If no failing matches have been aggregated, returns [`Match::Ok`]. pub(crate) fn finish_with_remark(self, remark: Option>) -> Match { if self.children.is_empty() { Match::Ok @@ -369,9 +379,9 @@ where } } -//////////// +//--------// // Macros // -//////////// +//--------// macro_rules! check_all_fields { ($self:expr, $prev:expr, { $($field:ident),+ $(,)? } $(,)?) => {{ @@ -391,3 +401,310 @@ macro_rules! check_all_fields { } pub(crate) use check_all_fields; + +/////////// +// Tests // +/////////// + +#[cfg(test)] +mod tests { + use super::*; + + //-------// + // Match // + //-------// + + #[test] + fn match_is_ok() { + assert!(Match::Ok.is_ok()); + assert!(!Match::mismatch(&1, &2).is_ok()); + } + + #[test] + fn mismatch_records_got_and_expected() { + match Match::mismatch(&1, &2) { + Match::Mismatch { + got, + expected, + remark, + } => { + assert_eq!(got, "1"); + assert_eq!(expected, "2"); + assert!(remark.is_none()); + } + other => panic!("expected Mismatch, got {other:?}"), + } + } + + #[test] + fn mismatch_with_remark_records_remark() { + match Match::mismatch_with_remark(&"a", &"b", Some("note".into())) { + Match::Mismatch { + got, + expected, + remark, + } => { + assert_eq!(got, "a"); + assert_eq!(expected, "b"); + assert_eq!(remark.as_deref(), Some("note")); + } + other => panic!("expected Mismatch, got {other:?}"), + } + } + + #[test] + fn pass_fail_follows_is_ok() { + assert!(matches!(Match::Ok.pass_fail(), PassFail::Pass(Match::Ok))); + + assert!(matches!( + Match::mismatch(&1, &2).pass_fail(), + PassFail::Fail(Match::Mismatch { .. 
}) + )); + + let mut builder = MatchBuilder::new(); + builder.push(Key::from("test"), Match::mismatch(&1, &2)); + builder.push(Key::from("test2"), Match::mismatch(&2, &3)); + let mismatch = builder.finish(); + + assert!(matches!(mismatch, Match::Nested { .. })); + assert!(matches!( + mismatch.pass_fail(), + PassFail::Fail(Match::Nested { .. }) + )); + } + + //------------// + // CheckMatch // + //------------// + + #[test] + fn primitive_check_match() { + assert!(1u32.check_match(&1u32).is_ok()); + assert!(!2u32.check_match(&3u32).is_ok()); + assert!("x".check_match(&"x").is_ok()); + assert!(!"x".check_match(&"y").is_ok()); + } + + #[test] + fn slice_check_match_equal() { + let a = vec![1u32, 2, 3]; + let b = vec![1u32, 2, 3]; + assert!(a.check_match(&b).is_ok()); + } + + #[test] + fn slice_check_match_length_mismatch() { + let a = vec![1u32, 2, 3]; + let b = vec![1u32, 2]; + match a.check_match(&b) { + Match::Mismatch { + got, + expected, + remark, + } => { + assert_eq!(got, "3"); + assert_eq!(expected, "2"); + assert!(remark.is_some()); + } + other => panic!("expected length Mismatch, got {other:?}"), + } + } + + #[test] + fn slice_check_match_element_mismatch() { + let a = vec![1u32, 9, 3]; + let b = vec![1u32, 2, 3]; + match a.check_match(&b) { + Match::Nested { children, .. 
} => { + assert_eq!(children.len(), 1); + assert!(matches!(children[0].0, Key::Position(1))); + } + other => panic!("expected Nested, got {other:?}"), + } + } + + //--------------// + // MatchBuilder // + //--------------// + + #[test] + fn builder_empty_is_ok() { + assert!(MatchBuilder::new().finish().is_ok()); + } + + #[test] + fn builder_skips_ok_matches() { + let mut builder = MatchBuilder::new(); + builder.push("a".into(), Match::Ok); + builder.push("b".into(), Match::Ok); + assert!(builder.finish().is_ok()); + } + + #[test] + fn builder_collects_failures() { + let mut builder = MatchBuilder::new(); + builder.push("a".into(), Match::Ok); + builder.push("b".into(), Match::mismatch(&1, &2)); + match builder.finish() { + Match::Nested { children, remark } => { + assert_eq!(children.len(), 1); + assert!(remark.is_none()); + } + other => panic!("expected Nested, got {other:?}"), + } + } + + #[test] + fn builder_finish_with_remark() { + let mut builder = MatchBuilder::new(); + builder.push("b".into(), Match::mismatch(&1, &2)); + match builder.finish_with_remark(Some("ctx".into())) { + Match::Nested { remark, .. 
} => assert_eq!(remark.as_deref(), Some("ctx")), + other => panic!("expected Nested, got {other:?}"), + } + } + + //-----// + // Key // + //-----// + + #[test] + fn key_display() { + assert_eq!(Key::from("field").to_string(), "field"); + assert_eq!(Key::from(7usize).to_string(), "7"); + assert_eq!(Key::from(String::from("owned")).to_string(), "owned"); + } + + #[test] + fn key_serde() { + let k = serde_json::to_value(Key::Str("field")).unwrap(); + assert_eq!(k, serde_json::Value::String("field".into())); + + let k = serde_json::to_value(Key::Position(10)).unwrap(); + assert_eq!(k, serde_json::Value::Number(10.into())); + + let k = serde_json::to_value(Key::String("world".into())).unwrap(); + assert_eq!(k, serde_json::Value::String("world".into())); + } + + //---------// + // Display // + //---------// + + #[test] + fn display_ok() { + assert_eq!(Match::Ok.to_string(), "ok"); + } + + #[test] + fn display_nonnested() { + let mismatch = Match::mismatch_with_remark(&"hello", &1, Some("word".into())); + let rendered = mismatch.to_string(); + + let expected = r#" + got, expected, remark +=========================== +hello, 1, word +"#; + let expected = expected.strip_prefix('\n').unwrap(); + + println!("rendered = {:?}", rendered); + + let mut count = 0; + for (line, (got, expected)) in + std::iter::zip(rendered.lines(), expected.lines()).enumerate() + { + count += 1; + assert_eq!(got.trim(), expected.trim(), "failed on line {line}",); + } + assert_eq!(count, 3); + } + + #[test] + fn display_nested() { + // Build a nested match and ensure the hierarchical path is rendered. 
+ let mut inner = MatchBuilder::new(); + inner.push(1usize.into(), Match::mismatch(&9, &2)); + inner.push( + "test".into(), + Match::mismatch_with_remark(&9, &2, Some("hello".into())), + ); + let nested = inner.finish_with_remark(Some("some remark".into())); + + let mut outer = MatchBuilder::new(); + outer.push("results".into(), nested); + let rendered = outer + .finish_with_remark(Some("final remarks".into())) + .to_string(); + + let expected = r#" + path, got, expected, remark + ================================================ + , , , final remarks + results, , , some remark + results.1, 9, 2, + results.test, 9, 2, hello + "#; + + let expected = expected.strip_prefix('\n').unwrap(); + + println!("rendered = {:?}", rendered); + + let mut count = 0; + for (line, (got, expected)) in + std::iter::zip(rendered.lines(), expected.lines()).enumerate() + { + count += 1; + assert_eq!(got.trim(), expected.trim(), "failed on line {line}",); + } + assert_eq!(count, 6); + } + + //-------------------// + // check_all_fields! // + //-------------------// + + #[derive(Debug)] + struct Sample { + a: u32, + b: String, + } + + impl CheckMatch for Sample { + fn check_match(&self, previous: &Self) -> Match { + check_all_fields!(self, previous, { a, b }).finish() + } + } + + #[test] + fn check_all_fields_equal() { + let x = Sample { + a: 1, + b: "hi".into(), + }; + let y = Sample { + a: 1, + b: "hi".into(), + }; + assert!(x.check_match(&y).is_ok()); + } + + #[test] + fn check_all_fields_reports_changed_field() { + let x = Sample { + a: 1, + b: "hi".into(), + }; + let y = Sample { + a: 1, + b: "bye".into(), + }; + match x.check_match(&y) { + Match::Nested { children, .. 
} => { + assert_eq!(children.len(), 1); + assert_eq!(children[0].0.to_string(), "b"); + } + other => panic!("expected Nested, got {other:?}"), + } + } +} diff --git a/diskann-inmem/integration/support/datatype.rs b/diskann-inmem/integration/support/datatype.rs index 6864f434c..dff4528cc 100644 --- a/diskann-inmem/integration/support/datatype.rs +++ b/diskann-inmem/integration/support/datatype.rs @@ -5,7 +5,7 @@ use diskann_utils::{ sampling::medoid::ComputeMedoid, - views::{Matrix, MatrixView}, + views::{Matrix, MatrixView, MutMatrixView}, }; use half::f16; use serde::{Deserialize, Serialize}; @@ -258,6 +258,70 @@ impl Dataset { pub(crate) fn medoid(&self) -> Dataset { self.as_view().medoid() } + + pub(crate) fn preprocess(&mut self, op: &Preprocess) { + match self { + Self::F32(m) => op.apply(m.as_mut_view()), + Self::F16(m) => op.apply(m.as_mut_view()), + Self::U8(m) => op.apply(m.as_mut_view()), + Self::I8(m) => op.apply(m.as_mut_view()), + } + } +} + +/// Preprocess steps for [`Dataset`]s. +/// +/// These exist so we can coax `u8` data into a form compatible for testing `i8` data. +#[derive(Debug)] +pub(crate) enum Preprocess { + // Divide each component by 2. + Halve, + // Perform a `floor` operation on the each component. 
+ Floor, +} + +trait Apply { + fn apply(&self, m: MutMatrixView<'_, T>); +} + +impl Apply for Preprocess { + fn apply(&self, mut m: MutMatrixView<'_, f32>) { + match self { + Self::Halve => m.as_mut_slice().iter_mut().for_each(|v| *v *= 0.5), + Self::Floor => m.as_mut_slice().iter_mut().for_each(|v| *v = v.floor()), + } + } +} + +impl Apply for Preprocess { + fn apply(&self, mut m: MutMatrixView<'_, f16>) { + match self { + Self::Halve => m.as_mut_slice().iter_mut().for_each(|v| { + *v = f16::from_f32(f32::from(*v) * 0.5); + }), + Self::Floor => m.as_mut_slice().iter_mut().for_each(|v| { + *v = f16::from_f32(f32::from(*v).floor()); + }), + } + } +} + +impl Apply for Preprocess { + fn apply(&self, mut m: MutMatrixView<'_, u8>) { + match self { + Self::Halve => m.as_mut_slice().iter_mut().for_each(|v| *v /= 2), + Self::Floor => {} + } + } +} + +impl Apply for Preprocess { + fn apply(&self, mut m: MutMatrixView<'_, i8>) { + match self { + Self::Halve => m.as_mut_slice().iter_mut().for_each(|v| *v /= 2), + Self::Floor => {} + } + } } ///////////////// @@ -317,6 +381,34 @@ impl<'a> DatasetView<'a> { Self::I8(v) => Matrix::row_vector(Box::from(i8::compute_medoid(*v))).into(), } } + + pub(crate) fn iter(&self) -> Iter<'_> { + Iter::new(self) + } +} + +pub(crate) struct Iter<'a> { + view: &'a DatasetView<'a>, + row: usize, +} + +impl<'a> Iter<'a> { + fn new(view: &'a DatasetView<'a>) -> Self { + Self { + view, + row: 0, + } + } +} + +impl<'a> Iterator for Iter<'a> { + type Item = Slice<'a>; + + fn next(&mut self) -> Option> { + let r = self.view.row(self.row)?; + self.row += 1; + Some(r) + } } //------// @@ -363,3 +455,240 @@ define!(f32, F32); define!(f16, F16); define!(u8, U8); define!(i8, I8); + +/////////// +// Tests // +/////////// + +#[cfg(test)] +mod tests { + use super::*; + + fn matrix(data: &[T], nrows: usize, ncols: usize) -> Matrix + where + T: Copy, + { + Matrix::try_from(Box::from(data), nrows, ncols).unwrap() + } + + //----------// + // DataType // + 
//----------// + + #[test] + fn datatype_display() { + assert_eq!(DataType::F32.to_string(), "f32"); + assert_eq!(DataType::F16.to_string(), "f16"); + assert_eq!(DataType::U8.to_string(), "u8"); + assert_eq!(DataType::I8.to_string(), "i8"); + } + + //-------// + // Slice // + //-------// + + #[test] + fn slice_data_type_and_len() { + let f: &[f32] = &[1.0, 2.0, 3.0]; + let s = Slice::from(f); + assert_eq!(s.data_type(), DataType::F32); + assert_eq!(s.len(), 3); + + let u: &[u8] = &[1, 2]; + assert_eq!(Slice::from(u).data_type(), DataType::U8); + assert_eq!(Slice::from(u).len(), 2); + } + + #[test] + fn slice_try_cast_success() { + let f: &[f32] = &[1.0, 2.0]; + let s = Slice::from(f); + let out: &[f32] = s.try_cast().unwrap(); + assert_eq!(out, &[1.0, 2.0]); + } + + #[test] + fn slice_try_cast_wrong_type() { + let f: &[f32] = &[1.0, 2.0]; + let s = Slice::from(f); + let err = s.try_cast::().unwrap_err(); + let msg = err.to_string(); + assert!(msg.contains("u8"), "msg: {msg}"); + assert!(msg.contains("f32"), "msg: {msg}"); + } + + //----------// + // SliceMut // + //----------// + + #[test] + fn convert_lossless_same_type() { + let mut dst = [0.0f32; 3]; + let src: &[f32] = &[1.0, 2.0, 3.0]; + SliceMut::from(dst.as_mut_slice()) + .convert_lossless(Slice::from(src)) + .unwrap(); + assert_eq!(dst, [1.0, 2.0, 3.0]); + } + + #[test] + fn convert_lossless_widening() { + // u8 -> f32 is always lossless. + let mut dst = [0.0f32; 3]; + let src: &[u8] = &[1, 2, 250]; + SliceMut::from(dst.as_mut_slice()) + .convert_lossless(Slice::from(src)) + .unwrap(); + assert_eq!(dst, [1.0, 2.0, 250.0]); + + // i8 -> f16 is always lossless. + let mut dst = [f16::ZERO; 2]; + let src: &[i8] = &[-5, 7]; + SliceMut::from(dst.as_mut_slice()) + .convert_lossless(Slice::from(src)) + .unwrap(); + assert_eq!(dst, [f16::from_f32(-5.0), f16::from_f32(7.0)]); + } + + #[test] + fn convert_lossless_narrowing_exact() { + // Whole-valued, in-range f32 -> u8 is lossless. 
+ let mut dst = [0u8; 3]; + let src: &[f32] = &[0.0, 12.0, 255.0]; + SliceMut::from(dst.as_mut_slice()) + .convert_lossless(Slice::from(src)) + .unwrap(); + assert_eq!(dst, [0, 12, 255]); + } + + #[test] + fn convert_lossless_narrowing_fraction_errors() { + let mut dst = [0u8; 2]; + let src: &[f32] = &[1.0, 0.5]; + let err = SliceMut::from(dst.as_mut_slice()) + .convert_lossless(Slice::from(src)) + .unwrap_err(); + assert!(err.to_string().contains("losslessly"), "{err}"); + } + + #[test] + fn convert_lossless_signedness_errors() { + // Negative i8 cannot fit into u8. + let mut dst = [0u8; 2]; + let src: &[i8] = &[5, -1]; + assert!( + SliceMut::from(dst.as_mut_slice()) + .convert_lossless(Slice::from(src)) + .is_err() + ); + + // u8 > 127 cannot fit into i8. + let mut dst = [0i8; 2]; + let src: &[u8] = &[10, 200]; + assert!( + SliceMut::from(dst.as_mut_slice()) + .convert_lossless(Slice::from(src)) + .is_err() + ); + } + + #[test] + fn convert_lossless_length_mismatch_errors() { + let mut dst = [0.0f32; 2]; + let src: &[f32] = &[1.0, 2.0, 3.0]; + let err = SliceMut::from(dst.as_mut_slice()) + .convert_lossless(Slice::from(src)) + .unwrap_err(); + assert!(err.to_string().contains("len"), "{err}"); + } + + //---------// + // Dataset // + //---------// + + #[test] + fn dataset_shape_and_views() { + let ds: Dataset = matrix(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3).into(); + assert_eq!(ds.nrows(), 2); + assert_eq!(ds.ncols(), 3); + assert_eq!(ds.as_view().data_type(), DataType::F32); + assert_eq!(ds.as_slice().data_type(), DataType::F32); + assert_eq!(ds.as_slice().len(), 6); + } + + #[test] + fn dataset_medoid_shape() { + let ds: Dataset = matrix(&[1.0f32, 2.0, 3.0, 4.0], 2, 2).into(); + let medoid = ds.medoid(); + assert_eq!(medoid.nrows(), 1); + assert_eq!(medoid.ncols(), 2); + } + + #[test] + fn dataset_preprocess_halve() { + let mut ds: Dataset = matrix(&[2.0f32, 4.0, 6.0, 8.0], 2, 2).into(); + ds.preprocess(&Preprocess::Halve); + let slice: &[f32] = 
ds.as_slice().try_cast().unwrap(); + assert_eq!(slice, &[1.0, 2.0, 3.0, 4.0]); + } + + //------------// + // Preprocess // + //------------// + + #[test] + fn preprocess_floor_f32() { + let mut ds: Dataset = matrix(&[1.7f32, 2.2, 3.9, 4.0], 1, 4).into(); + ds.preprocess(&Preprocess::Floor); + let slice: &[f32] = ds.as_slice().try_cast().unwrap(); + assert_eq!(slice, &[1.0, 2.0, 3.0, 4.0]); + } + + #[test] + fn preprocess_floor_integer_is_noop() { + let mut ds: Dataset = matrix(&[3u8, 7, 9, 11], 1, 4).into(); + ds.preprocess(&Preprocess::Floor); + let slice: &[u8] = ds.as_slice().try_cast().unwrap(); + assert_eq!(slice, &[3, 7, 9, 11]); + } + + #[test] + fn preprocess_halve_integer() { + let mut ds: Dataset = matrix(&[4i8, 7, 8, 10], 1, 4).into(); + ds.preprocess(&Preprocess::Halve); + let slice: &[i8] = ds.as_slice().try_cast().unwrap(); + assert_eq!(slice, &[2, 3, 4, 5]); + } + + //-------------// + // DatasetView // + //-------------// + + #[test] + fn dataset_view_accessors() { + let ds: Dataset = matrix(&[1u8, 2, 3, 4, 5, 6], 2, 3).into(); + let view = ds.as_view(); + assert_eq!(view.data_type(), DataType::U8); + assert_eq!(view.nrows(), 2); + assert_eq!(view.ncols(), 3); + } + + #[test] + fn dataset_view_row() { + let ds: Dataset = matrix(&[1u8, 2, 3, 4, 5, 6], 2, 3).into(); + let view = ds.as_view(); + + let row1: &[u8] = view.row(1).unwrap().try_cast().unwrap(); + assert_eq!(row1, &[4, 5, 6]); + + assert!(view.row(2).is_none()); + } + + #[test] + fn dataset_view_medoid() { + let ds: Dataset = matrix(&[1i8, 2, 3, 4], 2, 2).into(); + let medoid = ds.as_view().medoid(); + assert_eq!(medoid.nrows(), 1); + assert_eq!(medoid.ncols(), 2); + } +} diff --git a/diskann-inmem/integration/support/io.rs b/diskann-inmem/integration/support/io.rs index 335ead7cf..7b45018b5 100644 --- a/diskann-inmem/integration/support/io.rs +++ b/diskann-inmem/integration/support/io.rs @@ -6,23 +6,28 @@ use diskann_utils::{io::read_bin, views::Matrix}; use half::f16; -use 
super::datatype::{DataType, Dataset, SliceMut}; +use super::datatype::{DataType, Dataset, Preprocess, SliceMut}; pub(crate) fn load_and_convert( io: &mut IO, src: DataType, target: DataType, + ops: &[Preprocess], ) -> anyhow::Result where IO: std::io::Read + std::io::Seek, { - let data = match src { + let mut data = match src { DataType::F32 => Dataset::from(read_bin::(io)?), DataType::F16 => Dataset::from(read_bin::(io)?), DataType::U8 => Dataset::from(read_bin::(io)?), DataType::I8 => Dataset::from(read_bin::(io)?), }; + for op in ops { + data.preprocess(op); + } + if src == target { return Ok(data); } diff --git a/diskann-inmem/src/epoch.rs b/diskann-inmem/src/epoch.rs index ec006d4b0..3eb77b558 100644 --- a/diskann-inmem/src/epoch.rs +++ b/diskann-inmem/src/epoch.rs @@ -164,7 +164,9 @@ impl Registry { let m = &self.guards[slot]; delay.pre_cas(); - if m.compare_exchange(0, epoch, Ordering::Relaxed, Ordering::Relaxed).is_ok() { + if m.compare_exchange(0, epoch, Ordering::Relaxed, Ordering::Relaxed) + .is_ok() + { delay.post_cas(); let mut reset = false; loop { @@ -271,6 +273,10 @@ impl Registry { self.try_advance_inner(NoDelay) } + #[expect( + clippy::panic, + reason = "the panic is exceedingly unlikely to happen and if it does, we can't continue" + )] fn try_advance_inner(&self, mut delay: T) -> Option> where T: TryAdvanceDelay, diff --git a/diskann-inmem/src/integration/store.rs b/diskann-inmem/src/integration/store.rs index eba8a9632..d48238b6f 100644 --- a/diskann-inmem/src/integration/store.rs +++ b/diskann-inmem/src/integration/store.rs @@ -3,6 +3,11 @@ * Licensed under the MIT license. 
*/ +#![expect( + clippy::expect_used, + reason = "integration test tools are not production code", +)] + use diskann_utils::views::Matrix; use crate::{num::Bytes, store}; diff --git a/diskann-inmem/src/provider.rs b/diskann-inmem/src/provider.rs index 42819432a..7fbc85468 100644 --- a/diskann-inmem/src/provider.rs +++ b/diskann-inmem/src/provider.rs @@ -3,6 +3,33 @@ * Licensed under the MIT license. */ +//! An in-memory provider for the DiskANN graph index. +//! +//! This type supports the following: +//! +//! * Arbitrary external IDs for store data (provided they satisfy [`Id`]). +//! * Support for concurrent insertions, deletions, and searches. +//! * Specialized implementations of [`glue::SearchAccessor::expand_beam`] enabling full +//! inlining of distance kernels. +//! +//! Known areas for future work: +//! +//! * Insert and delete protection: The [`DiskANNIndex`](diskann::graph::DiskANNIndex) doesn't +//! support ergonomic insert or delete guards to protect slots during insert or delete +//! operations. This leaves open a situation where an item can be inserted and during +//! the insertion algorithm, it is deleted, and then re-inserted. +//! +//! This can cause some issues within the main indexing algorithms which assume the inserted +//! ID is present but requires upstream changes to properly fix. +//! +//! * Failed insert rollback: again, this needs some upstream changes to fully support. +//! +//! * Quantization + reranking: The current version of this index targets just a single +//! data-store and is planned to be addressed in the near future. +//! +//! * Lack of save/load support: The index is currently ephemeral, but there are plans to +//! address this gap. 
+ use std::hash::Hash; use diskann::{ @@ -17,6 +44,7 @@ use diskann::{ utils::IntoUsize, }; use diskann_utils::views::Matrix; +use thiserror::Error; use crate::{ counters::{Counters, LocalCounters}, @@ -26,16 +54,26 @@ use crate::{ store::{self, Store}, }; +/// Aggregate trait for the external ID type of [`Provider`]. pub trait Id: Send + Sync + Hash + Eq + Clone + 'static {} + impl Id for T where T: Send + Sync + Hash + Eq + Clone + 'static {} +/// An in-memory data-provider for DiskANN's graph indexing algorithms. +/// +/// The first type parameter `L` is a [`layers::Layer`] for describing the kind of data +/// stored within the provider. The second parameter `M` is the associated data for items +/// inserted into the provider. #[derive(Debug)] pub struct Provider where M: Id, { + // The raw binary store store: Store, + // Data representation. layer: L, + // ID translation. mapping: Sharded, // `Counters` is only non-trivial under the `integration-test` feature flag. Otherwise, @@ -47,7 +85,10 @@ impl Provider where M: Id, { - pub fn new(layer: L, config: Config, start_points: I) -> Self + /// Construct a new [`Provider`]. + /// + /// The list of `start_points` must be compatible with `layer`. 
+ pub fn new(layer: L, config: Config, start_points: I) -> Result where I: IntoIterator, L: layers::Set, @@ -57,7 +98,7 @@ where let mut data = Matrix::new(0u8, start_points.len(), bytes.value()); for (row, point) in std::iter::zip(data.row_iter_mut(), start_points.into_iter()) { - layers::Set::set(&layer, point, row).unwrap(); + layers::Set::set(&layer, point, row)?; } let store = Store::new( @@ -66,16 +107,16 @@ where config.max_degree(), data.as_view(), ) - .unwrap(); + .map_err(|err| ProviderError::CreatingStore(Box::new(err)))?; let mapping = Sharded::new(config.capacity()); - Self { + Ok(Self { store, layer, mapping, counters: Counters::new(), - } + }) } fn local_counters(&self) -> LocalCounters<'_> { @@ -94,6 +135,15 @@ where } } +#[derive(Debug, Error)] +pub enum ProviderError { + #[error("error when trying to set start points")] + SettingStartPoints(#[from] ANNError), + #[error("could not create data store")] + CreatingStore(#[source] Box), +} + +/// Configuration for [`Provider`]. #[derive(Debug)] pub struct Config { capacity: usize, @@ -101,6 +151,10 @@ pub struct Config { } impl Config { + /// Construct a new [`Config`]. + /// + /// * `capacity`: The number of dynamic entries in the resulting provider. + /// * `max_degree`: The maximum degree of any adjacency list in the graph. pub fn new(capacity: usize, max_degree: usize) -> Self { Self { capacity, @@ -108,10 +162,12 @@ impl Config { } } + /// Return the number of dynamic entries in the resulting provider. pub fn capacity(&self) -> usize { self.capacity } + /// Return the maximum degree of any adjacency list. pub fn max_degree(&self) -> usize { self.max_degree } @@ -121,6 +177,7 @@ impl Config { // Data Provider // /////////////////// +/// A zero-sized [`diskann::provider::ExecutionContext`] for [`Provider`]. 
#[derive(Debug, Clone, Default)] pub struct Context; @@ -161,8 +218,8 @@ where } } -// TODO: The element-status checks here are profoundly expensive as they require epoch -// registration for each check! +// TODO: The element-status checks here are profoundly approximate because we try to avoid +// any kind of EBR registration. // // `diskann` has plans to move deletion checks behind an accessor trait, which will help // with this situation. @@ -172,7 +229,10 @@ where M: Id, { async fn delete(&self, _context: &Context, gid: &M) -> ANNResult<()> { - // TODO: These need to actually happen in lock-step. + // This guarantees that we have a valid mapping, but defers the actual deletion until + // we know it's also safe to retire the internal slot. + // + // This ensures both either succeed or are aborted. let entry = match self.mapping.occupied_entry(gid.clone()) { None => { return Err(ANNError::message( @@ -204,14 +264,16 @@ where ) -> ANNResult { // Not that this check is approximate. A full check requires materialization of // a `reader`. - if self.store.can_read_approximate(id.into_usize()).unwrap() { - Ok(diskann::provider::ElementStatus::Valid) - } else { - Ok(diskann::provider::ElementStatus::Deleted) + match self.store.can_read_approximate(id.into_usize()) { + Some(true) => Ok(diskann::provider::ElementStatus::Valid), + Some(false) => Ok(diskann::provider::ElementStatus::Deleted), + None => Err(ANNError::message( + ANNErrorKind::Opaque, + "accessed invalid internal ID", + )), } } - /// Check the status via external ID. async fn status_by_external_id( &self, _context: &Context, @@ -251,7 +313,7 @@ where })?; // TODO: Proper cleanup via `Guard` or some other mechanism on the event of - // insert failure. + // insert failure after `set_element` returns. >::set(&self.layer, element, slot.as_mut_slice())?; self.mapping.insert(id.clone(), slot.slot())?; @@ -276,11 +338,16 @@ where // Search // //////////// +/// A [`glue::SearchAccessor`] for [`Provider`]. 
+/// +/// This type intentionally avoids generic parameters and instead compiles optimized +/// `expand_beam` kernels that get reused. The idea is to generate an efficient graph search +/// kernel once and reuse it to balance compile times and performance. #[derive(Debug)] pub struct SearchAccessor<'a> { reader: store::Reader<'a>, ids: AdjacencyList, - expand_beam: Box, + expand_beam: Box, // The parent provider for the accessor. provider: &'a (dyn std::any::Any + Send + Sync), @@ -343,7 +410,7 @@ impl glue::SearchAccessor for SearchAccessor<'_> { { let work = move || -> ANNResult<()> { for i in ids { - self.reader.neighbors().get(i, &mut self.ids).unwrap(); + self.reader.neighbors().get(i, &mut self.ids)?; self.counters.get_neighbors(1); // Filter out unvisited IDs and ensure that all the IDs we are about @@ -359,6 +426,7 @@ impl glue::SearchAccessor for SearchAccessor<'_> { on_neighbors(id, distance); }; + // SAFETY: We've verified that each entry in `self.ids` is in-bounds. unsafe { self.expand_beam .expand_beam(&self.ids, 8, &self.reader, &mut on_neighbors) @@ -372,10 +440,15 @@ impl glue::SearchAccessor for SearchAccessor<'_> { } } -trait ExpandBeam2: Send + Sync + std::fmt::Debug { +trait ExpandBeam: Send + Sync + std::fmt::Debug { /// Evaluate a raw distance function. fn evaluate(&self, x: &[u8]) -> ANNResult; + /// Compute the distance between the query and each neighbor in `list`. + /// + /// # Safety + /// + /// All items in `list` must be in-bounds with respect to `reader`. unsafe fn expand_beam( &self, list: &[u32], @@ -389,7 +462,7 @@ trait ExpandBeam2: Send + Sync + std::fmt::Debug { #[repr(transparent)] struct ExpandBeamImpl(T); -impl ExpandBeam2 for ExpandBeamImpl +impl ExpandBeam for ExpandBeamImpl where T: layers::QueryDistance, { @@ -404,6 +477,7 @@ where reader: &store::Reader<'_>, f: &mut dyn FnMut(u32, f32), ) -> ANNResult<()> { + // SAFETY: Inherited from caller. 
unsafe { expand_beam_inner::(&self.0, list, lookahead, reader, f) } } } @@ -414,13 +488,13 @@ struct ExpandBeamVisitor { } impl<'a> layers::QueryVisitor<'a> for ExpandBeamVisitor { - type Output = Box; + type Output = Box; fn visit_sized(self, distance: T) -> Self::Output where T: QueryDistance + 'a, { - // Make sure there's no lying. + // This is critical to ensure we emit the correct number of prefetches. assert_eq!(Bytes::new(BYTES + store::TAG_SIZE.value()), self.bytes); Box::new(ExpandBeamImpl::<_, BYTES>(distance)) } @@ -433,6 +507,13 @@ impl<'a> layers::QueryVisitor<'a> for ExpandBeamVisitor { } } +/// Prefetch `len` bytes beginning at `ptr`. +/// +/// The last cache line is prefetched first, followed by the rest in ascending order. +/// +/// # Safety +/// +/// The memory range `[ptr, ptr.add(len))` must be valid. #[inline(always)] unsafe fn prefetch(ptr: *const u8, len: usize) { use std::arch::x86_64::*; @@ -445,17 +526,18 @@ unsafe fn prefetch(ptr: *const u8, len: usize) { return; } + // SAFETY: Inherited from caller. unsafe { _mm_prefetch(ptr.add(stride * (lines - 1)), _MM_HINT_T0) }; for i in 0..(lines - 1) { + // SAFETY: Inherited from caller. unsafe { _mm_prefetch(ptr.add(stride * i), _MM_HINT_T0); } } } -/// Safety (no # yet because we need to revisit this - clippy will lint) +/// # Safety /// -/// * The concrete type of `distance` must be `T`. /// * All items in `list` must in-bounds with respect to `reader`. /// * The number of bytes associated with `N` cache lines must "make sense". #[inline] @@ -486,6 +568,8 @@ where let lookahead = lookahead.min(len); for j in 0..lookahead { + // SAFETY: The in-bounds constraint is assured by the caller, both for `j` as well + // as the validity of the prefetch bounds. unsafe { prefetch( reader @@ -500,6 +584,8 @@ where let mut j = lookahead; for &i in list.iter() { if j != len { + // SAFETY: The in-bounds constraint is assured by the caller, both for `j` as + // well as the validity of the prefetch bounds. 
unsafe { prefetch( reader @@ -512,6 +598,7 @@ where j += 1; } + // SAFETY: Caller asserts that `i` is in-bounds. if let Some(data) = unsafe { reader.read_in_bounds(i.into_usize()) } { f(i, distance.evaluate(data)?) } @@ -547,6 +634,10 @@ impl<'a> Distance<'a> { } } +#[expect( + clippy::unwrap_used, + reason = "prune does not allow fallible distance functions yet" +)] impl diskann_vector::DistanceFunction<&[u8], &[u8], f32> for Distance<'_> { #[inline] fn evaluate_similarity(&self, x: &[u8], y: &[u8]) -> f32 { @@ -713,8 +804,8 @@ pub fn test_function<'a>( strategy: &'a Strategy, context: &'a Context, query: &'a [u8], -) -> SearchAccessor<'a> { - glue::SearchStrategy::search_accessor(strategy, x, context, query).unwrap() +) -> ANNResult> { + glue::SearchStrategy::search_accessor(strategy, x, context, query) } #[derive(Debug, Clone, Copy)] @@ -746,7 +837,11 @@ where { let work = move || { // By construction - the downcast should succeed. Otherwise, this is a program bug. - let provider = accessor.provider.downcast_ref::>().unwrap(); + let provider = match accessor.provider.downcast_ref::>() { + Some(provider) => provider, + None => return Err(ANNError::message(ANNErrorKind::Opaque, "bad any cast")), + }; + let mut count = 0; for c in candidates { if let Some(ext) = provider.mapping.to_external(c.id) { @@ -839,9 +934,19 @@ where ) -> impl Future> + Send { let work = move || { - let reader = provider.store.reader().unwrap(); + let reader = provider.store.reader()?; + let data = match reader.read(id.into_usize()) { + Some(data) => data, + None => { + return Err(ANNError::message( + ANNErrorKind::Opaque, + "item could not be read", + )); + } + }; + let mut buf: Box<[_]> = std::iter::repeat_n(0.0, provider.layer.dim()).collect(); - let data = reader.read(id.into_usize()).unwrap(); + bytemuck::must_cast_slice_mut::(&mut buf).copy_from_slice(data); Ok(buf) }; @@ -896,7 +1001,8 @@ mod tests { let config = Config::new(grid.num_points(size), degree); - let provider = 
Provider::<_, u64>::new(full, config, std::iter::once(start.as_slice())); + let provider = + Provider::<_, u64>::new(full, config, std::iter::once(start.as_slice())).unwrap(); assert_eq!(provider.max_degree(), degree); let config = diskann::graph::config::Builder::new( diff --git a/diskann-inmem/src/store.rs b/diskann-inmem/src/store.rs index 93480d9dc..2b92d3e72 100644 --- a/diskann-inmem/src/store.rs +++ b/diskann-inmem/src/store.rs @@ -62,10 +62,10 @@ use diskann_utils::views::MatrixView; use thiserror::Error; use crate::{ - buffer::{Buffer, RawSlice}, + buffer::{Buffer, BufferError, RawSlice}, epoch::{self, Registry}, freelist::{self, Freelist}, - neighbors::Neighbors, + neighbors::{Neighbors, NeighborsError}, num::{Align, Bytes}, tag::{AtomicTag, Tag}, }; @@ -124,6 +124,10 @@ impl Store { return Err(StoreError::need_frozen_point()); } + #[expect( + clippy::expect_used, + reason = "we expect `init` to have at least one row, so this should never happen" + )] let unpadded = bytes .checked_add(TAG_SIZE) .expect("unreachable because `init` cannot exceed `isize::MAX` bytes"); @@ -131,6 +135,10 @@ impl Store { // Pad to half a cache line. When data occupies just part of a cache line, this // results in the same total number of cache lines being fetched while potentially // enabling more compact memory. 
+ #[expect( + clippy::expect_used, + reason = "we expect `init` to have at least one row, so this should never happen" + )] let padded_bytes = unpadded .checked_next_multiple_of(Bytes::CACHELINE.div(TWO)) .expect("unreachabel because `init` cannot exceed `isize::MAX` bytes"); @@ -150,8 +158,10 @@ impl Store { .try_into() .map_err(|_| StoreError::too_many_neighbors(max_neighbors))?; + const FREELIST_SIZE: NonZeroU32 = NonZeroU32::new(1024).unwrap(); + let me = Self { - buffer: Buffer::new(total.into_usize(), padded_bytes, Align::_128).unwrap(), + buffer: Buffer::new(total.into_usize(), padded_bytes, Align::_128)?, unpadded, unfrozen: entries.into_usize(), tags: repeat_n(Tag::AVAILABLE, total.into_usize()) @@ -160,15 +170,16 @@ impl Store { // NOTE: The `Freelist` is initialized to `entries` and not `total` because // we do not want it to release frozen IDs. - freelist: Freelist::new(entries, NonZeroU32::new(1024).unwrap()), + freelist: Freelist::new(entries, FREELIST_SIZE), registry: Registry::new(), - neighbors: Neighbors::new(total, max_neighbors).unwrap(), + neighbors: Neighbors::new(total, max_neighbors)?, }; // Populate frozen points. for (i, data) in init.row_iter().enumerate() { // We have checked that the total number of entries fits in `u32`, so this // arithmetic cannot overflow. + #[expect(clippy::expect_used, reason = "this should always succeed")] let mut slot = me .slot(entries + (i as u32)) .expect("store was just created - claiming the slot must succeed"); @@ -199,6 +210,7 @@ impl Store { /// /// If successful, returns the number of slots reclaimed. pub(crate) fn try_drain(&self) -> Option { + #[expect(clippy::panic, reason = "we cannot proceed if we observe this")] fn release(tag: &AtomicTag, kind: &'static str) { // Relaxed ordering is sufficient as all readers/writers are synchronized on // the central generation. 
@@ -336,7 +348,14 @@ impl Store { remaining = remaining.saturating_sub(chunk.len()); for slot in chunk { - let tag = self.tags.get(slot.into_usize()).unwrap(); + #[expect( + clippy::expect_used, + reason = "this is a serious bug with the freelist" + )] + let tag = self + .tags + .get(slot.into_usize()) + .expect("freelist scan should not give out invalid IDs"); // If this slot is available and we haven't claimed a slot yet, try to // claim it. Otherwise, continue with the scan to partially repopulate the @@ -369,7 +388,8 @@ impl Store { } fn slot(&self, i: u32) -> Option> { - let tag = &self.tags.get(i.into_usize()).unwrap(); + let tag = &self.tags.get(i.into_usize())?; + // SAFETY: We've guaranteed that `tag` belongs to `slot`. unsafe { self.try_acquire(tag, i) } } @@ -460,6 +480,18 @@ impl StoreError { } } +impl From for StoreError { + fn from(err: BufferError) -> Self { + Self(err.into()) + } +} + +impl From for StoreError { + fn from(err: NeighborsError) -> Self { + Self(err.into()) + } +} + #[derive(Debug, Error)] enum StoreErrorInner { #[error( @@ -478,6 +510,10 @@ enum StoreErrorInner { TooManyEntries { entries: usize, frozen: usize }, #[error("number of neighbors ({}) may not exceed `u32::MAX`", neighbors)] TooManyNeighbors { neighbors: usize }, + #[error(transparent)] + BufferError(#[from] BufferError), + #[error(transparent)] + NeighborsError(#[from] NeighborsError), } /// Error conditions for [`Store::retire`]. 
diff --git a/diskann-inmem/src/test/epoch.rs b/diskann-inmem/src/test/epoch.rs index c6fa9104e..f97f4dfe3 100644 --- a/diskann-inmem/src/test/epoch.rs +++ b/diskann-inmem/src/test/epoch.rs @@ -37,12 +37,16 @@ impl Slot { where F: FnOnce(), { - if self.tag.compare_exchange( - Tag::AVAILABLE, - Tag::OWNED, - Ordering::Acquire, - Ordering::Relaxed, - ).is_ok() { + if self + .tag + .compare_exchange( + Tag::AVAILABLE, + Tag::OWNED, + Ordering::Acquire, + Ordering::Relaxed, + ) + .is_ok() + { // SAFETY: By transitioning from AVAILABLE to OWNED, we've acquired ownership // of this slot and are thus free to write to the `UnsafeCell`. unsafe { &mut *self.payload.get() }.write(Box::new(payload)); @@ -70,12 +74,14 @@ impl Slot { return false; } - self.tag.compare_exchange( - Tag::PUBLISHED, - Tag::RETIRING, - Ordering::Relaxed, - Ordering::Relaxed, - ).is_ok() + self.tag + .compare_exchange( + Tag::PUBLISHED, + Tag::RETIRING, + Ordering::Relaxed, + Ordering::Relaxed, + ) + .is_ok() } unsafe fn make_available(&self) { @@ -84,12 +90,16 @@ impl Slot { // SAFETY: Items tagged as `RETIRING` must be initialized. unsafe { (&mut *self.payload.get()).assume_init_drop() }; - if self.tag.compare_exchange( - Tag::RETIRING, - Tag::AVAILABLE, - Ordering::Release, - Ordering::Relaxed, - ).is_err() { + if self + .tag + .compare_exchange( + Tag::RETIRING, + Tag::AVAILABLE, + Ordering::Release, + Ordering::Relaxed, + ) + .is_err() + { panic!("concurrency violation"); } } From 366aee2e1a429ebdccdb3a1795306de184b0fc53 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Fri, 26 Jun 2026 11:22:26 -0700 Subject: [PATCH 29/45] We're getting there! 
--- Cargo.lock | 2 +- diskann-benchmark/src/index/inmem2.rs | 4 +- .../jsons/integration-baseline.json | 256 +++++++++--------- .../integration/jsons/integration.json | 34 +-- diskann-inmem/integration/support/datatype.rs | 5 +- diskann-inmem/src/integration/store.rs | 2 +- diskann-inmem/src/layers/full.rs | 4 +- diskann-inmem/src/lib.rs | 2 + 8 files changed, 154 insertions(+), 155 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c76fd4bf7..6110f576c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -822,7 +822,7 @@ dependencies = [ "diskann-wide", "half", "parking_lot", - "rand 0.9.4", + "rand", "serde", "serde_json", "tempfile", diff --git a/diskann-benchmark/src/index/inmem2.rs b/diskann-benchmark/src/index/inmem2.rs index 9e703d4d2..b92b63f14 100644 --- a/diskann-benchmark/src/index/inmem2.rs +++ b/diskann-benchmark/src/index/inmem2.rs @@ -161,7 +161,7 @@ impl Benchmark for Inmem2 { let exact_max_degree = (input.max_degree as f32 * 1.3) as usize; let layer = Full::::new(dim, metric); let config = diskann_inmem::provider::Config::new(num_points, exact_max_degree); - let provider = Provider::new(layer, config, start.row_iter()); + let provider = Provider::new(layer, config, start.row_iter())?; let config = graph::config::Builder::new_with( input.max_degree, @@ -425,7 +425,7 @@ impl Benchmark for Inmem2Stream { let layer = Full::::new(dim, metric); let config = diskann_inmem::provider::Config::new(max_points, exact_max_degree); - let provider = Provider::new(layer, config, start.row_iter()); + let provider = Provider::new(layer, config, start.row_iter())?; let config = graph::config::Builder::new_with( input.max_degree, diff --git a/diskann-inmem/integration/jsons/integration-baseline.json b/diskann-inmem/integration/jsons/integration-baseline.json index a21a83b8e..453375e2e 100644 --- a/diskann-inmem/integration/jsons/integration-baseline.json +++ b/diskann-inmem/integration/jsons/integration-baseline.json @@ -4,17 +4,17 @@ "content": { "build": { "alpha": 
1.2000000476837158, - "l_build": 50, + "l_build": 20, "max_degree": 20, "pruned_degree": 16 }, "data": { - "data": "/disk_index_search/disk_index_siftsmall_learn_256pts_data.fbin", + "data": "/yfcc/yfcc_10k.fbin", "data_type": "f32", - "groundtruth": "/disk_index_search/disk_index_10pts_idx_uint32_truth_search_res.bin", + "groundtruth": "/yfcc/groundtruth.bin", "metric": "l2", "preprocess": [], - "queries": "/disk_index_search/disk_index_sample_query_10pts.fbin" + "queries": "/yfcc/yfcc_query_100.fbin" }, "layer": { "FullPrecision": { @@ -45,32 +45,32 @@ }, "results": { "build": { - "append_neighbors": 2447, - "distance": 59957, - "get_neighbors": 14477, - "get_vector": 42151, - "query_distance": 22813, - "set_neighbors": 430, - "set_vector": 256 + "append_neighbors": 96949, + "distance": 2867876, + "get_neighbors": 352139, + "get_vector": 3067092, + "query_distance": 2240744, + "set_neighbors": 23599, + "set_vector": 10000 }, "knn": [ { "counters": { "append_neighbors": 0, "distance": 0, - "get_neighbors": 514, - "get_vector": 1449, - "query_distance": 1449, + "get_neighbors": 5441, + "get_vector": 44988, + "query_distance": 44988, "set_neighbors": 0, "set_vector": 0 }, "misc": { - "cmps": 1449, - "hops": 514 + "cmps": 44988, + "hops": 5441 }, "recall": { - "average": 0.91, - "num_queries": 10, + "average": 0.975, + "num_queries": 100, "recall_k": 10, "recall_n": 10 } @@ -79,19 +79,19 @@ "counters": { "append_neighbors": 0, "distance": 0, - "get_neighbors": 519, - "get_vector": 1450, - "query_distance": 1450, + "get_neighbors": 5806, + "get_vector": 49075, + "query_distance": 49075, "set_neighbors": 0, "set_vector": 0 }, "misc": { - "cmps": 1450, - "hops": 519 + "cmps": 49075, + "hops": 5806 }, "recall": { - "average": 0.91, - "num_queries": 10, + "average": 0.974, + "num_queries": 100, "recall_k": 10, "recall_n": 10 } @@ -100,19 +100,19 @@ "counters": { "append_neighbors": 0, "distance": 0, - "get_neighbors": 1013, - "get_vector": 1615, - "query_distance": 1615, 
+ "get_neighbors": 10634, + "get_vector": 74001, + "query_distance": 74001, "set_neighbors": 0, "set_vector": 0 }, "misc": { - "cmps": 1615, - "hops": 1013 + "cmps": 74001, + "hops": 10634 }, "recall": { - "average": 0.91, - "num_queries": 10, + "average": 0.992, + "num_queries": 100, "recall_k": 10, "recall_n": 10 } @@ -125,17 +125,17 @@ "content": { "build": { "alpha": 1.2000000476837158, - "l_build": 50, + "l_build": 20, "max_degree": 20, "pruned_degree": 16 }, "data": { - "data": "/disk_index_search/disk_index_siftsmall_learn_256pts_data.fbin", + "data": "/yfcc/yfcc_10k.fbin", "data_type": "f32", - "groundtruth": "/disk_index_search/disk_index_10pts_idx_uint32_truth_search_res.bin", + "groundtruth": "/yfcc/groundtruth.bin", "metric": "l2", "preprocess": [], - "queries": "/disk_index_search/disk_index_sample_query_10pts.fbin" + "queries": "/yfcc/yfcc_query_100.fbin" }, "layer": { "FullPrecision": { @@ -166,32 +166,32 @@ }, "results": { "build": { - "append_neighbors": 2447, - "distance": 59957, - "get_neighbors": 14477, - "get_vector": 42151, - "query_distance": 22813, - "set_neighbors": 430, - "set_vector": 256 + "append_neighbors": 96949, + "distance": 2867876, + "get_neighbors": 352139, + "get_vector": 3067092, + "query_distance": 2240744, + "set_neighbors": 23599, + "set_vector": 10000 }, "knn": [ { "counters": { "append_neighbors": 0, "distance": 0, - "get_neighbors": 514, - "get_vector": 1449, - "query_distance": 1449, + "get_neighbors": 5441, + "get_vector": 44988, + "query_distance": 44988, "set_neighbors": 0, "set_vector": 0 }, "misc": { - "cmps": 1449, - "hops": 514 + "cmps": 44988, + "hops": 5441 }, "recall": { - "average": 0.91, - "num_queries": 10, + "average": 0.975, + "num_queries": 100, "recall_k": 10, "recall_n": 10 } @@ -200,19 +200,19 @@ "counters": { "append_neighbors": 0, "distance": 0, - "get_neighbors": 519, - "get_vector": 1450, - "query_distance": 1450, + "get_neighbors": 5806, + "get_vector": 49075, + "query_distance": 49075, 
"set_neighbors": 0, "set_vector": 0 }, "misc": { - "cmps": 1450, - "hops": 519 + "cmps": 49075, + "hops": 5806 }, "recall": { - "average": 0.91, - "num_queries": 10, + "average": 0.974, + "num_queries": 100, "recall_k": 10, "recall_n": 10 } @@ -221,19 +221,19 @@ "counters": { "append_neighbors": 0, "distance": 0, - "get_neighbors": 1013, - "get_vector": 1615, - "query_distance": 1615, + "get_neighbors": 10634, + "get_vector": 74001, + "query_distance": 74001, "set_neighbors": 0, "set_vector": 0 }, "misc": { - "cmps": 1615, - "hops": 1013 + "cmps": 74001, + "hops": 10634 }, "recall": { - "average": 0.91, - "num_queries": 10, + "average": 0.992, + "num_queries": 100, "recall_k": 10, "recall_n": 10 } @@ -246,17 +246,17 @@ "content": { "build": { "alpha": 1.2000000476837158, - "l_build": 50, + "l_build": 20, "max_degree": 20, "pruned_degree": 16 }, "data": { - "data": "/disk_index_search/disk_index_siftsmall_learn_256pts_data.fbin", + "data": "/yfcc/yfcc_10k.fbin", "data_type": "f32", - "groundtruth": "/disk_index_search/disk_index_10pts_idx_uint32_truth_search_res.bin", + "groundtruth": "/yfcc/groundtruth.bin", "metric": "l2", "preprocess": [], - "queries": "/disk_index_search/disk_index_sample_query_10pts.fbin" + "queries": "/yfcc/yfcc_query_100.fbin" }, "layer": { "FullPrecision": { @@ -287,32 +287,32 @@ }, "results": { "build": { - "append_neighbors": 2447, - "distance": 59957, - "get_neighbors": 14477, - "get_vector": 42151, - "query_distance": 22813, - "set_neighbors": 430, - "set_vector": 256 + "append_neighbors": 96949, + "distance": 2867876, + "get_neighbors": 352139, + "get_vector": 3067092, + "query_distance": 2240744, + "set_neighbors": 23599, + "set_vector": 10000 }, "knn": [ { "counters": { "append_neighbors": 0, "distance": 0, - "get_neighbors": 514, - "get_vector": 1449, - "query_distance": 1449, + "get_neighbors": 5441, + "get_vector": 44988, + "query_distance": 44988, "set_neighbors": 0, "set_vector": 0 }, "misc": { - "cmps": 1449, - "hops": 514 + 
"cmps": 44988, + "hops": 5441 }, "recall": { - "average": 0.91, - "num_queries": 10, + "average": 0.975, + "num_queries": 100, "recall_k": 10, "recall_n": 10 } @@ -321,19 +321,19 @@ "counters": { "append_neighbors": 0, "distance": 0, - "get_neighbors": 519, - "get_vector": 1450, - "query_distance": 1450, + "get_neighbors": 5806, + "get_vector": 49075, + "query_distance": 49075, "set_neighbors": 0, "set_vector": 0 }, "misc": { - "cmps": 1450, - "hops": 519 + "cmps": 49075, + "hops": 5806 }, "recall": { - "average": 0.91, - "num_queries": 10, + "average": 0.974, + "num_queries": 100, "recall_k": 10, "recall_n": 10 } @@ -342,19 +342,19 @@ "counters": { "append_neighbors": 0, "distance": 0, - "get_neighbors": 1013, - "get_vector": 1615, - "query_distance": 1615, + "get_neighbors": 10634, + "get_vector": 74001, + "query_distance": 74001, "set_neighbors": 0, "set_vector": 0 }, "misc": { - "cmps": 1615, - "hops": 1013 + "cmps": 74001, + "hops": 10634 }, "recall": { - "average": 0.91, - "num_queries": 10, + "average": 0.992, + "num_queries": 100, "recall_k": 10, "recall_n": 10 } @@ -367,20 +367,20 @@ "content": { "build": { "alpha": 1.2000000476837158, - "l_build": 50, + "l_build": 20, "max_degree": 20, "pruned_degree": 16 }, "data": { - "data": "/disk_index_search/disk_index_siftsmall_learn_256pts_data.fbin", + "data": "/yfcc/yfcc_10k.fbin", "data_type": "f32", - "groundtruth": "/disk_index_search/disk_index_10pts_idx_uint32_truth_search_res.bin", + "groundtruth": "/yfcc/groundtruth.bin", "metric": "l2", "preprocess": [ "halve", "floor" ], - "queries": "/disk_index_search/disk_index_sample_query_10pts.fbin" + "queries": "/yfcc/yfcc_query_100.fbin" }, "layer": { "FullPrecision": { @@ -411,32 +411,32 @@ }, "results": { "build": { - "append_neighbors": 2449, - "distance": 59803, - "get_neighbors": 14471, - "get_vector": 42075, - "query_distance": 22829, - "set_neighbors": 428, - "set_vector": 256 + "append_neighbors": 97055, + "distance": 2867292, + "get_neighbors": 352087, 
+ "get_vector": 3064106, + "query_distance": 2238420, + "set_neighbors": 23587, + "set_vector": 10000 }, "knn": [ { "counters": { "append_neighbors": 0, "distance": 0, - "get_neighbors": 517, - "get_vector": 1434, - "query_distance": 1434, + "get_neighbors": 5446, + "get_vector": 44805, + "query_distance": 44805, "set_neighbors": 0, "set_vector": 0 }, "misc": { - "cmps": 1434, - "hops": 517 + "cmps": 44805, + "hops": 5446 }, "recall": { - "average": 0.9, - "num_queries": 10, + "average": 0.961, + "num_queries": 100, "recall_k": 10, "recall_n": 10 } @@ -445,19 +445,19 @@ "counters": { "append_neighbors": 0, "distance": 0, - "get_neighbors": 523, - "get_vector": 1462, - "query_distance": 1462, + "get_neighbors": 5771, + "get_vector": 48508, + "query_distance": 48508, "set_neighbors": 0, "set_vector": 0 }, "misc": { - "cmps": 1462, - "hops": 523 + "cmps": 48508, + "hops": 5771 }, "recall": { - "average": 0.9, - "num_queries": 10, + "average": 0.9590000000000002, + "num_queries": 100, "recall_k": 10, "recall_n": 10 } @@ -466,19 +466,19 @@ "counters": { "append_neighbors": 0, "distance": 0, - "get_neighbors": 1016, - "get_vector": 1625, - "query_distance": 1625, + "get_neighbors": 10638, + "get_vector": 74034, + "query_distance": 74034, "set_neighbors": 0, "set_vector": 0 }, "misc": { - "cmps": 1625, - "hops": 1016 + "cmps": 74034, + "hops": 10638 }, "recall": { - "average": 0.9, - "num_queries": 10, + "average": 0.9680000000000004, + "num_queries": 100, "recall_k": 10, "recall_n": 10 } diff --git a/diskann-inmem/integration/jsons/integration.json b/diskann-inmem/integration/jsons/integration.json index dd9e6e4bf..57f2a338c 100644 --- a/diskann-inmem/integration/jsons/integration.json +++ b/diskann-inmem/integration/jsons/integration.json @@ -1,6 +1,6 @@ { "search_directories": [ - "disk_index_search" + "yfcc" ], "output_directory": null, "jobs": [ @@ -9,16 +9,16 @@ "content": { "build": { "alpha": 1.2000000476837158, - "l_build": 50, + "l_build": 20, "max_degree": 20, 
"pruned_degree": 16 }, "data": { - "data": "disk_index_siftsmall_learn_256pts_data.fbin", + "data": "yfcc_10k.fbin", "data_type": "f32", - "groundtruth": "disk_index_10pts_idx_uint32_truth_search_res.bin", + "groundtruth": "groundtruth.bin", "metric": "l2", - "queries": "disk_index_sample_query_10pts.fbin", + "queries": "yfcc_query_100.fbin", "preprocess": [] }, "layer": { @@ -52,16 +52,16 @@ "content": { "build": { "alpha": 1.2000000476837158, - "l_build": 50, + "l_build": 20, "max_degree": 20, "pruned_degree": 16 }, "data": { - "data": "disk_index_siftsmall_learn_256pts_data.fbin", + "data": "yfcc_10k.fbin", "data_type": "f32", - "groundtruth": "disk_index_10pts_idx_uint32_truth_search_res.bin", + "groundtruth": "groundtruth.bin", "metric": "l2", - "queries": "disk_index_sample_query_10pts.fbin", + "queries": "yfcc_query_100.fbin", "preprocess": [] }, "layer": { @@ -95,16 +95,16 @@ "content": { "build": { "alpha": 1.2000000476837158, - "l_build": 50, + "l_build": 20, "max_degree": 20, "pruned_degree": 16 }, "data": { - "data": "disk_index_siftsmall_learn_256pts_data.fbin", + "data": "yfcc_10k.fbin", "data_type": "f32", - "groundtruth": "disk_index_10pts_idx_uint32_truth_search_res.bin", + "groundtruth": "groundtruth.bin", "metric": "l2", - "queries": "disk_index_sample_query_10pts.fbin", + "queries": "yfcc_query_100.fbin", "preprocess": [] }, "layer": { @@ -138,16 +138,16 @@ "content": { "build": { "alpha": 1.2000000476837158, - "l_build": 50, + "l_build": 20, "max_degree": 20, "pruned_degree": 16 }, "data": { - "data": "disk_index_siftsmall_learn_256pts_data.fbin", + "data": "yfcc_10k.fbin", "data_type": "f32", - "groundtruth": "disk_index_10pts_idx_uint32_truth_search_res.bin", + "groundtruth": "groundtruth.bin", "metric": "l2", - "queries": "disk_index_sample_query_10pts.fbin", + "queries": "yfcc_query_100.fbin", "preprocess": [ "halve", "floor" diff --git a/diskann-inmem/integration/support/datatype.rs b/diskann-inmem/integration/support/datatype.rs index 
dff4528cc..fe61de539 100644 --- a/diskann-inmem/integration/support/datatype.rs +++ b/diskann-inmem/integration/support/datatype.rs @@ -394,10 +394,7 @@ pub(crate) struct Iter<'a> { impl<'a> Iter<'a> { fn new(view: &'a DatasetView<'a>) -> Self { - Self { - view, - row: 0, - } + Self { view, row: 0 } } } diff --git a/diskann-inmem/src/integration/store.rs b/diskann-inmem/src/integration/store.rs index d48238b6f..ecfe22878 100644 --- a/diskann-inmem/src/integration/store.rs +++ b/diskann-inmem/src/integration/store.rs @@ -5,7 +5,7 @@ #![expect( clippy::expect_used, - reason = "integration test tools are not production code", + reason = "integration test tools are not production code" )] use diskann_utils::views::Matrix; diff --git a/diskann-inmem/src/layers/full.rs b/diskann-inmem/src/layers/full.rs index 83ceb5b20..ccc0b75b6 100644 --- a/diskann-inmem/src/layers/full.rs +++ b/diskann-inmem/src/layers/full.rs @@ -50,12 +50,12 @@ where Self { distance, metric } } - /// Return the logical dimension of the data handled by this [`Layer`]. + /// Return the logical dimension of the data handled by this [`layers::Layer`]. pub fn dim(&self) -> usize { self.distance.dim } - /// Return the number of bytes of the data handles by this [`Layer`]. + /// Return the number of bytes of the data handles by this [`layers::Layer`]. pub fn bytes(&self) -> Bytes { Bytes::new(self.dim() * std::mem::size_of::()) } diff --git a/diskann-inmem/src/lib.rs b/diskann-inmem/src/lib.rs index 5a2934761..55d5b4653 100644 --- a/diskann-inmem/src/lib.rs +++ b/diskann-inmem/src/lib.rs @@ -3,6 +3,8 @@ * Licensed under the MIT license. */ +//! The inmem index for DiskANN. + #![deny(rustdoc::broken_intra_doc_links)] pub mod num; From 03ac022ee9349958ad1e3a992f589d3b8337e8a4 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Fri, 26 Jun 2026 16:57:24 -0700 Subject: [PATCH 30/45] Checkpoint. 
--- diskann-benchmark/src/index/benchmarks.rs | 8 +- diskann-benchmark/src/index/inmem2.rs | 554 ++++++++++++------ diskann-inmem/integration/index/object.rs | 3 +- diskann-inmem/src/layers/full.rs | 122 +++- diskann-inmem/src/layers/mod.rs | 2 +- diskann-inmem/src/provider.rs | 53 +- .../src/distance/implementations.rs | 2 +- diskann/src/graph/start_point.rs | 2 + 8 files changed, 504 insertions(+), 242 deletions(-) diff --git a/diskann-benchmark/src/index/benchmarks.rs b/diskann-benchmark/src/index/benchmarks.rs index cc033a0d5..67109bd87 100644 --- a/diskann-benchmark/src/index/benchmarks.rs +++ b/diskann-benchmark/src/index/benchmarks.rs @@ -86,10 +86,10 @@ pub(crate) fn register_benchmarks(registry: &mut Registry) -> anyhow::Result<()> // "graph-index-full-precision-f16", // FullPrecision::::new().search(plugins::Topk), // )?; - // registry.register( - // "graph-index-full-precision-u8", - // FullPrecision::::new().search(plugins::Topk), - // )?; + registry.register( + "graph-index-full-precision-u8", + FullPrecision::::new().search(plugins::Topk), + )?; // registry.register( // "graph-index-full-precision-i8", // FullPrecision::::new().search(plugins::Topk), diff --git a/diskann-benchmark/src/index/inmem2.rs b/diskann-benchmark/src/index/inmem2.rs index b92b63f14..6cb14fccf 100644 --- a/diskann-benchmark/src/index/inmem2.rs +++ b/diskann-benchmark/src/index/inmem2.rs @@ -25,6 +25,7 @@ use diskann_benchmark_runner::{ benchmark::{FailureScore, MatchScore}, files::InputFile, output::Output, + utils::datatype::DataType, Benchmark, Checker, Checkpoint, Input, Registry, }; use diskann_inmem::{layers::Full, Provider, Strategy}; @@ -33,11 +34,13 @@ use diskann_vector::distance::Metric; use serde::{Deserialize, Serialize}; use crate::{ - index::build::ProgressMeter, inputs::graph_index::DynamicRunbookParams, utils::datafiles, + index::build::ProgressMeter, + inputs::graph_index::DynamicRunbookParams, + utils::{datafiles, SimilarityMeasure}, }; pub(crate) fn 
register_benchmarks(registry: &mut Registry) -> anyhow::Result<()> { - registry.register("inmem2-f32", Inmem2)?; + // registry.register("inmem2-f32", Inmem2)?; registry.register("inmem2-f32-stream", Inmem2Stream)?; Ok(()) } @@ -46,233 +49,422 @@ pub(crate) fn register_benchmarks(registry: &mut Registry) -> anyhow::Result<()> // Input // /////////// -#[derive(Debug, Serialize, Deserialize)] -pub(crate) struct Inmem2Build { - data: InputFile, - queries: InputFile, - groundtruth: InputFile, - - max_degree: usize, - l_build: usize, - alpha: f32, - - search_n: usize, - search_l: Vec, - recall_k: usize, - - num_threads: usize, - reps: NonZeroUsize, -} - -impl Input for Inmem2Build { - type Raw = Inmem2Build; +mod dto { + use super::*; - fn tag() -> &'static str { - "inmem2" + #[derive(Debug, Serialize, Deserialize)] + pub(super) struct KnnSweep { + pub(super) search_n: usize, + pub(super) search_l: Vec, + pub(super) recall_k: usize, } - fn from_raw(mut raw: Self::Raw, checker: &mut Checker) -> anyhow::Result { - raw.data.resolve(checker)?; - raw.queries.resolve(checker)?; - raw.groundtruth.resolve(checker)?; - Ok(raw) + #[derive(Debug, Serialize, Deserialize)] + pub(super) struct KnnSearch { + pub(super) queries: InputFile, + pub(super) groundtruth: InputFile, + pub(super) reps: NonZeroUsize, + pub(super) num_threads: Vec, + pub(super) runs: Vec, } - fn serialize(&self) -> anyhow::Result { - Ok(serde_json::to_value(self)?) 
+ #[derive(Debug, Serialize, Deserialize)] + pub(super) struct Data { + pub(super) data_type: DataType, + pub(super) data: InputFile, + pub(super) distance: SimilarityMeasure, } - fn example() -> Self { - Self { - data: InputFile::new("path/to/base.bin"), - queries: InputFile::new("path/to/query.bin"), - groundtruth: InputFile::new("path/to/groundtruth.bin"), - max_degree: 64, - l_build: 100, - alpha: 1.2, - search_n: 10, - search_l: vec![10, 20, 50, 100], - recall_k: 10, - num_threads: 4, - reps: NonZeroUsize::new(3).unwrap(), - } + #[derive(Debug, Serialize, Deserialize)] + pub(super) struct BuildParams { + pub(super) pruned_degree: usize, + pub(super) max_degree: usize, + pub(super) l_build: usize, + pub(super) alpha: f32, + pub(super) num_threads: NonZeroUsize, } -} -impl std::fmt::Display for Inmem2Build { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - writeln!(f, "inmem2 f32 benchmark")?; - writeln!(f, " max_degree: {}", self.max_degree)?; - writeln!(f, " l_build: {}", self.l_build)?; - writeln!(f, " alpha: {}", self.alpha)?; - writeln!(f, " search_n: {}", self.search_n)?; - writeln!(f, " search_l: {:?}", self.search_l)?; - writeln!(f, " recall_k: {}", self.recall_k)?; - writeln!(f, " num_threads: {}", self.num_threads)?; - writeln!(f, " reps: {}", self.reps) + #[derive(Debug, Serialize, Deserialize)] + pub(super) struct StaticBuild { + pub(super) data: Data, + pub(super) build: BuildParams, + pub(super) search: KnnSearch, } } -/////////////// -// Benchmark // -/////////////// +#[derive(Debug)] +struct KnnInstance { + knn: graph::search::Knn, + recall_k: usize, +} #[derive(Debug)] -struct Inmem2; +struct KnnSearch { + queries: InputFile, + groundtruth: InputFile, + reps: NonZeroUsize, + num_threads: Vec, + runs: Vec, +} -impl Benchmark for Inmem2 { - type Input = Inmem2Build; - type Output = (); +impl KnnSearch { + fn from_raw(raw: dto::KnnSearch, checker: Option<&mut Checker>) -> anyhow::Result { + let dto::KnnSearch { + mut 
queries, + mut groundtruth, + reps, + num_threads, + runs, + } = raw; + + if let Some(checker) = checker { + queries.resolve(checker)?; + groundtruth.resolve(checker)?; + } - fn try_match(&self, _input: &Inmem2Build) -> Result { - Ok(MatchScore(0)) + let runs = runs + .into_iter() + .flat_map(|sweep| { + sweep + .search_l + .into_iter() + .map(move |search_l| -> anyhow::Result<_> { + let knn = graph::search::Knn::new_default(sweep.search_n, search_l)?; + Ok(KnnInstance { + knn, + recall_k: sweep.recall_k, + }) + }) + }) + .collect::>>()?; + + Ok(Self { + queries, + groundtruth, + reps, + num_threads, + runs, + }) } - fn description( - &self, - f: &mut std::fmt::Formatter<'_>, - input: Option<&Inmem2Build>, - ) -> std::fmt::Result { - match input { - Some(i) => write!(f, "{i}"), - None => write!(f, "inmem2 f32 benchmark"), - } + fn maximum_recall_k(&self) -> usize { + self.runs.iter().map(|r| r.recall_k).max().unwrap_or(0) } +} - fn run( - &self, - input: &Inmem2Build, - checkpoint: Checkpoint<'_>, - mut output: &mut dyn Output, - ) -> anyhow::Result<()> { - writeln!(output, "{input}")?; +#[derive(Debug)] +struct Data { + data_type: DataType, + data: InputFile, + distance: Metric, +} - // Load data. - let data: Arc> = - Arc::new(datafiles::load_dataset(datafiles::BinFile(&input.data))?); +impl Data { + fn from_raw(raw: dto::Data, checker: Option<&mut Checker>) -> anyhow::Result { + let dto::Data { + data_type, + mut data, + distance, + } = raw; - let dim = data.ncols(); - let num_points = data.nrows(); - writeln!(output, "Loaded {num_points} points, dim={dim}")?; + if let Some(checker) = checker { + data.resolve(checker)?; + } - // Compute the medoid of the dataset as the single start point. 
- let start = StartPointStrategy::Medoid.compute(data.as_view())?; - let metric = Metric::L2; - let exact_max_degree = (input.max_degree as f32 * 1.3) as usize; - let layer = Full::::new(dim, metric); - let config = diskann_inmem::provider::Config::new(num_points, exact_max_degree); - let provider = Provider::new(layer, config, start.row_iter())?; + Ok(Self { + data_type, + data, + distance: distance.into(), + }) + } +} + +#[derive(Debug)] +struct BuildParams { + config: graph::Config, + num_threads: NonZeroUsize, +} + +impl BuildParams { + fn from_raw(raw: dto::BuildParams, metric: Metric) -> anyhow::Result { + let dto::BuildParams { + pruned_degree, + max_degree, + l_build, + alpha, + num_threads, + } = raw; let config = graph::config::Builder::new_with( - input.max_degree, - graph::config::MaxDegree::new(exact_max_degree), - input.l_build, + pruned_degree, + graph::config::MaxDegree::new(max_degree), + l_build, metric.into(), |b| { - b.alpha(input.alpha); + b.alpha(alpha); }, ) .build()?; - let index = Arc::new(DiskANNIndex::new(config, provider, None)); + Ok(Self { + config, + num_threads, + }) + } +} - // Build via SingleInsert. 
- let rt = benchmark_core::tokio::runtime(input.num_threads)?; - let builder = build_core::graph::SingleInsert::new( - index.clone(), - data, - Strategy, - build_core::ids::Identity::::new(), - ); +#[derive(Debug)] +struct StaticBuild { + data: Data, + build: BuildParams, + search: KnnSearch, +} - writeln!( - output, - "Building index with {} threads...", - input.num_threads - )?; - let build_results = build_core::build_tracked( - builder, - build_core::Parallelism::dynamic( - diskann::utils::ONE, - NonZeroUsize::new(input.num_threads).unwrap(), - ), - &rt, - Some(&ProgressMeter::new(output)), - )?; +impl StaticBuild { + fn from_raw(raw: dto::StaticBuild, mut checker: Option<&mut Checker>) -> anyhow::Result { + let dto::StaticBuild { + data, + build, + search, + } = raw; - let total_build_time = build_results.end_to_end_latency(); - writeln!( - output, - "Build complete in {:.2}s", - total_build_time.as_seconds() - )?; - checkpoint.checkpoint(&total_build_time)?; + let data = Data::from_raw(data, checker.as_deref_mut())?; + let build = BuildParams::from_raw(build, data.distance)?; + let search = KnnSearch::from_raw(search, checker)?; - // Search. 
- let queries: Arc> = - Arc::new(datafiles::load_dataset(datafiles::BinFile(&input.queries))?); - let max_k = input.recall_k; - let groundtruth = - datafiles::load_groundtruth(datafiles::BinFile(&input.groundtruth), Some(max_k))?; + Ok(Self { + data, + build, + search, + }) + } +} - writeln!(output, "Loaded {} queries", queries.nrows())?; +impl Input for StaticBuild { + type Raw = dto::StaticBuild; - let knn = benchmark_core::search::graph::KNN::new( - index, - queries, - benchmark_core::search::graph::Strategy::broadcast(Strategy), - )?; + fn tag() -> &'static str { + "inmem2" + } - let num_threads = NonZeroUsize::new(input.num_threads).unwrap(); + fn from_raw(raw: Self::Raw, checker: &mut Checker) -> anyhow::Result { + Self::from_raw(raw, Some(checker)) + } - for &search_l in &input.search_l { - let params = graph::search::Knn::new(input.search_n, search_l, None)?; + fn serialize(&self) -> anyhow::Result { + Ok(serde_json::to_value(())?) + } - let setup = core_search::Setup { - threads: num_threads, - tasks: num_threads, - reps: input.reps, - }; + fn example() -> Self::Raw { + const FOUR: NonZeroUsize = NonZeroUsize::new(4).unwrap(); + const THREE: NonZeroUsize = NonZeroUsize::new(3).unwrap(); - let run = core_search::Run::new(params, setup); + dto::StaticBuild { + data: dto::Data { + data_type: DataType::Float32, + data: InputFile::new("path/to/data"), + distance: SimilarityMeasure::SquaredL2, + }, + build: dto::BuildParams { + pruned_degree: 28, + max_degree: 32, + l_build: 100, + alpha: 1.2, + num_threads: FOUR, + }, + search: dto::KnnSearch { + queries: InputFile::new("path/to/queries"), + groundtruth: InputFile::new("path/to/groundtruth"), + reps: THREE, + num_threads: vec![FOUR], + runs: vec![dto::KnnSweep { + search_n: 10, + search_l: vec![10, 20, 30, 40, 50], + recall_k: 10, + }], + }, + } + } +} - let summaries = core_search::search_all( - knn.clone(), - std::iter::once(run), - benchmark_core::search::graph::knn::Aggregator::new( - &groundtruth, - 
input.recall_k, - input.search_n, - GroundTruthMode::Fixed, - ), - )?; +// impl std::fmt::Display for Inmem2Build { +// fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +// writeln!(f, "inmem2 f32 benchmark")?; +// writeln!(f, " max_degree: {}", self.max_degree)?; +// writeln!(f, " l_build: {}", self.l_build)?; +// writeln!(f, " alpha: {}", self.alpha)?; +// writeln!(f, " search_n: {}", self.search_n)?; +// writeln!(f, " search_l: {:?}", self.search_l)?; +// writeln!(f, " recall_k: {}", self.recall_k)?; +// writeln!(f, " num_threads: {}", self.num_threads)?; +// writeln!(f, " reps: {}", self.reps) +// } +// } - for summary in &summaries { - let qps: Vec = summary - .end_to_end_latencies - .iter() - .map(|lat| summary.recall.num_queries as f64 / lat.as_seconds()) - .collect(); - let max_qps = qps.iter().cloned().fold(0.0f64, f64::max); - let mean_qps = qps.iter().sum::() / qps.len() as f64; +/////////////// +// Benchmark // +/////////////// - writeln!( - output, - " L={:<4} recall={:.4} QPS={:.0} (max {:.0}) cmps={:.1} hops={:.1}", - search_l, - summary.recall.average, - mean_qps, - max_qps, - summary.mean_cmps, - summary.mean_hops, - )?; - } - } +#[derive(Debug)] +struct Build(std::marker::PhantomData); - Ok(()) +impl Build { + fn new() -> Self { + Self(std::marker::PhantomData) } } +// impl Benchmark for Build +// where +// T: diskann_inmem::layers::FullPrecision +// + diskann::graph::SampleableForStart, +// { +// type Input = StaticBuild; +// type Output = (); +// +// fn try_match(&self, input: &StaticBuild) -> Result { +// Ok(MatchScore(0)) +// } +// +// fn description( +// &self, +// f: &mut std::fmt::Formatter<'_>, +// input: Option<&StaticBuild>, +// ) -> std::fmt::Result { +// Ok(()) +// // match input { +// // Some(i) => write!(f, "{i}"), +// // None => write!(f, "inmem2 f32 benchmark"), +// // } +// } +// +// fn run( +// &self, +// input: &StaticBuild, +// checkpoint: Checkpoint<'_>, +// mut output: &mut dyn Output, +// ) -> 
anyhow::Result<()> { +// // writeln!(output, "{input}")?; +// +// // Load data. +// let data: Arc> = Arc::new(datafiles::load_dataset(datafiles::BinFile( +// &input.data.data, +// ))?); +// +// let dim = data.ncols(); +// let num_points = data.nrows(); +// writeln!(output, "Loaded {num_points} points, dim={dim}")?; +// +// // Compute the medoid of the dataset as the single start point. +// let start = StartPointStrategy::Medoid.compute(data.as_view())?; +// let layer = Full::::new(dim, input.data.distance); +// let config = +// diskann_inmem::provider::Config::new(num_points, input.build.config.max_degree().get()); +// let provider = Provider::<_, u32>::new(layer, config, start.row_iter())?; +// +// let index = Arc::new(DiskANNIndex::new( +// input.build.config.clone(), +// provider, +// None, +// )); +// +// // Build via SingleInsert. +// let rt = benchmark_core::tokio::runtime(input.build.num_threads.get())?; +// let builder = build_core::graph::SingleInsert::new( +// index.clone(), +// data, +// Strategy, +// build_core::ids::Identity::::new(), +// ); +// +// // writeln!( +// // output, +// // "Building index with {} threads...", +// // input.num_threads +// // )?; +// let build_results = build_core::build_tracked( +// builder, +// build_core::Parallelism::dynamic( +// diskann::utils::ONE, +// input.build.num_threads, +// ), +// &rt, +// Some(&ProgressMeter::new(output)), +// )?; +// +// let total_build_time = build_results.end_to_end_latency(); +// writeln!( +// output, +// "Build complete in {:.2}s", +// total_build_time.as_seconds() +// )?; +// checkpoint.checkpoint(&total_build_time)?; +// +// // Search. 
+// let queries: Arc> = +// Arc::new(datafiles::load_dataset(datafiles::BinFile(&input.search.queries))?); +// let max_k = input.search.maximum_recall_k(); +// let groundtruth = +// datafiles::load_groundtruth(datafiles::BinFile(&input.search.groundtruth), Some(max_k))?; +// +// writeln!(output, "Loaded {} queries", queries.nrows())?; +// +// let knn = benchmark_core::search::graph::KNN::new( +// index, +// queries, +// benchmark_core::search::graph::Strategy::broadcast(Strategy), +// )?; +// +// // let num_threads = NonZeroUsize::new(input.num_threads).unwrap(); +// let mut results = Vec::new(); +// for num_threads in input.search.num_threads.iter() { +// let params = graph::search::Knn::new(input.search_n, search_l, None)?; +// +// let setup = core_search::Setup { +// threads: num_threads, +// tasks: num_threads, +// reps: input.reps, +// }; +// +// let run = core_search::Run::new(params, setup); +// +// let summaries = core_search::search_all( +// knn.clone(), +// std::iter::once(run), +// benchmark_core::search::graph::knn::Aggregator::new( +// &groundtruth, +// input.recall_k, +// input.search_n, +// GroundTruthMode::Fixed, +// ), +// )?; +// +// for summary in &summaries { +// let qps: Vec = summary +// .end_to_end_latencies +// .iter() +// .map(|lat| summary.recall.num_queries as f64 / lat.as_seconds()) +// .collect(); +// let max_qps = qps.iter().cloned().fold(0.0f64, f64::max); +// let mean_qps = qps.iter().sum::() / qps.len() as f64; +// +// writeln!( +// output, +// " L={:<4} recall={:.4} QPS={:.0} (max {:.0}) cmps={:.1} hops={:.1}", +// search_l, +// summary.recall.average, +// mean_qps, +// max_qps, +// summary.mean_cmps, +// summary.mean_hops, +// )?; +// } +// } +// +// Ok(()) +// } +// } + /////////////// // Streaming // /////////////// diff --git a/diskann-inmem/integration/index/object.rs b/diskann-inmem/integration/index/object.rs index 95505366d..a065a9b95 100644 --- a/diskann-inmem/integration/index/object.rs +++ 
b/diskann-inmem/integration/index/object.rs @@ -188,8 +188,7 @@ impl CheckMatch for Counters { impl Index for DiskANNIndex, u64>> where - layers::Full: for<'a> layers::Insert = &'a [T]>, - T: FromSlice + AsDataType + Send + Sync + 'static, + T: layers::FullPrecision + FromSlice + AsDataType, { fn search<'a>( &'a self, diff --git a/diskann-inmem/src/layers/full.rs b/diskann-inmem/src/layers/full.rs index ccc0b75b6..41935409b 100644 --- a/diskann-inmem/src/layers/full.rs +++ b/diskann-inmem/src/layers/full.rs @@ -23,6 +23,24 @@ use thiserror::Error; use crate::{layers, num::Bytes}; +/// A useful trait bound for types compatible with [`Full`]. +/// +/// This encompases *everything* required for `Full: layers::Insert` and can be used as +/// a single bound. +pub trait FullPrecision: bytemuck::Pod + std::fmt::Debug + Send + Sync { + #[doc(hidden)] + fn __new(dim: usize, metric: Metric) -> Full; + + #[doc(hidden)] + fn __query_distance<'a, V>( + full: &'a Full, + query: &'a [Self], + visitor: V, + ) -> ANNResult + where + V: layers::QueryVisitor<'a>; +} + /// Full-precision data layer. #[derive(Debug)] pub struct Full @@ -39,6 +57,13 @@ where { /// Create a new full-precision layer for data with the given `dim` and `metric`. 
pub fn new(dim: usize, metric: Metric) -> Self + where + T: FullPrecision, + { + T::__new(dim, metric) + } + + fn from_distance_provider(dim: usize, metric: Metric) -> Self where T: DistanceProvider, { @@ -74,7 +99,7 @@ where impl layers::Layer for Full where - T: bytemuck::Pod + Send + Sync, + T: FullPrecision, { fn bytes(&self) -> Bytes { >::bytes(self) @@ -83,7 +108,7 @@ where impl layers::Set<&[T]> for Full where - T: bytemuck::Pod + Send + Sync, + T: FullPrecision, { fn set(&self, v: &[T], bytes: &mut [u8]) -> ANNResult<()> { if v.len() != self.dim() { @@ -123,26 +148,36 @@ crate::opaque!(SetError); impl layers::AsDistance for Full where - T: Debug + Send + Sync + 'static, + T: FullPrecision, { fn as_distance(&self) -> &dyn layers::Distance { &self.distance } } -impl layers::Insert for Full +impl layers::Search for Full where - T: bytemuck::Pod + Debug + Send + Sync + 'static, - Self: for<'a> layers::Search = &'a [T]>, + T: FullPrecision, { + type Query<'a> = &'a [T]; + + fn query_distance<'a, V>(&'a self, query: &'a [T], visitor: V) -> ANNResult + where + V: layers::QueryVisitor<'a>, + { + T::__query_distance(self, query, visitor) + } } +impl layers::Insert for Full where T: FullPrecision {} + ////////////// // Distance // ////////////// #[derive(Debug)] -struct Distance +#[doc(hidden)] +pub struct Distance where T: 'static, U: 'static, @@ -329,20 +364,26 @@ macro_rules! 
mint { }}; } -impl layers::Search for Full { - type Query<'a> = &'a [f32]; +impl FullPrecision for f32 { + fn __new(dim: usize, metric: Metric) -> Full { + Full::from_distance_provider(dim, metric) + } - fn query_distance<'a, V>(&'a self, query: &'a [f32], visitor: V) -> ANNResult + fn __query_distance<'a, V>( + full: &'a Full, + query: &'a [f32], + visitor: V, + ) -> ANNResult where V: layers::QueryVisitor<'a>, { - self.check_dim(query.len())?; + full.check_dim(query.len())?; let query = Calf::Borrowed(query); - let output = match self.metric { + let output = match full.metric { Metric::L2 => { - if self.dim() == 100 { + if full.dim() == 100 { mint!(query, visitor, f32 => { 100, SquaredL2 }) } else { mint!(query, visitor, f32 => SquaredL2) @@ -357,20 +398,26 @@ impl layers::Search for Full { } } -impl layers::Search for Full { - type Query<'a> = &'a [f16]; +impl FullPrecision for f16 { + fn __new(dim: usize, metric: Metric) -> Full { + Full::from_distance_provider(dim, metric) + } - fn query_distance<'a, V>(&'a self, query: &'a [f16], visitor: V) -> ANNResult + fn __query_distance<'a, V>( + full: &'a Full, + query: &'a [f16], + visitor: V, + ) -> ANNResult where V: layers::QueryVisitor<'a>, { - self.check_dim(query.len())?; + full.check_dim(query.len())?; - let mut as_f32: Box<[f32]> = std::iter::repeat_n(0.0, self.dim()).collect(); + let mut as_f32: Box<[f32]> = std::iter::repeat_n(0.0, full.dim()).collect(); diskann_wide::arch::dispatch2(SliceCast::new(), &mut *as_f32, query); let query = Calf::Owned(as_f32); - let output = match self.metric { + let output = match full.metric { Metric::L2 => mint!(query, visitor, { f32, f16 } => SquaredL2), Metric::InnerProduct => mint!(query, visitor, { f32, f16 } => InnerProduct), Metric::Cosine => mint!(query, visitor, { f32, f16 } => Cosine), @@ -381,20 +428,26 @@ impl layers::Search for Full { } } -impl layers::Search for Full { - type Query<'a> = &'a [u8]; +impl FullPrecision for u8 { + fn __new(dim: usize, metric: Metric) 
-> Full { + Full::from_distance_provider(dim, metric) + } - fn query_distance<'a, V>(&'a self, query: &'a [u8], visitor: V) -> ANNResult + fn __query_distance<'a, V>( + full: &'a Full, + query: &'a [u8], + visitor: V, + ) -> ANNResult where V: layers::QueryVisitor<'a>, { - self.check_dim(query.len())?; + full.check_dim(query.len())?; let query = Calf::Borrowed(query); - let output = match self.metric { + let output = match full.metric { Metric::L2 => { - if self.dim() == 128 { + if full.dim() == 128 { mint!(query, visitor, u8 => { 128, SquaredL2 }) } else { mint!(query, visitor, u8 => SquaredL2) @@ -409,18 +462,24 @@ impl layers::Search for Full { } } -impl layers::Search for Full { - type Query<'a> = &'a [i8]; +impl FullPrecision for i8 { + fn __new(dim: usize, metric: Metric) -> Full { + Full::from_distance_provider(dim, metric) + } - fn query_distance<'a, V>(&'a self, query: &'a [i8], visitor: V) -> ANNResult + fn __query_distance<'a, V>( + full: &'a Full, + query: &'a [i8], + visitor: V, + ) -> ANNResult where V: layers::QueryVisitor<'a>, { - self.check_dim(query.len())?; + full.check_dim(query.len())?; let query = Calf::Borrowed(query); - let output = match self.metric { + let output = match full.metric { Metric::L2 => mint!(query, visitor, i8 => SquaredL2), Metric::InnerProduct => mint!(query, visitor, i8 => InnerProduct), Metric::Cosine => mint!(query, visitor, i8 => Cosine), @@ -509,8 +568,7 @@ mod tests { /// reject byte slices that are too long or too short. 
fn test_impl(max_dim: usize, ctx: &dyn Display) where - T: Sample + Debug + Send + Sync + DistanceProvider + 'static, - Full: for<'a> layers::Search = &'a [T]>, + T: FullPrecision + Sample + DistanceProvider, { let mut rng = StdRng::seed_from_u64(0x0D15_0ACE ^ max_dim as u64); let metrics = [ diff --git a/diskann-inmem/src/layers/mod.rs b/diskann-inmem/src/layers/mod.rs index b67c1ada1..648e5a6b4 100644 --- a/diskann-inmem/src/layers/mod.rs +++ b/diskann-inmem/src/layers/mod.rs @@ -31,7 +31,7 @@ use diskann::ANNResult; use crate::num::Bytes; mod full; -pub use full::Full; +pub use full::{Full, FullPrecision}; /// Base layer for data representations. pub trait Layer: Send + Sync + 'static { diff --git a/diskann-inmem/src/provider.rs b/diskann-inmem/src/provider.rs index 7fbc85468..398b6a18c 100644 --- a/diskann-inmem/src/provider.rs +++ b/diskann-inmem/src/provider.rs @@ -348,6 +348,7 @@ pub struct SearchAccessor<'a> { reader: store::Reader<'a>, ids: AdjacencyList, expand_beam: Box, + buffer: Vec<(u32, f32)>, // The parent provider for the accessor. provider: &'a (dyn std::any::Any + Send + Sync), @@ -417,20 +418,23 @@ impl glue::SearchAccessor for SearchAccessor<'_> { self.ids .retain(|i| pred.eval_mut(i) && self.reader.is_in_bounds(i.into_usize())); - // TODO: Move to an external buffer to avoid any dynamic dispatcn in - // `expand_beam_inner` - then we can do a bulk-update on the counters. - let mut on_neighbors = |id, distance| { - self.counters.get_vector(1); - self.counters.query_distance(1); + // This should always hold, but let's double check. + assert!(self.buffer.len() >= self.ids.len()); - on_neighbors(id, distance); - }; - - // SAFETY: We've verified that each entry in `self.ids` is in-bounds. - unsafe { + // SAFETY: We've verified that each entry in `self.ids` is in-bounds and the + // `self.buffer` is long enough to hold all the IDs. 
+ let processed = unsafe { self.expand_beam - .expand_beam(&self.ids, 8, &self.reader, &mut on_neighbors) + .expand_beam(&self.ids, 8, &self.reader, &mut self.buffer) }?; + + self.counters.get_vector(processed as u64); + self.counters.query_distance(processed as u64); + + self.buffer + .iter() + .take(processed) + .for_each(|(id, dist)| on_neighbors(*id, *dist)); } Ok(()) @@ -448,14 +452,15 @@ trait ExpandBeam: Send + Sync + std::fmt::Debug { /// /// # Safety /// - /// All items in `list` must in-bounds with respect to `reader`. + /// * All items in `list` must in-bounds with respect to `reader`. + /// * `buffer.len() >= list.len()`. unsafe fn expand_beam( &self, list: &[u32], lookahead: usize, reader: &store::Reader<'_>, - f: &mut dyn FnMut(u32, f32), - ) -> ANNResult<()>; + buffer: &mut [(u32, f32)], + ) -> ANNResult; } #[derive(Debug)] @@ -475,10 +480,10 @@ where list: &[u32], lookahead: usize, reader: &store::Reader<'_>, - f: &mut dyn FnMut(u32, f32), - ) -> ANNResult<()> { + buffer: &mut [(u32, f32)], + ) -> ANNResult { // SAFETY: Inherited from caller. - unsafe { expand_beam_inner::(&self.0, list, lookahead, reader, f) } + unsafe { expand_beam_inner::(&self.0, list, lookahead, reader, buffer) } } } @@ -540,14 +545,15 @@ unsafe fn prefetch(ptr: *const u8, len: usize) { /// /// * All items in `list` must in-bounds with respect to `reader`. /// * The number of bytes associated with `N` cache lines must "make sense". +/// * `buffer.len() >= list.len()`. 
#[inline] unsafe fn expand_beam_inner( distance: &T, list: &[u32], lookahead: usize, reader: &store::Reader<'_>, - f: &mut dyn FnMut(u32, f32), -) -> ANNResult<()> + buffer: &mut [(u32, f32)], +) -> ANNResult where T: layers::QueryDistance, { @@ -558,6 +564,8 @@ where reader.bytes() ); + debug_assert!(buffer.len() >= list.len()); + let bytes = if BYTES == 0 { reader.bytes().value() } else { @@ -582,6 +590,7 @@ where } let mut j = lookahead; + let mut processed = 0; for &i in list.iter() { if j != len { // SAFETY: The in-bounds constraint is assured by the caller, both for `j` as @@ -600,11 +609,12 @@ where // SAFETY: Caller asserts that `i` is in-bounds. if let Some(data) = unsafe { reader.read_in_bounds(i.into_usize()) } { - f(i, distance.evaluate(data)?) + *unsafe { buffer.get_unchecked_mut(processed) } = (i, distance.evaluate(data)?); + processed += 1; } } - Ok(()) + Ok(processed) } //////////// @@ -791,6 +801,7 @@ where reader, ids: AdjacencyList::new(), expand_beam, + buffer: vec![(0, 0.0); provider.max_degree()], provider, start_points: provider.store.frozen(), counters: provider.local_counters(), diff --git a/diskann-vector/src/distance/implementations.rs b/diskann-vector/src/distance/implementations.rs index 8e119e43c..a2f250bb9 100644 --- a/diskann-vector/src/distance/implementations.rs +++ b/diskann-vector/src/distance/implementations.rs @@ -41,7 +41,7 @@ macro_rules! 
architecture_hook { { #[inline(always)] fn run(arch: A, left: L, right: R) -> T { - arch.run2(Self::default(), left, right) + arch.run2_inline(Self::default(), left, right) } } }; diff --git a/diskann/src/graph/start_point.rs b/diskann/src/graph/start_point.rs index 2f0fca217..c3bee67bd 100644 --- a/diskann/src/graph/start_point.rs +++ b/diskann/src/graph/start_point.rs @@ -44,6 +44,7 @@ pub trait SampleableForStart: diskann_utils::sampling::medoid::ComputeMedoid + diskann_utils::sampling::latin_hypercube::SampleLatinHyperCube + diskann_utils::sampling::random::RoundFromf32 + + diskann_utils::sampling::WithApproximateNorm { } @@ -51,6 +52,7 @@ impl SampleableForStart for T where T: diskann_utils::sampling::medoid::ComputeMedoid + diskann_utils::sampling::latin_hypercube::SampleLatinHyperCube + diskann_utils::sampling::random::RoundFromf32 + + diskann_utils::sampling::WithApproximateNorm { } From 4207c1d269b89b5a32cd9cc4ff8ea523b6ebafdf Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Mon, 29 Jun 2026 12:47:19 -0700 Subject: [PATCH 31/45] Move machines. 
--- diskann-benchmark/src/index/inmem2.rs | 287 ++++++++++++-------------- diskann-inmem/src/layers/full.rs | 8 +- diskann-inmem/src/provider.rs | 8 +- 3 files changed, 146 insertions(+), 157 deletions(-) diff --git a/diskann-benchmark/src/index/inmem2.rs b/diskann-benchmark/src/index/inmem2.rs index 6cb14fccf..a36d46b7f 100644 --- a/diskann-benchmark/src/index/inmem2.rs +++ b/diskann-benchmark/src/index/inmem2.rs @@ -34,13 +34,13 @@ use diskann_vector::distance::Metric; use serde::{Deserialize, Serialize}; use crate::{ - index::build::ProgressMeter, + index::{result::{SearchResults, AggregatedSearchResults}, build::ProgressMeter}, inputs::graph_index::DynamicRunbookParams, utils::{datafiles, SimilarityMeasure}, }; pub(crate) fn register_benchmarks(registry: &mut Registry) -> anyhow::Result<()> { - // registry.register("inmem2-f32", Inmem2)?; + registry.register("inmem2-f32", Build::::new())?; registry.register("inmem2-f32-stream", Inmem2Stream)?; Ok(()) } @@ -314,156 +314,139 @@ impl Build { } } -// impl Benchmark for Build -// where -// T: diskann_inmem::layers::FullPrecision -// + diskann::graph::SampleableForStart, -// { -// type Input = StaticBuild; -// type Output = (); -// -// fn try_match(&self, input: &StaticBuild) -> Result { -// Ok(MatchScore(0)) -// } -// -// fn description( -// &self, -// f: &mut std::fmt::Formatter<'_>, -// input: Option<&StaticBuild>, -// ) -> std::fmt::Result { -// Ok(()) -// // match input { -// // Some(i) => write!(f, "{i}"), -// // None => write!(f, "inmem2 f32 benchmark"), -// // } -// } -// -// fn run( -// &self, -// input: &StaticBuild, -// checkpoint: Checkpoint<'_>, -// mut output: &mut dyn Output, -// ) -> anyhow::Result<()> { -// // writeln!(output, "{input}")?; -// -// // Load data. 
-// let data: Arc> = Arc::new(datafiles::load_dataset(datafiles::BinFile( -// &input.data.data, -// ))?); -// -// let dim = data.ncols(); -// let num_points = data.nrows(); -// writeln!(output, "Loaded {num_points} points, dim={dim}")?; -// -// // Compute the medoid of the dataset as the single start point. -// let start = StartPointStrategy::Medoid.compute(data.as_view())?; -// let layer = Full::::new(dim, input.data.distance); -// let config = -// diskann_inmem::provider::Config::new(num_points, input.build.config.max_degree().get()); -// let provider = Provider::<_, u32>::new(layer, config, start.row_iter())?; -// -// let index = Arc::new(DiskANNIndex::new( -// input.build.config.clone(), -// provider, -// None, -// )); -// -// // Build via SingleInsert. -// let rt = benchmark_core::tokio::runtime(input.build.num_threads.get())?; -// let builder = build_core::graph::SingleInsert::new( -// index.clone(), -// data, -// Strategy, -// build_core::ids::Identity::::new(), -// ); -// -// // writeln!( -// // output, -// // "Building index with {} threads...", -// // input.num_threads -// // )?; -// let build_results = build_core::build_tracked( -// builder, -// build_core::Parallelism::dynamic( -// diskann::utils::ONE, -// input.build.num_threads, -// ), -// &rt, -// Some(&ProgressMeter::new(output)), -// )?; -// -// let total_build_time = build_results.end_to_end_latency(); -// writeln!( -// output, -// "Build complete in {:.2}s", -// total_build_time.as_seconds() -// )?; -// checkpoint.checkpoint(&total_build_time)?; -// -// // Search. 
-// let queries: Arc> = -// Arc::new(datafiles::load_dataset(datafiles::BinFile(&input.search.queries))?); -// let max_k = input.search.maximum_recall_k(); -// let groundtruth = -// datafiles::load_groundtruth(datafiles::BinFile(&input.search.groundtruth), Some(max_k))?; -// -// writeln!(output, "Loaded {} queries", queries.nrows())?; -// -// let knn = benchmark_core::search::graph::KNN::new( -// index, -// queries, -// benchmark_core::search::graph::Strategy::broadcast(Strategy), -// )?; -// -// // let num_threads = NonZeroUsize::new(input.num_threads).unwrap(); -// let mut results = Vec::new(); -// for num_threads in input.search.num_threads.iter() { -// let params = graph::search::Knn::new(input.search_n, search_l, None)?; -// -// let setup = core_search::Setup { -// threads: num_threads, -// tasks: num_threads, -// reps: input.reps, -// }; -// -// let run = core_search::Run::new(params, setup); -// -// let summaries = core_search::search_all( -// knn.clone(), -// std::iter::once(run), -// benchmark_core::search::graph::knn::Aggregator::new( -// &groundtruth, -// input.recall_k, -// input.search_n, -// GroundTruthMode::Fixed, -// ), -// )?; -// -// for summary in &summaries { -// let qps: Vec = summary -// .end_to_end_latencies -// .iter() -// .map(|lat| summary.recall.num_queries as f64 / lat.as_seconds()) -// .collect(); -// let max_qps = qps.iter().cloned().fold(0.0f64, f64::max); -// let mean_qps = qps.iter().sum::() / qps.len() as f64; -// -// writeln!( -// output, -// " L={:<4} recall={:.4} QPS={:.0} (max {:.0}) cmps={:.1} hops={:.1}", -// search_l, -// summary.recall.average, -// mean_qps, -// max_qps, -// summary.mean_cmps, -// summary.mean_hops, -// )?; -// } -// } -// -// Ok(()) -// } -// } +impl Benchmark for Build +where + T: diskann_inmem::layers::FullPrecision + + diskann::graph::SampleableForStart, +{ + type Input = StaticBuild; + type Output = (); + + fn try_match(&self, input: &StaticBuild) -> Result { + Ok(MatchScore(0)) + } + + fn description( 
+ &self, + f: &mut std::fmt::Formatter<'_>, + input: Option<&StaticBuild>, + ) -> std::fmt::Result { + Ok(()) + // match input { + // Some(i) => write!(f, "{i}"), + // None => write!(f, "inmem2 f32 benchmark"), + // } + } + + fn run( + &self, + input: &StaticBuild, + checkpoint: Checkpoint<'_>, + mut output: &mut dyn Output, + ) -> anyhow::Result<()> { + // writeln!(output, "{input}")?; + + // Load data. + let data: Arc> = Arc::new(datafiles::load_dataset(datafiles::BinFile( + &input.data.data, + ))?); + + let dim = data.ncols(); + let num_points = data.nrows(); + writeln!(output, "Loaded {num_points} points, dim={dim}")?; + + // Compute the medoid of the dataset as the single start point. + let start = StartPointStrategy::Medoid.compute(data.as_view())?; + let layer = Full::::new(dim, input.data.distance); + let config = + diskann_inmem::provider::Config::new(num_points, input.build.config.max_degree().get()); + let provider = Provider::<_, u32>::new(layer, config, start.row_iter())?; + + let index = Arc::new(DiskANNIndex::new( + input.build.config.clone(), + provider, + None, + )); + + // Build via SingleInsert. + let rt = benchmark_core::tokio::runtime(input.build.num_threads.get())?; + let builder = build_core::graph::SingleInsert::new( + index.clone(), + data, + Strategy, + build_core::ids::Identity::::new(), + ); + + // writeln!( + // output, + // "Building index with {} threads...", + // input.num_threads + // )?; + let build_results = build_core::build_tracked( + builder, + build_core::Parallelism::dynamic( + diskann::utils::ONE, + input.build.num_threads, + ), + &rt, + Some(&ProgressMeter::new(output)), + )?; + + let total_build_time = build_results.end_to_end_latency(); + writeln!( + output, + "Build complete in {:.2}s", + total_build_time.as_seconds() + )?; + checkpoint.checkpoint(&total_build_time)?; + + // Search. 
+ let queries: Arc> = + Arc::new(datafiles::load_dataset(datafiles::BinFile(&input.search.queries))?); + let max_k = input.search.maximum_recall_k(); + let groundtruth = + datafiles::load_groundtruth(datafiles::BinFile(&input.search.groundtruth), Some(max_k))?; + + writeln!(output, "Loaded {} queries", queries.nrows())?; + + let knn = benchmark_core::search::graph::KNN::new( + index, + queries, + benchmark_core::search::graph::Strategy::broadcast(Strategy), + )?; + + let mut results = Vec::new(); + for num_threads in input.search.num_threads.iter() { + for instance in input.search.runs.iter() { + let setup = core_search::Setup { + threads: *num_threads, + tasks: *num_threads, + reps: input.search.reps, + }; + + let run = core_search::Run::new(instance.knn, setup); + + let r = core_search::search_all( + knn.clone(), + std::iter::once(run), + benchmark_core::search::graph::knn::Aggregator::new( + &groundtruth, + instance.recall_k, + instance.knn.k_value().get(), + GroundTruthMode::Fixed, + ), + )?; + results.extend(r.into_iter().map(SearchResults::new)); + } + } + + let results = AggregatedSearchResults::Topk(results); + + writeln!(output, "{}", results)?; + + Ok(()) + } +} /////////////// // Streaming // diff --git a/diskann-inmem/src/layers/full.rs b/diskann-inmem/src/layers/full.rs index 41935409b..a9fef5deb 100644 --- a/diskann-inmem/src/layers/full.rs +++ b/diskann-inmem/src/layers/full.rs @@ -389,7 +389,13 @@ impl FullPrecision for f32 { mint!(query, visitor, f32 => SquaredL2) } } - Metric::InnerProduct => mint!(query, visitor, f32 => InnerProduct), + Metric::InnerProduct => { + // if full.dim() == 768 { + // mint!(query, visitor, f32 => { 768, InnerProduct }) + // } else { + mint!(query, visitor, f32 => InnerProduct) + // } + }, Metric::Cosine => mint!(query, visitor, f32 => Cosine), Metric::CosineNormalized => mint!(query, visitor, f32 => CosineNormalized), }; diff --git a/diskann-inmem/src/provider.rs b/diskann-inmem/src/provider.rs index 
398b6a18c..1e6dc8646 100644 --- a/diskann-inmem/src/provider.rs +++ b/diskann-inmem/src/provider.rs @@ -500,7 +500,7 @@ impl<'a> layers::QueryVisitor<'a> for ExpandBeamVisitor { T: QueryDistance + 'a, { // This is critical to ensure we emit the correct number of prefetches. - assert_eq!(Bytes::new(BYTES + store::TAG_SIZE.value()), self.bytes); + assert!(Bytes::new(BYTES + store::TAG_SIZE.value()) <= self.bytes); Box::new(ExpandBeamImpl::<_, BYTES>(distance)) } @@ -533,7 +533,7 @@ unsafe fn prefetch(ptr: *const u8, len: usize) { // SAFETY: Inherited from caller. unsafe { _mm_prefetch(ptr.add(stride * (lines - 1)), _MM_HINT_T0) }; - for i in 0..(lines - 1) { + for i in 0..(lines - 1).min(16) { // SAFETY: Inherited from caller. unsafe { _mm_prefetch(ptr.add(stride * i), _MM_HINT_T0); @@ -811,10 +811,10 @@ where } pub fn test_function<'a>( - x: &'a Provider>, + x: &'a Provider>, strategy: &'a Strategy, context: &'a Context, - query: &'a [u8], + query: &'a [f32], ) -> ANNResult> { glue::SearchStrategy::search_accessor(strategy, x, context, query) } From 23e386a933cc2d616ffd516cf997a9ca49a3daa1 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Tue, 30 Jun 2026 17:22:51 -0700 Subject: [PATCH 32/45] Checkpoint. 
--- .../src/streaming/executors/bigann/runbook.rs | 2 +- diskann-benchmark-runner/src/checker.rs | 30 + diskann-benchmark-runner/src/files.rs | 6 + diskann-benchmark-runner/src/utils/fmt.rs | 63 +- diskann-benchmark/src/index/benchmarks.rs | 12 +- diskann-benchmark/src/index/inmem2.rs | 934 ++++++++++-------- diskann-benchmark/src/index/streaming/mod.rs | 1 + diskann-inmem/src/epoch.rs | 4 +- diskann-inmem/src/layers/full.rs | 41 +- diskann-inmem/src/lib.rs | 19 + diskann-inmem/src/provider.rs | 24 +- 11 files changed, 693 insertions(+), 443 deletions(-) diff --git a/diskann-benchmark-core/src/streaming/executors/bigann/runbook.rs b/diskann-benchmark-core/src/streaming/executors/bigann/runbook.rs index 19ee0b7c9..971a88d54 100644 --- a/diskann-benchmark-core/src/streaming/executors/bigann/runbook.rs +++ b/diskann-benchmark-core/src/streaming/executors/bigann/runbook.rs @@ -19,7 +19,7 @@ use super::{parsing, validate}; /// /// If using this struct as a [`streaming::Executor`], consider using the /// [`super::WithData`] adaptor to provide dataset and query matrices. -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct RunBook { // The individual runbook stages. stages: Vec, diff --git a/diskann-benchmark-runner/src/checker.rs b/diskann-benchmark-runner/src/checker.rs index 03ca6dec0..00329e955 100644 --- a/diskann-benchmark-runner/src/checker.rs +++ b/diskann-benchmark-runner/src/checker.rs @@ -145,6 +145,36 @@ impl Checker { self.search_directories(), ))) } + + pub fn __check_dir(&self, dir: &Path) -> Result { + // Check if the file exists (allowing for relative paths with respect to the current + // directory. + // + // If the path is an absolute path and the file does not exist, then bail. 
+ if dir.is_absolute() { + if dir.is_dir() { + return Ok(dir.into()); + } else { + return Err(anyhow::Error::msg(format!( + "input file with absolute path \"{}\" either does not exist or is not a file", + dir.display() + ))); + } + }; + + // At this point, start searching in the provided directories. + for d in self.search_directories() { + let absolute = d.join(dir); + if absolute.is_dir() { + return Ok(absolute); + } + } + Err(anyhow::Error::msg(format!( + "could not find input file \"{}\" in the search directories \"{:?}\"", + dir.display(), + self.search_directories(), + ))) + } } /////////// diff --git a/diskann-benchmark-runner/src/files.rs b/diskann-benchmark-runner/src/files.rs index 1672f6f62..f51760dd5 100644 --- a/diskann-benchmark-runner/src/files.rs +++ b/diskann-benchmark-runner/src/files.rs @@ -57,6 +57,12 @@ impl std::ops::Deref for InputFile { } } +impl std::fmt::Display for InputFile { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.display()) + } +} + /////////// // Tests // /////////// diff --git a/diskann-benchmark-runner/src/utils/fmt.rs b/diskann-benchmark-runner/src/utils/fmt.rs index 92fa6d7d7..95a49d066 100644 --- a/diskann-benchmark-runner/src/utils/fmt.rs +++ b/diskann-benchmark-runner/src/utils/fmt.rs @@ -379,7 +379,28 @@ where // KeyValue // ////////////// -/// Display a dynamic list of key-value pairs such that the keys are right-aligned. +enum MaybeLazy<'a> { + Lazy(&'a dyn std::fmt::Display), + Eager(String), +} + +impl std::fmt::Display for MaybeLazy<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Lazy(lazy) => write!(f, "{}", lazy), + Self::Eager(s) => f.write_str(&s), + } + } +} + +/// Display a dynamic list of key-value pairs in a YAML-like style. +/// +/// Keys are left-aligned and single-line values are aligned into a common column +/// just past the longest key. 
A value that renders to multiple lines (for example +/// a nested [`KeyValue`] or any other multi-line block) is placed on the lines +/// following its key, indented by two spaces. This keeps nested structures visibly +/// subordinate to their key regardless of whether the value is itself a key-value +/// list or an opaque block. /// /// # Examples /// @@ -390,12 +411,31 @@ where /// kv.push("a", &1); /// kv.push("hello", &"world"); /// -/// let expected = " a: 1\nhello: world"; +/// let expected = "a: 1\nhello: world"; +/// +/// assert_eq!(kv.to_string(), expected); +/// ``` +/// +/// Multi-line values are indented beneath their key: +/// +/// ``` +/// use diskann_benchmark_runner::utils::fmt::KeyValue; +/// +/// let mut inner = KeyValue::new(); +/// inner.push("x", &1); +/// inner.push("yy", &2); +/// let inner = inner.to_string(); +/// +/// let mut kv = KeyValue::new(); +/// kv.push("name", &"example"); +/// kv.push("nested", &inner); +/// +/// let expected = "name: example\nnested:\n x: 1\n yy: 2"; /// /// assert_eq!(kv.to_string(), expected); /// ``` pub struct KeyValue<'a> { - kv: Vec<(&'a str, &'a dyn std::fmt::Display)>, + kv: Vec<(&'a str, MaybeLazy<'a>)>, max_key_length: usize, } @@ -411,7 +451,15 @@ impl<'a> KeyValue<'a> { /// Push the key-value pair to `self` for formatting. 
pub fn push(&mut self, key: &'a str, value: &'a dyn std::fmt::Display) { self.max_key_length = self.max_key_length.max(key.len()); - self.kv.push((key, value)) + self.kv.push((key, MaybeLazy::Lazy(value))) + } + + pub fn push_eager(&mut self, key: &'a str, value: D) + where + D: std::fmt::Display, + { + self.max_key_length = self.max_key_length.max(key.len()); + self.kv.push((key, MaybeLazy::Eager(value.to_string()))) } pub fn render(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -426,9 +474,12 @@ impl std::fmt::Display for KeyValue<'_> { for (k, v) in self.kv.iter() { let rendered = v.to_string(); if rendered.contains('\n') { - write!(f, "{}{:>width$}:\n{}", prefix, k, Indent::new(&rendered, 2))? + write!(f, "{}{}:\n{}", prefix, k, Indent::new(&rendered, 2))? } else { - write!(f, "{}{:>width$}: {rendered}", prefix, k)?; + // Left-align the key and pad so that all single-line values line up in a + // column one space past the longest key's colon. + let pad = (width + 1).saturating_sub(k.len()); + write!(f, "{}{}:{:pad$}{rendered}", prefix, k, "")?; } prefix = "\n"; } diff --git a/diskann-benchmark/src/index/benchmarks.rs b/diskann-benchmark/src/index/benchmarks.rs index 67109bd87..45a142344 100644 --- a/diskann-benchmark/src/index/benchmarks.rs +++ b/diskann-benchmark/src/index/benchmarks.rs @@ -82,15 +82,15 @@ pub(crate) fn register_benchmarks(registry: &mut Registry) -> anyhow::Result<()> .search(plugins::DeterminantDiversity), )?; - // registry.register( - // "graph-index-full-precision-f16", - // FullPrecision::::new().search(plugins::Topk), - // )?; registry.register( - "graph-index-full-precision-u8", - FullPrecision::::new().search(plugins::Topk), + "graph-index-full-precision-f16", + FullPrecision::::new().search(plugins::Topk), )?; // registry.register( + // "graph-index-full-precision-u8", + // FullPrecision::::new().search(plugins::Topk), + // )?; + // registry.register( // "graph-index-full-precision-i8", // 
FullPrecision::::new().search(plugins::Topk), // )?; diff --git a/diskann-benchmark/src/index/inmem2.rs b/diskann-benchmark/src/index/inmem2.rs index a36d46b7f..69debbea9 100644 --- a/diskann-benchmark/src/index/inmem2.rs +++ b/diskann-benchmark/src/index/inmem2.rs @@ -3,45 +3,57 @@ * Licensed under the MIT license. */ -//! Benchmark backend for the `diskann-inmem` (inmem2) provider. -//! -//! This wires up the inmem2 `Provider>` to the standard build and search -//! infrastructure in `diskann-benchmark-core`, giving us parallel insertion via -//! [`SingleInsert`] and KNN search with recall/latency reporting via [`KNN`]. - -use std::{io::Write, num::NonZeroUsize, ops::Range, sync::Arc}; - -use diskann::{ - graph::{self, DiskANNIndex, StartPointStrategy}, - provider::{self as ann_provider}, +use std::{ + fmt::{self, Display, Formatter}, + io::Write, + num::NonZeroUsize, + ops::Range, + sync::Arc, + path::PathBuf, }; + +use diskann::graph::{self, DiskANNIndex, InplaceDeleteMethod, StartPointStrategy}; use diskann_benchmark_core::{ - self as benchmark_core, build as build_core, - recall::{self, GroundTruthMode}, - search as core_search, + self as benchmark_core, build as build_core, recall, search as core_search, streaming::{self, executors::bigann, Executor}, }; use diskann_benchmark_runner::{ benchmark::{FailureScore, MatchScore}, files::InputFile, output::Output, - utils::datatype::DataType, + utils::{ + datatype::{AsDataType, DataType}, + fmt::{Delimit, KeyValue, Quote}, + }, Benchmark, Checker, Checkpoint, Input, Registry, }; -use diskann_inmem::{layers::Full, Provider, Strategy}; +use diskann_inmem::{ + layers::{Full, FullPrecision}, + Provider, Strategy, +}; use diskann_utils::views::{Matrix, MatrixView}; use diskann_vector::distance::Metric; +use half::f16; use serde::{Deserialize, Serialize}; use crate::{ - index::{result::{SearchResults, AggregatedSearchResults}, build::ProgressMeter}, + index::{ + build::{BuildKind, BuildStats, ProgressMeter}, + 
result::{AggregatedSearchResults, SearchResults}, + streaming::{ + stats::{GenericStats, Summary}, + StreamStats, + }, + }, inputs::graph_index::DynamicRunbookParams, utils::{datafiles, SimilarityMeasure}, }; pub(crate) fn register_benchmarks(registry: &mut Registry) -> anyhow::Result<()> { registry.register("inmem2-f32", Build::::new())?; - registry.register("inmem2-f32-stream", Inmem2Stream)?; + registry.register("inmem2-f16", Build::::new())?; + + registry.register("inmem2-f32-stream", StreamingBenchmark)?; Ok(()) } @@ -84,20 +96,84 @@ mod dto { pub(super) num_threads: NonZeroUsize, } + //-----------// + // Streaming // + //-----------// + + #[derive(Debug, Serialize, Deserialize)] + pub(super) struct StreamingKnnSearch { + pub(super) queries: InputFile, + pub(super) reps: NonZeroUsize, + pub(super) num_threads: NonZeroUsize, + pub(super) runs: Vec, + } + + #[derive(Debug, Serialize, Deserialize)] + pub(super) struct RunBook { + pub(super) path: InputFile, + pub(super) dataset: String, + pub(super) groundtruth_directory: String, + pub(super) delete_method: crate::inputs::graph_index::InplaceDeleteMethod, + pub(super) delete_num_to_replace: usize, + } + + //------------------// + // Top Level Inputs // + //------------------// + #[derive(Debug, Serialize, Deserialize)] pub(super) struct StaticBuild { pub(super) data: Data, pub(super) build: BuildParams, pub(super) search: KnnSearch, } + + #[derive(Debug, Serialize, Deserialize)] + pub(super) struct BigANNStreaming { + pub(super) data: Data, + pub(super) build: BuildParams, + pub(super) search: StreamingKnnSearch, + pub(super) runbook: RunBook, + } } -#[derive(Debug)] +#[derive(Debug, Clone)] struct KnnInstance { knn: graph::search::Knn, recall_k: usize, } +impl KnnInstance { + fn flatten(runs: &[dto::KnnSweep]) -> anyhow::Result> { + runs.iter() + .flat_map(|sweep| { + let search_n = sweep.search_n; + let recall_k = sweep.recall_k; + + sweep + .search_l + .iter() + .map(move |search_l| -> anyhow::Result<_> { + 
let knn = graph::search::Knn::new_default(search_n, *search_l)?; + Ok(KnnInstance { knn, recall_k }) + }) + }) + .collect() + } +} + +impl Display for KnnInstance { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "knn = {}, search_l = {}, beam_width = {}", + self.recall_k, + self.knn.l_value(), + self.knn.beam_width(), + ) + } +} + #[derive(Debug)] struct KnnSearch { queries: InputFile, @@ -122,28 +198,12 @@ impl KnnSearch { groundtruth.resolve(checker)?; } - let runs = runs - .into_iter() - .flat_map(|sweep| { - sweep - .search_l - .into_iter() - .map(move |search_l| -> anyhow::Result<_> { - let knn = graph::search::Knn::new_default(sweep.search_n, search_l)?; - Ok(KnnInstance { - knn, - recall_k: sweep.recall_k, - }) - }) - }) - .collect::>>()?; - Ok(Self { queries, groundtruth, reps, num_threads, - runs, + runs: KnnInstance::flatten(&runs)?, }) } @@ -152,6 +212,22 @@ impl KnnSearch { } } +impl Display for KnnSearch { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + let mut kv = KeyValue::new(); + kv.push("queries", &self.queries); + kv.push("groundtruth", &self.groundtruth); + kv.push("reps", &self.reps); + + let num_threads = Delimit::new(self.num_threads.iter(), ", "); + kv.push("num_threads", &num_threads); + + let runs = Delimit::new(self.runs.iter(), "\n").to_string(); + kv.push("runs", &runs); + write!(f, "{}", kv) + } +} + #[derive(Debug)] struct Data { data_type: DataType, @@ -179,6 +255,16 @@ impl Data { } } +impl Display for Data { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + let mut kv = KeyValue::new(); + kv.push("data_type", &self.data_type); + kv.push("data", &self.data); + kv.push("distance", &self.distance); + write!(f, "{}", kv) + } +} + #[derive(Debug)] struct BuildParams { config: graph::Config, @@ -213,6 +299,24 @@ impl BuildParams { } } +impl Display for BuildParams { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + let mut kv = KeyValue::new(); + + let pruned_degree = 
self.config.pruned_degree(); + let max_degree = self.config.max_degree(); + let alpha = self.config.alpha(); + let l_build = self.config.l_build(); + + kv.push("pruned_degree", &pruned_degree); + kv.push("max_degree", &max_degree); + kv.push("alpha", &alpha); + kv.push("l_build", &l_build); + kv.push("num_threads", &self.num_threads); + write!(f, "{}", kv) + } +} + #[derive(Debug)] struct StaticBuild { data: Data, @@ -240,6 +344,17 @@ impl StaticBuild { } } +impl Display for StaticBuild { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let mut kv = KeyValue::new(); + kv.push("data", &self.data); + kv.push("build", &self.build); + kv.push("search", &self.search); + + write!(f, "{}", kv) + } +} + impl Input for StaticBuild { type Raw = dto::StaticBuild; @@ -287,20 +402,6 @@ impl Input for StaticBuild { } } -// impl std::fmt::Display for Inmem2Build { -// fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { -// writeln!(f, "inmem2 f32 benchmark")?; -// writeln!(f, " max_degree: {}", self.max_degree)?; -// writeln!(f, " l_build: {}", self.l_build)?; -// writeln!(f, " alpha: {}", self.alpha)?; -// writeln!(f, " search_n: {}", self.search_n)?; -// writeln!(f, " search_l: {:?}", self.search_l)?; -// writeln!(f, " recall_k: {}", self.recall_k)?; -// writeln!(f, " num_threads: {}", self.num_threads)?; -// writeln!(f, " reps: {}", self.reps) -// } -// } - /////////////// // Benchmark // /////////////// @@ -316,14 +417,17 @@ impl Build { impl Benchmark for Build where - T: diskann_inmem::layers::FullPrecision - + diskann::graph::SampleableForStart, + T: diskann_inmem::layers::FullPrecision + diskann::graph::SampleableForStart + AsDataType, { type Input = StaticBuild; type Output = (); fn try_match(&self, input: &StaticBuild) -> Result { - Ok(MatchScore(0)) + if T::is_match(input.data.data_type) { + Ok(MatchScore(0)) + } else { + Err(FailureScore(1000)) + } } fn description( @@ -331,11 +435,28 @@ where f: &mut std::fmt::Formatter<'_>, input: 
Option<&StaticBuild>, ) -> std::fmt::Result { + match input { + Some(input) => { + let data_type = input.data.data_type; + if !T::is_match(data_type) { + write!( + f, + "expected data-type {}, instead got {}", + Quote(T::DATA_TYPE), + Quote(data_type) + )?; + } + } + None => { + write!( + f, + "full-precision streaming with data type {}", + Quote(T::DATA_TYPE) + )?; + } + } + Ok(()) - // match input { - // Some(i) => write!(f, "{i}"), - // None => write!(f, "inmem2 f32 benchmark"), - // } } fn run( @@ -344,7 +465,7 @@ where checkpoint: Checkpoint<'_>, mut output: &mut dyn Output, ) -> anyhow::Result<()> { - // writeln!(output, "{input}")?; + writeln!(output, "{input}\n")?; // Load data. let data: Arc> = Arc::new(datafiles::load_dataset(datafiles::BinFile( @@ -377,17 +498,9 @@ where build_core::ids::Identity::::new(), ); - // writeln!( - // output, - // "Building index with {} threads...", - // input.num_threads - // )?; let build_results = build_core::build_tracked( builder, - build_core::Parallelism::dynamic( - diskann::utils::ONE, - input.build.num_threads, - ), + build_core::Parallelism::dynamic(diskann::utils::ONE, input.build.num_threads), &rt, Some(&ProgressMeter::new(output)), )?; @@ -395,19 +508,22 @@ where let total_build_time = build_results.end_to_end_latency(); writeln!( output, - "Build complete in {:.2}s", + "\nBuild complete in {:.2}s", total_build_time.as_seconds() )?; checkpoint.checkpoint(&total_build_time)?; // Search. 
- let queries: Arc> = - Arc::new(datafiles::load_dataset(datafiles::BinFile(&input.search.queries))?); + let queries: Arc> = Arc::new(datafiles::load_dataset(datafiles::BinFile( + &input.search.queries, + ))?); let max_k = input.search.maximum_recall_k(); - let groundtruth = - datafiles::load_groundtruth(datafiles::BinFile(&input.search.groundtruth), Some(max_k))?; + let groundtruth = datafiles::load_groundtruth( + datafiles::BinFile(&input.search.groundtruth), + Some(max_k), + )?; - writeln!(output, "Loaded {} queries", queries.nrows())?; + writeln!(output, "Loaded {} queries\n", queries.nrows())?; let knn = benchmark_core::search::graph::KNN::new( index, @@ -415,30 +531,13 @@ where benchmark_core::search::graph::Strategy::broadcast(Strategy), )?; - let mut results = Vec::new(); - for num_threads in input.search.num_threads.iter() { - for instance in input.search.runs.iter() { - let setup = core_search::Setup { - threads: *num_threads, - tasks: *num_threads, - reps: input.search.reps, - }; - - let run = core_search::Run::new(instance.knn, setup); - - let r = core_search::search_all( - knn.clone(), - std::iter::once(run), - benchmark_core::search::graph::knn::Aggregator::new( - &groundtruth, - instance.recall_k, - instance.knn.k_value().get(), - GroundTruthMode::Fixed, - ), - )?; - results.extend(r.into_iter().map(SearchResults::new)); - } - } + let mut results = _knn( + &knn, + &groundtruth, + input.search.reps, + &input.search.num_threads, + &input.search.runs, + )?; let results = AggregatedSearchResults::Topk(results); @@ -448,318 +547,400 @@ where } } +fn _knn( + runner: &dyn crate::index::search::knn::Knn, + groundtruth: &dyn benchmark_core::recall::Rows, + reps: NonZeroUsize, + num_threads: &[NonZeroUsize], + instances: &[KnnInstance], +) -> anyhow::Result> { + let mut results = Vec::new(); + + for num_threads in num_threads.iter() { + for instance in instances.iter() { + let setup = core_search::Setup { + threads: *num_threads, + tasks: *num_threads, + reps: 
reps, + }; + + let run = core_search::Run::new(instance.knn, setup); + + let r = runner.search_all( + vec![run], + groundtruth, + instance.recall_k, + instance.knn.k_value().get(), + )?; + + results.extend(r); + } + } + + Ok(results) +} + /////////////// // Streaming // /////////////// -/// Input for the streaming inmem2 benchmark. -/// -/// Drives the inmem2 provider through a BigANN-style runbook. Because the inmem2 -/// provider already does external↔internal id translation, no `Managed`/ -/// `TagSlotManager` adapter is needed — the runbook talks to the provider directly. -#[derive(Debug, Serialize, Deserialize)] -pub(crate) struct Inmem2StreamInput { - /// Full dataset that the runbook indexes into. - data: InputFile, - /// Query set used for every search stage. +#[derive(Debug, Clone)] +struct StreamingKnnSearch { queries: InputFile, + reps: NonZeroUsize, + num_threads: NonZeroUsize, + runs: Vec, +} - /// Runbook parameters (path, dataset name, gt directory, ...). - runbook_params: DynamicRunbookParams, +impl StreamingKnnSearch { + fn from_raw( + raw: dto::StreamingKnnSearch, + checker: Option<&mut Checker>, + ) -> anyhow::Result { + let dto::StreamingKnnSearch { + mut queries, + reps, + num_threads, + runs, + } = raw; - max_degree: usize, - l_build: usize, - alpha: f32, + if let Some(checker) = checker { + queries.resolve(checker)?; + } - search_n: usize, - search_l: Vec, - recall_k: usize, + Ok(Self { + queries, + reps, + num_threads, + runs: KnnInstance::flatten(&runs)?, + }) + } +} - num_threads: usize, - reps: NonZeroUsize, +impl Display for StreamingKnnSearch { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + let mut kv = KeyValue::new(); + kv.push("queries", &self.queries); + kv.push("reps", &self.reps); + kv.push("num_threads", &self.num_threads); + + let runs = Delimit::new(self.runs.iter(), "\n"); + kv.push("runs", &runs); + write!(f, "{}", runs) + } +} + +#[derive(Debug)] +struct RunBook { + runbook: bigann::RunBook, + delete_method: 
InplaceDeleteMethod, + delete_num_to_replace: usize, + // This is kept for display purposes. + runbook_path: InputFile, + dataset: String, +} + +impl RunBook { + fn from_raw(raw: dto::RunBook, checker: &mut Checker) -> anyhow::Result { + let dto::RunBook { + mut path, + dataset, + mut groundtruth_directory, + delete_method, + delete_num_to_replace, + } = raw; + + path.resolve(checker)?; + + let groundtruth_directory = checker.__check_dir(groundtruth_directory.as_ref())?; + + let runbook = bigann::RunBook::load( + &path, + &dataset, + &mut bigann::ScanDirectory::new(&groundtruth_directory)?, + )?; + + Ok(Self { + runbook, + delete_method: delete_method.into(), + delete_num_to_replace, + runbook_path: path, + dataset, + }) + } +} + +impl Display for RunBook { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + let mut kv = KeyValue::new(); + let path = self.runbook_path.display(); + kv.push("runbook", &path); + kv.push("dataset", &self.dataset); + + let max_points = self.runbook.max_points(); + let max_tag = self.runbook.max_tag(); + let num_stages = self.runbook.len(); + + kv.push("num_stages", &num_stages); + kv.push("max_active_points", &max_points); + if let Some(ref max_tag) = max_tag { + kv.push("max_tag", max_tag); + } + + kv.push_eager("delete_method", format_args!("{:?}", self.delete_method)); + kv.push("delete_num_to_replace", &self.delete_num_to_replace); + write!(f, "{}", kv) + } +} + +#[derive(Debug)] +struct BigANNStreaming { + data: Data, + build: BuildParams, + search: StreamingKnnSearch, + runbook: RunBook, +} + +impl BigANNStreaming { + fn from_raw(raw: dto::BigANNStreaming, checker: &mut Checker) -> anyhow::Result { + let data = Data::from_raw(raw.data, Some(checker))?; + let build = BuildParams::from_raw(raw.build, data.distance)?; + Ok(Self { + data, + build, + search: StreamingKnnSearch::from_raw(raw.search, Some(checker))?, + runbook: RunBook::from_raw(raw.runbook, checker)?, + }) + } +} + +impl Display for BigANNStreaming { + fn 
fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + let mut kv = KeyValue::new(); + kv.push("data", &self.data); + kv.push("build", &self.build); + kv.push("search", &self.search); + kv.push("runbook", &self.runbook); + write!(f, "{}", kv) + } } -impl Input for Inmem2StreamInput { - type Raw = Inmem2StreamInput; +impl Input for BigANNStreaming { + type Raw = dto::BigANNStreaming; fn tag() -> &'static str { - "inmem2-stream" + "inmem2-streaming" } - fn from_raw(mut raw: Self::Raw, checker: &mut Checker) -> anyhow::Result { - raw.data.resolve(checker)?; - raw.queries.resolve(checker)?; - raw.runbook_params.validate(checker)?; - Ok(raw) + fn from_raw(raw: Self::Raw, checker: &mut Checker) -> anyhow::Result { + Self::from_raw(raw, checker) } fn serialize(&self) -> anyhow::Result { - Ok(serde_json::to_value(self)?) - } - - fn example() -> Self { - Self { - data: InputFile::new("path/to/base.bin"), - queries: InputFile::new("path/to/query.bin"), - runbook_params: ::example(), - max_degree: 64, - l_build: 100, - alpha: 1.2, - search_n: 10, - search_l: vec![10, 20, 50, 100], - recall_k: 10, - num_threads: 4, - reps: NonZeroUsize::new(3).unwrap(), - } + Ok(serde_json::to_value(())?) 
} -} -impl std::fmt::Display for Inmem2StreamInput { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - writeln!(f, "inmem2 f32 streaming benchmark")?; - writeln!( - f, - " runbook: {}", - self.runbook_params.runbook_path.display() - )?; - writeln!(f, " dataset: {}", self.runbook_params.dataset_name)?; - writeln!(f, " max_degree: {}", self.max_degree)?; - writeln!(f, " l_build: {}", self.l_build)?; - writeln!(f, " alpha: {}", self.alpha)?; - writeln!(f, " search_n: {}", self.search_n)?; - writeln!(f, " search_l: {:?}", self.search_l)?; - writeln!(f, " recall_k: {}", self.recall_k)?; - writeln!(f, " num_threads: {}", self.num_threads)?; - writeln!(f, " reps: {}", self.reps) + fn example() -> Self::Raw { + const FOUR: NonZeroUsize = NonZeroUsize::new(4).unwrap(); + const THREE: NonZeroUsize = NonZeroUsize::new(3).unwrap(); + + dto::BigANNStreaming { + data: dto::Data { + data_type: DataType::Float32, + data: InputFile::new("path/to/data"), + distance: SimilarityMeasure::SquaredL2, + }, + build: dto::BuildParams { + pruned_degree: 28, + max_degree: 32, + l_build: 100, + alpha: 1.2, + num_threads: FOUR, + }, + search: dto::StreamingKnnSearch { + queries: InputFile::new("path/to/queries"), + reps: THREE, + num_threads: FOUR, + runs: vec![dto::KnnSweep { + search_n: 10, + search_l: vec![10, 20, 30, 40, 50], + recall_k: 10, + }], + }, + runbook: dto::RunBook { + path: InputFile::new("path/to/runbook.yaml"), + dataset: "dataset-1M".into(), + groundtruth_directory: "groundtruth/dir".into(), + delete_method: crate::inputs::graph_index::InplaceDeleteMethod::TwoHopAndOneHop, + delete_num_to_replace: 3, + }, + } } } #[derive(Debug)] -struct Inmem2Stream; +struct StreamingBenchmark; -impl Benchmark for Inmem2Stream { - type Input = Inmem2StreamInput; - type Output = Vec; +impl Benchmark for StreamingBenchmark { + type Input = BigANNStreaming; + type Output = (); - fn try_match(&self, _input: &Inmem2StreamInput) -> Result { + fn try_match(&self, _input: 
&BigANNStreaming) -> Result { Ok(MatchScore(0)) } fn description( &self, f: &mut std::fmt::Formatter<'_>, - input: Option<&Inmem2StreamInput>, + input: Option<&BigANNStreaming>, ) -> std::fmt::Result { - match input { - Some(i) => write!(f, "{i}"), - None => write!(f, "inmem2 f32 streaming benchmark"), - } + Ok(()) + // match input { + // Some(i) => write!(f, "{i}"), + // None => write!(f, "inmem2 f32 streaming benchmark"), + // } } fn run( &self, - input: &Inmem2StreamInput, + input: &BigANNStreaming, _checkpoint: Checkpoint<'_>, mut output: &mut dyn Output, ) -> anyhow::Result { - writeln!(output, "{input}")?; + writeln!(output, "{input}\n")?; // Load the runbook so we know the eventual capacity. - let gt_dir = input - .runbook_params - .resolved_gt_directory - .as_ref() - .ok_or_else(|| anyhow::anyhow!("groundtruth directory not resolved"))?; - - let runbook = bigann::RunBook::load( - &input.runbook_params.runbook_path, - &input.runbook_params.dataset_name, - &mut bigann::ScanDirectory::new(gt_dir)?, - )?; + let runbook = input.runbook.runbook.clone(); let max_points = runbook.max_points(); // Load the dataset (consumed by `WithData`) and queries. - let dataset: Matrix = datafiles::load_dataset(datafiles::BinFile(&input.data))?; - let queries: Arc> = - Arc::new(datafiles::load_dataset(datafiles::BinFile(&input.queries))?); + let dataset: Matrix = datafiles::load_dataset(datafiles::BinFile(&input.data.data))?; + let queries: Arc> = Arc::new(datafiles::load_dataset(datafiles::BinFile( + &input.search.queries, + ))?); let dim = dataset.ncols(); - writeln!( - output, - "Loaded dataset: {} points, dim={}", - dataset.nrows(), - dim - )?; - writeln!(output, "Loaded queries: {}", queries.nrows())?; - writeln!(output, "Runbook max_points: {max_points}")?; - // Compute the medoid of the dataset as the single start point. 
let start = StartPointStrategy::Medoid.compute(dataset.as_view())?; - let metric = Metric::L2; - let exact_max_degree = (input.max_degree as f32 * 1.3) as usize; - let layer = Full::::new(dim, metric); + let index_config = input.build.config.clone(); + let layer = Full::::new(dim, input.data.distance); - let config = diskann_inmem::provider::Config::new(max_points, exact_max_degree); - let provider = Provider::new(layer, config, start.row_iter())?; - - let config = graph::config::Builder::new_with( - input.max_degree, - graph::config::MaxDegree::new(exact_max_degree), - input.l_build, - metric.into(), - |b| { - b.alpha(input.alpha); - }, - ) - .build()?; + let config = + diskann_inmem::provider::Config::new(max_points, index_config.max_degree().get()); + let provider = Provider::<_, u32>::new(layer, config, start.row_iter())?; - let index = Arc::new(DiskANNIndex::new(config, provider, None)); + let index = Arc::new(DiskANNIndex::new(index_config, provider, None)); - let num_threads = NonZeroUsize::new(input.num_threads.max(1)).unwrap(); + let num_threads = input.build.num_threads; let runtime = benchmark_core::tokio::runtime(num_threads.get())?; - // Build the inner stream and wrap it with `WithData`. let stream = Stream { - index: index.clone(), + index, runtime, - ntasks: num_threads, - search_n: input.search_n, - search_l: input.search_l.clone(), - recall_k: input.recall_k, - reps: input.reps, + search: input.search.clone(), + ntasks: input.build.num_threads, + delete_method: input.runbook.delete_method, + delete_num_to_replace: input.runbook.delete_num_to_replace, }; - let max_k = input.recall_k; - let queries_for_search = queries.clone(); - let mut layered = bigann::WithData::new(stream, dataset, queries_for_search, move |path| { + let mut layered = bigann::WithData::new(stream, dataset, queries, move |path| { Ok(Box::new(datafiles::load_groundtruth( datafiles::BinFile(path), - Some(max_k), + None, )?)) }); - // Drive the runbook. 
- let mut runbook = runbook; + // Here we go! let mut results = Vec::new(); let stages = runbook.len(); - let mut stage_idx = 1usize; - - runbook.run_with(&mut layered, |o: StreamOutput| -> anyhow::Result<()> { - let banner = format!("Stage {} of {}: {}", stage_idx, stages, o.kind()); - write!(output, "{}", crate::utils::SmallBanner(&banner))?; - writeln!(output, "{o}")?; - stage_idx += 1; - results.push(o); - Ok(()) - })?; + let mut i = 1; + input.runbook.runbook.clone().run_with( + &mut layered, + |o: StreamStats| -> anyhow::Result<()> { + if o.is_maintain() { + let message = format!("Ran maintenance before stage {}", i); + write!(output, "{}", crate::utils::SmallBanner(&message))?; + } else { + let message = format!("Finished stage {} of {}: {}", i, stages, o.kind()); + write!(output, "{}", crate::utils::SmallBanner(&message))?; + i += 1; + } + writeln!(output, "{}", o)?; + results.push(o); + Ok(()) + }, + )?; write!( output, "{}", crate::utils::SmallBanner("End of Run Summary") )?; - let total_inserts: usize = results.iter().filter_map(|r| r.insert_count()).sum(); - let total_deletes: usize = results.iter().filter_map(|r| r.delete_count()).sum(); - let n_searches = results - .iter() - .filter(|r| matches!(r, StreamOutput::Search { .. })) - .count(); - writeln!( - output, - "stages={stages} inserts={total_inserts} deletes={total_deletes} searches={n_searches}", - )?; - Ok(results) + writeln!(output, "{}", Summary::new(results.iter()))?; + + Ok(()) } } -///////////////// -// Stream impl // -///////////////// +//////////// +// Stream // +//////////// -/// Inner streaming index over `inmem2`. -/// -/// Implements `streaming::Stream>` so it can be wrapped -/// by `bigann::WithData` and driven by `bigann::RunBook`. Replace and maintain are -/// not supported in v1; deletes are eager so no consolidation pass is needed. 
-struct Stream { - index: Arc>>>, +struct Stream +where + T: FullPrecision, +{ + index: Arc>>>, runtime: tokio::runtime::Runtime, + search: StreamingKnnSearch, ntasks: NonZeroUsize, - search_n: usize, - search_l: Vec, - recall_k: usize, - reps: NonZeroUsize, + delete_method: InplaceDeleteMethod, + delete_num_to_replace: usize, } -#[derive(Debug, Serialize)] -pub(crate) enum StreamOutput { - Insert { count: usize, latency_s: f64 }, - Delete { count: usize, latency_s: f64 }, - Search(Vec), -} - -#[derive(Debug, Serialize)] -pub(crate) struct SearchPoint { - pub search_l: usize, - pub recall: f64, - pub mean_qps: f64, - pub max_qps: f64, -} - -impl StreamOutput { - fn kind(&self) -> &'static str { - match self { - Self::Insert { .. } => "insert", - Self::Delete { .. } => "delete", - Self::Search(_) => "search", - } - } +impl Stream +where + T: FullPrecision, +{ + fn insert_( + &mut self, + data: MatrixView<'_, T>, + ids: Range, + ) -> anyhow::Result { + anyhow::ensure!( + data.nrows() == ids.len(), + "insert: data rows ({}) != ids range ({})", + data.nrows(), + ids.len(), + ); - fn insert_count(&self) -> Option { - match self { - Self::Insert { count, .. } => Some(*count), - _ => None, - } - } + let runner = build_core::graph::SingleInsert::new( + self.index.clone(), + Arc::new(data.to_owned()), + Strategy, + build_core::ids::Range::::new(ids.start as u32..ids.end as u32), + ); - fn delete_count(&self) -> Option { - match self { - Self::Delete { count, .. 
} => Some(*count), - _ => None, - } - } -} + let results = build_core::build( + runner, + build_core::Parallelism::dynamic(diskann::utils::ONE, self.ntasks), + &self.runtime, + )?; -impl std::fmt::Display for StreamOutput { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::Insert { count, latency_s } => { - writeln!(f, " inserted {count} points in {latency_s:.3}s") - } - Self::Delete { count, latency_s } => { - writeln!(f, " deleted {count} points in {latency_s:.3}s") - } - Self::Search(points) => { - for p in points { - writeln!( - f, - " L={:<4} recall={:.4} QPS={:.0} (max {:.0})", - p.search_l, p.recall, p.mean_qps, p.max_qps, - )?; - } - Ok(()) - } - } + BuildStats::new(BuildKind::SingleInsert, results) } } -impl streaming::Stream> for Stream { - type Output = StreamOutput; +impl streaming::Stream> for Stream +where + T: FullPrecision, +{ + type Output = StreamStats; fn search( &mut self, - (queries, groundtruth): (Arc>, &dyn recall::Rows), + (queries, groundtruth): (Arc>, &dyn recall::Rows), ) -> anyhow::Result { let knn = benchmark_core::search::graph::KNN::new( self.index.clone(), @@ -767,95 +948,34 @@ impl streaming::Stream> for Stream { benchmark_core::search::graph::Strategy::broadcast(Strategy), )?; - let mut points = Vec::with_capacity(self.search_l.len()); - for &search_l in &self.search_l { - let params = graph::search::Knn::new(self.search_n, search_l, None)?; - let setup = core_search::Setup { - threads: self.ntasks, - tasks: self.ntasks, - reps: self.reps, - }; - let run = core_search::Run::new(params, setup); - - let summaries = core_search::search_all( - knn.clone(), - std::iter::once(run), - benchmark_core::search::graph::knn::Aggregator::new( - groundtruth, - self.recall_k, - self.search_n, - GroundTruthMode::Fixed, - ), - )?; + let r = _knn( + &knn, + groundtruth, + self.search.reps, + std::slice::from_ref(&self.search.num_threads), + &self.search.runs, + )?; - for summary in &summaries { - let qps: 
Vec = summary - .end_to_end_latencies - .iter() - .map(|lat| summary.recall.num_queries as f64 / lat.as_seconds()) - .collect(); - let max_qps = qps.iter().cloned().fold(0.0f64, f64::max); - let mean_qps = qps.iter().sum::() / qps.len().max(1) as f64; - points.push(SearchPoint { - search_l, - recall: summary.recall.average, - mean_qps, - max_qps, - }); - } - } - Ok(StreamOutput::Search(points)) + Ok(StreamStats::Search(r)) } fn insert( &mut self, - (data, ids): (MatrixView<'_, f32>, Range), + (data, ids): (MatrixView<'_, T>, Range), ) -> anyhow::Result { - anyhow::ensure!( - data.nrows() == ids.len(), - "insert: data rows ({}) != ids range ({})", - data.nrows(), - ids.len(), - ); - - let count = data.nrows(); - let slots: Box<[u32]> = ids - .map(|id| u32::try_from(id)) - .collect::, _>>()?; - - let runner = build_core::graph::SingleInsert::new( - self.index.clone(), - Arc::new(data.to_owned()), - Strategy, - build_core::ids::Slice::new(slots), - ); - - let results = build_core::build( - runner, - build_core::Parallelism::dynamic(diskann::utils::ONE, self.ntasks), - &self.runtime, - )?; - - let latency_s = results.end_to_end_latency().as_seconds(); - Ok(StreamOutput::Insert { count, latency_s }) + self.insert_(data, ids).map(StreamStats::Insert) } fn delete(&mut self, ids: Range) -> anyhow::Result { - let count = ids.len(); - let provider = self.index.provider(); - let ctx = diskann_inmem::Context; - - let start = std::time::Instant::now(); - let runner = streaming::graph::InplaceDelete::new( self.index.clone(), Strategy, - 3, - diskann::graph::InplaceDeleteMethod::OneHop, - build_core::ids::Slice::new(ids.clone().into_iter().map(|i| i as u32).collect()), + self.delete_num_to_replace, + self.delete_method, + build_core::ids::Range::new(ids.start as u32..ids.end as u32), ); - let _ = build_core::build( + let r = build_core::build( runner, diskann_benchmark_core::build::Parallelism::fixed( Some(diskann::utils::ONE), @@ -864,22 +984,30 @@ impl streaming::Stream> for 
Stream { &self.runtime, )?; - let latency_s = start.elapsed().as_secs_f64(); - - Ok(StreamOutput::Delete { count, latency_s }) + Ok(StreamStats::Delete(GenericStats::new("delete".into(), r)?)) } fn replace( &mut self, - _args: (MatrixView<'_, f32>, Range), + (data, ids): (MatrixView<'_, T>, Range), ) -> anyhow::Result { - anyhow::bail!("inmem2-f32-stream: replace is not supported in v1") + use diskann::provider::Delete; + + // TODO: This is kind of a hack. It would be ideal to parallelize this. + // + // Also, this is *way* more expensive than it needs to be because each delete creates + // and then destroys an EBR guard. + let ctx = diskann_inmem::Context; + for id in ids.clone() { + self.runtime + .block_on(self.index.provider().delete(&ctx, &(id as u32)))?; + } + + self.insert_(data, ids).map(StreamStats::Replace) } fn maintain(&mut self, _: ()) -> anyhow::Result { - anyhow::bail!( - "inmem2-f32-stream: maintain is not supported (deletes are eager, no consolidation needed)" - ) + Ok(StreamStats::Maintain(vec![])) } fn needs_maintenance(&mut self) -> bool { diff --git a/diskann-benchmark/src/index/streaming/mod.rs b/diskann-benchmark/src/index/streaming/mod.rs index 3b4814549..83b26e1d0 100644 --- a/diskann-benchmark/src/index/streaming/mod.rs +++ b/diskann-benchmark/src/index/streaming/mod.rs @@ -9,3 +9,4 @@ pub(crate) mod stats; pub(crate) use full_precision::FullPrecisionStream; pub(crate) use managed::{Managed, ManagedStream}; +pub(crate) use stats::StreamStats; diff --git a/diskann-inmem/src/epoch.rs b/diskann-inmem/src/epoch.rs index 3eb77b558..e8e97af72 100644 --- a/diskann-inmem/src/epoch.rs +++ b/diskann-inmem/src/epoch.rs @@ -99,7 +99,7 @@ pub(crate) struct Registry { // ``` // // We cycle among the queues in a round-robin manner. - retiring: [SegQueue; 4], + retiring: Box<[SegQueue; 4]>, } // Return the queue index for the `epoch`. 
@@ -125,7 +125,7 @@ impl Registry { guards: (0..capacity).map(|_| AtomicU64::new(0)).collect(), hint: AtomicUsize::new(0), epoch: AtomicU64::new(1), - retiring: core::array::from_fn(|_| SegQueue::new()), + retiring: Box::new(core::array::from_fn(|_| SegQueue::new())), drain: Mutex::new(()), } } diff --git a/diskann-inmem/src/layers/full.rs b/diskann-inmem/src/layers/full.rs index a9fef5deb..67e637daf 100644 --- a/diskann-inmem/src/layers/full.rs +++ b/diskann-inmem/src/layers/full.rs @@ -21,7 +21,7 @@ use diskann_wide::{ use half::f16; use thiserror::Error; -use crate::{layers, num::Bytes}; +use crate::{layers, num::Bytes, Hidden}; /// A useful trait bound for types compatible with [`Full`]. /// @@ -29,10 +29,11 @@ use crate::{layers, num::Bytes}; /// a single bound. pub trait FullPrecision: bytemuck::Pod + std::fmt::Debug + Send + Sync { #[doc(hidden)] - fn __new(dim: usize, metric: Metric) -> Full; + fn __new(_: Hidden, dim: usize, metric: Metric) -> Full; #[doc(hidden)] fn __query_distance<'a, V>( + _: Hidden, full: &'a Full, query: &'a [Self], visitor: V, @@ -60,7 +61,7 @@ where where T: FullPrecision, { - T::__new(dim, metric) + T::__new(Hidden::new(), dim, metric) } fn from_distance_provider(dim: usize, metric: Metric) -> Self @@ -165,7 +166,7 @@ where where V: layers::QueryVisitor<'a>, { - T::__query_distance(self, query, visitor) + T::__query_distance(Hidden::new(), self, query, visitor) } } @@ -349,9 +350,9 @@ crate::opaque!(QueryDistanceError); macro_rules! 
mint { ($query:ident, $visitor:ident, $T:ty => { $N:literal, $f:ident }) => {{ - mint!($query, $visitor, { $T, $T } => { $N x $f }) + mint!($query, $visitor, { $T, $T } => { $N, $f }) }}; - ($query:ident, $visitor:ident, { $T:ty, $U:ty } => { $N:literal x $f:ident }) => {{ + ($query:ident, $visitor:ident, { $T:ty, $U:ty } => { $N:literal, $f:ident }) => {{ let inner = QueryDistance::<$T, $U, Specialize<$N, $f>>::new($query); $visitor.visit_sized::<{ $N * std::mem::size_of::<$U>() }, _>(inner) }}; @@ -365,11 +366,12 @@ macro_rules! mint { } impl FullPrecision for f32 { - fn __new(dim: usize, metric: Metric) -> Full { + fn __new(_: Hidden, dim: usize, metric: Metric) -> Full { Full::from_distance_provider(dim, metric) } fn __query_distance<'a, V>( + _: Hidden, full: &'a Full, query: &'a [f32], visitor: V, @@ -390,12 +392,8 @@ impl FullPrecision for f32 { } } Metric::InnerProduct => { - // if full.dim() == 768 { - // mint!(query, visitor, f32 => { 768, InnerProduct }) - // } else { - mint!(query, visitor, f32 => InnerProduct) - // } - }, + mint!(query, visitor, f32 => InnerProduct) + } Metric::Cosine => mint!(query, visitor, f32 => Cosine), Metric::CosineNormalized => mint!(query, visitor, f32 => CosineNormalized), }; @@ -405,11 +403,12 @@ impl FullPrecision for f32 { } impl FullPrecision for f16 { - fn __new(dim: usize, metric: Metric) -> Full { + fn __new(_: Hidden, dim: usize, metric: Metric) -> Full { Full::from_distance_provider(dim, metric) } fn __query_distance<'a, V>( + _: Hidden, full: &'a Full, query: &'a [f16], visitor: V, @@ -424,7 +423,13 @@ impl FullPrecision for f16 { let query = Calf::Owned(as_f32); let output = match full.metric { - Metric::L2 => mint!(query, visitor, { f32, f16 } => SquaredL2), + Metric::L2 => { + if full.dim() == 100 { + mint!(query, visitor, { f32, f16 } => { 100, SquaredL2 }) + } else { + mint!(query, visitor, { f32, f16 } => SquaredL2) + } + } Metric::InnerProduct => mint!(query, visitor, { f32, f16 } => InnerProduct), 
Metric::Cosine => mint!(query, visitor, { f32, f16 } => Cosine), Metric::CosineNormalized => mint!(query, visitor, { f32, f16 } => CosineNormalized), @@ -435,11 +440,12 @@ impl FullPrecision for f16 { } impl FullPrecision for u8 { - fn __new(dim: usize, metric: Metric) -> Full { + fn __new(_: Hidden, dim: usize, metric: Metric) -> Full { Full::from_distance_provider(dim, metric) } fn __query_distance<'a, V>( + _: Hidden, full: &'a Full, query: &'a [u8], visitor: V, @@ -469,11 +475,12 @@ impl FullPrecision for u8 { } impl FullPrecision for i8 { - fn __new(dim: usize, metric: Metric) -> Full { + fn __new(_: Hidden, dim: usize, metric: Metric) -> Full { Full::from_distance_provider(dim, metric) } fn __query_distance<'a, V>( + _: Hidden, full: &'a Full, query: &'a [i8], visitor: V, diff --git a/diskann-inmem/src/lib.rs b/diskann-inmem/src/lib.rs index 55d5b4653..c126a77ad 100644 --- a/diskann-inmem/src/lib.rs +++ b/diskann-inmem/src/lib.rs @@ -31,6 +31,23 @@ mod test; #[doc(hidden)] pub mod integration; +//----------------// +// Internal Tools // +//----------------// + +/// A "public" type that can only be constructed by this crate. +/// +/// This helps with public traits with internal methods that we don't want users to call. +#[doc(hidden)] +#[derive(Debug)] +pub struct Hidden(()); + +impl Hidden { + const fn new() -> Self { + Self(()) + } +} + macro_rules! opaque { ($T:ty) => { impl From<$T> for diskann::ANNError { @@ -44,3 +61,5 @@ macro_rules! opaque { } pub(crate) use opaque; + + diff --git a/diskann-inmem/src/provider.rs b/diskann-inmem/src/provider.rs index 1e6dc8646..ffcea501d 100644 --- a/diskann-inmem/src/provider.rs +++ b/diskann-inmem/src/provider.rs @@ -424,6 +424,8 @@ impl glue::SearchAccessor for SearchAccessor<'_> { // SAFETY: We've verified that each entry in `self.ids` is in-bounds and the // `self.buffer` is long enough to hold all the IDs. 
let processed = unsafe { + // self.expand_beam + // .expand_beam(&self.ids, 8, &self.reader, &mut on_neighbors) self.expand_beam .expand_beam(&self.ids, 8, &self.reader, &mut self.buffer) }?; @@ -459,6 +461,7 @@ trait ExpandBeam: Send + Sync + std::fmt::Debug { list: &[u32], lookahead: usize, reader: &store::Reader<'_>, + // f: &mut dyn FnMut(u32, f32), buffer: &mut [(u32, f32)], ) -> ANNResult; } @@ -480,9 +483,11 @@ where list: &[u32], lookahead: usize, reader: &store::Reader<'_>, + // f: &mut dyn FnMut(u32, f32), buffer: &mut [(u32, f32)], ) -> ANNResult { // SAFETY: Inherited from caller. + // unsafe { expand_beam_inner::(&self.0, list, lookahead, reader, f) } unsafe { expand_beam_inner::(&self.0, list, lookahead, reader, buffer) } } } @@ -533,7 +538,7 @@ unsafe fn prefetch(ptr: *const u8, len: usize) { // SAFETY: Inherited from caller. unsafe { _mm_prefetch(ptr.add(stride * (lines - 1)), _MM_HINT_T0) }; - for i in 0..(lines - 1).min(16) { + for i in 0..(lines - 1) { // SAFETY: Inherited from caller. unsafe { _mm_prefetch(ptr.add(stride * i), _MM_HINT_T0); @@ -552,6 +557,7 @@ unsafe fn expand_beam_inner( list: &[u32], lookahead: usize, reader: &store::Reader<'_>, + // f: &mut dyn FnMut(u32, f32), buffer: &mut [(u32, f32)], ) -> ANNResult where @@ -610,6 +616,7 @@ where // SAFETY: Caller asserts that `i` is in-bounds. if let Some(data) = unsafe { reader.read_in_bounds(i.into_usize()) } { *unsafe { buffer.get_unchecked_mut(processed) } = (i, distance.evaluate(data)?); + // f(i, distance.evaluate(data)?); processed += 1; } } @@ -911,13 +918,13 @@ where } } -// TODO: This is such a hack. 
-impl glue::InplaceDeleteStrategy, M>> for Strategy +impl glue::InplaceDeleteStrategy, M>> for Strategy where M: Id, + T: layers::FullPrecision, { - type DeleteElement<'a> = &'a [f32]; - type DeleteElementGuard = Box<[f32]>; + type DeleteElement<'a> = &'a [T]; + type DeleteElementGuard = Box<[T]>; type DeleteElementError = ANNError; type PruneStrategy = Self; @@ -939,7 +946,7 @@ where fn get_delete_element<'a>( &'a self, - provider: &'a Provider, M>, + provider: &'a Provider, M>, _context: &'a Context, id: u32, ) -> impl Future> + Send @@ -956,9 +963,10 @@ where } }; - let mut buf: Box<[_]> = std::iter::repeat_n(0.0, provider.layer.dim()).collect(); + let mut buf: Box<[_]> = + std::iter::repeat_n(T::zeroed(), provider.layer.dim()).collect(); - bytemuck::must_cast_slice_mut::(&mut buf).copy_from_slice(data); + bytemuck::must_cast_slice_mut::(&mut buf).copy_from_slice(data); Ok(buf) }; ready(work) From af57bb661cc9e2fa804e141252603dcc3d6a513c Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Tue, 30 Jun 2026 18:03:16 -0700 Subject: [PATCH 33/45] Minor tweaks. 
--- diskann-benchmark-runner/src/utils/fmt.rs | 23 ++++++- diskann-benchmark/src/index/benchmarks.rs | 48 +++++++------- diskann-benchmark/src/index/inmem2.rs | 81 ++++++++++++++++------- diskann-inmem/src/layers/full.rs | 2 +- diskann-inmem/src/lib.rs | 2 - diskann-inmem/src/provider.rs | 8 +-- 6 files changed, 106 insertions(+), 58 deletions(-) diff --git a/diskann-benchmark-runner/src/utils/fmt.rs b/diskann-benchmark-runner/src/utils/fmt.rs index 95a49d066..f9034a4e6 100644 --- a/diskann-benchmark-runner/src/utils/fmt.rs +++ b/diskann-benchmark-runner/src/utils/fmt.rs @@ -388,7 +388,26 @@ impl std::fmt::Display for MaybeLazy<'_> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::Lazy(lazy) => write!(f, "{}", lazy), - Self::Eager(s) => f.write_str(&s), + Self::Eager(s) => f.write_str(s), + } + } +} + +impl std::fmt::Debug for MaybeLazy<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + struct AsDisplay<'a>(&'a dyn std::fmt::Display); + impl std::fmt::Debug for AsDisplay<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } + } + + match self { + Self::Lazy(o) => { + let as_display = AsDisplay(o); + f.debug_tuple("MaybeLazy::Lazy").field(&as_display).finish() + } + Self::Eager(s) => f.debug_tuple("MaybeLazy::Eager").field(s).finish(), } } } @@ -434,6 +453,7 @@ impl std::fmt::Display for MaybeLazy<'_> { /// /// assert_eq!(kv.to_string(), expected); /// ``` +#[derive(Debug, Default)] pub struct KeyValue<'a> { kv: Vec<(&'a str, MaybeLazy<'a>)>, max_key_length: usize, @@ -454,6 +474,7 @@ impl<'a> KeyValue<'a> { self.kv.push((key, MaybeLazy::Lazy(value))) } + /// Push the key-value pair to `self` for formatting - eagerly formatting `value`. 
pub fn push_eager(&mut self, key: &'a str, value: D) where D: std::fmt::Display, diff --git a/diskann-benchmark/src/index/benchmarks.rs b/diskann-benchmark/src/index/benchmarks.rs index 45a142344..0a66576a5 100644 --- a/diskann-benchmark/src/index/benchmarks.rs +++ b/diskann-benchmark/src/index/benchmarks.rs @@ -86,36 +86,36 @@ pub(crate) fn register_benchmarks(registry: &mut Registry) -> anyhow::Result<()> "graph-index-full-precision-f16", FullPrecision::::new().search(plugins::Topk), )?; - // registry.register( - // "graph-index-full-precision-u8", - // FullPrecision::::new().search(plugins::Topk), - // )?; - // registry.register( - // "graph-index-full-precision-i8", - // FullPrecision::::new().search(plugins::Topk), - // )?; + registry.register( + "graph-index-full-precision-u8", + FullPrecision::::new().search(plugins::Topk), + )?; + registry.register( + "graph-index-full-precision-i8", + FullPrecision::::new().search(plugins::Topk), + )?; // Dynamic Full Precision registry.register( "graph-index-dynamic-full-precision-f32", DynamicFullPrecision::::new(), )?; - // registry.register( - // "graph-index-dynamic-full-precision-f16", - // DynamicFullPrecision::::new(), - // )?; - // registry.register( - // "graph-index-dynamic-full-precision-u8", - // DynamicFullPrecision::::new(), - // )?; - // registry.register( - // "graph-index-dynamic-full-precision-i8", - // DynamicFullPrecision::::new(), - // )?; - - // product::register_benchmarks(registry)?; - // scalar::register_benchmarks(registry)?; - // spherical::register_benchmarks(registry)?; + registry.register( + "graph-index-dynamic-full-precision-f16", + DynamicFullPrecision::::new(), + )?; + registry.register( + "graph-index-dynamic-full-precision-u8", + DynamicFullPrecision::::new(), + )?; + registry.register( + "graph-index-dynamic-full-precision-i8", + DynamicFullPrecision::::new(), + )?; + + product::register_benchmarks(registry)?; + scalar::register_benchmarks(registry)?; + 
spherical::register_benchmarks(registry)?; Ok(()) } diff --git a/diskann-benchmark/src/index/inmem2.rs b/diskann-benchmark/src/index/inmem2.rs index 69debbea9..14130bcf8 100644 --- a/diskann-benchmark/src/index/inmem2.rs +++ b/diskann-benchmark/src/index/inmem2.rs @@ -9,7 +9,6 @@ use std::{ num::NonZeroUsize, ops::Range, sync::Arc, - path::PathBuf, }; use diskann::graph::{self, DiskANNIndex, InplaceDeleteMethod, StartPointStrategy}; @@ -33,7 +32,6 @@ use diskann_inmem::{ }; use diskann_utils::views::{Matrix, MatrixView}; use diskann_vector::distance::Metric; -use half::f16; use serde::{Deserialize, Serialize}; use crate::{ @@ -45,15 +43,13 @@ use crate::{ StreamStats, }, }, - inputs::graph_index::DynamicRunbookParams, utils::{datafiles, SimilarityMeasure}, }; pub(crate) fn register_benchmarks(registry: &mut Registry) -> anyhow::Result<()> { registry.register("inmem2-f32", Build::::new())?; - registry.register("inmem2-f16", Build::::new())?; - - registry.register("inmem2-f32-stream", StreamingBenchmark)?; + // registry.register("inmem2-f16", Build::::new())?; + registry.register("inmem2-f32-stream", StreamingBenchmark::::new())?; Ok(()) } @@ -322,10 +318,14 @@ struct StaticBuild { data: Data, build: BuildParams, search: KnnSearch, + // The serialized representation of the original input. + input: serde_json::Value, } impl StaticBuild { fn from_raw(raw: dto::StaticBuild, mut checker: Option<&mut Checker>) -> anyhow::Result { + let input = serde_json::to_value(&raw)?; + let dto::StaticBuild { data, build, @@ -340,6 +340,7 @@ impl StaticBuild { data, build, search, + input, }) } } @@ -367,7 +368,7 @@ impl Input for StaticBuild { } fn serialize(&self) -> anyhow::Result { - Ok(serde_json::to_value(())?) 
+ Ok(self.input.clone()) } fn example() -> Self::Raw { @@ -450,7 +451,7 @@ where None => { write!( f, - "full-precision streaming with data type {}", + "full-precision static build+search with data type {}", Quote(T::DATA_TYPE) )?; } @@ -531,7 +532,7 @@ where benchmark_core::search::graph::Strategy::broadcast(Strategy), )?; - let mut results = _knn( + let results = _knn( &knn, &groundtruth, input.search.reps, @@ -561,7 +562,7 @@ fn _knn( let setup = core_search::Setup { threads: *num_threads, tasks: *num_threads, - reps: reps, + reps, }; let run = core_search::Run::new(instance.knn, setup); @@ -645,7 +646,7 @@ impl RunBook { let dto::RunBook { mut path, dataset, - mut groundtruth_directory, + groundtruth_directory, delete_method, delete_num_to_replace, } = raw; @@ -699,10 +700,13 @@ struct BigANNStreaming { build: BuildParams, search: StreamingKnnSearch, runbook: RunBook, + // The serialized representation of the original input. + input: serde_json::Value, } impl BigANNStreaming { fn from_raw(raw: dto::BigANNStreaming, checker: &mut Checker) -> anyhow::Result { + let input = serde_json::to_value(&raw)?; let data = Data::from_raw(raw.data, Some(checker))?; let build = BuildParams::from_raw(raw.build, data.distance)?; Ok(Self { @@ -710,6 +714,7 @@ impl BigANNStreaming { build, search: StreamingKnnSearch::from_raw(raw.search, Some(checker))?, runbook: RunBook::from_raw(raw.runbook, checker)?, + input, }) } } @@ -737,7 +742,7 @@ impl Input for BigANNStreaming { } fn serialize(&self) -> anyhow::Result { - Ok(serde_json::to_value(())?) 
+ Ok(self.input.clone()) } fn example() -> Self::Raw { @@ -779,14 +784,27 @@ impl Input for BigANNStreaming { } #[derive(Debug)] -struct StreamingBenchmark; +struct StreamingBenchmark(std::marker::PhantomData); + +impl StreamingBenchmark { + fn new() -> Self { + Self(std::marker::PhantomData) + } +} -impl Benchmark for StreamingBenchmark { +impl Benchmark for StreamingBenchmark +where + T: FullPrecision + AsDataType + diskann::graph::SampleableForStart, +{ type Input = BigANNStreaming; type Output = (); - fn try_match(&self, _input: &BigANNStreaming) -> Result { - Ok(MatchScore(0)) + fn try_match(&self, input: &BigANNStreaming) -> Result { + if T::is_match(input.data.data_type) { + Ok(MatchScore(0)) + } else { + Err(FailureScore(1000)) + } } fn description( @@ -794,11 +812,28 @@ impl Benchmark for StreamingBenchmark { f: &mut std::fmt::Formatter<'_>, input: Option<&BigANNStreaming>, ) -> std::fmt::Result { + match input { + Some(input) => { + let data_type = input.data.data_type; + if !T::is_match(data_type) { + write!( + f, + "expected data-type {}, instead got {}", + Quote(T::DATA_TYPE), + Quote(data_type) + )?; + } + } + None => { + write!( + f, + "full-precision streaming with data type {}", + Quote(T::DATA_TYPE) + )?; + } + } + Ok(()) - // match input { - // Some(i) => write!(f, "{i}"), - // None => write!(f, "inmem2 f32 streaming benchmark"), - // } } fn run( @@ -814,8 +849,8 @@ impl Benchmark for StreamingBenchmark { let max_points = runbook.max_points(); // Load the dataset (consumed by `WithData`) and queries. 
- let dataset: Matrix = datafiles::load_dataset(datafiles::BinFile(&input.data.data))?; - let queries: Arc> = Arc::new(datafiles::load_dataset(datafiles::BinFile( + let dataset: Matrix = datafiles::load_dataset(datafiles::BinFile(&input.data.data))?; + let queries: Arc> = Arc::new(datafiles::load_dataset(datafiles::BinFile( &input.search.queries, ))?); let dim = dataset.ncols(); @@ -823,7 +858,7 @@ impl Benchmark for StreamingBenchmark { // Compute the medoid of the dataset as the single start point. let start = StartPointStrategy::Medoid.compute(dataset.as_view())?; let index_config = input.build.config.clone(); - let layer = Full::::new(dim, input.data.distance); + let layer = Full::::new(dim, input.data.distance); let config = diskann_inmem::provider::Config::new(max_points, index_config.max_degree().get()); diff --git a/diskann-inmem/src/layers/full.rs b/diskann-inmem/src/layers/full.rs index 67e637daf..e0f2c0d0d 100644 --- a/diskann-inmem/src/layers/full.rs +++ b/diskann-inmem/src/layers/full.rs @@ -21,7 +21,7 @@ use diskann_wide::{ use half::f16; use thiserror::Error; -use crate::{layers, num::Bytes, Hidden}; +use crate::{Hidden, layers, num::Bytes}; /// A useful trait bound for types compatible with [`Full`]. /// diff --git a/diskann-inmem/src/lib.rs b/diskann-inmem/src/lib.rs index c126a77ad..172a7478d 100644 --- a/diskann-inmem/src/lib.rs +++ b/diskann-inmem/src/lib.rs @@ -61,5 +61,3 @@ macro_rules! opaque { } pub(crate) use opaque; - - diff --git a/diskann-inmem/src/provider.rs b/diskann-inmem/src/provider.rs index ffcea501d..681a7b1ba 100644 --- a/diskann-inmem/src/provider.rs +++ b/diskann-inmem/src/provider.rs @@ -424,8 +424,6 @@ impl glue::SearchAccessor for SearchAccessor<'_> { // SAFETY: We've verified that each entry in `self.ids` is in-bounds and the // `self.buffer` is long enough to hold all the IDs. 
let processed = unsafe { - // self.expand_beam - // .expand_beam(&self.ids, 8, &self.reader, &mut on_neighbors) self.expand_beam .expand_beam(&self.ids, 8, &self.reader, &mut self.buffer) }?; @@ -461,7 +459,6 @@ trait ExpandBeam: Send + Sync + std::fmt::Debug { list: &[u32], lookahead: usize, reader: &store::Reader<'_>, - // f: &mut dyn FnMut(u32, f32), buffer: &mut [(u32, f32)], ) -> ANNResult; } @@ -483,11 +480,9 @@ where list: &[u32], lookahead: usize, reader: &store::Reader<'_>, - // f: &mut dyn FnMut(u32, f32), buffer: &mut [(u32, f32)], ) -> ANNResult { // SAFETY: Inherited from caller. - // unsafe { expand_beam_inner::(&self.0, list, lookahead, reader, f) } unsafe { expand_beam_inner::(&self.0, list, lookahead, reader, buffer) } } } @@ -557,7 +552,6 @@ unsafe fn expand_beam_inner( list: &[u32], lookahead: usize, reader: &store::Reader<'_>, - // f: &mut dyn FnMut(u32, f32), buffer: &mut [(u32, f32)], ) -> ANNResult where @@ -615,8 +609,8 @@ where // SAFETY: Caller asserts that `i` is in-bounds. if let Some(data) = unsafe { reader.read_in_bounds(i.into_usize()) } { + // SAFETY: Inherited from caller. *unsafe { buffer.get_unchecked_mut(processed) } = (i, distance.evaluate(data)?); - // f(i, distance.evaluate(data)?); processed += 1; } } From b91419bc2222ee4f165ecb8014ad286526c186e0 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Wed, 1 Jul 2026 17:55:08 -0700 Subject: [PATCH 34/45] Add RFC. 
--- .github/workflows/ci.yml | 2 +- diskann-benchmark/Cargo.toml | 5 +- diskann-benchmark/src/index/inmem2.rs | 4 +- diskann-benchmark/src/index/mod.rs | 8 +++- diskann-inmem/src/layers/mod.rs | 6 +-- diskann-inmem/src/num.rs | 23 ++++++++++ diskann-inmem/src/provider.rs | 66 ++++++++++++++++++++++----- diskann-inmem/src/store.rs | 11 +++++ 8 files changed, 105 insertions(+), 20 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f5eeffa6d..42983aaee 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -22,7 +22,7 @@ env: CARGO_TERM_COLOR: always # The features we want to explicitly test. For example, the `flatbuffers-build` feature # of `diskann-quantization` requires additional setup and so must not be included by default. - DISKANN_FEATURES: "virtual_storage,spherical-quantization,product-quantization,tracing,experimental_diversity_search,disk-index,flatbuffers,linalg,codegen" + DISKANN_FEATURES: "virtual_storage,spherical-quantization,product-quantization,tracing,experimental_diversity_search,disk-index,flatbuffers,linalg,codegen,integration-test,inmem2" # Intel SDE version used for baseline and AVX-512 emulation jobs. 
SDE_VERSION: "sde-external-10.7.0-2026-02-18-lin" diff --git a/diskann-benchmark/Cargo.toml b/diskann-benchmark/Cargo.toml index b4c29cc5a..27c249999 100644 --- a/diskann-benchmark/Cargo.toml +++ b/diskann-benchmark/Cargo.toml @@ -39,7 +39,7 @@ opentelemetry_sdk = { workspace = true, optional = true } scopeguard = { version = "1.2", optional = true } diskann-benchmark-core = { workspace = true, features = ["bigann"] } itertools.workspace = true -diskann-inmem = { workspace = true } +diskann-inmem = { workspace = true, optional = true } [lints] clippy.undocumented_unsafe_blocks = "warn" @@ -68,6 +68,9 @@ minmax-quantization = [] # Enable multi-vector MaxSim distance benchmarks multi-vector = [] +# Enable inmem 2.0 +inmem2 = ["dep:diskann-inmem"] + # Enable bftree backend bftree = ["dep:diskann-bftree"] diff --git a/diskann-benchmark/src/index/inmem2.rs b/diskann-benchmark/src/index/inmem2.rs index 14130bcf8..660d92517 100644 --- a/diskann-benchmark/src/index/inmem2.rs +++ b/diskann-benchmark/src/index/inmem2.rs @@ -797,7 +797,7 @@ where T: FullPrecision + AsDataType + diskann::graph::SampleableForStart, { type Input = BigANNStreaming; - type Output = (); + type Output = Vec; fn try_match(&self, input: &BigANNStreaming) -> Result { if T::is_match(input.data.data_type) { @@ -914,7 +914,7 @@ where writeln!(output, "{}", Summary::new(results.iter()))?; - Ok(()) + Ok(results) } } diff --git a/diskann-benchmark/src/index/mod.rs b/diskann-benchmark/src/index/mod.rs index f3191e6fd..0d8dd9de3 100644 --- a/diskann-benchmark/src/index/mod.rs +++ b/diskann-benchmark/src/index/mod.rs @@ -11,18 +11,22 @@ mod streaming; mod benchmarks; mod inmem; -mod inmem2; mod result; #[cfg(feature = "bftree")] mod bftree; +#[cfg(feature = "inmem2")] +mod inmem2; + pub(crate) fn register_benchmarks(registry: &mut Registry) -> anyhow::Result<()> { benchmarks::register_benchmarks(registry)?; - inmem2::register_benchmarks(registry)?; #[cfg(feature = "bftree")] 
bftree::register_benchmarks(registry)?; + #[cfg(feature = "inmem2")] + inmem2::register_benchmarks(registry)?; + Ok(()) } diff --git a/diskann-inmem/src/layers/mod.rs b/diskann-inmem/src/layers/mod.rs index 648e5a6b4..eb14f57eb 100644 --- a/diskann-inmem/src/layers/mod.rs +++ b/diskann-inmem/src/layers/mod.rs @@ -68,10 +68,10 @@ pub trait AsDistance: Send + Sync + std::fmt::Debug { /// A unary query distance on raw bytes slices. /// /// When paired with [`Layer`] via helpers like [`Search`], implementations may assume -/// that `x` and `y` have length [`Layer::bytes`]. +/// that `x` has length [`Layer::bytes`]. /// -/// No alignment guarantees are made for `x` and `y`, though in practice they are likely -/// to be aligned to 32 or 64 bytes. +/// No alignment guarantees are made for `x`, though in practice it is likely to be +/// aligned to 32 or 64 bytes. pub trait QueryDistance: Send + Sync + std::fmt::Debug { fn evaluate(&self, x: &[u8]) -> ANNResult; } diff --git a/diskann-inmem/src/num.rs b/diskann-inmem/src/num.rs index eeba142ac..eb4a8c8d3 100644 --- a/diskann-inmem/src/num.rs +++ b/diskann-inmem/src/num.rs @@ -5,23 +5,30 @@ use std::num::NonZeroUsize; +/// An unsigned number of bytes. #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] pub struct Bytes(usize); impl Bytes { + /// The approximate number of bytes in a CPU cache line. pub const CACHELINE: Self = Self::new(64); + + /// Zero bytes. pub const ZERO: Self = Self::new(0); + /// Construct a new [`Bytes`]. #[inline] pub const fn new(bytes: usize) -> Self { Self(bytes) } + /// Return the current value of `self`. #[inline] pub const fn value(self) -> usize { self.0 } + /// Add `self` and `other`, returning `None` if the sum would overflow `usize`. 
#[inline] pub(crate) const fn checked_add(self, other: Bytes) -> Option { match self.value().checked_add(other.value()) { @@ -30,6 +37,7 @@ impl Bytes { } } + /// Multiply `self` and `other`, returning `None` if the product would overflow `usize`. #[inline] pub(crate) const fn checked_mul(self, other: usize) -> Option { match self.value().checked_mul(other) { @@ -38,16 +46,21 @@ impl Bytes { } } + /// Perform integer division of `self` by `other`. #[inline] pub(crate) const fn div(self, other: NonZeroUsize) -> Bytes { Bytes::new(self.value() / other.get()) } + /// Subtract `other` from `self` without checking for underflow. #[inline] pub(crate) const fn unchecked_sub(self, other: Bytes) -> Bytes { Self::new(self.value() - other.value()) } + /// Return the smallest multiple of `other` greater-than or equal to `self`. + /// + /// Returns `None` if the next multiple exceeds `usize::MAX`. #[inline] pub(crate) const fn checked_next_multiple_of(self, other: Bytes) -> Option { match self.value().checked_next_multiple_of(other.value()) { @@ -56,11 +69,13 @@ impl Bytes { } } + /// Return the size of `T` in [`Bytes`]. #[inline] pub const fn size_of() -> Self { Self::new(std::mem::size_of::()) } + /// Return `true` if `self` is zero. pub const fn is_zero(self) -> bool { self.0 == 0 } @@ -72,11 +87,16 @@ impl std::fmt::Display for Bytes { } } +/// An alignment for an allocation. +/// +/// All alignments are guaranteed to be powers of two. #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] #[repr(transparent)] pub struct Align(NonZeroUsize); impl Align { + /// Construct a new [`Align`] from `value`, returning `None` if `value` is not a power + /// of two. pub const fn new(value: usize) -> Option { match NonZeroUsize::new(value) { Some(value) => { @@ -90,6 +110,7 @@ impl Align { } } + /// Return the raw value of `self`. 
pub const fn value(self) -> usize { self.0.get() } @@ -106,11 +127,13 @@ impl Align { Self(unsafe { NonZeroUsize::new_unchecked(value) }) } + /// Return the alignment of a type `T`. pub const fn of() -> Self { // SAFETY: `std::mem::align_of` is guaranteed to return a power of 2. unsafe { Self::new_unchecked(std::mem::align_of::()) } } + /// Construct a new [`Align`] from a [`std::alloc::Layout`]. pub const fn from_layout(layout: std::alloc::Layout) -> Self { // SAFETY: `Layout::align` is guaranteed to be a power of 2. unsafe { Self::new_unchecked(layout.align()) } diff --git a/diskann-inmem/src/provider.rs b/diskann-inmem/src/provider.rs index 681a7b1ba..f660df522 100644 --- a/diskann-inmem/src/provider.rs +++ b/diskann-inmem/src/provider.rs @@ -30,7 +30,7 @@ //! * Lack of save/load support: The index is currently ephemeral, but there are plans to //! address this gap. -use std::hash::Hash; +use std::{hash::Hash, num::NonZeroUsize}; use diskann::{ ANNError, ANNErrorKind, ANNResult, @@ -75,6 +75,8 @@ where layer: L, // ID translation. mapping: Sharded, + // Construction `Config`. + config: Config, // `Counters` is only non-trivial under the `integration-test` feature flag. Otherwise, // all counter related operations are no-ops. @@ -115,10 +117,12 @@ where store, layer, mapping, + config, counters: Counters::new(), }) } + /// A local set of counters that update the provider-wide counters in bulk. fn local_counters(&self) -> LocalCounters<'_> { self.counters.local() } @@ -148,9 +152,12 @@ pub enum ProviderError { pub struct Config { capacity: usize, max_degree: usize, + prefetch_lookahead: Option, } impl Config { + const DEFAULT_PREFETCH_LOOKAHEAD: NonZeroUsize = NonZeroUsize::new(8).unwrap(); + /// Construct a new [`Config`]. /// /// * `capacity`: The number of dynamic entries in the resulting provider. 
@@ -159,6 +166,7 @@ impl Config { Self { capacity, max_degree, + prefetch_lookahead: Some(Self::DEFAULT_PREFETCH_LOOKAHEAD), } } @@ -171,6 +179,13 @@ impl Config { pub fn max_degree(&self) -> usize { self.max_degree } + + /// Configure the prefetch lookahead. + /// + /// This is used during beam expansion to prefetch data into CPU caches. + pub fn prefetch_lookahead(&mut self, prefetch_lookahead: Option) { + self.prefetch_lookahead = prefetch_lookahead; + } } /////////////////// @@ -425,7 +440,7 @@ impl glue::SearchAccessor for SearchAccessor<'_> { // `self.buffer` is long enough to hold all the IDs. let processed = unsafe { self.expand_beam - .expand_beam(&self.ids, 8, &self.reader, &mut self.buffer) + .expand_beam(&self.ids, &self.reader, &mut self.buffer) }?; self.counters.get_vector(processed as u64); @@ -457,39 +472,57 @@ trait ExpandBeam: Send + Sync + std::fmt::Debug { unsafe fn expand_beam( &self, list: &[u32], - lookahead: usize, reader: &store::Reader<'_>, buffer: &mut [(u32, f32)], ) -> ANNResult; } #[derive(Debug)] -#[repr(transparent)] -struct ExpandBeamImpl(T); +struct ExpandBeamImpl { + inner: T, + prefetch_lookahead: usize, +} + +impl ExpandBeamImpl { + fn new(inner: T, prefetch_lookahead: usize) -> Self { + Self { + inner, + prefetch_lookahead, + } + } +} impl ExpandBeam for ExpandBeamImpl where T: layers::QueryDistance, { fn evaluate(&self, x: &[u8]) -> ANNResult { - self.0.evaluate(x) + self.inner.evaluate(x) } unsafe fn expand_beam( &self, list: &[u32], - lookahead: usize, reader: &store::Reader<'_>, buffer: &mut [(u32, f32)], ) -> ANNResult { // SAFETY: Inherited from caller. 
- unsafe { expand_beam_inner::(&self.0, list, lookahead, reader, buffer) } + unsafe { + expand_beam_inner::( + &self.inner, + list, + self.prefetch_lookahead, + reader, + buffer, + ) + } } } #[derive(Debug)] struct ExpandBeamVisitor { bytes: Bytes, + prefetch_lookahead: usize, } impl<'a> layers::QueryVisitor<'a> for ExpandBeamVisitor { @@ -501,14 +534,20 @@ impl<'a> layers::QueryVisitor<'a> for ExpandBeamVisitor { { // This is critical to ensure we emit the correct number of prefetches. assert!(Bytes::new(BYTES + store::TAG_SIZE.value()) <= self.bytes); - Box::new(ExpandBeamImpl::<_, BYTES>(distance)) + Box::new(ExpandBeamImpl::<_, BYTES>::new( + distance, + self.prefetch_lookahead, + )) } fn visit(self, distance: T) -> Self::Output where T: QueryDistance + 'a, { - Box::new(ExpandBeamImpl::<_, 0>(distance)) + Box::new(ExpandBeamImpl::<_, 0>::new( + distance, + self.prefetch_lookahead, + )) } } @@ -589,7 +628,8 @@ where } } - let mut j = lookahead; + // Disable prefetching if the lookahead is 0. + let mut j = if lookahead == 0 { len } else { lookahead }; let mut processed = 0; for &i in list.iter() { if j != len { @@ -795,6 +835,7 @@ where query, ExpandBeamVisitor { bytes: provider.store.bytes(), + prefetch_lookahead: provider.config.prefetch_lookahead.map_or(0, |x| x.get()), }, )?; @@ -811,6 +852,8 @@ where } } +// This is a utility for helping inspect the generated code for `ExpandBeam`. +// pub fn test_function<'a>( x: &'a Provider>, strategy: &'a Strategy, @@ -820,6 +863,7 @@ pub fn test_function<'a>( glue::SearchStrategy::search_accessor(strategy, x, context, query) } +/// Perform ID translation during post-processing. 
#[derive(Debug, Clone, Copy)] pub struct Translate(std::marker::PhantomData<(L, M)>); diff --git a/diskann-inmem/src/store.rs b/diskann-inmem/src/store.rs index 2b92d3e72..5aa04890c 100644 --- a/diskann-inmem/src/store.rs +++ b/diskann-inmem/src/store.rs @@ -336,6 +336,17 @@ impl Store { } } + /// A somewhat crude algorithm for cooperatively performing slot scanning. + /// + /// This uses [`Freelist::scan`] to acquire a disjoint chunk of the ID space for scanning, + /// spreading out the search across multiple threads. + /// + /// If we successfully acquire a slot, we continue for the rest of the bucket returned + /// by [`Freelist::scan`] and add any available slots to the freelist (allowing other + /// threads to find them). + /// + /// Periodically, the freelist is checked to see if another thread has found an available + /// slot for us. fn scan_acquire(&self) -> Option> { // This is potentially quite slow - but stop if we've scanned the entire range // without finding anything. From 9d5f1435fd3a327a8fc2cdf7f96d0ffeb164fa1b Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Wed, 1 Jul 2026 17:56:48 -0700 Subject: [PATCH 35/45] Upload RFC. --- rfcs/01206-inmem2.md | 270 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 270 insertions(+) create mode 100644 rfcs/01206-inmem2.md diff --git a/rfcs/01206-inmem2.md b/rfcs/01206-inmem2.md new file mode 100644 index 000000000..26572f921 --- /dev/null +++ b/rfcs/01206-inmem2.md @@ -0,0 +1,270 @@ +# Concurrent In-Memory Index + +| | | +|---|---| +| **Authors** | Mark Hildebrand | +| **Contributors** | | +| **Created** | 2026-07-01 | + +## Summary + +Let's make our in-memory index robust under concurrent operations. 
+ +## Motivation + +### Background + +There are [two methods](https://github.com/microsoft/DiskANN/blob/b603ec009ea5c3cdbbd5358ca88b4a2a30c8d52b/diskann-providers/src/model/graph/provider/async_/common.rs#L148-L195) at the heart of our current in-memory index that deeply bother me: `get_slice` and `get_slice_mut`. +These methods allow concurrent, unsynchronized access between mutable and immutable data. +While there is some measure of safety on the [write path](https://github.com/microsoft/DiskANN/blob/b603ec009ea5c3cdbbd5358ca88b4a2a30c8d52b/diskann-providers/src/model/graph/provider/async_/fast_memory_vector_provider.rs#L44-L47), this is insufficient to prevent a concurrent reader and a writer and there is no systematic safety protocol in place to prevent this situation: safety comments are basically "yolo" (I can make fun of them because I've written such comments myself). + +This is problematic for `diskann` for a number of reasons: + +* It prevents us from stress testing our algorithm under concurrent inserts/search/deletes etc. + Unfortunately, this is the situation under which most of our database integrations actually operate. + +* Coupled with the lack of an ID translation mechanism for the inmem provider, we have no protection against races inserting multiple internal IDs simultaneously or a coherent story between concurrent inserts and deletes to the same ID. + +* It makes me sad. + +So why is it this way? +Performance! +Yoloing a pointer read is cheap; concurrency is expensive. +An obvious safety mechanism would be to slap a mutex or [fine-grained RCU](https://github.com/microsoft/DiskANN/blob/b603ec009ea5c3cdbbd5358ca88b4a2a30c8d52b/diskann-providers/src/model/graph/provider/async_/memory_vector_provider.rs#L23) around each slot. 
+Unfortunately, approaches like this that require cache-line **writes** before (and after) each read operation are prohibitively slow and completely unnecessary for [static builds](https://github.com/microsoft/DiskANN/blob/main/diskann-disk/src/build/builder/inmem_builder.rs). + +### Problem Statement + +Let's try to fix the in-memory provider to make it: + +* Well defined under concurrent operations. +* Enable co-development of the core algorithm to patch our remaining concurrency issues. +* Easier to use than our current providers. +* Faster to compile. +* Support external/internal ID translation. +* Have much better tests. +* Without completely compromising performance. + +## Proposal + +This RFC accompanies [1206](https://github.com/microsoft/DiskANN/pull/1206), which is an MVP of the proposed design. + +The design assumes that traffic to the data store is primarily **read** heavy (to facilitate searches). +As such, we're willing to sacrifice some write performance (or more specifically, the time between when a slot is deleted and when it is reused) to make reads fast. +Importantly, readers can determine whether it is safe (or not) to access data with a single 1-byte atomic read. +**Following** a successful check, readers may access the associated data without fear of races. + +### Concurrency + +**Definition**: A "slot" is any data (potentially spread across multiple containers) that is uniquely associated with an internal ID. +This part of the discussion focuses on slots and concurrently controlling access to the data contained within. + +At a high level, consider associating every slot with a 1-byte atomic tag with states + +* "available": This slot is free to claim but cannot be read. + The data managed by this slot should be considered invalid. + +* "owned": This slot is exclusively owned by a thread. + No other threads may access data in the slot and the owner is free to write. + +* "published": The slot is publicly available. 
+ Readers that observe a "published" tag can read the associated data. + +#### Attempt 1 + +Consider a simple (but wholly invalid) concurrency protocol where a writer acquires a slot by transitioning (via [CAS](https://doc.rust-lang.org/std/sync/atomic/type.AtomicUsize.html#method.compare_exchange)) the tag from "available" to "owned", writes data, then transitions the slot to "published". +A delete could happen by transitioning from "published" to "owned", do its thing, and then either re-publish or make it "available". +This does not work as a solution to our goal. + +The problem is that even if an intrepid writer transitions the tag to "owned", it has no guarantee that readers who observed the "published" state prior to this transition are done using the data. +As a diagram +``` +Time | Reader | Writer +-------+----------------------------+----------- + 1 | Reads "published" | + | Decides it's safe to read | +-------+----------------------------+---------------------------------------- + 2 | | Transitions from "published" to "owned" +-------+----------------------------+---------------------------------------- +RACE 3 | Starts Reading | Starts Writing +``` +We get undefined behavior at step 3! + +#### Attempt 2 + +We could augment attempt 1 by having the reader thread check the tag after it finishes its operation. +If it observes a non-"published" state, it can abort its operation. +This cannot be done safely with our current tag scheme due to the ABA problem: a reader could see "published", start its operation, and then get preempted by the operating system (OS). +In the meantime, a writer can transition from "published" to "owned", do its thing, and then set it back to "published". +When the original reader gets back from its vacation, it still sees "published" and thinks everything is okay even though it potentially operated on invalid data. + +Schemes like this can be used by constructs like [sequence locks](https://en.wikipedia.org/wiki/Seqlock). 
+However, these are + +* Famously impossible to represent (without UB) in the semantics of high-level programming languages. + +* Basically the only operation you can do "safely" under a sequence lock is a memcpy. + You can't compute - and we'd really like to be able to compute a distance on the data in-place. + +* Require larger counters to avoid longer-ranged versions of the previously described ABA problem where the counter fully wraps around. + +#### Solution + +What we need is a little more communication. +Conceptually, the writer transitions a tag from "published" to "owned", but then **waits** until it can guarantee that no readers are alive any more that could have observed that change. +Only at this point can the writer do its thing without undefined behavior (don't worry, in the implementation here, writers are *not* blocked waiting for readers). + +This is where [epoch-based reclamation](https://docs.rs/crossbeam-epoch/latest/crossbeam_epoch/) enters the picture. +For each provider, we maintain a monotonically increasing "epoch". +The idea is that for every search/insert/delete operation, a `Reader` is first created which registers itself as using the current epoch (call it `E`). +This `Reader` behaves as previously described, reading tags and if it observes "published": reading the associated data. +When the operation finishes, the `Reader` deregisters itself. + +**Importantly**, a `Reader` can also retire slots by transitioning their tag (via CAS) from "published" to a new "retired" state and inserting the tag index into its epoch-specific queue. +This "retired" state (1) prevents future readers from accessing the data, and (2) prevents other threads from trying to claim the slot. +The epoch-specific queue holds onto the slot ID until all readers who could have observed that transition have been deregistered. + +Cleaning up retired items from the epoch-queue happens during epoch advancement. 
+Periodically (e.g., once in every `N` inserts or searches or via a background process), we try to advance the epoch. +An epoch can **only** be advanced from `E` to `E+1` if all registered readers belong to epoch `E`. +Any reader at epoch `E-1` will prevent the transition. +When we successfully advance the epoch, we get the epoch-queue associated with epoch `E-2`. +This queue contains the slot indexes for those retired at epoch `E-2`. +Because all current readers belong to `E` or `E+1`, we are guaranteed that all current readers agree that the state of the tag is "retired" and will not be trying to read any data in the slots contained within the queue. +As such, the thread processing epoch `E-2` is free to write to the data in these slots and transition them to other states without fear of a race, thus solving our problem! + +With this scheme, we can keep recycling the same four epoch queues because the scheme guarantees that only two are ever written to at a time. +Literature often claims that just three queues are needed. +We need an extra one because each `Reader` pushes items into the epoch queue associated with the `Reader`'s creation epoch. +A `Reader` in epoch `E` can retire a slot into the `E` queue, but a `Reader` in epoch `E+1` **can** observe this retirement. +Then the epoch advances to `E+2`. +If we pulled the offending slot out of the `E` queue, the `Reader` in `E+1` could still be reading it and we're back to undefined behavior. +Introducing a fourth queue fixes this issue. + +#### Implementation + +The core components of this protocol are split across three files in #1206: + +* `tag.rs`: The implementation of atomic slot tags. + The PR contains a few more states for slots to enable slots to be in special (e.g., "frozen") states, but follows the main "available", "published", "owned", "retired" scheme outlined above. + +* `epoch.rs`: The logic for `Reader` registration, deregistration, epoch advancement, and epoch queues. 
+ Unsurprisingly, convincing a bunch of concurrent threads to get along with minimal locking is subtle. + +* `store.rs`: A package data store with slots that completes the implementation by providing + a safe `Reader` based abstraction on the store. + +Within `store.rs`, there are actually two tags per slot. +An authoritative tag in a `Vec`, and a mirrored tag that lives inline with the data being stored in the slot. +The idea here is that during search, we can emit prefetches for the data we're going to process and either get the mirror for free, or take advantage of locality to avoid something like a page fault. +This mirror tag can then be used for the safety check instead. +For quantization algorithms like spherical that don't always generate a nice power-of-two number of bytes, there is likely unused space in our cache line padding so this 1-byte tag can be stored for free. + +### Reconciling Performance + +Even though we've brought the concurrency overhead down to just 1-byte per slot with a light weight check (about 6 instructions), this is still strictly more data than the current index. +This is particularly painful for datasets like "sift", where any additional read pulls in an additional cache line, moving from 2 cachelines to 3. +There is a little bit of work that can be done. +PR [1067](https://github.com/microsoft/DiskANN/pull/1067) moved the search contract behind a single `expand_beam` function. +PR 1026 uses a variation of [bring your own type-erasure](https://github.com/microsoft/DiskANN/pull/1068) to enable distance layers to + +* Inline their final distance functions directly into the `expand_beam` implementation rather than relying on [function pointers](https://github.com/microsoft/DiskANN/blob/main/diskann-vector/src/distance/distance_provider.rs). 
+* Further, length-specialized implementation can communicate their element byte size via const-generics, allowing the final `expand_beam` implementation to emit the exact number of prefetch instructions. + +As an example, the `expand_beam` inner loop for 100-dimensional L2 vectors compiles to the following assembly: +``` +.LBB4_27: | Check if all neighbors have been processed + mov r10, r15 | + add rdx, 4 | + mov rax, rsi | + cmp r14, rdx | + je .LBB4_30 | +.LBB4_23: + mov r14, r10 + mov r8, qword ptr [rbx + 16] | Prefetch if there are still items to prefetch + mov r10, qword ptr [rbx + 24] | + mov rsi, rbp | + cmp rax, rbp | + je .LBB4_25 | + mov esi, dword ptr [r15 + 4*rax] | + imul rsi, r10 | + prefetcht0 byte ptr [r8 + rsi + 384] | + prefetcht0 byte ptr [r8 + rsi] | + prefetcht0 byte ptr [r8 + rsi + 64] | + prefetcht0 byte ptr [r8 + rsi + 128] | + prefetcht0 byte ptr [r8 + rsi + 192] | + prefetcht0 byte ptr [r8 + rsi + 256] | + prefetcht0 byte ptr [r8 + rsi + 320] | + inc rax + mov rsi, rax +.LBB4_25: + mov eax, dword ptr [r15 + rdx] | Safety tag check + imul r10, rax | + add r8, r10 | + movzx r10d, byte ptr [r11 + r8] | + cmp r10b, -2 | + jb .LBB4_27 | + vmovups ymm2, ymmword ptr [rdi] | Inlined Distance Computation + vmovups ymm3, ymmword ptr [rdi + 32] | + vmovups ymm4, ymmword ptr [rdi + 64] | + vsubps ymm2, ymm2, ymmword ptr [r8] | + vmovups ymm5, ymmword ptr [rdi + 96] | + vfmadd213ps ymm2, ymm2, ymm0 | + ... 
repeats a lot | + vfmadd213ps ymm3, ymm3, ymm4 | + vaddps ymm3, ymm5, ymm3 | + vmaskmovps ymm4, ymm1, ymmword ptr [rcx] | + vaddps ymm2, ymm2, ymm3 | + vmaskmovps ymm3, ymm1, ymmword ptr [r8 + 384] | + vsubps ymm3, ymm4, ymm3 | + vfmadd213ps ymm3, ymm3, ymm2 | + vextractf128 xmm2, ymm3, 1 | + vaddps xmm2, xmm3, xmm2 | + vshufpd xmm3, xmm2, xmm2, 1 | + vaddps xmm2, xmm2, xmm3 | + vmovshdup xmm3, xmm2 | + vaddss xmm2, xmm2, xmm3 | + mov r8, qword ptr [rsp + 8] + mov dword ptr [r9 + 8*r8], eax | Write Back + vmovss dword ptr [r9 + 8*r8 + 4], xmm2 | + inc r8 | + mov qword ptr [rsp + 8], r8 | + jmp .LBB4_27 +``` +To prevent compile time explosions, these aggressively optimized inner loops are only generated once and then packaged in a trait object. +This avoids re-monomorphization as different closures and iterators that can be passed to the top level `SearchAccessor::expand_beam` method. + +The hope with this specialization hook is that we can tune and optimize `expand_beam` more aggressively than our current providers to offset the extra byte read (and search accessor creation times due to epoch registration). + +### Testing + +Our current in-memory index uses a [very large](https://github.com/microsoft/DiskANN/blob/main/diskann-providers/src/index/diskann_async.rs) test file with ad-hoc tests. +PR 1206 uses the [A/B test functionality](https://github.com/microsoft/DiskANN/pull/900) in `diskann-benchmark-runner` to + +* Execute [longer running](https://github.com/microsoft/DiskANN/pull/1199) tests. +* Gather richer metrics and recall stats for these tests. +* Compare against a checked-in JSON baseline and notify of any changes. + +To allow for future adaptability, the baseline can be regenerated with the `DISKANN_TEST=overwrite` environment variable setting. +Since the baseline is raw JSON, changes will show up in the git diff for reviewers to inspect. + +The goal here is to enable more robust testing of the in-memory index and by-extension the core DiskANN algorithm. 
+ +## Trade-offs + +All of this is fairly complex stuff to solve an insidious safety loophole. +And unfortunately, the concurrency infrastructure is not strictly needed for static in-memory builds. +I've been thinking about this a lot, and have never really been able to come up with a scheme that provides the read-only property with so little overhead. +I am more than happy to entertain alternative ideas. + +## Benchmark Results + +Incoming, but inmem2 is generally on-par with inmem1 (except for sift, where it has about a 10% performance regression). +For streaming workloads, the hard deletes required by inmem2 may actually lead to higher recall. + +## Future Work + +- [ ] Add quantization (this will require figuring out how an extra blob can be protected by `store` - this is a solvable problem). +- [ ] Implement saving and loading. +- [ ] Optimize `expand_beam` a little more. +- [ ] Migrate existing users over to inmem2. From 9a103790b1a5bd7072a18203344e193752744bfd Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Thu, 2 Jul 2026 13:54:32 -0700 Subject: [PATCH 36/45] Restore stress tests. --- diskann-inmem/DEV.md | 7 +++++++ diskann-inmem/integration/store.rs | 4 ++-- diskann-inmem/src/integration/store.rs | 4 ++++ 3 files changed, 13 insertions(+), 2 deletions(-) create mode 100644 diskann-inmem/DEV.md diff --git a/diskann-inmem/DEV.md b/diskann-inmem/DEV.md new file mode 100644 index 000000000..59b29c213 --- /dev/null +++ b/diskann-inmem/DEV.md @@ -0,0 +1,7 @@ +# Dev Docs + +Fully testing this crate requires enabling the `integration-test` feature. 
+The suggested command is +``` +cargo test --package diskann-inmem --all-features --profile ci +``` diff --git a/diskann-inmem/integration/store.rs b/diskann-inmem/integration/store.rs index b8f6d7e8b..d4432358a 100644 --- a/diskann-inmem/integration/store.rs +++ b/diskann-inmem/integration/store.rs @@ -330,8 +330,8 @@ fn writer(shared: &Shared) { Some(mut writer) => { let stamp = shared.stamp.fetch_add(1, Relaxed); write_stamp(writer.as_mut_slice(), stamp); - // Dropping the writer publishes the slot. - drop(writer); + writer.publish(); + let live = shared.live.fetch_add(1, Relaxed) + 1; shared.peak_live.fetch_max(live, Relaxed); shared.acquires_ok.fetch_add(1, Relaxed); diff --git a/diskann-inmem/src/integration/store.rs b/diskann-inmem/src/integration/store.rs index ecfe22878..ff8b5f797 100644 --- a/diskann-inmem/src/integration/store.rs +++ b/diskann-inmem/src/integration/store.rs @@ -90,6 +90,10 @@ impl<'a> Writer<'a> { Self { slot } } + pub fn publish(self) { + self.slot.publish(); + } + pub fn as_mut_slice(&mut self) -> &mut [u8] { self.slot.as_mut_slice() } From f1a71214b45d52c9a40455320c72f1f121fea263 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Thu, 2 Jul 2026 14:13:49 -0700 Subject: [PATCH 37/45] More cleanups. 
--- diskann-benchmark-runner/src/utils/fmt.rs | 4 ---- diskann-benchmark/src/index/inmem2.rs | 5 +---- diskann-benchmark/src/index/mod.rs | 2 +- diskann-inmem/integration/index/object.rs | 2 +- diskann-inmem/integration/index/tests.rs | 2 +- diskann-utils/src/views.rs | 13 ------------- 6 files changed, 4 insertions(+), 24 deletions(-) diff --git a/diskann-benchmark-runner/src/utils/fmt.rs b/diskann-benchmark-runner/src/utils/fmt.rs index 22756cd0d..795c2fab9 100644 --- a/diskann-benchmark-runner/src/utils/fmt.rs +++ b/diskann-benchmark-runner/src/utils/fmt.rs @@ -486,10 +486,6 @@ impl<'a> KeyValue<'a> { self.max_key_length = self.max_key_length.max(key.len()); self.kv.push((key, MaybeLazy::Eager(value.to_string()))) } - - pub fn render(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self) - } } impl std::fmt::Display for KeyValue<'_> { diff --git a/diskann-benchmark/src/index/inmem2.rs b/diskann-benchmark/src/index/inmem2.rs index 660d92517..10e47e4c1 100644 --- a/diskann-benchmark/src/index/inmem2.rs +++ b/diskann-benchmark/src/index/inmem2.rs @@ -38,10 +38,7 @@ use crate::{ index::{ build::{BuildKind, BuildStats, ProgressMeter}, result::{AggregatedSearchResults, SearchResults}, - streaming::{ - stats::{GenericStats, Summary}, - StreamStats, - }, + streaming::stats::{GenericStats, StreamStats, Summary}, }, utils::{datafiles, SimilarityMeasure}, }; diff --git a/diskann-benchmark/src/index/mod.rs b/diskann-benchmark/src/index/mod.rs index 0d8dd9de3..7d9878c44 100644 --- a/diskann-benchmark/src/index/mod.rs +++ b/diskann-benchmark/src/index/mod.rs @@ -5,7 +5,7 @@ use diskann_benchmark_runner::Registry; -pub(crate) mod build; +mod build; mod search; mod streaming; diff --git a/diskann-inmem/integration/index/object.rs b/diskann-inmem/integration/index/object.rs index a065a9b95..cc5f4dc6f 100644 --- a/diskann-inmem/integration/index/object.rs +++ b/diskann-inmem/integration/index/object.rs @@ -142,7 +142,7 @@ impl std::fmt::Display for 
Counters { kv.push("get_neighbors", &self.get_neighbors); kv.push("set_neighbors", &self.set_neighbors); kv.push("append_neighbors", &self.append_neighbors); - kv.render(f) + write!(f, "{}", kv) } } diff --git a/diskann-inmem/integration/index/tests.rs b/diskann-inmem/integration/index/tests.rs index 7b0277a0d..436a851e3 100644 --- a/diskann-inmem/integration/index/tests.rs +++ b/diskann-inmem/integration/index/tests.rs @@ -127,7 +127,7 @@ impl std::fmt::Display for KnnStats { kv.push("counters", &self.counters); kv.push("recall", &self.recall); kv.push("misc", &self.misc); - kv.render(f) + write!(f, "{}", kv) } } diff --git a/diskann-utils/src/views.rs b/diskann-utils/src/views.rs index 844420e82..a9352918c 100644 --- a/diskann-utils/src/views.rs +++ b/diskann-utils/src/views.rs @@ -344,19 +344,6 @@ where unsafe { self.get_row_unchecked_mut(row) } } - /// Return row `row` as a mutable slice. - pub fn get_row_mut(&mut self, row: usize) -> Option<&mut [T::Elem]> - where - T: MutDenseData, - { - if row < self.nrows() { - // SAFETY: `row` is in-bounds. - Some(unsafe { self.get_row_unchecked_mut(row) }) - } else { - None - } - } - /// Returns the requested row without boundschecking. /// /// # Safety From b3a00b7e01f7f0a80ff2a06a0ba46e95b0d12dd3 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Thu, 2 Jul 2026 14:31:59 -0700 Subject: [PATCH 38/45] Add tests for `KeyValue`. --- diskann-benchmark-runner/src/utils/fmt.rs | 78 +++++++++++++++++++++++ 1 file changed, 78 insertions(+) diff --git a/diskann-benchmark-runner/src/utils/fmt.rs b/diskann-benchmark-runner/src/utils/fmt.rs index 795c2fab9..29b6f9ad8 100644 --- a/diskann-benchmark-runner/src/utils/fmt.rs +++ b/diskann-benchmark-runner/src/utils/fmt.rs @@ -735,4 +735,82 @@ string, , string .with_pair(" and "); assert_eq!(d.to_string(), "\"topk\" and \"range\""); } + + //----------// + // KeyValue // + //----------// + + // Strip a preceeding newline if it exists. 
+ fn process(x: &str) -> &str { + let x = x.strip_prefix('\n').unwrap_or(x); + x.strip_suffix('\n').unwrap_or(x) + } + + #[test] + fn test_key_value_empty() { + let kv = KeyValue::new(); + assert_eq!(kv.to_string(), ""); + } + + #[test] + fn test_key_value_single_pair() { + let mut kv = KeyValue::new(); + kv.push("a", &1); + assert_eq!(kv.to_string(), "a: 1"); + } + + #[test] + fn test_key_value_aligns_values() { + let mut kv = KeyValue::new(); + kv.push("a", &1); + kv.push("hello", &"world"); + let expected = process( + r#" +a: 1 +hello: world +"#, + ); + assert_eq!(kv.to_string(), expected); + } + + #[test] + fn test_key_value_push_eager() { + let mut kv = KeyValue::new(); + kv.push_eager("a", 1); + kv.push_eager("hello", "world"); + + let expected = process( + r#" +a: 1 +hello: world +"#, + ); + + assert_eq!(kv.to_string(), expected); + } + + #[test] + fn test_key_value_multiline_value_is_indented() { + let mut inner = KeyValue::new(); + inner.push("x", &1); + inner.push("yy", &2); + let inner = inner.to_string(); + + let mut kv = KeyValue::new(); + kv.push("name", &"example"); + kv.push("nested", &inner); + kv.push("another line", &1); + + let expected = process( + r#" +name: example +nested: + x: 1 + yy: 2 +another line: 1 +"#, + ); + + assert_eq!(kv.to_string(), expected); + } } From ab92daee13c0d2ef0066cbdcb4501ebdb15e53b5 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Thu, 2 Jul 2026 14:35:28 -0700 Subject: [PATCH 39/45] Tests for KeyValue. 
--- diskann-benchmark-runner/src/utils/fmt.rs | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/diskann-benchmark-runner/src/utils/fmt.rs b/diskann-benchmark-runner/src/utils/fmt.rs index 29b6f9ad8..7d0adeecf 100644 --- a/diskann-benchmark-runner/src/utils/fmt.rs +++ b/diskann-benchmark-runner/src/utils/fmt.rs @@ -813,4 +813,15 @@ another line: 1 assert_eq!(kv.to_string(), expected); } + + #[test] + fn maybe_lazy_debug() { + let x = MaybeLazy::Lazy(&1); + let dbg = format!("{:?}", x); + assert_eq!(dbg, "MaybeLazy::Lazy(1)"); + + let x = MaybeLazy::Eager("hello".into()); + let dbg = format!("{:?}", x); + assert_eq!(dbg, "MaybeLazy::Eager(\"hello\")"); + } } From 3b59cf1971c3c89f5a460111ddaccc5882835c85 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Thu, 2 Jul 2026 14:53:33 -0700 Subject: [PATCH 40/45] Cleanup and reduce false-sharing in stress test. --- diskann-benchmark-runner/src/checker.rs | 30 ---------- diskann-benchmark/src/index/inmem2.rs | 2 +- diskann-inmem/integration/store.rs | 73 ++++++++++++++++++++++--- 3 files changed, 65 insertions(+), 40 deletions(-) diff --git a/diskann-benchmark-runner/src/checker.rs b/diskann-benchmark-runner/src/checker.rs index 23624ca5f..f561983a1 100644 --- a/diskann-benchmark-runner/src/checker.rs +++ b/diskann-benchmark-runner/src/checker.rs @@ -181,36 +181,6 @@ impl Checker { self.search_directories(), ))) } - - pub fn __check_dir(&self, dir: &Path) -> Result { - // Check if the file exists (allowing for relative paths with respect to the current - // directory. - // - // If the path is an absolute path and the file does not exist, then bail. - if dir.is_absolute() { - if dir.is_dir() { - return Ok(dir.into()); - } else { - return Err(anyhow::Error::msg(format!( - "input file with absolute path \"{}\" either does not exist or is not a file", - dir.display() - ))); - } - }; - - // At this point, start searching in the provided directories. 
- for d in self.search_directories() { - let absolute = d.join(dir); - if absolute.is_dir() { - return Ok(absolute); - } - } - Err(anyhow::Error::msg(format!( - "could not find input file \"{}\" in the search directories \"{:?}\"", - dir.display(), - self.search_directories(), - ))) - } } #[derive(Debug, Clone, Copy)] diff --git a/diskann-benchmark/src/index/inmem2.rs b/diskann-benchmark/src/index/inmem2.rs index 10e47e4c1..a794fbfcf 100644 --- a/diskann-benchmark/src/index/inmem2.rs +++ b/diskann-benchmark/src/index/inmem2.rs @@ -650,7 +650,7 @@ impl RunBook { path.resolve(checker)?; - let groundtruth_directory = checker.__check_dir(groundtruth_directory.as_ref())?; + let groundtruth_directory = checker.find_input_dir(groundtruth_directory.as_ref())?; let runbook = bigann::RunBook::load( &path, diff --git a/diskann-inmem/integration/store.rs b/diskann-inmem/integration/store.rs index d4432358a..625f9aafe 100644 --- a/diskann-inmem/integration/store.rs +++ b/diskann-inmem/integration/store.rs @@ -278,6 +278,48 @@ fn observe( // Shared // //////////// +struct Local<'a> { + counter: u64, + parent: &'a AtomicU64, +} + +impl<'a> Local<'a> { + fn new(parent: &'a AtomicU64) -> Self { + Self { counter: 0, parent } + } + + fn add(&mut self, by: u64) { + self.counter += by; + } +} + +impl Drop for Local<'_> { + fn drop(&mut self) { + self.parent.fetch_add(self.counter, Relaxed); + } +} + +struct LocalMax<'a> { + max: usize, + parent: &'a AtomicUsize, +} + +impl<'a> LocalMax<'a> { + fn new(parent: &'a AtomicUsize) -> Self { + Self { max: 0, parent } + } + + fn max(&mut self, m: usize) { + self.max = self.max.max(m); + } +} + +impl Drop for LocalMax<'_> { + fn drop(&mut self) { + self.parent.fetch_max(self.max, Relaxed); + } +} + /// State shared by all worker threads for the duration of a run. 
struct Shared { store: Store, @@ -324,8 +366,14 @@ fn should_stop(shared: &Shared) -> bool { ///////////// fn writer(shared: &Shared) { + let mut ops = Local::new(&shared.ops); + let mut acquires_ok = Local::new(&shared.acquires_ok); + let mut acquires_fail = Local::new(&shared.acquires_fail); + + let mut peak_live = LocalMax::new(&shared.peak_live); + while !should_stop(shared) { - shared.ops.fetch_add(1, Relaxed); + ops.add(1); match shared.store.acquire() { Some(mut writer) => { let stamp = shared.stamp.fetch_add(1, Relaxed); @@ -333,11 +381,11 @@ fn writer(shared: &Shared) { writer.publish(); let live = shared.live.fetch_add(1, Relaxed) + 1; - shared.peak_live.fetch_max(live, Relaxed); - shared.acquires_ok.fetch_add(1, Relaxed); + peak_live.max(live); + acquires_ok.add(1); } None => { - shared.acquires_fail.fetch_add(1, Relaxed); + acquires_fail.add(1); std::thread::yield_now(); } } @@ -348,6 +396,10 @@ fn retirer(shared: &Shared, seed: u64) { let mut rng = StdRng::seed_from_u64(seed); let mut iteration: u64 = 0; + let mut retires_ok = Local::new(&shared.retires_ok); + let mut retires_fail = Local::new(&shared.retires_fail); + let mut reclaims = Local::new(&shared.reclaims); + while !should_stop(shared) { shared.ops.fetch_add(1, Relaxed); iteration += 1; @@ -357,16 +409,16 @@ fn retirer(shared: &Shared, seed: u64) { let i = rng.sample(shared.writable); if shared.store.retire(i) { shared.live.fetch_sub(1, Relaxed); - shared.retires_ok.fetch_add(1, Relaxed); + retires_ok.add(1); } else { - shared.retires_fail.fetch_add(1, Relaxed); + retires_fail.add(1); } } if iteration.is_multiple_of(RECLAIM_EVERY) && let Some(reclaimed) = shared.store.reclaim() { - shared.reclaims.fetch_add(reclaimed as u64, Relaxed); + reclaims.add(reclaimed as u64); } std::thread::yield_now(); @@ -379,8 +431,11 @@ fn reader(shared: &Shared, seed: u64) { let window = READER_WINDOW.min(slots); let mut observations = HashMap::with_capacity(window); + let mut ops = Local::new(&shared.ops); + let 
mut reads = Local::new(&shared.reads); + while !should_stop(shared) { - shared.ops.fetch_add(1, Relaxed); + ops.add(1); let Some(guard) = shared.store.reader() else { // All guard slots are occupied; back off and retry. std::thread::yield_now(); @@ -393,7 +448,7 @@ fn reader(shared: &Shared, seed: u64) { for k in 0..window { let i = (start + k) % slots; observe(shared, &mut observations, i, guard.read(i)); - shared.reads.fetch_add(1, Relaxed); + reads.add(1); } } } From 2678f36f62ae70a31981ccf57902d51f00e7ca54 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Thu, 2 Jul 2026 15:30:13 -0700 Subject: [PATCH 41/45] Fix-up Cargo.toml --- Cargo.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 594c20999..97ecfc640 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,6 +17,7 @@ members = [ "diskann-disk", "diskann-label-filter", "diskann-garnet", + "diskann-inmem", # Infrastructure "diskann-benchmark-runner", "diskann-benchmark-core", @@ -25,7 +26,7 @@ members = [ "diskann-record", "diskann-tools", "vectorset", - "diskann-bftree", "diskann-inmem", + "diskann-bftree", ] default-members = [ From db9d1b9bd51dc61f6e187c211a14efdabf2b4bc7 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Thu, 2 Jul 2026 15:38:05 -0700 Subject: [PATCH 42/45] Make compatible with Aarch64. --- diskann-inmem/src/provider.rs | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/diskann-inmem/src/provider.rs b/diskann-inmem/src/provider.rs index f660df522..f0569b701 100644 --- a/diskann-inmem/src/provider.rs +++ b/diskann-inmem/src/provider.rs @@ -558,6 +558,7 @@ impl<'a> layers::QueryVisitor<'a> for ExpandBeamVisitor { /// # Safety /// /// The memory range `[ptr, ptr.add(len))` must be valid. 
+#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))] #[inline(always)] unsafe fn prefetch(ptr: *const u8, len: usize) { use std::arch::x86_64::*; @@ -580,6 +581,16 @@ unsafe fn prefetch(ptr: *const u8, len: usize) { } } +/// Prefetch `len` bytes beginning at `ptr`. +/// +/// The last cache line prefetched first, followed by the rest in ascending order. +/// +/// # Safety +/// +/// The memory range `[ptr, ptr.add(len))` must be valid. +#[cfg(not(any(target_arch = "x86_64", target_feature = "avx2")))] +unsafe fn prefetch(_ptr: *const u8, _len: usize) {} + /// # Safety /// /// * All items in `list` must in-bounds with respect to `reader`. From 43d3fcf8db9e9bb104527c99b4c35efa4981869c Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Thu, 2 Jul 2026 15:43:13 -0700 Subject: [PATCH 43/45] Not syncing with git? --- diskann-inmem/src/provider.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/diskann-inmem/src/provider.rs b/diskann-inmem/src/provider.rs index f0569b701..671a17bac 100644 --- a/diskann-inmem/src/provider.rs +++ b/diskann-inmem/src/provider.rs @@ -581,6 +581,7 @@ unsafe fn prefetch(ptr: *const u8, len: usize) { } } + /// Prefetch `len` bytes beginning at `ptr`. /// /// The last cache line prefetched first, followed by the rest in ascending order. From ce7a3219da58364443ec5220e2dfe6b46e706dd6 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Thu, 2 Jul 2026 15:43:37 -0700 Subject: [PATCH 44/45] Remove extra whitespace. --- diskann-inmem/src/provider.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/diskann-inmem/src/provider.rs b/diskann-inmem/src/provider.rs index 671a17bac..f0569b701 100644 --- a/diskann-inmem/src/provider.rs +++ b/diskann-inmem/src/provider.rs @@ -581,7 +581,6 @@ unsafe fn prefetch(ptr: *const u8, len: usize) { } } - /// Prefetch `len` bytes beginning at `ptr`. /// /// The last cache line prefetched first, followed by the rest in ascending order. 
From a4ee3b2f6671566567e17a14a74021bbcd404522 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Thu, 2 Jul 2026 17:21:34 -0700 Subject: [PATCH 45/45] Typos. Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- diskann-benchmark-runner/src/utils/fmt.rs | 2 +- diskann-inmem/integration/store.rs | 2 +- diskann-inmem/integration/support/check.rs | 4 ++-- diskann-inmem/src/layers/full.rs | 4 ++-- diskann-inmem/src/layers/mod.rs | 10 +++++----- diskann-inmem/src/num.rs | 2 +- rfcs/01206-inmem2.md | 6 +++--- 7 files changed, 15 insertions(+), 15 deletions(-) diff --git a/diskann-benchmark-runner/src/utils/fmt.rs b/diskann-benchmark-runner/src/utils/fmt.rs index 7d0adeecf..f700b115a 100644 --- a/diskann-benchmark-runner/src/utils/fmt.rs +++ b/diskann-benchmark-runner/src/utils/fmt.rs @@ -740,7 +740,7 @@ string, , string // KeyValue // //----------// - // Strip a preceeding newline if it exists. + // Strip a preceding newline if it exists. fn process(x: &str) -> &str { let x = x.strip_prefix('\n').unwrap_or(x); x.strip_suffix('\n').unwrap_or(x) diff --git a/diskann-inmem/integration/store.rs b/diskann-inmem/integration/store.rs index 625f9aafe..b26458b43 100644 --- a/diskann-inmem/integration/store.rs +++ b/diskann-inmem/integration/store.rs @@ -347,7 +347,7 @@ struct Shared { transitions: AtomicU64, } -/// Record the first observed invariant violation and signal all workers to stop. +/// Record an observed invariant violation and signal all workers to stop. 
fn record_violation(shared: &Shared, message: String) { let mut slot = shared.violation.lock().unwrap(); slot.push(message); diff --git a/diskann-inmem/integration/support/check.rs b/diskann-inmem/integration/support/check.rs index 60fde4b31..8e8ce2121 100644 --- a/diskann-inmem/integration/support/check.rs +++ b/diskann-inmem/integration/support/check.rs @@ -20,12 +20,12 @@ use std::{ use diskann_benchmark_runner::{benchmark::PassFail, utils::fmt::Table}; use serde::{Serialize, Serializer}; -/// Perform a baseline check on `self` and a `previous`ly saved result. +/// Perform a baseline check on `self` and a previously saved result. pub(crate) trait CheckMatch { fn check_match(&self, previous: &Self) -> Match; } -/// The result of a basline check. +/// The result of a baseline check. #[must_use = "this is a result type"] #[derive(Debug, Serialize)] #[serde(rename_all = "kebab-case")] diff --git a/diskann-inmem/src/layers/full.rs b/diskann-inmem/src/layers/full.rs index e0f2c0d0d..b9807282e 100644 --- a/diskann-inmem/src/layers/full.rs +++ b/diskann-inmem/src/layers/full.rs @@ -243,7 +243,7 @@ where #[derive(Debug, Error)] #[error( - "expected slices of lenght {} - instead got {} and {}", + "expected slices of length {} - instead got {} and {}", self.expected, self.xlen, self.ylen @@ -337,7 +337,7 @@ where #[derive(Debug, Error)] #[error( - "expected slice of lenght {} - instead got {}", + "expected slice of length {} - instead got {}", self.expected, self.xlen, )] diff --git a/diskann-inmem/src/layers/mod.rs b/diskann-inmem/src/layers/mod.rs index eb14f57eb..13179d9bc 100644 --- a/diskann-inmem/src/layers/mod.rs +++ b/diskann-inmem/src/layers/mod.rs @@ -15,8 +15,8 @@ //! The design of this module allows aggressive optimization of graph search kernels via //! the [`Search`] and [`QueryVisitor`] pairs of traits. //! -//! Implementations of [`Search`] can pass an [`QueryDistance`] kernels specialized to -//! 
to a specific geometry (dimensionality or metric type) which upstream [`QueryVisitor`] +//! Implementations of [`Search`] can pass a [`QueryDistance`] kernel specialized to +//! a specific geometry (dimensionality or metric type) which upstream [`QueryVisitor`] //! will fuse into larger kernels. While this allows for high performance graph kernels, //! some considerations should be taken into account: //! @@ -41,7 +41,7 @@ pub trait Layer: Send + Sync + 'static { fn bytes(&self) -> Bytes; } -/// Store a element of type `T` into a raw byte buffer. +/// Store an element of type `T` into a raw byte buffer. /// /// Implementations may assume that `bytes.len()` is equal to [`Layer::bytes`]. pub trait Set: Layer { @@ -65,12 +65,12 @@ pub trait AsDistance: Send + Sync + std::fmt::Debug { fn as_distance(&self) -> &dyn Distance; } -/// A unary query distance on raw bytes slices. +/// A unary query distance on raw byte slices. /// /// When paired with [`Layer`] via helpers like [`Search`], implementations may assume /// that `x` has length [`Layer::bytes`]. /// -/// No alignment guarantees are made for `x`, though in practice is isk likely to be +/// No alignment guarantees are made for `x`, though in practice it is likely to be /// aligned to 32 or 64 bytes. pub trait QueryDistance: Send + Sync + std::fmt::Debug { fn evaluate(&self, x: &[u8]) -> ANNResult; diff --git a/diskann-inmem/src/num.rs b/diskann-inmem/src/num.rs index eb4a8c8d3..98c20d82b 100644 --- a/diskann-inmem/src/num.rs +++ b/diskann-inmem/src/num.rs @@ -89,7 +89,7 @@ impl std::fmt::Display for Bytes { /// An alignment for an allocation. /// -/// All alignemnts are guaranteed to be powers of two. +/// All alignments are guaranteed to be powers of two. 
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] #[repr(transparent)] pub struct Align(NonZeroUsize); diff --git a/rfcs/01206-inmem2.md b/rfcs/01206-inmem2.md index 26572f921..55bd0f07e 100644 --- a/rfcs/01206-inmem2.md +++ b/rfcs/01206-inmem2.md @@ -20,10 +20,10 @@ While there is some measure of safety on the [write path](https://github.com/mic This is problematic for `diskann` for a number of reasons: -* It prevents us from stress testing our algorithm under concurrent inserts/search/deletes etc.. +* It prevents us from stress testing our algorithm under concurrent inserts/search/deletes etc. Unfortunately, this is the situation under which most of our database integrations actually operate. -* Coupled with the lack of ID translation mechanism for the inmem provider, we have no protection against races inserting multiple internal IDs simulataneously or a coherent story between concurrent inserts and deletes to the same ID. +* Coupled with the lack of ID translation mechanism for the inmem provider, we have no protection against races inserting multiple internal IDs simultaneously or a coherent story between concurrent inserts and deletes to the same ID. * It makes me sad. @@ -35,7 +35,7 @@ Unfortunately, approaches like this that require cache-line **writes** before (a ### Problem Statement -Let's try to fix with in-memory provider to make it: +Let's try to fix the in-memory provider to make it: * Well defined under concurrent operations. * Enable co-development of the core algorithm to patch our remaining concurrency issues.