diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ba73d6cd1..e719ff575 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -22,7 +22,7 @@ env: CARGO_TERM_COLOR: always # The features we want to explicitly test. For example, the `flatbuffers-build` feature # of `diskann-quantization` requires additional setup and so must not be included by default. - DISKANN_FEATURES: "virtual_storage,spherical-quantization,product-quantization,tracing,experimental_diversity_search,disk-index,flatbuffers,linalg,codegen" + DISKANN_FEATURES: "virtual_storage,spherical-quantization,product-quantization,tracing,experimental_diversity_search,disk-index,flatbuffers,linalg,codegen,integration-test,inmem2" # Intel SDE version used for baseline and AVX-512 emulation jobs. SDE_VERSION: "sde-external-10.8.0-2026-03-15-lin" diff --git a/Cargo.lock b/Cargo.lock index b71b13a8f..6397b80d5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -657,6 +657,7 @@ dependencies = [ "diskann-benchmark-runner", "diskann-bftree", "diskann-disk", + "diskann-inmem", "diskann-label-filter", "diskann-providers", "diskann-quantization", @@ -805,6 +806,30 @@ dependencies = [ "tokio", ] +[[package]] +name = "diskann-inmem" +version = "0.54.0" +dependencies = [ + "anyhow", + "bytemuck", + "crossbeam-queue", + "dashmap", + "diskann", + "diskann-benchmark-core", + "diskann-benchmark-runner", + "diskann-utils", + "diskann-vector", + "diskann-wide", + "half", + "parking_lot", + "rand", + "serde", + "serde_json", + "tempfile", + "thiserror 2.0.17", + "tokio", +] + [[package]] name = "diskann-label-filter" version = "0.54.0" diff --git a/Cargo.toml b/Cargo.toml index 85f29ca71..97ecfc640 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,6 +17,7 @@ members = [ "diskann-disk", "diskann-label-filter", "diskann-garnet", + "diskann-inmem", # Infrastructure "diskann-benchmark-runner", "diskann-benchmark-core", @@ -60,6 +61,7 @@ diskann-platform = { path = "diskann-platform", version = "0.54.0" } 
diskann = { path = "diskann", version = "0.54.0" } # Providers diskann-providers = { path = "diskann-providers", default-features = false, version = "0.54.0" } +diskann-inmem = { path = "diskann-inmem", default-features = false, version = "0.54.0" } diskann-disk = { path = "diskann-disk", version = "0.54.0" } diskann-label-filter = { path = "diskann-label-filter", version = "0.54.0" } # Infra @@ -120,3 +122,7 @@ opt-level = 1 debug = true debug-assertions = true overflow-checks = true + +[profile.samply] +inherits = "release" +debug = true diff --git a/diskann-benchmark-runner/src/app.rs b/diskann-benchmark-runner/src/app.rs index 1bdec6adf..09fd58df6 100644 --- a/diskann-benchmark-runner/src/app.rs +++ b/diskann-benchmark-runner/src/app.rs @@ -258,8 +258,8 @@ impl App { )?; writeln!(output, "Closest matches:\n")?; for (i, mismatch) in mismatches.into_iter().enumerate() { - writeln!(output, " {}. \"{}\":", i + 1, mismatch.method(),)?; - writeln!(output, "{}\n", Indent::new(mismatch.reason(), 8),)?; + writeln!(output, " {}. 
\"{}\":", i + 1, mismatch.method())?; + writeln!(output, "{}\n", Indent::new(mismatch.reason(), 8))?; } writeln!(output)?; diff --git a/diskann-benchmark-runner/src/files.rs b/diskann-benchmark-runner/src/files.rs index a1bb453ff..a9c0f6171 100644 --- a/diskann-benchmark-runner/src/files.rs +++ b/diskann-benchmark-runner/src/files.rs @@ -62,6 +62,12 @@ impl std::ops::Deref for InputFile { } } +impl std::fmt::Display for InputFile { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.display()) + } +} + /////////// // Tests // /////////// diff --git a/diskann-benchmark-runner/src/utils/fmt.rs b/diskann-benchmark-runner/src/utils/fmt.rs index 7680a059b..f700b115a 100644 --- a/diskann-benchmark-runner/src/utils/fmt.rs +++ b/diskann-benchmark-runner/src/utils/fmt.rs @@ -379,6 +379,135 @@ where } } +////////////// +// KeyValue // +////////////// + +enum MaybeLazy<'a> { + Lazy(&'a dyn std::fmt::Display), + Eager(String), +} + +impl std::fmt::Display for MaybeLazy<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Lazy(lazy) => write!(f, "{}", lazy), + Self::Eager(s) => f.write_str(s), + } + } +} + +impl std::fmt::Debug for MaybeLazy<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + struct AsDisplay<'a>(&'a dyn std::fmt::Display); + impl std::fmt::Debug for AsDisplay<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } + } + + match self { + Self::Lazy(o) => { + let as_display = AsDisplay(o); + f.debug_tuple("MaybeLazy::Lazy").field(&as_display).finish() + } + Self::Eager(s) => f.debug_tuple("MaybeLazy::Eager").field(s).finish(), + } + } +} + +/// Display a dynamic list of key-value pairs in a YAML-like style. +/// +/// Keys are left-aligned and single-line values are aligned into a common column +/// just past the longest key. 
A value that renders to multiple lines (for example +/// a nested [`KeyValue`] or any other multi-line block) is placed on the lines +/// following its key, indented by two spaces. This keeps nested structures visibly +/// subordinate to their key regardless of whether the value is itself a key-value +/// list or an opaque block. +/// +/// # Examples +/// +/// ``` +/// use diskann_benchmark_runner::utils::fmt::KeyValue; +/// +/// let mut kv = KeyValue::new(); +/// kv.push("a", &1); +/// kv.push("hello", &"world"); +/// +/// let expected = "a: 1\nhello: world"; +/// +/// assert_eq!(kv.to_string(), expected); +/// ``` +/// +/// Multi-line values are indented beneath their key: +/// +/// ``` +/// use diskann_benchmark_runner::utils::fmt::KeyValue; +/// +/// let mut inner = KeyValue::new(); +/// inner.push("x", &1); +/// inner.push("yy", &2); +/// let inner = inner.to_string(); +/// +/// let mut kv = KeyValue::new(); +/// kv.push("name", &"example"); +/// kv.push("nested", &inner); +/// +/// let expected = "name: example\nnested:\n x: 1\n yy: 2"; +/// +/// assert_eq!(kv.to_string(), expected); +/// ``` +#[derive(Debug, Default)] +pub struct KeyValue<'a> { + kv: Vec<(&'a str, MaybeLazy<'a>)>, + max_key_length: usize, +} + +impl<'a> KeyValue<'a> { + /// Create a new empty [`KeyValue`] formatter. + pub fn new() -> Self { + Self { + kv: Vec::new(), + max_key_length: 0, + } + } + + /// Push the key-value pair to `self` for formatting. + pub fn push(&mut self, key: &'a str, value: &'a dyn std::fmt::Display) { + self.max_key_length = self.max_key_length.max(key.len()); + self.kv.push((key, MaybeLazy::Lazy(value))) + } + + /// Push the key-value pair to `self` for formatting - eagerly formatting `value`. 
+ pub fn push_eager(&mut self, key: &'a str, value: D) + where + D: std::fmt::Display, + { + self.max_key_length = self.max_key_length.max(key.len()); + self.kv.push((key, MaybeLazy::Eager(value.to_string()))) + } +} + +impl std::fmt::Display for KeyValue<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let width = self.max_key_length; + let mut prefix = ""; + for (k, v) in self.kv.iter() { + let rendered = v.to_string(); + if rendered.contains('\n') { + write!(f, "{}{}:\n{}", prefix, k, Indent::new(&rendered, 2))? + } else { + // Left-align the key and pad so that all single-line values line up in a + // column one space past the longest key's colon. + let pad = (width + 1).saturating_sub(k.len()); + write!(f, "{}{}:{:pad$}{rendered}", prefix, k, "")?; + } + prefix = "\n"; + } + Ok(()) + } +} + /////////// // Tests // /////////// @@ -606,4 +735,93 @@ string, , string .with_pair(" and "); assert_eq!(d.to_string(), "\"topk\" and \"range\""); } + + //----------// + // KeyValue // + //----------// + + // Strip a preceding newline if it exists. 
+ fn process(x: &str) -> &str { + let x = x.strip_prefix('\n').unwrap_or(x); + x.strip_suffix('\n').unwrap_or(x) + } + + #[test] + fn test_key_value_empty() { + let kv = KeyValue::new(); + assert_eq!(kv.to_string(), ""); + } + + #[test] + fn test_key_value_single_pair() { + let mut kv = KeyValue::new(); + kv.push("a", &1); + assert_eq!(kv.to_string(), "a: 1"); + } + + #[test] + fn test_key_value_aligns_values() { + let mut kv = KeyValue::new(); + kv.push("a", &1); + kv.push("hello", &"world"); + let expected = process( + r#" +a: 1 +hello: world +"#, + ); + assert_eq!(kv.to_string(), expected); + } + + #[test] + fn test_key_value_push_eager() { + let mut kv = KeyValue::new(); + kv.push_eager("a", 1); + kv.push_eager("hello", "world"); + + let expected = process( + r#" +a: 1 +hello: world +"#, + ); + + assert_eq!(kv.to_string(), expected); + } + + #[test] + fn test_key_value_multiline_value_is_indented() { + let mut inner = KeyValue::new(); + inner.push("x", &1); + inner.push("yy", &2); + let inner = inner.to_string(); + + let mut kv = KeyValue::new(); + kv.push("name", &"example"); + kv.push("nested", &inner); + kv.push("another line", &1); + + let expected = process( + r#" +name: example +nested: + x: 1 + yy: 2 +another line: 1 +"#, + ); + + assert_eq!(kv.to_string(), expected); + } + + #[test] + fn maybe_lazy_debug() { + let x = MaybeLazy::Lazy(&1); + let dbg = format!("{:?}", x); + assert_eq!(dbg, "MaybeLazy::Lazy(1)"); + + let x = MaybeLazy::Eager("hello".into()); + let dbg = format!("{:?}", x); + assert_eq!(dbg, "MaybeLazy::Eager(\"hello\")"); + } } diff --git a/diskann-benchmark/Cargo.toml b/diskann-benchmark/Cargo.toml index ce5018aad..27c249999 100644 --- a/diskann-benchmark/Cargo.toml +++ b/diskann-benchmark/Cargo.toml @@ -39,6 +39,7 @@ opentelemetry_sdk = { workspace = true, optional = true } scopeguard = { version = "1.2", optional = true } diskann-benchmark-core = { workspace = true, features = ["bigann"] } itertools.workspace = true +diskann-inmem = { 
workspace = true, optional = true } [lints] clippy.undocumented_unsafe_blocks = "warn" @@ -67,6 +68,9 @@ minmax-quantization = [] # Enable multi-vector MaxSim distance benchmarks multi-vector = [] +# Enable inmem 2.0 +inmem2 = ["dep:diskann-inmem"] + # Enable bftree backend bftree = ["dep:diskann-bftree"] diff --git a/diskann-benchmark/src/index/inmem2.rs b/diskann-benchmark/src/index/inmem2.rs new file mode 100644 index 000000000..a794fbfcf --- /dev/null +++ b/diskann-benchmark/src/index/inmem2.rs @@ -0,0 +1,1048 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +use std::{ + fmt::{self, Display, Formatter}, + io::Write, + num::NonZeroUsize, + ops::Range, + sync::Arc, +}; + +use diskann::graph::{self, DiskANNIndex, InplaceDeleteMethod, StartPointStrategy}; +use diskann_benchmark_core::{ + self as benchmark_core, build as build_core, recall, search as core_search, + streaming::{self, executors::bigann, Executor}, +}; +use diskann_benchmark_runner::{ + benchmark::{FailureScore, MatchScore}, + files::InputFile, + output::Output, + utils::{ + datatype::{AsDataType, DataType}, + fmt::{Delimit, KeyValue, Quote}, + }, + Benchmark, Checker, Checkpoint, Input, Registry, +}; +use diskann_inmem::{ + layers::{Full, FullPrecision}, + Provider, Strategy, +}; +use diskann_utils::views::{Matrix, MatrixView}; +use diskann_vector::distance::Metric; +use serde::{Deserialize, Serialize}; + +use crate::{ + index::{ + build::{BuildKind, BuildStats, ProgressMeter}, + result::{AggregatedSearchResults, SearchResults}, + streaming::stats::{GenericStats, StreamStats, Summary}, + }, + utils::{datafiles, SimilarityMeasure}, +}; + +pub(crate) fn register_benchmarks(registry: &mut Registry) -> anyhow::Result<()> { + registry.register("inmem2-f32", Build::::new())?; + // registry.register("inmem2-f16", Build::::new())?; + registry.register("inmem2-f32-stream", StreamingBenchmark::::new())?; + Ok(()) +} + +/////////// +// Input // +/////////// + +mod dto { 
+ use super::*; + + #[derive(Debug, Serialize, Deserialize)] + pub(super) struct KnnSweep { + pub(super) search_n: usize, + pub(super) search_l: Vec, + pub(super) recall_k: usize, + } + + #[derive(Debug, Serialize, Deserialize)] + pub(super) struct KnnSearch { + pub(super) queries: InputFile, + pub(super) groundtruth: InputFile, + pub(super) reps: NonZeroUsize, + pub(super) num_threads: Vec, + pub(super) runs: Vec, + } + + #[derive(Debug, Serialize, Deserialize)] + pub(super) struct Data { + pub(super) data_type: DataType, + pub(super) data: InputFile, + pub(super) distance: SimilarityMeasure, + } + + #[derive(Debug, Serialize, Deserialize)] + pub(super) struct BuildParams { + pub(super) pruned_degree: usize, + pub(super) max_degree: usize, + pub(super) l_build: usize, + pub(super) alpha: f32, + pub(super) num_threads: NonZeroUsize, + } + + //-----------// + // Streaming // + //-----------// + + #[derive(Debug, Serialize, Deserialize)] + pub(super) struct StreamingKnnSearch { + pub(super) queries: InputFile, + pub(super) reps: NonZeroUsize, + pub(super) num_threads: NonZeroUsize, + pub(super) runs: Vec, + } + + #[derive(Debug, Serialize, Deserialize)] + pub(super) struct RunBook { + pub(super) path: InputFile, + pub(super) dataset: String, + pub(super) groundtruth_directory: String, + pub(super) delete_method: crate::inputs::graph_index::InplaceDeleteMethod, + pub(super) delete_num_to_replace: usize, + } + + //------------------// + // Top Level Inputs // + //------------------// + + #[derive(Debug, Serialize, Deserialize)] + pub(super) struct StaticBuild { + pub(super) data: Data, + pub(super) build: BuildParams, + pub(super) search: KnnSearch, + } + + #[derive(Debug, Serialize, Deserialize)] + pub(super) struct BigANNStreaming { + pub(super) data: Data, + pub(super) build: BuildParams, + pub(super) search: StreamingKnnSearch, + pub(super) runbook: RunBook, + } +} + +#[derive(Debug, Clone)] +struct KnnInstance { + knn: graph::search::Knn, + recall_k: usize, +} + 
+impl KnnInstance { + fn flatten(runs: &[dto::KnnSweep]) -> anyhow::Result> { + runs.iter() + .flat_map(|sweep| { + let search_n = sweep.search_n; + let recall_k = sweep.recall_k; + + sweep + .search_l + .iter() + .map(move |search_l| -> anyhow::Result<_> { + let knn = graph::search::Knn::new_default(search_n, *search_l)?; + Ok(KnnInstance { knn, recall_k }) + }) + }) + .collect() + } +} + +impl Display for KnnInstance { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "knn = {}, search_l = {}, beam_width = {}", + self.recall_k, + self.knn.l_value(), + self.knn.beam_width(), + ) + } +} + +#[derive(Debug)] +struct KnnSearch { + queries: InputFile, + groundtruth: InputFile, + reps: NonZeroUsize, + num_threads: Vec, + runs: Vec, +} + +impl KnnSearch { + fn from_raw(raw: dto::KnnSearch, checker: Option<&mut Checker>) -> anyhow::Result { + let dto::KnnSearch { + mut queries, + mut groundtruth, + reps, + num_threads, + runs, + } = raw; + + if let Some(checker) = checker { + queries.resolve(checker)?; + groundtruth.resolve(checker)?; + } + + Ok(Self { + queries, + groundtruth, + reps, + num_threads, + runs: KnnInstance::flatten(&runs)?, + }) + } + + fn maximum_recall_k(&self) -> usize { + self.runs.iter().map(|r| r.recall_k).max().unwrap_or(0) + } +} + +impl Display for KnnSearch { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + let mut kv = KeyValue::new(); + kv.push("queries", &self.queries); + kv.push("groundtruth", &self.groundtruth); + kv.push("reps", &self.reps); + + let num_threads = Delimit::new(self.num_threads.iter(), ", "); + kv.push("num_threads", &num_threads); + + let runs = Delimit::new(self.runs.iter(), "\n").to_string(); + kv.push("runs", &runs); + write!(f, "{}", kv) + } +} + +#[derive(Debug)] +struct Data { + data_type: DataType, + data: InputFile, + distance: Metric, +} + +impl Data { + fn from_raw(raw: dto::Data, checker: Option<&mut Checker>) -> anyhow::Result { + let dto::Data { + data_type, + mut data, + 
distance, + } = raw; + + if let Some(checker) = checker { + data.resolve(checker)?; + } + + Ok(Self { + data_type, + data, + distance: distance.into(), + }) + } +} + +impl Display for Data { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + let mut kv = KeyValue::new(); + kv.push("data_type", &self.data_type); + kv.push("data", &self.data); + kv.push("distance", &self.distance); + write!(f, "{}", kv) + } +} + +#[derive(Debug)] +struct BuildParams { + config: graph::Config, + num_threads: NonZeroUsize, +} + +impl BuildParams { + fn from_raw(raw: dto::BuildParams, metric: Metric) -> anyhow::Result { + let dto::BuildParams { + pruned_degree, + max_degree, + l_build, + alpha, + num_threads, + } = raw; + + let config = graph::config::Builder::new_with( + pruned_degree, + graph::config::MaxDegree::new(max_degree), + l_build, + metric.into(), + |b| { + b.alpha(alpha); + }, + ) + .build()?; + + Ok(Self { + config, + num_threads, + }) + } +} + +impl Display for BuildParams { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + let mut kv = KeyValue::new(); + + let pruned_degree = self.config.pruned_degree(); + let max_degree = self.config.max_degree(); + let alpha = self.config.alpha(); + let l_build = self.config.l_build(); + + kv.push("pruned_degree", &pruned_degree); + kv.push("max_degree", &max_degree); + kv.push("alpha", &alpha); + kv.push("l_build", &l_build); + kv.push("num_threads", &self.num_threads); + write!(f, "{}", kv) + } +} + +#[derive(Debug)] +struct StaticBuild { + data: Data, + build: BuildParams, + search: KnnSearch, + // The serialized representation of the original input. 
+ input: serde_json::Value, +} + +impl StaticBuild { + fn from_raw(raw: dto::StaticBuild, mut checker: Option<&mut Checker>) -> anyhow::Result { + let input = serde_json::to_value(&raw)?; + + let dto::StaticBuild { + data, + build, + search, + } = raw; + + let data = Data::from_raw(data, checker.as_deref_mut())?; + let build = BuildParams::from_raw(build, data.distance)?; + let search = KnnSearch::from_raw(search, checker)?; + + Ok(Self { + data, + build, + search, + input, + }) + } +} + +impl Display for StaticBuild { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let mut kv = KeyValue::new(); + kv.push("data", &self.data); + kv.push("build", &self.build); + kv.push("search", &self.search); + + write!(f, "{}", kv) + } +} + +impl Input for StaticBuild { + type Raw = dto::StaticBuild; + + fn tag() -> &'static str { + "inmem2" + } + + fn from_raw(raw: Self::Raw, checker: &mut Checker) -> anyhow::Result { + Self::from_raw(raw, Some(checker)) + } + + fn serialize(&self) -> anyhow::Result { + Ok(self.input.clone()) + } + + fn example() -> Self::Raw { + const FOUR: NonZeroUsize = NonZeroUsize::new(4).unwrap(); + const THREE: NonZeroUsize = NonZeroUsize::new(3).unwrap(); + + dto::StaticBuild { + data: dto::Data { + data_type: DataType::Float32, + data: InputFile::new("path/to/data"), + distance: SimilarityMeasure::SquaredL2, + }, + build: dto::BuildParams { + pruned_degree: 28, + max_degree: 32, + l_build: 100, + alpha: 1.2, + num_threads: FOUR, + }, + search: dto::KnnSearch { + queries: InputFile::new("path/to/queries"), + groundtruth: InputFile::new("path/to/groundtruth"), + reps: THREE, + num_threads: vec![FOUR], + runs: vec![dto::KnnSweep { + search_n: 10, + search_l: vec![10, 20, 30, 40, 50], + recall_k: 10, + }], + }, + } + } +} + +/////////////// +// Benchmark // +/////////////// + +#[derive(Debug)] +struct Build(std::marker::PhantomData); + +impl Build { + fn new() -> Self { + Self(std::marker::PhantomData) + } +} + +impl Benchmark for Build +where 
+ T: diskann_inmem::layers::FullPrecision + diskann::graph::SampleableForStart + AsDataType, +{ + type Input = StaticBuild; + type Output = (); + + fn try_match(&self, input: &StaticBuild) -> Result { + if T::is_match(input.data.data_type) { + Ok(MatchScore(0)) + } else { + Err(FailureScore(1000)) + } + } + + fn description( + &self, + f: &mut std::fmt::Formatter<'_>, + input: Option<&StaticBuild>, + ) -> std::fmt::Result { + match input { + Some(input) => { + let data_type = input.data.data_type; + if !T::is_match(data_type) { + write!( + f, + "expected data-type {}, instead got {}", + Quote(T::DATA_TYPE), + Quote(data_type) + )?; + } + } + None => { + write!( + f, + "full-precision static build+search with data type {}", + Quote(T::DATA_TYPE) + )?; + } + } + + Ok(()) + } + + fn run( + &self, + input: &StaticBuild, + checkpoint: Checkpoint<'_>, + mut output: &mut dyn Output, + ) -> anyhow::Result<()> { + writeln!(output, "{input}\n")?; + + // Load data. + let data: Arc> = Arc::new(datafiles::load_dataset(datafiles::BinFile( + &input.data.data, + ))?); + + let dim = data.ncols(); + let num_points = data.nrows(); + writeln!(output, "Loaded {num_points} points, dim={dim}")?; + + // Compute the medoid of the dataset as the single start point. + let start = StartPointStrategy::Medoid.compute(data.as_view())?; + let layer = Full::::new(dim, input.data.distance); + let config = + diskann_inmem::provider::Config::new(num_points, input.build.config.max_degree().get()); + let provider = Provider::<_, u32>::new(layer, config, start.row_iter())?; + + let index = Arc::new(DiskANNIndex::new( + input.build.config.clone(), + provider, + None, + )); + + // Build via SingleInsert. 
+ let rt = benchmark_core::tokio::runtime(input.build.num_threads.get())?; + let builder = build_core::graph::SingleInsert::new( + index.clone(), + data, + Strategy, + build_core::ids::Identity::::new(), + ); + + let build_results = build_core::build_tracked( + builder, + build_core::Parallelism::dynamic(diskann::utils::ONE, input.build.num_threads), + &rt, + Some(&ProgressMeter::new(output)), + )?; + + let total_build_time = build_results.end_to_end_latency(); + writeln!( + output, + "\nBuild complete in {:.2}s", + total_build_time.as_seconds() + )?; + checkpoint.checkpoint(&total_build_time)?; + + // Search. + let queries: Arc> = Arc::new(datafiles::load_dataset(datafiles::BinFile( + &input.search.queries, + ))?); + let max_k = input.search.maximum_recall_k(); + let groundtruth = datafiles::load_groundtruth( + datafiles::BinFile(&input.search.groundtruth), + Some(max_k), + )?; + + writeln!(output, "Loaded {} queries\n", queries.nrows())?; + + let knn = benchmark_core::search::graph::KNN::new( + index, + queries, + benchmark_core::search::graph::Strategy::broadcast(Strategy), + )?; + + let results = _knn( + &knn, + &groundtruth, + input.search.reps, + &input.search.num_threads, + &input.search.runs, + )?; + + let results = AggregatedSearchResults::Topk(results); + + writeln!(output, "{}", results)?; + + Ok(()) + } +} + +fn _knn( + runner: &dyn crate::index::search::knn::Knn, + groundtruth: &dyn benchmark_core::recall::Rows, + reps: NonZeroUsize, + num_threads: &[NonZeroUsize], + instances: &[KnnInstance], +) -> anyhow::Result> { + let mut results = Vec::new(); + + for num_threads in num_threads.iter() { + for instance in instances.iter() { + let setup = core_search::Setup { + threads: *num_threads, + tasks: *num_threads, + reps, + }; + + let run = core_search::Run::new(instance.knn, setup); + + let r = runner.search_all( + vec![run], + groundtruth, + instance.recall_k, + instance.knn.k_value().get(), + )?; + + results.extend(r); + } + } + + Ok(results) +} + 
+/////////////// +// Streaming // +/////////////// + +#[derive(Debug, Clone)] +struct StreamingKnnSearch { + queries: InputFile, + reps: NonZeroUsize, + num_threads: NonZeroUsize, + runs: Vec<KnnInstance>, +} + +impl StreamingKnnSearch { + fn from_raw( + raw: dto::StreamingKnnSearch, + checker: Option<&mut Checker>, + ) -> anyhow::Result<Self> { + let dto::StreamingKnnSearch { + mut queries, + reps, + num_threads, + runs, + } = raw; + + if let Some(checker) = checker { + queries.resolve(checker)?; + } + + Ok(Self { + queries, + reps, + num_threads, + runs: KnnInstance::flatten(&runs)?, + }) + } +} + +impl Display for StreamingKnnSearch { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + let mut kv = KeyValue::new(); + kv.push("queries", &self.queries); + kv.push("reps", &self.reps); + kv.push("num_threads", &self.num_threads); + + let runs = Delimit::new(self.runs.iter(), "\n"); + kv.push("runs", &runs); + write!(f, "{}", kv) + } +} + +#[derive(Debug)] +struct RunBook { + runbook: bigann::RunBook, + delete_method: InplaceDeleteMethod, + delete_num_to_replace: usize, + // This is kept for display purposes. 
+ runbook_path: InputFile, + dataset: String, +} + +impl RunBook { + fn from_raw(raw: dto::RunBook, checker: &mut Checker) -> anyhow::Result { + let dto::RunBook { + mut path, + dataset, + groundtruth_directory, + delete_method, + delete_num_to_replace, + } = raw; + + path.resolve(checker)?; + + let groundtruth_directory = checker.find_input_dir(groundtruth_directory.as_ref())?; + + let runbook = bigann::RunBook::load( + &path, + &dataset, + &mut bigann::ScanDirectory::new(&groundtruth_directory)?, + )?; + + Ok(Self { + runbook, + delete_method: delete_method.into(), + delete_num_to_replace, + runbook_path: path, + dataset, + }) + } +} + +impl Display for RunBook { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + let mut kv = KeyValue::new(); + let path = self.runbook_path.display(); + kv.push("runbook", &path); + kv.push("dataset", &self.dataset); + + let max_points = self.runbook.max_points(); + let max_tag = self.runbook.max_tag(); + let num_stages = self.runbook.len(); + + kv.push("num_stages", &num_stages); + kv.push("max_active_points", &max_points); + if let Some(ref max_tag) = max_tag { + kv.push("max_tag", max_tag); + } + + kv.push_eager("delete_method", format_args!("{:?}", self.delete_method)); + kv.push("delete_num_to_replace", &self.delete_num_to_replace); + write!(f, "{}", kv) + } +} + +#[derive(Debug)] +struct BigANNStreaming { + data: Data, + build: BuildParams, + search: StreamingKnnSearch, + runbook: RunBook, + // The serialized representation of the original input. 
+ input: serde_json::Value, +} + +impl BigANNStreaming { + fn from_raw(raw: dto::BigANNStreaming, checker: &mut Checker) -> anyhow::Result { + let input = serde_json::to_value(&raw)?; + let data = Data::from_raw(raw.data, Some(checker))?; + let build = BuildParams::from_raw(raw.build, data.distance)?; + Ok(Self { + data, + build, + search: StreamingKnnSearch::from_raw(raw.search, Some(checker))?, + runbook: RunBook::from_raw(raw.runbook, checker)?, + input, + }) + } +} + +impl Display for BigANNStreaming { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + let mut kv = KeyValue::new(); + kv.push("data", &self.data); + kv.push("build", &self.build); + kv.push("search", &self.search); + kv.push("runbook", &self.runbook); + write!(f, "{}", kv) + } +} + +impl Input for BigANNStreaming { + type Raw = dto::BigANNStreaming; + + fn tag() -> &'static str { + "inmem2-streaming" + } + + fn from_raw(raw: Self::Raw, checker: &mut Checker) -> anyhow::Result { + Self::from_raw(raw, checker) + } + + fn serialize(&self) -> anyhow::Result { + Ok(self.input.clone()) + } + + fn example() -> Self::Raw { + const FOUR: NonZeroUsize = NonZeroUsize::new(4).unwrap(); + const THREE: NonZeroUsize = NonZeroUsize::new(3).unwrap(); + + dto::BigANNStreaming { + data: dto::Data { + data_type: DataType::Float32, + data: InputFile::new("path/to/data"), + distance: SimilarityMeasure::SquaredL2, + }, + build: dto::BuildParams { + pruned_degree: 28, + max_degree: 32, + l_build: 100, + alpha: 1.2, + num_threads: FOUR, + }, + search: dto::StreamingKnnSearch { + queries: InputFile::new("path/to/queries"), + reps: THREE, + num_threads: FOUR, + runs: vec![dto::KnnSweep { + search_n: 10, + search_l: vec![10, 20, 30, 40, 50], + recall_k: 10, + }], + }, + runbook: dto::RunBook { + path: InputFile::new("path/to/runbook.yaml"), + dataset: "dataset-1M".into(), + groundtruth_directory: "groundtruth/dir".into(), + delete_method: crate::inputs::graph_index::InplaceDeleteMethod::TwoHopAndOneHop, + 
delete_num_to_replace: 3, + }, + } + } +} + +#[derive(Debug)] +struct StreamingBenchmark(std::marker::PhantomData); + +impl StreamingBenchmark { + fn new() -> Self { + Self(std::marker::PhantomData) + } +} + +impl Benchmark for StreamingBenchmark +where + T: FullPrecision + AsDataType + diskann::graph::SampleableForStart, +{ + type Input = BigANNStreaming; + type Output = Vec; + + fn try_match(&self, input: &BigANNStreaming) -> Result { + if T::is_match(input.data.data_type) { + Ok(MatchScore(0)) + } else { + Err(FailureScore(1000)) + } + } + + fn description( + &self, + f: &mut std::fmt::Formatter<'_>, + input: Option<&BigANNStreaming>, + ) -> std::fmt::Result { + match input { + Some(input) => { + let data_type = input.data.data_type; + if !T::is_match(data_type) { + write!( + f, + "expected data-type {}, instead got {}", + Quote(T::DATA_TYPE), + Quote(data_type) + )?; + } + } + None => { + write!( + f, + "full-precision streaming with data type {}", + Quote(T::DATA_TYPE) + )?; + } + } + + Ok(()) + } + + fn run( + &self, + input: &BigANNStreaming, + _checkpoint: Checkpoint<'_>, + mut output: &mut dyn Output, + ) -> anyhow::Result { + writeln!(output, "{input}\n")?; + + // Load the runbook so we know the eventual capacity. + let runbook = input.runbook.runbook.clone(); + let max_points = runbook.max_points(); + + // Load the dataset (consumed by `WithData`) and queries. + let dataset: Matrix = datafiles::load_dataset(datafiles::BinFile(&input.data.data))?; + let queries: Arc> = Arc::new(datafiles::load_dataset(datafiles::BinFile( + &input.search.queries, + ))?); + let dim = dataset.ncols(); + + // Compute the medoid of the dataset as the single start point. 
+ let start = StartPointStrategy::Medoid.compute(dataset.as_view())?; + let index_config = input.build.config.clone(); + let layer = Full::::new(dim, input.data.distance); + + let config = + diskann_inmem::provider::Config::new(max_points, index_config.max_degree().get()); + let provider = Provider::<_, u32>::new(layer, config, start.row_iter())?; + + let index = Arc::new(DiskANNIndex::new(index_config, provider, None)); + + let num_threads = input.build.num_threads; + let runtime = benchmark_core::tokio::runtime(num_threads.get())?; + + let stream = Stream { + index, + runtime, + search: input.search.clone(), + ntasks: input.build.num_threads, + delete_method: input.runbook.delete_method, + delete_num_to_replace: input.runbook.delete_num_to_replace, + }; + + let mut layered = bigann::WithData::new(stream, dataset, queries, move |path| { + Ok(Box::new(datafiles::load_groundtruth( + datafiles::BinFile(path), + None, + )?)) + }); + + // Here we go! + let mut results = Vec::new(); + let stages = runbook.len(); + let mut i = 1; + input.runbook.runbook.clone().run_with( + &mut layered, + |o: StreamStats| -> anyhow::Result<()> { + if o.is_maintain() { + let message = format!("Ran maintenance before stage {}", i); + write!(output, "{}", crate::utils::SmallBanner(&message))?; + } else { + let message = format!("Finished stage {} of {}: {}", i, stages, o.kind()); + write!(output, "{}", crate::utils::SmallBanner(&message))?; + i += 1; + } + writeln!(output, "{}", o)?; + results.push(o); + Ok(()) + }, + )?; + + write!( + output, + "{}", + crate::utils::SmallBanner("End of Run Summary") + )?; + + writeln!(output, "{}", Summary::new(results.iter()))?; + + Ok(results) + } +} + +//////////// +// Stream // +//////////// + +struct Stream +where + T: FullPrecision, +{ + index: Arc>>>, + runtime: tokio::runtime::Runtime, + search: StreamingKnnSearch, + ntasks: NonZeroUsize, + delete_method: InplaceDeleteMethod, + delete_num_to_replace: usize, +} + +impl Stream +where + T: 
FullPrecision, +{ + fn insert_( + &mut self, + data: MatrixView<'_, T>, + ids: Range, + ) -> anyhow::Result { + anyhow::ensure!( + data.nrows() == ids.len(), + "insert: data rows ({}) != ids range ({})", + data.nrows(), + ids.len(), + ); + + let runner = build_core::graph::SingleInsert::new( + self.index.clone(), + Arc::new(data.to_owned()), + Strategy, + build_core::ids::Range::::new(ids.start as u32..ids.end as u32), + ); + + let results = build_core::build( + runner, + build_core::Parallelism::dynamic(diskann::utils::ONE, self.ntasks), + &self.runtime, + )?; + + BuildStats::new(BuildKind::SingleInsert, results) + } +} + +impl streaming::Stream> for Stream +where + T: FullPrecision, +{ + type Output = StreamStats; + + fn search( + &mut self, + (queries, groundtruth): (Arc>, &dyn recall::Rows), + ) -> anyhow::Result { + let knn = benchmark_core::search::graph::KNN::new( + self.index.clone(), + queries, + benchmark_core::search::graph::Strategy::broadcast(Strategy), + )?; + + let r = _knn( + &knn, + groundtruth, + self.search.reps, + std::slice::from_ref(&self.search.num_threads), + &self.search.runs, + )?; + + Ok(StreamStats::Search(r)) + } + + fn insert( + &mut self, + (data, ids): (MatrixView<'_, T>, Range), + ) -> anyhow::Result { + self.insert_(data, ids).map(StreamStats::Insert) + } + + fn delete(&mut self, ids: Range) -> anyhow::Result { + let runner = streaming::graph::InplaceDelete::new( + self.index.clone(), + Strategy, + self.delete_num_to_replace, + self.delete_method, + build_core::ids::Range::new(ids.start as u32..ids.end as u32), + ); + + let r = build_core::build( + runner, + diskann_benchmark_core::build::Parallelism::fixed( + Some(diskann::utils::ONE), + self.ntasks, + ), + &self.runtime, + )?; + + Ok(StreamStats::Delete(GenericStats::new("delete".into(), r)?)) + } + + fn replace( + &mut self, + (data, ids): (MatrixView<'_, T>, Range), + ) -> anyhow::Result { + use diskann::provider::Delete; + + // TODO: This is kind of a hack. 
It would be ideal to parallelize this. + // + // Also, this is *way* more expensive than it needs to be because each delete creates + // and then destroys an EBR guard. + let ctx = diskann_inmem::Context; + for id in ids.clone() { + self.runtime + .block_on(self.index.provider().delete(&ctx, &(id as u32)))?; + } + + self.insert_(data, ids).map(StreamStats::Replace) + } + + fn maintain(&mut self, _: ()) -> anyhow::Result { + Ok(StreamStats::Maintain(vec![])) + } + + fn needs_maintenance(&mut self) -> bool { + false + } +} diff --git a/diskann-benchmark/src/index/mod.rs b/diskann-benchmark/src/index/mod.rs index 3900dc337..7d9878c44 100644 --- a/diskann-benchmark/src/index/mod.rs +++ b/diskann-benchmark/src/index/mod.rs @@ -16,11 +16,17 @@ mod result; #[cfg(feature = "bftree")] mod bftree; +#[cfg(feature = "inmem2")] +mod inmem2; + pub(crate) fn register_benchmarks(registry: &mut Registry) -> anyhow::Result<()> { benchmarks::register_benchmarks(registry)?; #[cfg(feature = "bftree")] bftree::register_benchmarks(registry)?; + #[cfg(feature = "inmem2")] + inmem2::register_benchmarks(registry)?; + Ok(()) } diff --git a/diskann-inmem/.clippy.toml b/diskann-inmem/.clippy.toml new file mode 100644 index 000000000..7bada5473 --- /dev/null +++ b/diskann-inmem/.clippy.toml @@ -0,0 +1,3 @@ +allow-unwrap-in-tests = true +allow-expect-in-tests = true +allow-panic-in-tests = true diff --git a/diskann-inmem/Cargo.toml b/diskann-inmem/Cargo.toml new file mode 100644 index 000000000..00e02dbf2 --- /dev/null +++ b/diskann-inmem/Cargo.toml @@ -0,0 +1,62 @@ +[package] +name = "diskann-inmem" +version.workspace = true +description.workspace = true +authors.workspace = true +repository.workspace = true +license.workspace = true +edition = "2024" + +[dependencies] +bytemuck = { workspace = true, features = ["must_cast"] } +crossbeam-queue = "0.3.12" +dashmap.workspace = true +diskann = { workspace = true } +diskann-utils = { workspace = true, default-features = false } +diskann-vector = { 
workspace = true } +diskann-wide = { workspace = true } +parking_lot = "0.12.5" +thiserror = { workspace = true } +half = { workspace = true } + +# Integration Test Dependencies +diskann-benchmark-runner = { workspace = true, optional = true, features = ["ux-tools"] } +serde = { workspace = true, features = ["derive"], optional = true } +serde_json = { workspace = true, optional = true } +anyhow = { workspace = true, optional = true } +rand = { workspace = true, optional = true } +diskann-benchmark-core = { workspace = true, optional = true } +tokio = { workspace = true, optional = true } + +[lints.clippy] +undocumented_unsafe_blocks = "warn" +unwrap_used = "warn" +expect_used = "warn" +panic = "warn" +uninlined_format_args = "allow" + +[dev-dependencies] +diskann = { workspace = true, features = ["testing"] } +rand = { workspace = true } +tokio = { workspace = true, features = ["macros"] } +tempfile = { workspace = true } + +[[bin]] +name = "integration-test" +path = "integration/main.rs" +required-features = ["integration-test"] + +[features] +default = [] + +# Enable stress test module +integration-test = [ + "dep:diskann-benchmark-runner", + "dep:diskann-benchmark-core", + "dep:tokio", + "dep:serde", + "dep:serde_json", + "dep:anyhow", + "dep:rand", + "diskann-utils/testing", +] diff --git a/diskann-inmem/DEV.md b/diskann-inmem/DEV.md new file mode 100644 index 000000000..59b29c213 --- /dev/null +++ b/diskann-inmem/DEV.md @@ -0,0 +1,7 @@ +# Dev Docs + +Fully testing this crate requires enabling the `integration-test` feature. +The suggested command is +``` +cargo test --package diskann-inmem --all-features --profile ci +``` diff --git a/diskann-inmem/integration/index/mod.rs b/diskann-inmem/integration/index/mod.rs new file mode 100644 index 000000000..735ca820b --- /dev/null +++ b/diskann-inmem/integration/index/mod.rs @@ -0,0 +1,17 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. 
+ */ + +mod object; +mod runner; +mod tests; + +use object::{Counters, Index, KnnSearch}; + +use diskann_benchmark_runner::{Registry, RegistryError}; + +pub(super) fn register(registry: &mut Registry) -> Result<(), RegistryError> { + runner::register(registry)?; + Ok(()) +} diff --git a/diskann-inmem/integration/index/object.rs b/diskann-inmem/integration/index/object.rs new file mode 100644 index 000000000..cc5f4dc6f --- /dev/null +++ b/diskann-inmem/integration/index/object.rs @@ -0,0 +1,229 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +use std::{future::Future, pin::Pin}; + +use diskann::{ + graph::{DiskANNIndex, search::Knn}, + neighbor::Neighbor, + utils::IntoUsize, +}; +use diskann_benchmark_runner::utils::fmt::KeyValue; +use serde::{Deserialize, Serialize}; +use thiserror::Error; + +use diskann_inmem::{Context, Provider, Strategy, integration, layers}; + +use crate::support::{ + check::{CheckMatch, Match, check_all_fields}, + datatype::{AsDataType, FromSlice, Slice}, +}; + +pub(crate) trait Index { + fn search<'a>( + &'a self, + query: Slice<'a>, + knn: Knn, + neighbors: &'a mut Vec>, + ) -> Pin> + 'a>>; + + fn insert<'a>( + &'a self, + vector: Slice<'a>, + id: u64, + ) -> Pin> + 'a>>; + + fn counters(&self) -> Counters; +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub(crate) struct KnnSearch { + hops: usize, + cmps: usize, +} + +impl KnnSearch { + pub(crate) fn new() -> Self { + Self { hops: 0, cmps: 0 } + } +} + +impl From for KnnSearch { + fn from(stats: diskann::graph::index::SearchStats) -> Self { + Self { + hops: stats.hops.into_usize(), + cmps: stats.cmps.into_usize(), + } + } +} + +impl std::ops::AddAssign for KnnSearch { + fn add_assign(&mut self, rhs: Self) { + self.hops = self.hops.wrapping_add(rhs.hops); + self.cmps = self.cmps.wrapping_add(rhs.cmps); + } +} + +impl std::fmt::Display for KnnSearch { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "hops = 
{}, cmps = {}", self.hops, self.cmps) + } +} + +impl CheckMatch for KnnSearch { + fn check_match(&self, previous: &Self) -> Match { + let builder = check_all_fields!( + self, + previous, + { hops, cmps }, + ); + + builder.finish_with_remark(Some( + "check assumes deterministic (usually single-threaded) execution".into(), + )) + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub(crate) struct Counters { + query_distance: u64, + distance: u64, + get_vector: u64, + set_vector: u64, + get_neighbors: u64, + set_neighbors: u64, + append_neighbors: u64, +} + +impl Counters { + pub(crate) fn delta(&self, after: &Counters) -> anyhow::Result { + #[derive(Debug, Error)] + #[error( + "counter \"{}\" non-monotonically increasing from {} to {}", + self.0, + self.1, + self.2 + )] + struct NonMonotonic(&'static str, u64, u64); + + fn check(before: u64, after: u64, field: &'static str) -> Result { + after + .checked_sub(before) + .ok_or(NonMonotonic(field, before, after)) + } + + let delta = Self { + query_distance: check(self.query_distance, after.query_distance, "query_distance")?, + distance: check(self.distance, after.distance, "distance")?, + get_vector: check(self.get_vector, after.get_vector, "get_vector")?, + set_vector: check(self.set_vector, after.set_vector, "set_vector")?, + get_neighbors: check(self.get_neighbors, after.get_neighbors, "get_neighbors")?, + set_neighbors: check(self.set_neighbors, after.set_neighbors, "set_neighbors")?, + append_neighbors: check( + self.append_neighbors, + after.append_neighbors, + "append_neighbors", + )?, + }; + + Ok(delta) + } +} + +impl std::fmt::Display for Counters { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut kv = KeyValue::new(); + kv.push("query_distance", &self.query_distance); + kv.push("distance", &self.distance); + kv.push("get_vector", &self.get_vector); + kv.push("set_vector", &self.set_vector); + kv.push("get_neighbors", &self.get_neighbors); + kv.push("set_neighbors", 
&self.set_neighbors); + kv.push("append_neighbors", &self.append_neighbors); + write!(f, "{}", kv) + } +} + +impl From for Counters { + fn from(snapshot: integration::counters::CounterSnapshot) -> Self { + Self { + query_distance: snapshot.query_distance, + distance: snapshot.distance, + get_vector: snapshot.get_vector, + set_vector: snapshot.set_vector, + get_neighbors: snapshot.get_neighbors, + set_neighbors: snapshot.set_neighbors, + append_neighbors: snapshot.append_neighbors, + } + } +} + +impl CheckMatch for Counters { + fn check_match(&self, previous: &Self) -> Match { + let builder = check_all_fields!( + self, + previous, + { + query_distance, + distance, + get_vector, + set_vector, + get_neighbors, + set_neighbors, + append_neighbors + } + ); + + builder.finish_with_remark(Some( + "check assumes deterministic (usually single-threaded) execution".into(), + )) + } +} + +/////////// +// Impls // +/////////// + +impl Index for DiskANNIndex, u64>> +where + T: layers::FullPrecision + FromSlice + AsDataType, +{ + fn search<'a>( + &'a self, + query: Slice<'a>, + knn: Knn, + neighbors: &'a mut Vec>, + ) -> Pin> + 'a>> { + let fut = async move { + let query = query.try_cast()?; + let stats = self + .search(knn, &Strategy, &Context, query, neighbors) + .await?; + + Ok(stats.into()) + }; + + Box::pin(fut) + } + + fn insert<'a>( + &'a self, + vector: Slice<'a>, + id: u64, + ) -> Pin> + 'a>> { + let fut = async move { + let vector = vector.try_cast()?; + self.insert(&Strategy, &Context, &id, vector).await?; + + Ok(()) + }; + + Box::pin(fut) + } + + fn counters(&self) -> Counters { + self.provider().counters().into() + } +} diff --git a/diskann-inmem/integration/index/runner.rs b/diskann-inmem/integration/index/runner.rs new file mode 100644 index 000000000..ce85e106d --- /dev/null +++ b/diskann-inmem/integration/index/runner.rs @@ -0,0 +1,600 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. 
+ */ + +use std::{io::Write, sync::Arc}; + +use anyhow::Context; +use diskann::graph::{DiskANNIndex, search::Knn}; +use diskann_benchmark_runner::{ + Checker, Checkpoint, Output, Registry, RegistryError, + benchmark::{FailureScore, MatchScore, PassFail, Regression}, + files::InputFile, + utils::fmt::Indent, +}; +use diskann_utils::views::Matrix; +use diskann_vector::distance::Metric; +use half::f16; +use serde::{Deserialize, Serialize}; + +use diskann_inmem::{Provider, layers}; + +use crate::{ + index::{Counters, Index}, + support::{ + check::{CheckMatch, Match, check_all_fields}, + datatype::{self, DataType, Dataset, DatasetView}, + io::load_and_convert, + tolerance, + }, +}; + +pub(super) fn register(registry: &mut Registry) -> Result<(), RegistryError> { + registry.register_regression("full-precision-integration-test", FullPrecision)?; + Ok(()) +} + +mod dto { + use super::*; + + use serde::{Deserialize, Serialize}; + + #[derive(Debug, Serialize, Deserialize)] + #[serde(rename_all = "kebab-case")] + pub(super) enum SerdeMetric { + L2, + InnerProduct, + Cosine, + } + + impl From for Metric { + fn from(m: SerdeMetric) -> Self { + match m { + SerdeMetric::L2 => Metric::L2, + SerdeMetric::InnerProduct => Metric::InnerProduct, + SerdeMetric::Cosine => Metric::Cosine, + } + } + } + + impl TryFrom for SerdeMetric { + type Error = anyhow::Error; + fn try_from(m: Metric) -> anyhow::Result { + match m { + Metric::L2 => Ok(SerdeMetric::L2), + Metric::InnerProduct => Ok(SerdeMetric::InnerProduct), + Metric::Cosine => Ok(SerdeMetric::Cosine), + Metric::CosineNormalized => anyhow::bail!("cosine normalized is not supported"), + } + } + } + + #[derive(Debug, Serialize, Deserialize)] + #[serde(rename_all = "kebab-case")] + pub(super) enum Preprocess { + Halve, + Floor, + } + + impl From for datatype::Preprocess { + fn from(op: Preprocess) -> Self { + match op { + Preprocess::Halve => datatype::Preprocess::Halve, + Preprocess::Floor => datatype::Preprocess::Floor, + } + } + } + + 
impl From<&datatype::Preprocess> for Preprocess { + fn from(op: &datatype::Preprocess) -> Self { + match op { + datatype::Preprocess::Halve => Preprocess::Halve, + datatype::Preprocess::Floor => Preprocess::Floor, + } + } + } + + #[derive(Debug, Serialize, Deserialize)] + pub(super) struct Data { + pub(super) data: InputFile, + pub(super) queries: InputFile, + pub(super) groundtruth: InputFile, + pub(super) metric: SerdeMetric, + pub(super) data_type: DataType, + pub(super) preprocess: Vec, + } + + #[derive(Debug, Serialize, Deserialize)] + pub(super) enum Layer { + FullPrecision { data_type: DataType }, + } + + #[derive(Debug, Serialize, Deserialize)] + pub(super) struct Build { + pub(super) pruned_degree: usize, + pub(super) max_degree: usize, + pub(super) l_build: usize, + pub(super) alpha: f32, + } + + #[derive(Debug, Serialize, Deserialize)] + pub(super) struct KnnSearch { + pub(super) knn: usize, + pub(super) search_l: usize, + #[serde(deserialize_with = "Deserialize::deserialize")] + pub(super) beam_width: Option, + } + + #[derive(Debug, Serialize, Deserialize)] + pub(super) struct Search { + pub(super) knn: Vec, + } + + #[derive(Debug, Serialize, Deserialize)] + pub(super) struct Test { + pub(super) data: Data, + pub(super) layer: Layer, + pub(super) build: Build, + pub(super) search: Search, + } +} + +#[derive(Debug)] +struct Data { + data: InputFile, + queries: InputFile, + groundtruth: InputFile, + metric: Metric, + data_type: DataType, + preprocess: Vec, +} + +impl Data { + fn from_raw(raw: dto::Data, checker: Option<&mut Checker>) -> anyhow::Result { + let dto::Data { + mut data, + mut queries, + mut groundtruth, + metric, + data_type, + preprocess, + } = raw; + + if let Some(checker) = checker { + data.resolve(checker)?; + queries.resolve(checker)?; + groundtruth.resolve(checker)?; + } + + Ok(Self { + data, + queries, + groundtruth, + metric: metric.into(), + data_type, + preprocess: preprocess.into_iter().map(From::from).collect(), + }) + } + + fn 
as_raw(&self) -> anyhow::Result { + Ok(dto::Data { + data: self.data.clone(), + queries: self.queries.clone(), + groundtruth: self.groundtruth.clone(), + metric: self.metric.try_into()?, + data_type: self.data_type, + preprocess: self.preprocess.iter().map(From::from).collect(), + }) + } + + fn load_as(&self, data_type: DataType) -> anyhow::Result { + let data = { + let mut io = std::fs::File::open(&*self.data) + .with_context(|| format!("could not open {}", self.data.display()))?; + + load_and_convert(&mut io, self.data_type, data_type, &self.preprocess)? + }; + + let queries = { + let mut io = std::fs::File::open(&*self.queries) + .with_context(|| format!("could not open {}", self.queries.display()))?; + + load_and_convert(&mut io, self.data_type, data_type, &self.preprocess)? + }; + + let groundtruth = { + let mut io = std::fs::File::open(&*self.groundtruth) + .with_context(|| format!("could not open {}", self.groundtruth.display()))?; + + let raw = diskann_utils::io::read_bin::(&mut io)?; + raw.map(|&x| u64::from(x)) + }; + + Ok(Bundle { + data, + queries, + groundtruth, + }) + } +} + +#[derive(Debug)] +struct Bundle { + data: Dataset, + queries: Dataset, + groundtruth: Matrix, +} + +#[derive(Debug)] +enum Layer { + FullPrecision { data_type: DataType }, +} + +impl Layer { + fn from_raw(raw: dto::Layer) -> Self { + match raw { + dto::Layer::FullPrecision { data_type } => Self::FullPrecision { data_type }, + } + } + + fn as_raw(&self) -> dto::Layer { + match self { + Self::FullPrecision { data_type } => dto::Layer::FullPrecision { + data_type: *data_type, + }, + } + } +} + +#[derive(Debug)] +struct Build { + config: diskann::graph::Config, +} + +impl Build { + fn from_raw(raw: dto::Build, metric: Metric) -> anyhow::Result { + let dto::Build { + pruned_degree, + max_degree, + l_build, + alpha, + } = raw; + + let config = diskann::graph::config::Builder::new_with( + pruned_degree, + diskann::graph::config::MaxDegree::new(max_degree), + l_build, + metric.into(), + |b| 
{ + b.alpha(alpha); + }, + ) + .build()?; + + Ok(Self { config }) + } + + fn as_raw(&self) -> dto::Build { + dto::Build { + pruned_degree: self.config.pruned_degree().get(), + max_degree: self.config.max_degree().get(), + l_build: self.config.l_build().get(), + alpha: self.config.alpha(), + } + } +} + +#[derive(Debug)] +struct Search { + knn: Vec, +} + +impl Search { + fn from_raw(raw: dto::Search) -> anyhow::Result { + fn make_knn(raw: &dto::KnnSearch) -> anyhow::Result { + Ok(Knn::new(raw.knn, raw.search_l, raw.beam_width)?) + } + + Ok(Self { + knn: raw + .knn + .iter() + .map(make_knn) + .collect::>>()?, + }) + } + + fn as_raw(&self) -> dto::Search { + fn make_knn(knn: &Knn) -> dto::KnnSearch { + dto::KnnSearch { + knn: knn.k_value().get(), + search_l: knn.l_value().get(), + beam_width: Some(knn.beam_width().get()), + } + } + + dto::Search { + knn: self.knn.iter().map(make_knn).collect(), + } + } +} + +#[derive(Debug)] +struct Test { + data: Data, + layer: Layer, + build: Build, + search: Search, +} + +impl Test { + fn from_raw(raw: dto::Test, checker: Option<&mut Checker>) -> anyhow::Result { + let data = Data::from_raw(raw.data, checker)?; + let layer = Layer::from_raw(raw.layer); + let build = Build::from_raw(raw.build, data.metric)?; + let search = Search::from_raw(raw.search)?; + + Ok(Self { + data, + layer, + build, + search, + }) + } + + fn as_raw(&self) -> anyhow::Result { + Ok(dto::Test { + data: self.data.as_raw()?, + layer: self.layer.as_raw(), + build: self.build.as_raw(), + search: self.search.as_raw(), + }) + } + + fn index( + &self, + capacity: usize, + start_points: DatasetView<'_>, + ) -> anyhow::Result> { + match self.layer { + Layer::FullPrecision { data_type } => { + if start_points.data_type() != data_type { + anyhow::bail!( + "mismatched data types for start point - expected {}, got {}", + data_type, + start_points.data_type(), + ); + } + + let dim = start_points.ncols(); + let metric = self.data.metric; + let config = 
diskann_inmem::provider::Config::new( + capacity, + self.build.config.max_degree().get(), + ); + + let index_config = self.build.config.clone(); + + let index = match start_points { + DatasetView::F32(v) => finish( + Provider::new(layers::Full::::new(dim, metric), config, v.row_iter())?, + index_config, + ), + DatasetView::F16(v) => finish( + Provider::new(layers::Full::::new(dim, metric), config, v.row_iter())?, + index_config, + ), + DatasetView::U8(v) => finish( + Provider::new(layers::Full::::new(dim, metric), config, v.row_iter())?, + index_config, + ), + DatasetView::I8(v) => finish( + Provider::new(layers::Full::::new(dim, metric), config, v.row_iter())?, + index_config, + ), + }; + + Ok(index) + } + } + } +} + +fn finish(provider: DP, config: diskann::graph::Config) -> Arc +where + DP: diskann::provider::DataProvider, + DiskANNIndex: Index, +{ + Arc::new(DiskANNIndex::new(config, provider, None)) +} + +/////////////// +// Benchmark // +/////////////// + +impl diskann_benchmark_runner::Input for Test { + type Raw = dto::Test; + + fn tag() -> &'static str { + "integration-test" + } + + fn from_raw(raw: dto::Test, checker: &mut Checker) -> anyhow::Result { + ::from_raw(raw, Some(checker)) + } + + fn serialize(&self) -> anyhow::Result { + let raw = self.as_raw()?; + Ok(serde_json::to_value(raw)?) 
+ } + + fn example() -> dto::Test { + dto::Test { + data: dto::Data { + data: InputFile::new("path/to/data"), + queries: InputFile::new("path/to/queries"), + groundtruth: InputFile::new("path/to/groundtruth"), + metric: dto::SerdeMetric::L2, + data_type: DataType::F32, + preprocess: vec![], + }, + layer: dto::Layer::FullPrecision { + data_type: DataType::F32, + }, + build: dto::Build { + pruned_degree: 16, + max_degree: 20, + l_build: 50, + alpha: 1.2, + }, + search: dto::Search { + knn: vec![ + dto::KnnSearch { + knn: 10, + search_l: 50, + beam_width: None, + }, + dto::KnnSearch { + knn: 10, + search_l: 50, + beam_width: Some(3), + }, + dto::KnnSearch { + knn: 20, + search_l: 100, + beam_width: Some(3), + }, + ], + }, + } + } +} + +//////////////// +// Benchmarks // +//////////////// + +#[derive(Debug)] +struct FullPrecision; + +impl diskann_benchmark_runner::Benchmark for FullPrecision { + type Input = Test; + type Output = BuildAndSearch; + + fn try_match(&self, input: &Test) -> Result { + let Layer::FullPrecision { .. } = input.layer; + Ok(MatchScore(0)) + } + + fn description( + &self, + f: &mut std::fmt::Formatter<'_>, + _input: Option<&Test>, + ) -> std::fmt::Result { + write!(f, "nop") + } + + fn run( + &self, + input: &Test, + _checkpoint: Checkpoint<'_>, + mut output: &mut dyn Output, + ) -> anyhow::Result { + let Layer::FullPrecision { data_type } = input.layer; + + // Load the data and perform any necessary data conversions. 
+ let Bundle { + data, + queries, + groundtruth, + } = input.data.load_as(data_type)?; + + let index = input.index(data.nrows(), data.medoid().as_view())?; + let rt = diskann_benchmark_core::tokio::runtime(1)?; + let build = super::tests::insert(&*index, data.as_view(), rt.handle())?; + + let mut knn = Vec::new(); + for param in input.search.knn.iter() { + let stats = super::tests::knn( + &*index, + *param, + queries.as_view(), + &groundtruth.as_view(), + rt.handle(), + )?; + + knn.push(stats); + } + + let build_and_search = BuildAndSearch { build, knn }; + + writeln!(output, "{}", build_and_search)?; + + Ok(build_and_search) + } +} + +impl Regression for FullPrecision { + type Tolerances = tolerance::Empty; + type Pass = Match; + type Fail = Match; + + fn check( + &self, + _tolerances: &Self::Tolerances, + _input: &Self::Input, + before: &Self::Output, + after: &Self::Output, + ) -> anyhow::Result> { + Ok(before.check_match(after).pass_fail()) + } +} + +//////////// +// Output // +//////////// + +#[derive(Debug, Serialize, Deserialize)] +struct BuildAndSearch { + build: Counters, + knn: Vec, +} + +impl std::fmt::Display for BuildAndSearch { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + writeln!(f, "build stats")?; + writeln!(f, "{}", Indent::new(&self.build.to_string(), 4))?; + writeln!(f, "knn stats")?; + for k in self.knn.iter() { + writeln!(f, "{}\n", k)?; + } + + Ok(()) + } +} + +impl CheckMatch for BuildAndSearch { + fn check_match(&self, previous: &Self) -> Match { + let builder = check_all_fields!( + self, + previous, + { build, knn }, + ); + builder.finish() + } +} + +/////////// +// Tests // +/////////// + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn example_parses() { + let _ = Test::from_raw(::example(), None).unwrap(); + } +} diff --git a/diskann-inmem/integration/index/tests.rs b/diskann-inmem/integration/index/tests.rs new file mode 100644 index 000000000..436a851e3 --- /dev/null +++ 
b/diskann-inmem/integration/index/tests.rs @@ -0,0 +1,143 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +use diskann::graph::search::Knn; +use diskann_benchmark_core::recall::{RecallMetrics, Rows}; +use diskann_benchmark_runner::utils::fmt::KeyValue; +use diskann_utils::views::Matrix; +use serde::{Deserialize, Serialize}; + +use crate::{ + index::{Counters, Index, KnnSearch}, + support::{ + check::{CheckMatch, Match, check_all_fields}, + datatype::DatasetView, + }, +}; + +pub(super) fn insert( + index: &dyn Index, + dataset: DatasetView<'_>, + rt: &tokio::runtime::Handle, +) -> anyhow::Result { + let before = index.counters(); + for (i, r) in dataset.iter().enumerate() { + rt.block_on(index.insert(r, i as u64))?; + } + before.delta(&index.counters()) +} + +pub(super) fn knn( + index: &dyn Index, + knn: Knn, + queries: DatasetView<'_>, + groundtruth: &dyn Rows, + rt: &tokio::runtime::Handle, +) -> anyhow::Result { + anyhow::ensure!( + queries.nrows() == groundtruth.nrows(), + "number of queries ({}) must match number of groundtruth entries ({})", + queries.nrows(), + groundtruth.nrows(), + ); + + let mut ids = Matrix::new(u64::MAX, queries.nrows(), knn.k_value().get()); + + let before = index.counters(); + let mut misc = KnnSearch::new(); + let mut neighbors = Vec::new(); + for (out, query) in std::iter::zip(ids.row_iter_mut(), queries.iter()) { + neighbors.clear(); + + let stats = rt.block_on(index.search(query, knn, &mut neighbors))?; + misc += stats; + + std::iter::zip(out.iter_mut(), neighbors.iter()).for_each(|(d, s)| *d = s.id); + } + let counters = before.delta(&index.counters())?; + + let recall = diskann_benchmark_core::recall::knn( + groundtruth, + None, + &ids.as_view(), + knn.k_value().get(), + knn.k_value().get(), + diskann_benchmark_core::recall::GroundTruthMode::Fixed, + )?; + + Ok(KnnStats { + counters, + recall: recall.into(), + misc, + }) +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub(crate) 
struct KnnRecall { + recall_k: usize, + recall_n: usize, + num_queries: usize, + average: f64, +} + +impl From for KnnRecall { + fn from(metrics: RecallMetrics) -> Self { + Self { + recall_k: metrics.recall_k, + recall_n: metrics.recall_n, + num_queries: metrics.num_queries, + average: metrics.average, + } + } +} + +impl CheckMatch for KnnRecall { + fn check_match(&self, previous: &Self) -> Match { + let builder = check_all_fields!( + self, + previous, + { recall_k, recall_n, num_queries, average } + ); + builder.finish() + } +} + +impl std::fmt::Display for KnnRecall { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "recall = {:.4}, recall_k = {}, recall_n = {}, num_queries = {}", + self.average, self.recall_k, self.recall_n, self.num_queries + ) + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub(crate) struct KnnStats { + recall: KnnRecall, + counters: Counters, + misc: KnnSearch, +} + +impl std::fmt::Display for KnnStats { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut kv = KeyValue::new(); + kv.push("counters", &self.counters); + kv.push("recall", &self.recall); + kv.push("misc", &self.misc); + write!(f, "{}", kv) + } +} + +impl CheckMatch for KnnStats { + fn check_match(&self, previous: &Self) -> Match { + let builder = check_all_fields!( + self, + previous, + { recall, counters, misc } + ); + builder.finish() + } +} diff --git a/diskann-inmem/integration/jsons/checks.json b/diskann-inmem/integration/jsons/checks.json new file mode 100644 index 000000000..ed53e8d06 --- /dev/null +++ b/diskann-inmem/integration/jsons/checks.json @@ -0,0 +1,14 @@ +{ + "checks": [ + { + "input": { + "type": "integration-test", + "content": {} + }, + "tolerance": { + "type": "empty-tolerance", + "content": null + } + } + ] +} diff --git a/diskann-inmem/integration/jsons/integration-baseline.json b/diskann-inmem/integration/jsons/integration-baseline.json new file mode 100644 index 000000000..453375e2e 
--- /dev/null +++ b/diskann-inmem/integration/jsons/integration-baseline.json @@ -0,0 +1,489 @@ +[ + { + "input": { + "content": { + "build": { + "alpha": 1.2000000476837158, + "l_build": 20, + "max_degree": 20, + "pruned_degree": 16 + }, + "data": { + "data": "/yfcc/yfcc_10k.fbin", + "data_type": "f32", + "groundtruth": "/yfcc/groundtruth.bin", + "metric": "l2", + "preprocess": [], + "queries": "/yfcc/yfcc_query_100.fbin" + }, + "layer": { + "FullPrecision": { + "data_type": "f32" + } + }, + "search": { + "knn": [ + { + "beam_width": 1, + "knn": 10, + "search_l": 50 + }, + { + "beam_width": 3, + "knn": 10, + "search_l": 50 + }, + { + "beam_width": 3, + "knn": 10, + "search_l": 100 + } + ] + } + }, + "type": "integration-test" + }, + "results": { + "build": { + "append_neighbors": 96949, + "distance": 2867876, + "get_neighbors": 352139, + "get_vector": 3067092, + "query_distance": 2240744, + "set_neighbors": 23599, + "set_vector": 10000 + }, + "knn": [ + { + "counters": { + "append_neighbors": 0, + "distance": 0, + "get_neighbors": 5441, + "get_vector": 44988, + "query_distance": 44988, + "set_neighbors": 0, + "set_vector": 0 + }, + "misc": { + "cmps": 44988, + "hops": 5441 + }, + "recall": { + "average": 0.975, + "num_queries": 100, + "recall_k": 10, + "recall_n": 10 + } + }, + { + "counters": { + "append_neighbors": 0, + "distance": 0, + "get_neighbors": 5806, + "get_vector": 49075, + "query_distance": 49075, + "set_neighbors": 0, + "set_vector": 0 + }, + "misc": { + "cmps": 49075, + "hops": 5806 + }, + "recall": { + "average": 0.974, + "num_queries": 100, + "recall_k": 10, + "recall_n": 10 + } + }, + { + "counters": { + "append_neighbors": 0, + "distance": 0, + "get_neighbors": 10634, + "get_vector": 74001, + "query_distance": 74001, + "set_neighbors": 0, + "set_vector": 0 + }, + "misc": { + "cmps": 74001, + "hops": 10634 + }, + "recall": { + "average": 0.992, + "num_queries": 100, + "recall_k": 10, + "recall_n": 10 + } + } + ] + } + }, + { + "input": { + 
"content": { + "build": { + "alpha": 1.2000000476837158, + "l_build": 20, + "max_degree": 20, + "pruned_degree": 16 + }, + "data": { + "data": "/yfcc/yfcc_10k.fbin", + "data_type": "f32", + "groundtruth": "/yfcc/groundtruth.bin", + "metric": "l2", + "preprocess": [], + "queries": "/yfcc/yfcc_query_100.fbin" + }, + "layer": { + "FullPrecision": { + "data_type": "f16" + } + }, + "search": { + "knn": [ + { + "beam_width": 1, + "knn": 10, + "search_l": 50 + }, + { + "beam_width": 3, + "knn": 10, + "search_l": 50 + }, + { + "beam_width": 3, + "knn": 10, + "search_l": 100 + } + ] + } + }, + "type": "integration-test" + }, + "results": { + "build": { + "append_neighbors": 96949, + "distance": 2867876, + "get_neighbors": 352139, + "get_vector": 3067092, + "query_distance": 2240744, + "set_neighbors": 23599, + "set_vector": 10000 + }, + "knn": [ + { + "counters": { + "append_neighbors": 0, + "distance": 0, + "get_neighbors": 5441, + "get_vector": 44988, + "query_distance": 44988, + "set_neighbors": 0, + "set_vector": 0 + }, + "misc": { + "cmps": 44988, + "hops": 5441 + }, + "recall": { + "average": 0.975, + "num_queries": 100, + "recall_k": 10, + "recall_n": 10 + } + }, + { + "counters": { + "append_neighbors": 0, + "distance": 0, + "get_neighbors": 5806, + "get_vector": 49075, + "query_distance": 49075, + "set_neighbors": 0, + "set_vector": 0 + }, + "misc": { + "cmps": 49075, + "hops": 5806 + }, + "recall": { + "average": 0.974, + "num_queries": 100, + "recall_k": 10, + "recall_n": 10 + } + }, + { + "counters": { + "append_neighbors": 0, + "distance": 0, + "get_neighbors": 10634, + "get_vector": 74001, + "query_distance": 74001, + "set_neighbors": 0, + "set_vector": 0 + }, + "misc": { + "cmps": 74001, + "hops": 10634 + }, + "recall": { + "average": 0.992, + "num_queries": 100, + "recall_k": 10, + "recall_n": 10 + } + } + ] + } + }, + { + "input": { + "content": { + "build": { + "alpha": 1.2000000476837158, + "l_build": 20, + "max_degree": 20, + "pruned_degree": 16 + }, + 
"data": { + "data": "/yfcc/yfcc_10k.fbin", + "data_type": "f32", + "groundtruth": "/yfcc/groundtruth.bin", + "metric": "l2", + "preprocess": [], + "queries": "/yfcc/yfcc_query_100.fbin" + }, + "layer": { + "FullPrecision": { + "data_type": "u8" + } + }, + "search": { + "knn": [ + { + "beam_width": 1, + "knn": 10, + "search_l": 50 + }, + { + "beam_width": 3, + "knn": 10, + "search_l": 50 + }, + { + "beam_width": 3, + "knn": 10, + "search_l": 100 + } + ] + } + }, + "type": "integration-test" + }, + "results": { + "build": { + "append_neighbors": 96949, + "distance": 2867876, + "get_neighbors": 352139, + "get_vector": 3067092, + "query_distance": 2240744, + "set_neighbors": 23599, + "set_vector": 10000 + }, + "knn": [ + { + "counters": { + "append_neighbors": 0, + "distance": 0, + "get_neighbors": 5441, + "get_vector": 44988, + "query_distance": 44988, + "set_neighbors": 0, + "set_vector": 0 + }, + "misc": { + "cmps": 44988, + "hops": 5441 + }, + "recall": { + "average": 0.975, + "num_queries": 100, + "recall_k": 10, + "recall_n": 10 + } + }, + { + "counters": { + "append_neighbors": 0, + "distance": 0, + "get_neighbors": 5806, + "get_vector": 49075, + "query_distance": 49075, + "set_neighbors": 0, + "set_vector": 0 + }, + "misc": { + "cmps": 49075, + "hops": 5806 + }, + "recall": { + "average": 0.974, + "num_queries": 100, + "recall_k": 10, + "recall_n": 10 + } + }, + { + "counters": { + "append_neighbors": 0, + "distance": 0, + "get_neighbors": 10634, + "get_vector": 74001, + "query_distance": 74001, + "set_neighbors": 0, + "set_vector": 0 + }, + "misc": { + "cmps": 74001, + "hops": 10634 + }, + "recall": { + "average": 0.992, + "num_queries": 100, + "recall_k": 10, + "recall_n": 10 + } + } + ] + } + }, + { + "input": { + "content": { + "build": { + "alpha": 1.2000000476837158, + "l_build": 20, + "max_degree": 20, + "pruned_degree": 16 + }, + "data": { + "data": "/yfcc/yfcc_10k.fbin", + "data_type": "f32", + "groundtruth": "/yfcc/groundtruth.bin", + "metric": "l2", 
+ "preprocess": [ + "halve", + "floor" + ], + "queries": "/yfcc/yfcc_query_100.fbin" + }, + "layer": { + "FullPrecision": { + "data_type": "i8" + } + }, + "search": { + "knn": [ + { + "beam_width": 1, + "knn": 10, + "search_l": 50 + }, + { + "beam_width": 3, + "knn": 10, + "search_l": 50 + }, + { + "beam_width": 3, + "knn": 10, + "search_l": 100 + } + ] + } + }, + "type": "integration-test" + }, + "results": { + "build": { + "append_neighbors": 97055, + "distance": 2867292, + "get_neighbors": 352087, + "get_vector": 3064106, + "query_distance": 2238420, + "set_neighbors": 23587, + "set_vector": 10000 + }, + "knn": [ + { + "counters": { + "append_neighbors": 0, + "distance": 0, + "get_neighbors": 5446, + "get_vector": 44805, + "query_distance": 44805, + "set_neighbors": 0, + "set_vector": 0 + }, + "misc": { + "cmps": 44805, + "hops": 5446 + }, + "recall": { + "average": 0.961, + "num_queries": 100, + "recall_k": 10, + "recall_n": 10 + } + }, + { + "counters": { + "append_neighbors": 0, + "distance": 0, + "get_neighbors": 5771, + "get_vector": 48508, + "query_distance": 48508, + "set_neighbors": 0, + "set_vector": 0 + }, + "misc": { + "cmps": 48508, + "hops": 5771 + }, + "recall": { + "average": 0.9590000000000002, + "num_queries": 100, + "recall_k": 10, + "recall_n": 10 + } + }, + { + "counters": { + "append_neighbors": 0, + "distance": 0, + "get_neighbors": 10638, + "get_vector": 74034, + "query_distance": 74034, + "set_neighbors": 0, + "set_vector": 0 + }, + "misc": { + "cmps": 74034, + "hops": 10638 + }, + "recall": { + "average": 0.9680000000000004, + "num_queries": 100, + "recall_k": 10, + "recall_n": 10 + } + } + ] + } + } +] \ No newline at end of file diff --git a/diskann-inmem/integration/jsons/integration.json b/diskann-inmem/integration/jsons/integration.json new file mode 100644 index 000000000..57f2a338c --- /dev/null +++ b/diskann-inmem/integration/jsons/integration.json @@ -0,0 +1,183 @@ +{ + "search_directories": [ + "yfcc" + ], + "output_directory": 
null, + "jobs": [ + { + "type": "integration-test", + "content": { + "build": { + "alpha": 1.2000000476837158, + "l_build": 20, + "max_degree": 20, + "pruned_degree": 16 + }, + "data": { + "data": "yfcc_10k.fbin", + "data_type": "f32", + "groundtruth": "groundtruth.bin", + "metric": "l2", + "queries": "yfcc_query_100.fbin", + "preprocess": [] + }, + "layer": { + "FullPrecision": { + "data_type": "f32" + } + }, + "search": { + "knn": [ + { + "beam_width": null, + "knn": 10, + "search_l": 50 + }, + { + "beam_width": 3, + "knn": 10, + "search_l": 50 + }, + { + "beam_width": 3, + "knn": 10, + "search_l": 100 + } + ] + } + } + }, + { + "type": "integration-test", + "content": { + "build": { + "alpha": 1.2000000476837158, + "l_build": 20, + "max_degree": 20, + "pruned_degree": 16 + }, + "data": { + "data": "yfcc_10k.fbin", + "data_type": "f32", + "groundtruth": "groundtruth.bin", + "metric": "l2", + "queries": "yfcc_query_100.fbin", + "preprocess": [] + }, + "layer": { + "FullPrecision": { + "data_type": "f16" + } + }, + "search": { + "knn": [ + { + "beam_width": null, + "knn": 10, + "search_l": 50 + }, + { + "beam_width": 3, + "knn": 10, + "search_l": 50 + }, + { + "beam_width": 3, + "knn": 10, + "search_l": 100 + } + ] + } + } + }, + { + "type": "integration-test", + "content": { + "build": { + "alpha": 1.2000000476837158, + "l_build": 20, + "max_degree": 20, + "pruned_degree": 16 + }, + "data": { + "data": "yfcc_10k.fbin", + "data_type": "f32", + "groundtruth": "groundtruth.bin", + "metric": "l2", + "queries": "yfcc_query_100.fbin", + "preprocess": [] + }, + "layer": { + "FullPrecision": { + "data_type": "u8" + } + }, + "search": { + "knn": [ + { + "beam_width": null, + "knn": 10, + "search_l": 50 + }, + { + "beam_width": 3, + "knn": 10, + "search_l": 50 + }, + { + "beam_width": 3, + "knn": 10, + "search_l": 100 + } + ] + } + } + }, + { + "type": "integration-test", + "content": { + "build": { + "alpha": 1.2000000476837158, + "l_build": 20, + "max_degree": 20, + 
"pruned_degree": 16 + }, + "data": { + "data": "yfcc_10k.fbin", + "data_type": "f32", + "groundtruth": "groundtruth.bin", + "metric": "l2", + "queries": "yfcc_query_100.fbin", + "preprocess": [ + "halve", + "floor" + ] + }, + "layer": { + "FullPrecision": { + "data_type": "i8" + } + }, + "search": { + "knn": [ + { + "beam_width": null, + "knn": 10, + "search_l": 50 + }, + { + "beam_width": 3, + "knn": 10, + "search_l": 50 + }, + { + "beam_width": 3, + "knn": 10, + "search_l": 100 + } + ] + } + } + } + ] +} diff --git a/diskann-inmem/integration/jsons/store-stress-test.json b/diskann-inmem/integration/jsons/store-stress-test.json new file mode 100644 index 000000000..b04a008fa --- /dev/null +++ b/diskann-inmem/integration/jsons/store-stress-test.json @@ -0,0 +1,19 @@ +{ + "search_directories": [], + "jobs": [ + { + "type": "store-stress", + "content": { + "readers": 4, + "writers": 2, + "retirers": 1, + "capacity": 512, + "entry_bytes": 64, + "low_watermark": 128, + "duration_secs": 2, + "max_ops": 2000000, + "seed": 11939873485092837375 + } + } + ] +} diff --git a/diskann-inmem/integration/jsons/store-stress.json b/diskann-inmem/integration/jsons/store-stress.json new file mode 100644 index 000000000..6f1d2b836 --- /dev/null +++ b/diskann-inmem/integration/jsons/store-stress.json @@ -0,0 +1,19 @@ +{ + "search_directories": [], + "jobs": [ + { + "type": "store-stress", + "content": { + "readers": 8, + "writers": 4, + "retirers": 2, + "capacity": 4096, + "entry_bytes": 128, + "low_watermark": 1024, + "duration_secs": 10, + "max_ops": 50000000, + "seed": 11939873485092837375 + } + } + ] +} diff --git a/diskann-inmem/integration/main.rs b/diskann-inmem/integration/main.rs new file mode 100644 index 000000000..ddc69289d --- /dev/null +++ b/diskann-inmem/integration/main.rs @@ -0,0 +1,250 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. 
+ */
+
+mod index;
+mod store;
+mod support;
+
+use diskann_benchmark_runner::{App, Registry, output};
+
+/// Build a [`Registry`] with all integration benchmarks registered.
+///
+/// Registers the `store-stress` benchmark and then everything added by
+/// [`index::register`]. The first registration error is returned to the caller.
+fn registry() -> anyhow::Result<Registry> {
+    let mut registry = Registry::new();
+    registry.register("store-stress", store::StoreStress)?;
+    index::register(&mut registry)?;
+    Ok(registry)
+}
+
+/// Build the application via [`App::parse`] and run it against the full registry,
+/// writing to the default output.
+fn main() -> anyhow::Result<()> {
+    let app = App::parse();
+    app.run(&registry()?, &mut output::default())
+}
+
+///////////
+// Tests //
+///////////
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    use std::path::Path;
+
+    use diskann_benchmark_runner::{
+        app::{Check, Commands},
+        output::Memory,
+    };
+    use diskann_utils::test_data_root;
+    use serde::{Deserialize, Serialize};
+    use serde_json::Value;
+
+    // Environment variable used to regenerate committed regression baselines.
+    const DISKANN_TEST_ENV: &str = "DISKANN_TEST";
+
+    // Return `true` if `DISKANN_TEST=overwrite` is set, instructing regression tests to
+    // overwrite their committed baselines instead of checking against them.
+    //
+    // If `DISKANN_TEST` is set to anything other than `overwrite`, panic. An unset
+    // variable means "check against the baseline" and returns `false`.
+    fn overwrite_baselines() -> bool {
+        match std::env::var(DISKANN_TEST_ENV) {
+            Ok(v) if v == "overwrite" => true,
+            Ok(v) => {
+                panic!("unknown value for {DISKANN_TEST_ENV}: \"{v}\". Expected \"overwrite\"")
+            }
+            Err(std::env::VarError::NotPresent) => false,
+            Err(std::env::VarError::NotUnicode(_)) => {
+                panic!("value for {DISKANN_TEST_ENV} is not unicode")
+            }
+        }
+    }
+
+    // The directory containing the committed example input files:
+    // `$CARGO_MANIFEST_DIR/integration/jsons`.
+    fn example_directory() -> std::path::PathBuf {
+        std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
+            .join("integration")
+            .join("jsons")
+    }
+
+    // TODO: add first class `diskann-benchmark-runner` support for this.
+    // Deserialize the JSON file at `path` into a `T`.
+    //
+    // Panics (naming the offending path) if the file cannot be opened or parsed.
+    fn load_from_file<T>(path: &std::path::Path) -> T
+    where
+        T: for<'a> Deserialize<'a>,
+    {
+        let file = std::fs::File::open(path)
+            .unwrap_or_else(|err| panic!("failed to open {}: {err}", path.display()));
+        let reader = std::io::BufReader::new(file);
+        serde_json::from_reader(reader)
+            .unwrap_or_else(|err| panic!("failed to parse {}: {err}", path.display()))
+    }
+
+    // Load the JSON file at `path` as an untyped `serde_json::Value`.
+    fn value_from_file(path: &std::path::Path) -> serde_json::Value {
+        load_from_file(path)
+    }
+
+    // Pretty-print `value` as JSON into `path`.
+    //
+    // Panics if `path` already exists and `force` is `false`, or on any I/O or
+    // serialization error.
+    fn save_to_file<T>(path: &std::path::Path, value: &T, force: bool)
+    where
+        T: Serialize + ?Sized,
+    {
+        if path.exists() && !force {
+            panic!("path {} already exists!", path.display());
+        }
+        let file = std::fs::File::create(path)
+            .unwrap_or_else(|err| panic!("failed to create {}: {err}", path.display()));
+        // Buffer the many small writes the pretty printer issues, and flush explicitly so a
+        // write error fails the test instead of being swallowed when the writer is dropped.
+        let mut writer = std::io::BufWriter::new(file);
+        serde_json::to_writer_pretty(&mut writer, value)
+            .unwrap_or_else(|err| panic!("failed to write {}: {err}", path.display()));
+        std::io::Write::flush(&mut writer)
+            .unwrap_or_else(|err| panic!("failed to flush {}: {err}", path.display()));
+    }
+
+    // Rewrite every string entry of the top-level `search_directories` array in `raw` so
+    // that it is prefixed with `root`.
+    //
+    // Non-string array entries are left untouched. Panics if `raw` is not an object, if the
+    // key is missing, if its value is not an array, or if a joined path is not valid UTF-8.
+    fn prefix_search_directories(raw: &mut serde_json::Value, root: &std::path::Path) {
+        let key = "search_directories";
+        if let serde_json::Value::Object(obj) = raw {
+            // Build the message from `key` so the two cannot drift apart.
+            let value = obj
+                .get_mut(key)
+                .unwrap_or_else(|| panic!("key \"{key}\" should exist"));
+            if let serde_json::Value::Array(directories) = value {
+                for value in directories.iter_mut() {
+                    if let serde_json::Value::String(dir) = value {
+                        *dir = root
+                            .join(dir.as_str())
+                            .to_str()
+                            .expect("prefixed search directory should be valid UTF-8")
+                            .into();
+                    }
+                }
+            } else {
+                panic!("Expected an Array - got {}", raw);
+            }
+        } else {
+            panic!("Expected an Object - got {}", raw);
+        }
+    }
+
+    // Copy the input file `input` to `output`, prefixing its search directories with `root`.
+    //
+    // Panics if `output` already exists.
+    fn prepend(input: &Path, output: &Path, root: &Path) {
+        let mut v = value_from_file(input);
+        prefix_search_directories(&mut v, root);
+        save_to_file(output, &v, false);
+    }
+
+    // Drive the named example through the full runner flow: load the JSON input file,
+    // dispatch through the registry, run the benchmark, and write results to disk.
+    fn run_example(name: &str) {
+        let input_file = example_directory().join(name);
+        assert!(input_file.exists(), "missing example file: {input_file:?}");
+
+        // Work on a copy of the input whose search directories point at the test data root;
+        // the temporary directory also receives the results file.
+        let tempdir = tempfile::tempdir().unwrap();
+        let modified_input_file = tempdir.path().join("input.json");
+        let output_file = tempdir.path().join("output.json");
+
+        prepend(&input_file, &modified_input_file, &test_data_root());
+
+        let command = Commands::Run {
+            input_file: modified_input_file,
+            output_file: output_file.clone(),
+            dry_run: false,
+            // Unit tests are a debug build; bypass the runner's debug-mode guard.
+            allow_debug: true,
+        };
+        let app = App::from_commands(command);
+
+        let mut output = Memory::new();
+        // A benchmark error (e.g. an invariant violation) fails the test. Report the
+        // captured runner output alongside it, as the regression helper does.
+        if let Err(err) = app.run(&registry().unwrap(), &mut output) {
+            panic!(
+                "Example {name:?} failed:\n\n{}\n\n{}",
+                err,
+                String::from_utf8(output.into_inner()).unwrap()
+            );
+        }
+
+        assert!(output_file.exists(), "results file was not written");
+    }
+
+    // Drive the named example through the runner, then run a regression check comparing the
+    // freshly produced results against a committed baseline.
+    //
+    // By default this fails the test if the regression check reports a negative result. When
+    // `DISKANN_TEST=overwrite` is set, the committed baseline is instead overwritten with the
+    // freshly produced results (enabling future migrations) and no check is performed.
+    fn run_regression_example(input_name: &str, tolerances_name: &str, baseline_name: &str) {
+        let input_file = example_directory().join(input_name);
+        let tolerances_file = example_directory().join(tolerances_name);
+        let baseline_file = example_directory().join(baseline_name);
+        assert!(input_file.exists(), "missing example file: {input_file:?}");
+        assert!(
+            tolerances_file.exists(),
+            "missing tolerances file: {tolerances_file:?}"
+        );
+
+        let tempdir = tempfile::tempdir().unwrap();
+        let modified_input_file = tempdir.path().join("input.json");
+        let output_file = tempdir.path().join("output.json");
+
+        prepend(&input_file, &modified_input_file, &test_data_root());
+
+        // Run the benchmark to produce the "after" results.
+        let command = Commands::Run {
+            input_file: modified_input_file.clone(),
+            output_file: output_file.clone(),
+            dry_run: false,
+            // Unit tests are a debug build; bypass the runner's debug-mode guard.
+            allow_debug: true,
+        };
+        let mut output = Memory::new();
+        App::from_commands(command)
+            .run(&registry().unwrap(), &mut output)
+            .unwrap();
+        assert!(output_file.exists(), "results file was not written");
+
+        // In overwrite mode, replace the committed baseline and skip the check.
+        if overwrite_baselines() {
+            // When over-writing, we need to scrub the file paths of the test directory.
+            //
+            // Otherwise, we end up with absolute paths in the baselines.
+            let mut v = value_from_file(&output_file);
+            scrub(&mut v, &test_data_root());
+            save_to_file(&baseline_file, &v, true);
+
+            return;
+        }
+
+        assert!(
+            baseline_file.exists(),
+            "missing baseline {baseline_file:?}; regenerate it with {DISKANN_TEST_ENV}=overwrite"
+        );
+
+        // Run the regression check. A negative result (or any error) propagates here and
+        // fails the test.
+        let command = Commands::Check(Check::Run {
+            tolerances: tolerances_file,
+            input_file: modified_input_file,
+            before: baseline_file,
+            after: output_file,
+            output_file: None,
+        });
+        let mut output = Memory::new();
+
+        if let Err(err) = App::from_commands(command).run(&registry().unwrap(), &mut output) {
+            panic!(
+                "Regression check failed:\n\n{}\n\n{}",
+                err,
+                String::from_utf8(output.into_inner()).unwrap()
+            );
+        }
+    }
+
+    // Pass every string found anywhere inside `value` through `scrub_path` with `root` and
+    // an empty replacement.
+    //
+    // The traversal is iterative (explicit work stack), so deeply nested documents do not
+    // recurse.
+    fn scrub(value: &mut Value, root: &Path) {
+        let mut values = vec![value];
+        while let Some(value) = values.pop() {
+            match value {
+                Value::Null | Value::Bool(_) | Value::Number(_) => {}
+                Value::String(s) => {
+                    // Move the string out instead of cloning it; it is overwritten right away.
+                    *s = diskann_benchmark_runner::ux::scrub_path(std::mem::take(s), root, "");
+                }
+                Value::Array(v) => v.iter_mut().for_each(|v| values.push(v)),
+                Value::Object(m) => m.values_mut().for_each(|v| values.push(v)),
+            }
+        }
+    }
+
+    #[test]
+    fn store_stress_integration() {
+        run_example("store-stress-test.json");
+    }
+
+    #[test]
+    #[cfg(not(miri))]
+    fn graph_index() {
+        run_regression_example(
+            "integration.json",
+            "checks.json",
+            "integration-baseline.json",
+        );
+    }
+}
diff --git a/diskann-inmem/integration/store.rs b/diskann-inmem/integration/store.rs
new file mode 100644
index 000000000..b26458b43
--- /dev/null
+++ b/diskann-inmem/integration/store.rs
@@ -0,0 +1,580 @@
+/*
+ * Copyright (c) Microsoft Corporation.
+ * Licensed under the MIT license.
+ */
+
+//! Concurrency stress test for the in-memory [`Store`](diskann_inmem::integration::store::Store).
+//!
+//! Reader, writer, and retirer threads hammer the epoch-based store concurrently while a
+//! per-guard invariant checker verifies the store's safety guarantees:
+//!
+//! 1. Reads are never torn.
+//! 2. A readable value is stable for the lifetime of a single reader guard.
+//! 3. A slot never resurrects (`readable -> unreadable -> readable`) within one guard.
+
+#![expect(
+    clippy::unwrap_used,
+    reason = "this code works mainly as an integration test"
+)]
+
+use std::{
+    collections::HashMap,
+    io::Write,
+    sync::{
+        Mutex,
+        atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering::Relaxed},
+    },
+    time::{Duration, Instant},
+};
+
+use diskann_benchmark_runner::{
+    Benchmark, Checker, Checkpoint, Input, Output,
+    benchmark::{FailureScore, MatchScore},
+    utils::fmt::KeyValue,
+};
+use rand::{Rng, SeedableRng, distr::Uniform, rngs::StdRng};
+use serde::{Deserialize, Serialize};
+
+use diskann_inmem::integration::store::Store;
+
+/// Maximum number of concurrent reader guards supported by the epoch registry.
+const GUARD_CAPACITY: usize = 256;
+
+/// Number of slots a reader inspects per guard. Kept small so guards are short-lived,
+/// allowing the epoch to advance and reclamation to make progress.
+const READER_WINDOW: usize = 64;
+
+/// Number of times a reader re-reads its window within a single guard. Re-reading is what
+/// exercises the value-stability and no-resurrection invariants.
+const READER_PASSES: usize = 4;
+
+/// How often (in retirer iterations) a retirer attempts to reclaim retired slots.
+const RECLAIM_EVERY: u64 = 16;
+
+///////////
+// Input //
+///////////
+
+/// Configuration for a [`StoreStress`] run.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct StoreStressInput {
+    /// Number of reader threads. Must be below [`GUARD_CAPACITY`].
+    readers: usize,
+    /// Number of writer threads.
+    writers: usize,
+    /// Number of retirer threads.
+    retirers: usize,
+    /// Number of writable (non-frozen) slots.
+    capacity: usize,
+    /// Bytes per entry. Must be a non-zero multiple of 8 (the stamp lane width).
+    entry_bytes: usize,
+    /// Retirers only retire while the live published population exceeds this watermark.
+    low_watermark: usize,
+    /// Wall-clock cap for the run, in seconds. Zero means unbounded (rely on `max_ops`).
+    duration_secs: u64,
+    /// Total-operation cap across all worker threads. Zero means unbounded (rely on
+    /// `duration_secs`).
+    max_ops: u64,
+    /// Seed for the worker pseudo-random number generators.
+    seed: u64,
+}
+
+impl StoreStressInput {
+    /// Validate the configuration, returning it unchanged when every constraint holds.
+    ///
+    /// # Errors
+    ///
+    /// Fails when `readers` or `writers` is zero, `readers` is not below
+    /// [`GUARD_CAPACITY`], `capacity` is zero, `entry_bytes` is not a non-zero multiple of
+    /// 8, `low_watermark` exceeds `capacity`, or both `duration_secs` and `max_ops` are
+    /// zero (the run would have no termination condition).
+    fn check(self) -> anyhow::Result<Self> {
+        if self.readers == 0 || self.writers == 0 {
+            anyhow::bail!("`readers` and `writers` must be non-zero");
+        }
+        if self.readers >= GUARD_CAPACITY {
+            anyhow::bail!(
+                "`readers` ({}) must be below the epoch guard capacity ({GUARD_CAPACITY})",
+                self.readers,
+            );
+        }
+        if self.capacity == 0 {
+            anyhow::bail!("`capacity` must be non-zero");
+        }
+        if self.entry_bytes == 0 || !self.entry_bytes.is_multiple_of(8) {
+            anyhow::bail!(
+                "`entry_bytes` ({}) must be a non-zero multiple of 8",
+                self.entry_bytes,
+            );
+        }
+        if self.low_watermark > self.capacity {
+            anyhow::bail!(
+                "`low_watermark` ({}) must not exceed `capacity` ({})",
+                self.low_watermark,
+                self.capacity,
+            );
+        }
+        if self.duration_secs == 0 && self.max_ops == 0 {
+            anyhow::bail!("at least one of `duration_secs` or `max_ops` must be non-zero");
+        }
+        Ok(self)
+    }
+}
+
+impl Input for StoreStressInput {
+    type Raw = Self;
+
+    fn tag() -> &'static str {
+        "store-stress"
+    }
+
+    /// The raw form is the input itself; conversion only validates it via [`Self::check`].
+    fn from_raw(raw: Self::Raw, _checker: &mut Checker) -> anyhow::Result<Self> {
+        Self::check(raw)
+    }
+
+    fn serialize(&self) -> anyhow::Result<serde_json::Value> {
+        Ok(serde_json::to_value(self)?)
+    }
+
+    fn example() -> Self::Raw {
+        StoreStressInput {
+            readers: 8,
+            writers: 4,
+            retirers: 2,
+            capacity: 4096,
+            entry_bytes: 128,
+            low_watermark: 1024,
+            duration_secs: 5,
+            max_ops: 50_000_000,
+            seed: 0xA5A5_1234_DEAD_BEEF,
+        }
+    }
+}
+
+impl std::fmt::Display for StoreStressInput {
+    /// Render every configuration field as a key/value listing.
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        let mut kv = KeyValue::new();
+        kv.push("readers", &self.readers);
+        kv.push("writers", &self.writers);
+        kv.push("retirers", &self.retirers);
+        kv.push("capacity", &self.capacity);
+        kv.push("entry_bytes", &self.entry_bytes);
+        kv.push("low_watermark", &self.low_watermark);
+        kv.push("duration_secs", &self.duration_secs);
+        kv.push("max_ops", &self.max_ops);
+        kv.push("seed", &self.seed);
+        write!(f, "{}", kv)
+    }
+}
+
+////////////
+// Output //
+////////////
+
+/// Summary statistics produced by a [`StoreStress`] run.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct StoreStressStats {
+    /// Wall-clock duration of the run, in seconds.
+    elapsed_secs: f64,
+    /// Slot reads performed by reader threads (one per observation).
+    reads: u64,
+    /// Writer attempts that acquired a slot and published it.
+    acquires_ok: u64,
+    /// Writer attempts for which the store handed out no slot.
+    acquires_fail: u64,
+    /// Retire calls the store accepted.
+    retires_ok: u64,
+    /// Retire calls the store rejected.
+    retires_fail: u64,
+    /// Sum of the counts reported by successful reclaim calls.
+    reclaims: u64,
+    /// Observed `readable -> unreadable` transitions across all reader guards.
+    transitions: u64,
+    /// Peak observed live (published, not-yet-retired) population.
+    peak_live: usize,
+}
+
+impl std::fmt::Display for StoreStressStats {
+    /// Render every statistic as a key/value listing.
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        let mut kv = KeyValue::new();
+        kv.push("elapsed_secs", &self.elapsed_secs);
+        kv.push("reads", &self.reads);
+        kv.push("acquires_ok", &self.acquires_ok);
+        kv.push("acquires_fail", &self.acquires_fail);
+        kv.push("retires_ok", &self.retires_ok);
+        kv.push("retires_fail", &self.retires_fail);
+        kv.push("reclaims", &self.reclaims);
+        kv.push("transitions", &self.transitions);
+        kv.push("peak_live", &self.peak_live);
+        write!(f, "{}", kv)
+    }
+}
+
+/////////////
+// Payload //
+/////////////
+
+/// Fill `buf` with `stamp` replicated across every 8-byte lane.
+fn write_stamp(buf: &mut [u8], stamp: u64) {
+    let bytes = stamp.to_ne_bytes();
+    // Trailing bytes that do not fill a whole lane are left untouched.
+    for lane in buf.chunks_exact_mut(8) {
+        lane.copy_from_slice(&bytes);
+    }
+}
+
+/// Read the stamp from `buf`, returning `Err` if any 8-byte lane disagrees (a torn read).
+///
+/// Trailing bytes that do not fill a whole lane are ignored. A buffer with no complete
+/// lane also yields `Err`, since there is no stamp to report.
+fn read_stamp(buf: &[u8]) -> Result<u64, ()> {
+    let (lanes, _) = buf.as_chunks::<8>();
+    let mut lanes = lanes.iter();
+    let first = u64::from_ne_bytes(*lanes.next().ok_or(())?);
+    for lane in lanes {
+        if u64::from_ne_bytes(*lane) != first {
+            return Err(());
+        }
+    }
+    Ok(first)
+}
+
+////////////////
+// Invariants //
+////////////////
+
+/// Per-guard observation of a single slot.
+#[derive(Debug, Clone, Copy)]
+enum SlotObservations {
+    /// The slot was observed readable with the given stamp.
+    Readable(u64),
+    /// The slot was observed readable and then became unreadable (retired).
+    Retired,
+}
+
+/// Feed a single observation of slot `i` into the per-guard checker, recording a violation
+/// on the shared state if a safety invariant is broken.
+///
+/// `observed` holds what this guard has seen so far, keyed by slot index. `read` is the
+/// slot's current contents, or `None` when the slot is not readable.
+fn observe(
+    shared: &Shared,
+    observed: &mut HashMap<usize, SlotObservations>,
+    i: usize,
+    read: Option<&[u8]>,
+) {
+    match (observed.get(&i).copied(), read) {
+        // Not yet observed readable; an unreadable slot tells us nothing actionable.
+        (None, None) => {}
+        // First readable observation: record the stamp (after a tearing check).
+        (None, Some(bytes)) => match read_stamp(bytes) {
+            Ok(stamp) => {
+                observed.insert(i, SlotObservations::Readable(stamp));
+            }
+            Err(()) => record_violation(shared, format!("torn read at slot {i}")),
+        },
+        // Still readable: the value must be identical and untorn.
+        (Some(SlotObservations::Readable(prev)), Some(bytes)) => match read_stamp(bytes) {
+            Ok(stamp) if stamp != prev => record_violation(
+                shared,
+                format!("slot {i} value changed within guard: {prev} -> {stamp}"),
+            ),
+            Ok(_) => {}
+            Err(()) => record_violation(shared, format!("torn read at slot {i}")),
+        },
+        // Readable -> unreadable: an allowed, terminal transition.
+        (Some(SlotObservations::Readable(_)), None) => {
+            observed.insert(i, SlotObservations::Retired);
+            shared.transitions.fetch_add(1, Relaxed);
+        }
+        // Resurrection: a slot that retired came back to life within the same guard.
+        (Some(SlotObservations::Retired), Some(_)) => record_violation(
+            shared,
+            format!("resurrection at slot {i}: unreadable -> readable within one guard"),
+        ),
+        (Some(SlotObservations::Retired), None) => {}
+    }
+}
+
+////////////
+// Shared //
+////////////
+
+/// A thread-local tally that is added to its parent atomic when dropped.
+///
+/// The parent is not updated until the `Local` is dropped, so other threads cannot see
+/// the running total while the owning worker is still looping.
+struct Local<'a> {
+    counter: u64,
+    parent: &'a AtomicU64,
+}
+
+impl<'a> Local<'a> {
+    fn new(parent: &'a AtomicU64) -> Self {
+        Self { counter: 0, parent }
+    }
+
+    /// Increase the local tally by `by` without touching the parent.
+    fn add(&mut self, by: u64) {
+        self.counter += by;
+    }
+}
+
+impl Drop for Local<'_> {
+    fn drop(&mut self) {
+        self.parent.fetch_add(self.counter, Relaxed);
+    }
+}
+
+/// A thread-local running maximum that is folded into its parent atomic when dropped.
+struct LocalMax<'a> {
+    max: usize,
+    parent: &'a AtomicUsize,
+}
+
+impl<'a> LocalMax<'a> {
+    fn new(parent: &'a AtomicUsize) -> Self {
+        Self { max: 0, parent }
+    }
+
+    /// Raise the local maximum to `m` if `m` is larger.
+    fn max(&mut self, m: usize) {
+        self.max = self.max.max(m);
+    }
+}
+
+impl Drop for LocalMax<'_> {
+    fn drop(&mut self) {
+        self.parent.fetch_max(self.max, Relaxed);
+    }
+}
+
+/// State shared by all worker threads for the duration of a run.
+struct Shared {
+    /// The store under test.
+    store: Store,
+    /// Total number of slots readers index into.
+    slots: usize,
+    /// Distribution readers draw their window start from.
+    readable: Uniform<usize>,
+    /// Distribution retirers draw the slot to retire from.
+    // NOTE(review): element type inferred from usage (`store.retire(i)`) - confirm.
+    writable: Uniform<usize>,
+    low_watermark: usize,
+    max_ops: u64,
+    deadline: Instant,
+
+    /// Set when a violation is recorded; every worker polls it to stop early.
+    stop: AtomicBool,
+    /// Messages for every invariant violation observed so far.
+    violation: Mutex<Vec<String>>,
+
+    /// Next stamp a writer will publish.
+    stamp: AtomicU64,
+    /// Published-but-not-retired population, as tracked by writers and retirers.
+    live: AtomicUsize,
+    peak_live: AtomicUsize,
+
+    ops: AtomicU64,
+    reads: AtomicU64,
+    acquires_ok: AtomicU64,
+    acquires_fail: AtomicU64,
+    retires_ok: AtomicU64,
+    retires_fail: AtomicU64,
+    reclaims: AtomicU64,
+    transitions: AtomicU64,
+}
+
+/// Record an observed invariant violation and signal all workers to stop.
+fn record_violation(shared: &Shared, message: String) {
+    let mut slot = shared.violation.lock().unwrap();
+    slot.push(message);
+    shared.stop.store(true, Relaxed);
+}
+
+/// Return `true` once any termination condition is met.
+///
+/// The conditions are: a violation was recorded, the shared operation counter reached
+/// `max_ops`, or the deadline passed.
+fn should_stop(shared: &Shared) -> bool {
+    shared.stop.load(Relaxed)
+        || shared.ops.load(Relaxed) >= shared.max_ops
+        || Instant::now() >= shared.deadline
+}
+
+/////////////
+// Workers //
+/////////////
+
+/// Writer loop: acquire a slot, fill it with a fresh stamp, and publish it.
+///
+/// Runs until [`should_stop`] reports a termination condition. Yields the thread when the
+/// store has no slot to hand out.
+fn writer(shared: &Shared) {
+    let mut acquires_ok = Local::new(&shared.acquires_ok);
+    let mut acquires_fail = Local::new(&shared.acquires_fail);
+
+    let mut peak_live = LocalMax::new(&shared.peak_live);
+
+    while !should_stop(shared) {
+        // Count on the shared counter directly (as the retirer does). A `Local` is only
+        // flushed on drop, so `should_stop` would never see these operations and the
+        // `max_ops` cap could not end the run.
+        shared.ops.fetch_add(1, Relaxed);
+        match shared.store.acquire() {
+            Some(mut writer) => {
+                let stamp = shared.stamp.fetch_add(1, Relaxed);
+                write_stamp(writer.as_mut_slice(), stamp);
+                writer.publish();
+
+                let live = shared.live.fetch_add(1, Relaxed) + 1;
+                peak_live.max(live);
+                acquires_ok.add(1);
+            }
+            None => {
+                acquires_fail.add(1);
+                std::thread::yield_now();
+            }
+        }
+    }
+}
+
+/// Retirer loop: retire random slots while the live population is above the low
+/// watermark, and attempt a reclaim every [`RECLAIM_EVERY`] iterations.
+///
+/// `seed` initializes this worker's private random number generator.
+fn retirer(shared: &Shared, seed: u64) {
+    let mut rng = StdRng::seed_from_u64(seed);
+    let mut iteration: u64 = 0;
+
+    let mut retires_ok = Local::new(&shared.retires_ok);
+    let mut retires_fail = Local::new(&shared.retires_fail);
+    let mut reclaims = Local::new(&shared.reclaims);
+
+    while !should_stop(shared) {
+        shared.ops.fetch_add(1, Relaxed);
+        iteration += 1;
+
+        // Flow control: keep a steady readable population.
+        if shared.live.load(Relaxed) > shared.low_watermark {
+            let i = rng.sample(shared.writable);
+            if shared.store.retire(i) {
+                shared.live.fetch_sub(1, Relaxed);
+                retires_ok.add(1);
+            } else {
+                retires_fail.add(1);
+            }
+        }
+
+        if iteration.is_multiple_of(RECLAIM_EVERY)
+            && let Some(reclaimed) = shared.store.reclaim()
+        {
+            reclaims.add(reclaimed as u64);
+        }
+
+        std::thread::yield_now();
+    }
+}
+
+/// Reader loop: take a guard, then read a window of consecutive slots (wrapping around
+/// the slot range) [`READER_PASSES`] times, feeding every read into [`observe`].
+///
+/// `seed` initializes this worker's private random number generator. Observations are
+/// reset for every guard, so the invariants are checked per guard.
+fn reader(shared: &Shared, seed: u64) {
+    let mut rng = StdRng::seed_from_u64(seed);
+    let slots = shared.slots;
+    let window = READER_WINDOW.min(slots);
+    let mut observations = HashMap::with_capacity(window);
+
+    let mut reads = Local::new(&shared.reads);
+
+    while !should_stop(shared) {
+        // Counted directly on the shared counter so the `max_ops` cap can see it; a
+        // `Local` would only publish the total when this thread exits.
+        shared.ops.fetch_add(1, Relaxed);
+        let Some(guard) = shared.store.reader() else {
+            // All guard slots are occupied; back off and retry.
+            std::thread::yield_now();
+            continue;
+        };
+
+        observations.clear();
+        let start = rng.sample(shared.readable);
+        for _ in 0..READER_PASSES {
+            for k in 0..window {
+                let i = (start + k) % slots;
+                observe(shared, &mut observations, i, guard.read(i));
+                reads.add(1);
+            }
+        }
+    }
+}
+
+///////////////
+// Benchmark //
+///////////////
+
+/// The store concurrency stress benchmark.
+#[derive(Debug)]
+pub struct StoreStress;
+
+impl Benchmark for StoreStress {
+    type Input = StoreStressInput;
+    type Output = StoreStressStats;
+
+    /// Every validated [`StoreStressInput`] is accepted.
+    fn try_match(&self, _input: &StoreStressInput) -> Result<MatchScore, FailureScore> {
+        Ok(MatchScore(0))
+    }
+
+    fn description(
+        &self,
+        f: &mut std::fmt::Formatter<'_>,
+        _input: Option<&StoreStressInput>,
+    ) -> std::fmt::Result {
+        write!(
+            f,
+            "concurrency stress test for the in-memory store (readers/writers/retirers)"
+        )
+    }
+
+    /// Run the stress test described by `input` and return its summary statistics.
+    ///
+    /// Spawns `writers`, `retirers` and `readers` scoped threads over one shared store and
+    /// waits for all of them. The configuration and the statistics are written to `output`.
+    ///
+    /// # Errors
+    ///
+    /// Fails if any worker recorded an invariant violation, if the deadline cannot be
+    /// represented, if a sampling distribution cannot be built, or if writing to `output`
+    /// fails.
+    fn run(
+        &self,
+        input: &StoreStressInput,
+        _checkpoint: Checkpoint<'_>,
+        mut output: &mut dyn Output,
+    ) -> anyhow::Result<StoreStressStats> {
+        let store = Store::new(input.capacity, input.entry_bytes);
+        let writable = store.writable();
+        let slots = store.slots();
+        let start = Instant::now();
+
+        let budget = if input.duration_secs == 0 {
+            // Effectively unbounded; the op cap terminates the run.
+            Duration::from_secs(u64::from(u32::MAX))
+        } else {
+            Duration::from_secs(input.duration_secs)
+        };
+        // `Instant + Duration` panics on overflow; report an oversized configured duration
+        // as an error instead.
+        let deadline = start.checked_add(budget).ok_or_else(|| {
+            anyhow::anyhow!(
+                "`duration_secs` ({}) is too large to compute a deadline",
+                input.duration_secs,
+            )
+        })?;
+
+        let shared = Shared {
+            store,
+            slots,
+            readable: Uniform::new(0, slots)?,
+            writable: Uniform::try_from(writable)?,
+            low_watermark: input.low_watermark,
+            max_ops: if input.max_ops == 0 {
+                u64::MAX
+            } else {
+                input.max_ops
+            },
+            deadline,
+            stop: AtomicBool::new(false),
+            violation: Mutex::new(Vec::new()),
+            // Stamp 0 is reserved for the zeroed frozen point.
+            stamp: AtomicU64::new(1),
+            live: AtomicUsize::new(0),
+            peak_live: AtomicUsize::new(0),
+            ops: AtomicU64::new(0),
+            reads: AtomicU64::new(0),
+            acquires_ok: AtomicU64::new(0),
+            acquires_fail: AtomicU64::new(0),
+            retires_ok: AtomicU64::new(0),
+            retires_fail: AtomicU64::new(0),
+            reclaims: AtomicU64::new(0),
+            transitions: AtomicU64::new(0),
+        };
+
+        writeln!(output, "{}", input)?;
+
+        // Scoped threads borrow `shared`; the scope joins every worker before returning.
+        std::thread::scope(|scope| {
+            let shared = &shared;
+            for _ in 0..input.writers {
+                scope.spawn(move || writer(shared));
+            }
+            // Retirers and readers each get a distinct seed derived from the input seed.
+            for t in 0..input.retirers {
+                let seed = input.seed ^ (0x2000_0000 + t as u64);
+                scope.spawn(move || retirer(shared, seed));
+            }
+            for t in 0..input.readers {
+                let seed = input.seed ^ (0x4000_0000 + t as u64);
+                scope.spawn(move || reader(shared, seed));
+            }
+        });
+
+        let errors: Vec<_> = std::mem::take(&mut *shared.violation.lock().unwrap());
+        if !errors.is_empty() {
+            anyhow::bail!("invariants violated: {:?}", errors);
+        }
+
+        let elapsed = start.elapsed();
+        let stats = StoreStressStats {
+            elapsed_secs: elapsed.as_secs_f64(),
+            reads: shared.reads.load(Relaxed),
+            acquires_ok: shared.acquires_ok.load(Relaxed),
+            acquires_fail: shared.acquires_fail.load(Relaxed),
+            retires_ok: shared.retires_ok.load(Relaxed),
+            retires_fail: shared.retires_fail.load(Relaxed),
+            reclaims: shared.reclaims.load(Relaxed),
+            transitions: shared.transitions.load(Relaxed),
+            peak_live: shared.peak_live.load(Relaxed),
+        };
+
+        writeln!(output, "{}", stats)?;
+        Ok(stats)
+    }
+}
+
+///////////
+// Tests //
+///////////
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn make_sure_example_parses() {
+        let _ = StoreStressInput::check(StoreStressInput::example()).unwrap();
+    }
+}
diff --git a/diskann-inmem/integration/support/check.rs b/diskann-inmem/integration/support/check.rs
new file mode 100644
index 000000000..8e8ce2121
--- /dev/null
+++ b/diskann-inmem/integration/support/check.rs
@@ -0,0 +1,710 @@
+/*
+ * Copyright (c) Microsoft Corporation.
+ * Licensed under the MIT license.
+ */
+
+//! # Baseline Checking
+//!
+//! The [`Regression`](diskann_benchmark_runner::benchmark::Regression) provides a means
+//! of performing before/after comparisons against previously generated results. However,
+//! presentation of these results is largely left to the devices of the implementors.
+//!
+//! This module provides a means of aggregating all match failures (if any) and presenting
+//! all failures as a single unit.
+
+use std::{
+    borrow::Cow,
+    fmt::{Display, Write},
+};
+
+use diskann_benchmark_runner::{benchmark::PassFail, utils::fmt::Table};
+use serde::{Serialize, Serializer};
+
+/// Perform a baseline check on `self` and a previously saved result.
+pub(crate) trait CheckMatch {
+    fn check_match(&self, previous: &Self) -> Match;
+}
+
+/// The result of a baseline check.
+#[must_use = "this is a result type"]
+#[derive(Debug, Serialize)]
+#[serde(rename_all = "kebab-case")]
+pub(crate) enum Match {
+    /// Successful match.
+    Ok,
+
+    /// A mismatch on a specific field.
+    Mismatch {
+        got: String,
+        expected: String,
+        remark: Option<Cow<'static, str>>,
+    },
+
+    /// A collection of mismatches for an aggregate data type or collection.
+    ///
+    /// Use [`MatchBuilder`] for easier construction.
+    Nested {
+        children: Vec<(Key, Match)>,
+        remark: Option<Cow<'static, str>>,
+    },
+}
+
+impl Match {
+    /// Return `true` if `self` is [`Match::Ok`].
+    #[must_use = "this has no side-effects"]
+    pub(crate) fn is_ok(&self) -> bool {
+        matches!(self, Self::Ok)
+    }
+
+    /// Record a single mismatch between the retrieved value `got` and the `expected` result.
+    pub(crate) fn mismatch(got: &dyn Display, expected: &dyn Display) -> Self {
+        Self::mismatch_with_remark(got, expected, None)
+    }
+
+    /// Record a single mismatch between the retrieved value `got` and the `expected` result
+    /// with an additional optional remark.
+    ///
+    /// The remark can be used for contexts where matches are more complex than simple
+    /// equality.
+    pub(crate) fn mismatch_with_remark(
+        got: &dyn Display,
+        expected: &dyn Display,
+        remark: Option<Cow<'static, str>>,
+    ) -> Self {
+        Self::Mismatch {
+            expected: expected.to_string(),
+            got: got.to_string(),
+            remark,
+        }
+    }
+
+    /// Convert `self` into a [`PassFail`] for regression checks.
+    ///
+    /// Returns `PassFail::Pass` only if `self.is_ok`.
+    // NOTE(review): the type arguments of `PassFail` were lost from this text and are
+    // inferred from both variants carrying `self` - confirm against the runner crate.
+    pub(crate) fn pass_fail(self) -> PassFail<Self> {
+        if self.is_ok() {
+            PassFail::Pass(self)
+        } else {
+            PassFail::Fail(self)
+        }
+    }
+}
+
+impl std::fmt::Display for Match {
+    /// Render the match as `ok`, a one-row table for a single mismatch, or a table with one
+    /// row per nested mismatch (each labelled with its dotted path).
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            Self::Ok => f.write_str("ok"),
+            Self::Mismatch {
+                got,
+                expected,
+                remark,
+            } => {
+                let header = ["got", "expected", "remark"];
+                let mut table = Table::new(header, 1);
+                let mut row = table.row(0);
+                row.insert(got.clone(), 0);
+                row.insert(expected.clone(), 1);
+                if let Some(remark) = remark {
+                    row.insert(remark.clone(), 2);
+                }
+
+                table.fmt(f)
+            }
+            Self::Nested { children, remark } => {
+                let mut records = Vec::new();
+                // A remark on the root gets its own row with an empty path.
+                if let Some(remark) = remark {
+                    records.push(Record {
+                        path: String::new(),
+                        got: "",
+                        expected: "",
+                        remark,
+                    });
+                }
+
+                let mut buf = String::new();
+                gather_mismatches(children, &mut records, Stack::new(&mut buf));
+
+                let mut table = Table::new(["path", "got", "expected", "remark"], records.len());
+                for (i, r) in records.into_iter().enumerate() {
+                    let mut row = table.row(i);
+                    row.insert(r.path, 0);
+                    row.insert(r.got.to_owned(), 1);
+                    row.insert(r.expected.to_owned(), 2);
+                    row.insert(r.remark.to_owned(), 3);
+                }
+
+                table.fmt(f)
+            }
+        }
+    }
+}
+
+/// Flatten `mismatches` depth-first into `records`, one record per leaf mismatch and one
+/// per nested remark.
+///
+/// `path` is the dotted key path of the parent; each child's key is appended for the
+/// duration of its visit. [`Match::Ok`] children are skipped.
+fn gather_mismatches<'a>(
+    mismatches: &'a [(Key, Match)],
+    records: &mut Vec<Record<'a>>,
+    mut path: Stack<'_>,
+) {
+    for (k, m) in mismatches.iter() {
+        match m {
+            Match::Ok => continue,
+            Match::Mismatch {
+                got,
+                expected,
+                remark,
+            } => {
+                let record = Record {
+                    // The temporary child `Stack` is dropped right after `get`, which
+                    // removes the key from the shared buffer again.
+                    path: path.push(k).get(),
+                    got,
+                    expected,
+                    remark: remark.as_deref().unwrap_or(""),
+                };
+                records.push(record);
+            }
+            Match::Nested { children, remark } => {
+                let path = path.push(k);
+
+                if let Some(remark) = remark {
+                    records.push(Record {
+                        path: path.get(),
+                        got: "",
+                        expected: "",
+                        remark,
+                    })
+                }
+
+                gather_mismatches(children, records, path)
+            }
+        }
+    }
+}
+
+/// A dotted key path built in a shared `String` buffer.
+///
+/// Each `Stack` remembers the buffer length it started from and truncates back to it on
+/// drop, so a key pushed for a child is removed when that child's `Stack` goes away.
+#[derive(Debug)]
+struct Stack<'a> {
+    s: &'a mut String,
+    len: usize,
+}
+
+impl<'a> Stack<'a> {
+    /// Start an empty path in `s`, clearing any previous contents.
+    fn new(s: &'a mut String) -> Self {
+        s.clear();
+        Self { s, len: 0 }
+    }
+
+    /// Append `key` (separated by `.` unless the path is empty) and return a child `Stack`
+    /// that removes it again when dropped.
+    #[expect(clippy::unwrap_used, reason = "formatting shouldn't be failing here")]
+    fn push(&mut self, key: &Key) -> Stack<'_> {
+        let len = self.s.len();
+        if len == 0 {
+            write!(self.s, "{}", key).unwrap();
+        } else {
+            write!(self.s, ".{}", key).unwrap();
+        }
+
+        Stack { s: self.s, len }
+    }
+
+    /// Return an owned copy of the current path.
+    fn get(&self) -> String {
+        self.s.clone()
+    }
+}
+
+impl Drop for Stack<'_> {
+    fn drop(&mut self) {
+        self.s.truncate(self.len)
+    }
+}
+
+/// One row of the nested-mismatch table; the text fields borrow from the `Match` tree.
+#[derive(Debug)]
+struct Record<'a> {
+    path: String,
+    got: &'a str,
+    expected: &'a str,
+    remark: &'a str,
+}
+
+/////////
+// Key //
+/////////
+
+/// A key to develop the full hierarchical path for a match.
+///
+/// Keys can either be strings or positional indices. The latter are used when traversing
+/// arrays.
+#[derive(Debug, Clone, Eq, PartialEq)] +pub(crate) enum Key { + Str(&'static str), + Position(usize), + String(String), +} + +impl std::fmt::Display for Key { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Str(s) => f.write_str(s), + Self::Position(i) => write!(f, "{}", i), + Self::String(s) => f.write_str(s), + } + } +} + +impl Serialize for Key { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + match self { + Self::Str(s) => serializer.serialize_str(s), + Self::Position(i) => serializer.serialize_u64(*i as u64), + Self::String(s) => serializer.serialize_str(s), + } + } +} + +impl From<&'static str> for Key { + fn from(s: &'static str) -> Key { + Key::Str(s) + } +} + +impl From for Key { + fn from(i: usize) -> Key { + Key::Position(i) + } +} + +impl From for Key { + fn from(s: String) -> Key { + Key::String(s) + } +} + +///////////// +// Builder // +///////////// + +/// A utility for building a nested [`Match`]. +#[derive(Debug)] +pub(crate) struct MatchBuilder { + children: Vec<(Key, Match)>, +} + +impl MatchBuilder { + /// Construct a new empty collection of matches. + pub(crate) fn new() -> Self { + Self { + children: Vec::new(), + } + } + + /// Push the [`Match`] into the collection only if [`Match::is_ok`] fails. + pub(crate) fn push(&mut self, key: Key, child: Match) { + if !child.is_ok() { + self.children.push((key, child)); + } + } + + /// Package the collection of matches into a single [`Match`]. + /// + /// If no failing matches have been aggregated, returns [`Match::Ok`]. + pub(crate) fn finish(self) -> Match { + self.finish_with_remark(None) + } + + /// Package the collection of matches into a single [`Match`] with a remark. + /// + /// If no failing matches have been aggregated, returns [`Match::Ok`]. 
+ pub(crate) fn finish_with_remark(self, remark: Option>) -> Match { + if self.children.is_empty() { + Match::Ok + } else { + Match::Nested { + children: self.children, + remark, + } + } + } +} + +macro_rules! check_match_impl { + ($T:ty) => { + impl CheckMatch for $T { + fn check_match( + &self, + previous: &Self, + ) -> Match { + if self == previous { + Match::Ok + } else { + Match::mismatch(self, previous) + } + } + } + }; + ($($Ts:ty),+ $(,)?) => { + $(check_match_impl!($Ts);)+ + } +} + +check_match_impl!( + bool, u8, u16, u32, u64, usize, i8, i16, i32, i64, isize, f32, f64, &str, String +); + +impl CheckMatch for [T] +where + T: CheckMatch, +{ + fn check_match(&self, previous: &[T]) -> Match { + if self.len() != previous.len() { + return Match::mismatch_with_remark( + &self.len(), + &previous.len(), + Some("number of results is different between runs".into()), + ); + } + + let mut builder = MatchBuilder::new(); + for (i, (got, expected)) in std::iter::zip(self.iter(), previous.iter()).enumerate() { + builder.push(Key::from(i), got.check_match(expected)); + } + + builder.finish() + } +} + +impl CheckMatch for Vec +where + T: CheckMatch, +{ + fn check_match(&self, previous: &Vec) -> Match { + self.as_slice().check_match(previous.as_slice()) + } +} + +//--------// +// Macros // +//--------// + +macro_rules! check_all_fields { + ($self:expr, $prev:expr, { $($field:ident),+ $(,)? } $(,)?) 
=> {{ + let Self { $($field),+ } = $self; + let mut builder = $crate::support::check::MatchBuilder::new(); + $( + builder.push( + stringify!($field).into(), + <_ as $crate::support::check::CheckMatch>::check_match( + $field, + &$prev.$field + ), + ); + )+ + builder + }}; +} + +pub(crate) use check_all_fields; + +/////////// +// Tests // +/////////// + +#[cfg(test)] +mod tests { + use super::*; + + //-------// + // Match // + //-------// + + #[test] + fn match_is_ok() { + assert!(Match::Ok.is_ok()); + assert!(!Match::mismatch(&1, &2).is_ok()); + } + + #[test] + fn mismatch_records_got_and_expected() { + match Match::mismatch(&1, &2) { + Match::Mismatch { + got, + expected, + remark, + } => { + assert_eq!(got, "1"); + assert_eq!(expected, "2"); + assert!(remark.is_none()); + } + other => panic!("expected Mismatch, got {other:?}"), + } + } + + #[test] + fn mismatch_with_remark_records_remark() { + match Match::mismatch_with_remark(&"a", &"b", Some("note".into())) { + Match::Mismatch { + got, + expected, + remark, + } => { + assert_eq!(got, "a"); + assert_eq!(expected, "b"); + assert_eq!(remark.as_deref(), Some("note")); + } + other => panic!("expected Mismatch, got {other:?}"), + } + } + + #[test] + fn pass_fail_follows_is_ok() { + assert!(matches!(Match::Ok.pass_fail(), PassFail::Pass(Match::Ok))); + + assert!(matches!( + Match::mismatch(&1, &2).pass_fail(), + PassFail::Fail(Match::Mismatch { .. }) + )); + + let mut builder = MatchBuilder::new(); + builder.push(Key::from("test"), Match::mismatch(&1, &2)); + builder.push(Key::from("test2"), Match::mismatch(&2, &3)); + let mismatch = builder.finish(); + + assert!(matches!(mismatch, Match::Nested { .. })); + assert!(matches!( + mismatch.pass_fail(), + PassFail::Fail(Match::Nested { .. 
}) + )); + } + + //------------// + // CheckMatch // + //------------// + + #[test] + fn primitive_check_match() { + assert!(1u32.check_match(&1u32).is_ok()); + assert!(!2u32.check_match(&3u32).is_ok()); + assert!("x".check_match(&"x").is_ok()); + assert!(!"x".check_match(&"y").is_ok()); + } + + #[test] + fn slice_check_match_equal() { + let a = vec![1u32, 2, 3]; + let b = vec![1u32, 2, 3]; + assert!(a.check_match(&b).is_ok()); + } + + #[test] + fn slice_check_match_length_mismatch() { + let a = vec![1u32, 2, 3]; + let b = vec![1u32, 2]; + match a.check_match(&b) { + Match::Mismatch { + got, + expected, + remark, + } => { + assert_eq!(got, "3"); + assert_eq!(expected, "2"); + assert!(remark.is_some()); + } + other => panic!("expected length Mismatch, got {other:?}"), + } + } + + #[test] + fn slice_check_match_element_mismatch() { + let a = vec![1u32, 9, 3]; + let b = vec![1u32, 2, 3]; + match a.check_match(&b) { + Match::Nested { children, .. } => { + assert_eq!(children.len(), 1); + assert!(matches!(children[0].0, Key::Position(1))); + } + other => panic!("expected Nested, got {other:?}"), + } + } + + //--------------// + // MatchBuilder // + //--------------// + + #[test] + fn builder_empty_is_ok() { + assert!(MatchBuilder::new().finish().is_ok()); + } + + #[test] + fn builder_skips_ok_matches() { + let mut builder = MatchBuilder::new(); + builder.push("a".into(), Match::Ok); + builder.push("b".into(), Match::Ok); + assert!(builder.finish().is_ok()); + } + + #[test] + fn builder_collects_failures() { + let mut builder = MatchBuilder::new(); + builder.push("a".into(), Match::Ok); + builder.push("b".into(), Match::mismatch(&1, &2)); + match builder.finish() { + Match::Nested { children, remark } => { + assert_eq!(children.len(), 1); + assert!(remark.is_none()); + } + other => panic!("expected Nested, got {other:?}"), + } + } + + #[test] + fn builder_finish_with_remark() { + let mut builder = MatchBuilder::new(); + builder.push("b".into(), Match::mismatch(&1, &2)); 
+ match builder.finish_with_remark(Some("ctx".into())) { + Match::Nested { remark, .. } => assert_eq!(remark.as_deref(), Some("ctx")), + other => panic!("expected Nested, got {other:?}"), + } + } + + //-----// + // Key // + //-----// + + #[test] + fn key_display() { + assert_eq!(Key::from("field").to_string(), "field"); + assert_eq!(Key::from(7usize).to_string(), "7"); + assert_eq!(Key::from(String::from("owned")).to_string(), "owned"); + } + + #[test] + fn key_serde() { + let k = serde_json::to_value(Key::Str("field")).unwrap(); + assert_eq!(k, serde_json::Value::String("field".into())); + + let k = serde_json::to_value(Key::Position(10)).unwrap(); + assert_eq!(k, serde_json::Value::Number(10.into())); + + let k = serde_json::to_value(Key::String("world".into())).unwrap(); + assert_eq!(k, serde_json::Value::String("world".into())); + } + + //---------// + // Display // + //---------// + + #[test] + fn display_ok() { + assert_eq!(Match::Ok.to_string(), "ok"); + } + + #[test] + fn display_nonnested() { + let mismatch = Match::mismatch_with_remark(&"hello", &1, Some("word".into())); + let rendered = mismatch.to_string(); + + let expected = r#" + got, expected, remark +=========================== +hello, 1, word +"#; + let expected = expected.strip_prefix('\n').unwrap(); + + println!("rendered = {:?}", rendered); + + let mut count = 0; + for (line, (got, expected)) in + std::iter::zip(rendered.lines(), expected.lines()).enumerate() + { + count += 1; + assert_eq!(got.trim(), expected.trim(), "failed on line {line}",); + } + assert_eq!(count, 3); + } + + #[test] + fn display_nested() { + // Build a nested match and ensure the hierarchical path is rendered. 
+ let mut inner = MatchBuilder::new(); + inner.push(1usize.into(), Match::mismatch(&9, &2)); + inner.push( + "test".into(), + Match::mismatch_with_remark(&9, &2, Some("hello".into())), + ); + let nested = inner.finish_with_remark(Some("some remark".into())); + + let mut outer = MatchBuilder::new(); + outer.push("results".into(), nested); + let rendered = outer + .finish_with_remark(Some("final remarks".into())) + .to_string(); + + let expected = r#" + path, got, expected, remark + ================================================ + , , , final remarks + results, , , some remark + results.1, 9, 2, + results.test, 9, 2, hello + "#; + + let expected = expected.strip_prefix('\n').unwrap(); + + println!("rendered = {:?}", rendered); + + let mut count = 0; + for (line, (got, expected)) in + std::iter::zip(rendered.lines(), expected.lines()).enumerate() + { + count += 1; + assert_eq!(got.trim(), expected.trim(), "failed on line {line}",); + } + assert_eq!(count, 6); + } + + //-------------------// + // check_all_fields! // + //-------------------// + + #[derive(Debug)] + struct Sample { + a: u32, + b: String, + } + + impl CheckMatch for Sample { + fn check_match(&self, previous: &Self) -> Match { + check_all_fields!(self, previous, { a, b }).finish() + } + } + + #[test] + fn check_all_fields_equal() { + let x = Sample { + a: 1, + b: "hi".into(), + }; + let y = Sample { + a: 1, + b: "hi".into(), + }; + assert!(x.check_match(&y).is_ok()); + } + + #[test] + fn check_all_fields_reports_changed_field() { + let x = Sample { + a: 1, + b: "hi".into(), + }; + let y = Sample { + a: 1, + b: "bye".into(), + }; + match x.check_match(&y) { + Match::Nested { children, .. 
} => { + assert_eq!(children.len(), 1); + assert_eq!(children[0].0.to_string(), "b"); + } + other => panic!("expected Nested, got {other:?}"), + } + } +} diff --git a/diskann-inmem/integration/support/datatype.rs b/diskann-inmem/integration/support/datatype.rs new file mode 100644 index 000000000..fe61de539 --- /dev/null +++ b/diskann-inmem/integration/support/datatype.rs @@ -0,0 +1,691 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +use diskann_utils::{ + sampling::medoid::ComputeMedoid, + views::{Matrix, MatrixView, MutMatrixView}, +}; +use half::f16; +use serde::{Deserialize, Serialize}; +use thiserror::Error; + +////////////// +// DataType // +////////////// + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "kebab-case")] +pub(crate) enum DataType { + F32, + F16, + U8, + I8, +} + +impl DataType { + fn as_str(self) -> &'static str { + match self { + Self::F32 => "f32", + Self::F16 => "f16", + Self::U8 => "u8", + Self::I8 => "i8", + } + } +} + +impl std::fmt::Display for DataType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(self.as_str()) + } +} + +pub(crate) trait AsDataType { + const DATA_TYPE: DataType; +} + +#[derive(Debug, Error)] +#[error("wrong data-type: expected {}, got {}", self.expected, self.got)] +pub(crate) struct WrongDataType { + expected: DataType, + got: DataType, +} + +impl WrongDataType { + fn new(expected: DataType, got: DataType) -> Self { + Self { expected, got } + } +} + +/////////// +// Slice // +/////////// + +#[derive(Debug, Clone, Copy)] +pub(crate) enum Slice<'a> { + F32(&'a [f32]), + F16(&'a [f16]), + U8(&'a [u8]), + I8(&'a [i8]), +} + +impl<'a> Slice<'a> { + pub(crate) fn data_type(&self) -> DataType { + match self { + Self::F32(_) => DataType::F32, + Self::F16(_) => DataType::F16, + Self::U8(_) => DataType::U8, + Self::I8(_) => DataType::I8, + } + } + + pub(crate) fn len(&self) -> usize { + match self { 
+ Self::F32(s) => s.len(), + Self::F16(s) => s.len(), + Self::U8(s) => s.len(), + Self::I8(s) => s.len(), + } + } + + pub(crate) fn try_cast(self) -> Result<&'a [T], WrongDataType> + where + T: FromSlice, + { + T::from_slice(self) + } +} + +pub(crate) trait FromSlice: Sized { + fn from_slice(slice: Slice<'_>) -> Result<&[Self], WrongDataType>; +} + +////////////// +// SliceMut // +////////////// + +#[derive(Debug)] +pub(crate) enum SliceMut<'a> { + F32(&'a mut [f32]), + F16(&'a mut [f16]), + U8(&'a mut [u8]), + I8(&'a mut [i8]), +} + +fn map(dst: &mut [T], src: &[U], f: F) +where + T: std::fmt::Display + AsDataType, + U: std::fmt::Display + AsDataType + Copy, + F: Fn(U) -> T, +{ + std::iter::zip(dst.iter_mut(), src.iter()).for_each(|(d, s)| { + *d = f(*s); + }) +} + +fn try_map(dst: &mut [T], src: &[U], f: F) -> anyhow::Result<()> +where + T: std::fmt::Display + AsDataType, + U: std::fmt::Display + AsDataType + Copy, + F: Fn(U) -> Result, +{ + std::iter::zip(dst.iter_mut(), src.iter()).try_for_each(|(d, s)| { + let converted = match f(*s) { + Ok(c) => c, + Err(_) => anyhow::bail!( + "could not losslessly convert {} {} to {}", + U::DATA_TYPE, + s, + T::DATA_TYPE, + ), + }; + *d = converted; + Ok(()) + }) +} + +fn f32_to_f16(x: f32) -> Result { + let y = f16::from_f32(x); + let z = f32::from(y); + if z != x { Err(()) } else { Ok(y) } +} + +fn f32_to_u8(x: f32) -> Result { + let y = x as u8; + let z = f32::from(y); + if z != x { Err(()) } else { Ok(y) } +} + +fn f32_to_i8(x: f32) -> Result { + let y = x as i8; + let z = f32::from(y); + if z != x { Err(()) } else { Ok(y) } +} + +fn f16_to_u8(x: f16) -> Result { + f32_to_u8(x.into()) +} + +fn f16_to_i8(x: f16) -> Result { + f32_to_i8(x.into()) +} + +impl<'a> SliceMut<'a> { + fn len(&self) -> usize { + match self { + Self::F32(s) => s.len(), + Self::F16(s) => s.len(), + Self::U8(s) => s.len(), + Self::I8(s) => s.len(), + } + } + + pub(crate) fn convert_lossless(&mut self, rhs: Slice<'_>) -> anyhow::Result<()> { + if 
self.len() != rhs.len() { + anyhow::bail!( + "lhs len {} must be equal to rhs len {}", + self.len(), + rhs.len() + ); + } + + match (self, rhs) { + (SliceMut::F32(dst), Slice::F32(src)) => dst.copy_from_slice(src), + (SliceMut::F32(dst), Slice::F16(src)) => map(dst, src, |x| x.into()), + (SliceMut::F32(dst), Slice::U8(src)) => map(dst, src, |x| x.into()), + (SliceMut::F32(dst), Slice::I8(src)) => map(dst, src, |x| x.into()), + + (SliceMut::F16(dst), Slice::F32(src)) => try_map(dst, src, f32_to_f16)?, + (SliceMut::F16(dst), Slice::F16(src)) => dst.copy_from_slice(src), + (SliceMut::F16(dst), Slice::U8(src)) => map(dst, src, |x| x.into()), + (SliceMut::F16(dst), Slice::I8(src)) => map(dst, src, |x| x.into()), + + (SliceMut::U8(dst), Slice::F32(src)) => try_map(dst, src, f32_to_u8)?, + (SliceMut::U8(dst), Slice::F16(src)) => try_map(dst, src, f16_to_u8)?, + (SliceMut::U8(dst), Slice::U8(src)) => dst.copy_from_slice(src), + (SliceMut::U8(dst), Slice::I8(src)) => try_map(dst, src, |x| x.try_into())?, + + (SliceMut::I8(dst), Slice::F32(src)) => try_map(dst, src, f32_to_i8)?, + (SliceMut::I8(dst), Slice::F16(src)) => try_map(dst, src, f16_to_i8)?, + (SliceMut::I8(dst), Slice::U8(src)) => try_map(dst, src, |x| x.try_into())?, + (SliceMut::I8(dst), Slice::I8(src)) => dst.copy_from_slice(src), + }; + + Ok(()) + } +} + +///////////// +// Dataset // +///////////// + +#[derive(Debug)] +pub(crate) enum Dataset { + F32(Matrix), + F16(Matrix), + U8(Matrix), + I8(Matrix), +} + +impl Dataset { + pub(crate) fn nrows(&self) -> usize { + self.as_view().nrows() + } + + pub(crate) fn ncols(&self) -> usize { + self.as_view().ncols() + } + + pub(crate) fn as_view(&self) -> DatasetView<'_> { + match self { + Self::F32(m) => DatasetView::F32(m.as_view()), + Self::F16(m) => DatasetView::F16(m.as_view()), + Self::U8(m) => DatasetView::U8(m.as_view()), + Self::I8(m) => DatasetView::I8(m.as_view()), + } + } + + pub(crate) fn as_slice(&self) -> Slice<'_> { + match self { + Self::F32(m) => 
m.as_slice().into(), + Self::F16(m) => m.as_slice().into(), + Self::U8(m) => m.as_slice().into(), + Self::I8(m) => m.as_slice().into(), + } + } + + pub(crate) fn medoid(&self) -> Dataset { + self.as_view().medoid() + } + + pub(crate) fn preprocess(&mut self, op: &Preprocess) { + match self { + Self::F32(m) => op.apply(m.as_mut_view()), + Self::F16(m) => op.apply(m.as_mut_view()), + Self::U8(m) => op.apply(m.as_mut_view()), + Self::I8(m) => op.apply(m.as_mut_view()), + } + } +} + +/// Preprocess steps for [`Dataset`]s. +/// +/// These exist so we can coax `u8` data into a form compatible for testing `i8` data. +#[derive(Debug)] +pub(crate) enum Preprocess { + // Divide each component by 2. + Halve, + // Perform a `floor` operation on the each component. + Floor, +} + +trait Apply { + fn apply(&self, m: MutMatrixView<'_, T>); +} + +impl Apply for Preprocess { + fn apply(&self, mut m: MutMatrixView<'_, f32>) { + match self { + Self::Halve => m.as_mut_slice().iter_mut().for_each(|v| *v *= 0.5), + Self::Floor => m.as_mut_slice().iter_mut().for_each(|v| *v = v.floor()), + } + } +} + +impl Apply for Preprocess { + fn apply(&self, mut m: MutMatrixView<'_, f16>) { + match self { + Self::Halve => m.as_mut_slice().iter_mut().for_each(|v| { + *v = f16::from_f32(f32::from(*v) * 0.5); + }), + Self::Floor => m.as_mut_slice().iter_mut().for_each(|v| { + *v = f16::from_f32(f32::from(*v).floor()); + }), + } + } +} + +impl Apply for Preprocess { + fn apply(&self, mut m: MutMatrixView<'_, u8>) { + match self { + Self::Halve => m.as_mut_slice().iter_mut().for_each(|v| *v /= 2), + Self::Floor => {} + } + } +} + +impl Apply for Preprocess { + fn apply(&self, mut m: MutMatrixView<'_, i8>) { + match self { + Self::Halve => m.as_mut_slice().iter_mut().for_each(|v| *v /= 2), + Self::Floor => {} + } + } +} + +///////////////// +// DatasetView // +///////////////// + +#[derive(Debug, Clone, Copy)] +pub(crate) enum DatasetView<'a> { + F32(MatrixView<'a, f32>), + F16(MatrixView<'a, f16>), + 
U8(MatrixView<'a, u8>), + I8(MatrixView<'a, i8>), +} + +impl<'a> DatasetView<'a> { + pub(crate) fn data_type(&self) -> DataType { + match self { + Self::F32(_) => DataType::F32, + Self::F16(_) => DataType::F16, + Self::U8(_) => DataType::U8, + Self::I8(_) => DataType::I8, + } + } + + pub(crate) fn nrows(&self) -> usize { + match self { + Self::F32(m) => m.nrows(), + Self::F16(m) => m.nrows(), + Self::U8(m) => m.nrows(), + Self::I8(m) => m.nrows(), + } + } + + pub(crate) fn ncols(&self) -> usize { + match self { + Self::F32(m) => m.ncols(), + Self::F16(m) => m.ncols(), + Self::U8(m) => m.ncols(), + Self::I8(m) => m.ncols(), + } + } + + pub(crate) fn row(&self, i: usize) -> Option> { + match self { + Self::F32(m) => m.get_row(i).map(Slice::from), + Self::F16(m) => m.get_row(i).map(Slice::from), + Self::U8(m) => m.get_row(i).map(Slice::from), + Self::I8(m) => m.get_row(i).map(Slice::from), + } + } + + pub(crate) fn medoid(&self) -> Dataset { + match self { + Self::F32(v) => Matrix::row_vector(Box::from(f32::compute_medoid(*v))).into(), + Self::F16(v) => Matrix::row_vector(Box::from(f16::compute_medoid(*v))).into(), + Self::U8(v) => Matrix::row_vector(Box::from(u8::compute_medoid(*v))).into(), + Self::I8(v) => Matrix::row_vector(Box::from(i8::compute_medoid(*v))).into(), + } + } + + pub(crate) fn iter(&self) -> Iter<'_> { + Iter::new(self) + } +} + +pub(crate) struct Iter<'a> { + view: &'a DatasetView<'a>, + row: usize, +} + +impl<'a> Iter<'a> { + fn new(view: &'a DatasetView<'a>) -> Self { + Self { view, row: 0 } + } +} + +impl<'a> Iterator for Iter<'a> { + type Item = Slice<'a>; + + fn next(&mut self) -> Option> { + let r = self.view.row(self.row)?; + self.row += 1; + Some(r) + } +} + +//------// +// Impl // +//------// + +macro_rules! 
define { + ($T:ty, $variant:ident) => { + impl AsDataType for $T { + const DATA_TYPE: DataType = DataType::$variant; + } + + impl<'a> From<&'a [$T]> for Slice<'a> { + fn from(s: &'a [$T]) -> Self { + Self::$variant(s) + } + } + + impl<'a> From<&'a mut [$T]> for SliceMut<'a> { + fn from(s: &'a mut [$T]) -> Self { + Self::$variant(s) + } + } + + impl FromSlice for $T { + fn from_slice(slice: Slice<'_>) -> Result<&[Self], WrongDataType> { + if let Slice::$variant(s) = slice { + Ok(s) + } else { + Err(WrongDataType::new(DataType::$variant, slice.data_type())) + } + } + } + + impl From> for Dataset { + fn from(m: Matrix<$T>) -> Self { + Self::$variant(m) + } + } + }; +} + +define!(f32, F32); +define!(f16, F16); +define!(u8, U8); +define!(i8, I8); + +/////////// +// Tests // +/////////// + +#[cfg(test)] +mod tests { + use super::*; + + fn matrix(data: &[T], nrows: usize, ncols: usize) -> Matrix + where + T: Copy, + { + Matrix::try_from(Box::from(data), nrows, ncols).unwrap() + } + + //----------// + // DataType // + //----------// + + #[test] + fn datatype_display() { + assert_eq!(DataType::F32.to_string(), "f32"); + assert_eq!(DataType::F16.to_string(), "f16"); + assert_eq!(DataType::U8.to_string(), "u8"); + assert_eq!(DataType::I8.to_string(), "i8"); + } + + //-------// + // Slice // + //-------// + + #[test] + fn slice_data_type_and_len() { + let f: &[f32] = &[1.0, 2.0, 3.0]; + let s = Slice::from(f); + assert_eq!(s.data_type(), DataType::F32); + assert_eq!(s.len(), 3); + + let u: &[u8] = &[1, 2]; + assert_eq!(Slice::from(u).data_type(), DataType::U8); + assert_eq!(Slice::from(u).len(), 2); + } + + #[test] + fn slice_try_cast_success() { + let f: &[f32] = &[1.0, 2.0]; + let s = Slice::from(f); + let out: &[f32] = s.try_cast().unwrap(); + assert_eq!(out, &[1.0, 2.0]); + } + + #[test] + fn slice_try_cast_wrong_type() { + let f: &[f32] = &[1.0, 2.0]; + let s = Slice::from(f); + let err = s.try_cast::().unwrap_err(); + let msg = err.to_string(); + 
assert!(msg.contains("u8"), "msg: {msg}"); + assert!(msg.contains("f32"), "msg: {msg}"); + } + + //----------// + // SliceMut // + //----------// + + #[test] + fn convert_lossless_same_type() { + let mut dst = [0.0f32; 3]; + let src: &[f32] = &[1.0, 2.0, 3.0]; + SliceMut::from(dst.as_mut_slice()) + .convert_lossless(Slice::from(src)) + .unwrap(); + assert_eq!(dst, [1.0, 2.0, 3.0]); + } + + #[test] + fn convert_lossless_widening() { + // u8 -> f32 is always lossless. + let mut dst = [0.0f32; 3]; + let src: &[u8] = &[1, 2, 250]; + SliceMut::from(dst.as_mut_slice()) + .convert_lossless(Slice::from(src)) + .unwrap(); + assert_eq!(dst, [1.0, 2.0, 250.0]); + + // i8 -> f16 is always lossless. + let mut dst = [f16::ZERO; 2]; + let src: &[i8] = &[-5, 7]; + SliceMut::from(dst.as_mut_slice()) + .convert_lossless(Slice::from(src)) + .unwrap(); + assert_eq!(dst, [f16::from_f32(-5.0), f16::from_f32(7.0)]); + } + + #[test] + fn convert_lossless_narrowing_exact() { + // Whole-valued, in-range f32 -> u8 is lossless. + let mut dst = [0u8; 3]; + let src: &[f32] = &[0.0, 12.0, 255.0]; + SliceMut::from(dst.as_mut_slice()) + .convert_lossless(Slice::from(src)) + .unwrap(); + assert_eq!(dst, [0, 12, 255]); + } + + #[test] + fn convert_lossless_narrowing_fraction_errors() { + let mut dst = [0u8; 2]; + let src: &[f32] = &[1.0, 0.5]; + let err = SliceMut::from(dst.as_mut_slice()) + .convert_lossless(Slice::from(src)) + .unwrap_err(); + assert!(err.to_string().contains("losslessly"), "{err}"); + } + + #[test] + fn convert_lossless_signedness_errors() { + // Negative i8 cannot fit into u8. + let mut dst = [0u8; 2]; + let src: &[i8] = &[5, -1]; + assert!( + SliceMut::from(dst.as_mut_slice()) + .convert_lossless(Slice::from(src)) + .is_err() + ); + + // u8 > 127 cannot fit into i8. 
+ let mut dst = [0i8; 2]; + let src: &[u8] = &[10, 200]; + assert!( + SliceMut::from(dst.as_mut_slice()) + .convert_lossless(Slice::from(src)) + .is_err() + ); + } + + #[test] + fn convert_lossless_length_mismatch_errors() { + let mut dst = [0.0f32; 2]; + let src: &[f32] = &[1.0, 2.0, 3.0]; + let err = SliceMut::from(dst.as_mut_slice()) + .convert_lossless(Slice::from(src)) + .unwrap_err(); + assert!(err.to_string().contains("len"), "{err}"); + } + + //---------// + // Dataset // + //---------// + + #[test] + fn dataset_shape_and_views() { + let ds: Dataset = matrix(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3).into(); + assert_eq!(ds.nrows(), 2); + assert_eq!(ds.ncols(), 3); + assert_eq!(ds.as_view().data_type(), DataType::F32); + assert_eq!(ds.as_slice().data_type(), DataType::F32); + assert_eq!(ds.as_slice().len(), 6); + } + + #[test] + fn dataset_medoid_shape() { + let ds: Dataset = matrix(&[1.0f32, 2.0, 3.0, 4.0], 2, 2).into(); + let medoid = ds.medoid(); + assert_eq!(medoid.nrows(), 1); + assert_eq!(medoid.ncols(), 2); + } + + #[test] + fn dataset_preprocess_halve() { + let mut ds: Dataset = matrix(&[2.0f32, 4.0, 6.0, 8.0], 2, 2).into(); + ds.preprocess(&Preprocess::Halve); + let slice: &[f32] = ds.as_slice().try_cast().unwrap(); + assert_eq!(slice, &[1.0, 2.0, 3.0, 4.0]); + } + + //------------// + // Preprocess // + //------------// + + #[test] + fn preprocess_floor_f32() { + let mut ds: Dataset = matrix(&[1.7f32, 2.2, 3.9, 4.0], 1, 4).into(); + ds.preprocess(&Preprocess::Floor); + let slice: &[f32] = ds.as_slice().try_cast().unwrap(); + assert_eq!(slice, &[1.0, 2.0, 3.0, 4.0]); + } + + #[test] + fn preprocess_floor_integer_is_noop() { + let mut ds: Dataset = matrix(&[3u8, 7, 9, 11], 1, 4).into(); + ds.preprocess(&Preprocess::Floor); + let slice: &[u8] = ds.as_slice().try_cast().unwrap(); + assert_eq!(slice, &[3, 7, 9, 11]); + } + + #[test] + fn preprocess_halve_integer() { + let mut ds: Dataset = matrix(&[4i8, 7, 8, 10], 1, 4).into(); + 
ds.preprocess(&Preprocess::Halve); + let slice: &[i8] = ds.as_slice().try_cast().unwrap(); + assert_eq!(slice, &[2, 3, 4, 5]); + } + + //-------------// + // DatasetView // + //-------------// + + #[test] + fn dataset_view_accessors() { + let ds: Dataset = matrix(&[1u8, 2, 3, 4, 5, 6], 2, 3).into(); + let view = ds.as_view(); + assert_eq!(view.data_type(), DataType::U8); + assert_eq!(view.nrows(), 2); + assert_eq!(view.ncols(), 3); + } + + #[test] + fn dataset_view_row() { + let ds: Dataset = matrix(&[1u8, 2, 3, 4, 5, 6], 2, 3).into(); + let view = ds.as_view(); + + let row1: &[u8] = view.row(1).unwrap().try_cast().unwrap(); + assert_eq!(row1, &[4, 5, 6]); + + assert!(view.row(2).is_none()); + } + + #[test] + fn dataset_view_medoid() { + let ds: Dataset = matrix(&[1i8, 2, 3, 4], 2, 2).into(); + let medoid = ds.as_view().medoid(); + assert_eq!(medoid.nrows(), 1); + assert_eq!(medoid.ncols(), 2); + } +} diff --git a/diskann-inmem/integration/support/io.rs b/diskann-inmem/integration/support/io.rs new file mode 100644 index 000000000..7b45018b5 --- /dev/null +++ b/diskann-inmem/integration/support/io.rs @@ -0,0 +1,59 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. 
+ */ + +use diskann_utils::{io::read_bin, views::Matrix}; +use half::f16; + +use super::datatype::{DataType, Dataset, Preprocess, SliceMut}; + +pub(crate) fn load_and_convert( + io: &mut IO, + src: DataType, + target: DataType, + ops: &[Preprocess], +) -> anyhow::Result +where + IO: std::io::Read + std::io::Seek, +{ + let mut data = match src { + DataType::F32 => Dataset::from(read_bin::(io)?), + DataType::F16 => Dataset::from(read_bin::(io)?), + DataType::U8 => Dataset::from(read_bin::(io)?), + DataType::I8 => Dataset::from(read_bin::(io)?), + }; + + for op in ops { + data.preprocess(op); + } + + if src == target { + return Ok(data); + } + + let dst = match target { + DataType::F32 => { + let mut dst = Matrix::new(0.0f32, data.nrows(), data.ncols()); + SliceMut::from(dst.as_mut_slice()).convert_lossless(data.as_slice())?; + Dataset::from(dst) + } + DataType::F16 => { + let mut dst = Matrix::new(f16::from_f32(0.0f32), data.nrows(), data.ncols()); + SliceMut::from(dst.as_mut_slice()).convert_lossless(data.as_slice())?; + Dataset::from(dst) + } + DataType::U8 => { + let mut dst = Matrix::new(0u8, data.nrows(), data.ncols()); + SliceMut::from(dst.as_mut_slice()).convert_lossless(data.as_slice())?; + Dataset::from(dst) + } + DataType::I8 => { + let mut dst = Matrix::new(0i8, data.nrows(), data.ncols()); + SliceMut::from(dst.as_mut_slice()).convert_lossless(data.as_slice())?; + Dataset::from(dst) + } + }; + + Ok(dst) +} diff --git a/diskann-inmem/integration/support/mod.rs b/diskann-inmem/integration/support/mod.rs new file mode 100644 index 000000000..982329f46 --- /dev/null +++ b/diskann-inmem/integration/support/mod.rs @@ -0,0 +1,9 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. 
+ */ + +pub(crate) mod check; +pub(crate) mod datatype; +pub(crate) mod io; +pub(crate) mod tolerance; diff --git a/diskann-inmem/integration/support/tolerance.rs b/diskann-inmem/integration/support/tolerance.rs new file mode 100644 index 000000000..7f8efa870 --- /dev/null +++ b/diskann-inmem/integration/support/tolerance.rs @@ -0,0 +1,32 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +use diskann_benchmark_runner::{Checker, Input}; +use serde::{Deserialize, Serialize}; + +/// A tolerance [`Input`] for [`diskann_benchmark_runner::benchmark::Regression`]s that +/// do not need any external tolerances. +#[derive(Debug, Clone, Copy, Serialize, Deserialize)] +pub(crate) struct Empty; + +impl Input for Empty { + type Raw = Self; + + fn tag() -> &'static str { + "empty-tolerance" + } + + fn from_raw(raw: Self::Raw, _: &mut Checker) -> anyhow::Result { + Ok(raw) + } + + fn serialize(&self) -> anyhow::Result { + Ok(serde_json::to_value(self)?) + } + + fn example() -> Self::Raw { + Self + } +} diff --git a/diskann-inmem/src/buffer.rs b/diskann-inmem/src/buffer.rs new file mode 100644 index 000000000..32e6e4930 --- /dev/null +++ b/diskann-inmem/src/buffer.rs @@ -0,0 +1,539 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +use std::{alloc::Layout, marker::PhantomData, ptr::NonNull}; + +use crate::num::{Align, Bytes}; + +/// An unsynchronized row-store for raw data. +/// +/// The backing data is stored as raw pointers and interacted with via [`RawSlice`], which +/// is also raw pointer based. Careful use of this struct enables safe use of +/// [`RawSlice::as_slice`], [`RawSlice::as_mut_slice`], and other accesses from multiple +/// threads without undefined behavior. +/// +/// `Buffer` is unconditionally `Send` and `Sync`: it holds only a pointer plus metadata, +/// and the synchronization burden is shifted to users of [`RawSlice`]. 
+#[derive(Debug)] +pub(crate) struct Buffer { + ptr: NonNull<u8>, + stride: Bytes, + entries: usize, + layout: Layout, +} + +impl Buffer { + /// Construct a new [`Buffer`] capable of holding `entries` with each entry occupying + /// exactly `bytes_per_entry`. Subsequent entries are separated by exactly + /// `bytes_per_entry` bytes. The base pointer will be aligned to at least `align`. + /// + /// # Errors + /// + /// Returns an error if the number of bytes `bytes_per_entry * entries` rounded up to + /// the next multiple of `align` exceeds `isize::MAX`. + pub(crate) fn new( + entries: usize, + bytes_per_entry: Bytes, + align: Align, + ) -> Result<Self, BufferError> { + // If we overflow `usize::MAX`, we will definitely overflow `isize::MAX`. + let bytes = bytes_per_entry.checked_mul(entries).ok_or(BufferError)?; + + // Since `align` is constrained to be a power of two, the only way this fails is + // if we overflow `isize::MAX`. + let layout = std::alloc::Layout::from_size_align(bytes.value(), align.value()) + .map_err(|_: std::alloc::LayoutError| BufferError)?; + + let ptr = if layout.size() == 0 { + std::ptr::without_provenance_mut(layout.align()) // dangling, but still aligned to `align` + } else { + // SAFETY: `layout.size()` is non-zero. + unsafe { std::alloc::alloc_zeroed(layout) } + }; + + let ptr = match NonNull::new(ptr) { + Some(ptr) => ptr, + None => std::alloc::handle_alloc_error(layout), + }; + + Ok(Self { + ptr, + stride: bytes_per_entry, + entries, + layout, + }) + } + + /// Return the number of entries in this [`Buffer`]. + #[inline] + pub(crate) fn len(&self) -> usize { + self.entries + } + + /// Return the number of bytes for each entry. + #[inline] + pub(crate) fn stride(&self) -> Bytes { + self.stride + } + + /// Return the `i`th entry without bounds checking. + /// + /// The returned [`RawSlice`] is guaranteed to have a length of [`Self::stride`] and + /// begin at `self.as_ptr().add(self.stride().value() * i)`. + /// + /// # Safety + /// + /// `i` must be less than [`len`](Self::len).
+ #[inline] + pub(crate) unsafe fn get_unchecked(&self, i: usize) -> RawSlice<'_> { + debug_assert!(i < self.entries); + + // SAFETY: The caller asserts that `i` is in-bounds; the computed pointer stays + // within a single allocated object. + let ptr = unsafe { self.ptr.add(self.stride().value() * i) }; + + RawSlice { + ptr, + len: self.stride, + _lifetime: PhantomData, + } + } + + #[cfg(test)] + pub(crate) fn get(&self, i: usize) -> Option<RawSlice<'_>> { + if i >= self.entries { + None + } else { + // SAFETY: We have validated that `i < self.entries`. This does two things: + // + // 1. Ensures that the multiplication will not overflow. + // 2. Ensures that the computed offset is within the original allocation. + Some(unsafe { self.get_unchecked(i) }) + } + } + + #[cfg(test)] + fn as_ptr(&self) -> *const u8 { + self.ptr.as_ptr().cast_const() + } +} + +impl Drop for Buffer { + fn drop(&mut self) { + // If the layout size is zero, there's nothing to do because we hold a dangling pointer. + if self.layout.size() != 0 { + // SAFETY: This is the same pointer and allocation that was previously returned + // from a successful `alloc_zeroed`. + unsafe { std::alloc::dealloc(self.ptr.as_ptr(), self.layout) } + } + } +} + +#[derive(Debug)] +#[non_exhaustive] +pub(crate) struct BufferError; + +impl std::fmt::Display for BufferError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("requested allocation exceeds `isize::MAX`") + } +} + +impl std::error::Error for BufferError {} + +// SAFETY: We're safe to pass around the `Buffer`. It's just use of the returned `RawSlice` +// that needs to be arbitrated. +unsafe impl Send for Buffer {} + +// SAFETY: We're safe to pass around the `Buffer`. It's just use of the returned `RawSlice` +// that needs to be arbitrated. +unsafe impl Sync for Buffer {} + +/// A raw entry in [`Buffer`].
+/// +/// The memory in the range `[RawSlice::as_ptr(), RawSlice::as_ptr().add(slice.len()))` is +/// guaranteed to be within a single alive allocation. +/// +/// This has borrowing semantics of a raw pointer. +#[derive(Debug)] +pub(crate) struct RawSlice<'a> { + ptr: NonNull<u8>, + len: Bytes, + _lifetime: PhantomData<&'a ()>, +} + +impl<'a> RawSlice<'a> { + /// Create a new [`RawSlice`]. + /// + /// # Safety + /// + /// The memory `[ptr, ptr.add(len.value()))` must be part of a single allocation for + /// the duration of the lifetime `'a`. + /// + /// The underlying allocation is safe to alias from multiple threads. [`RawSlice`] + /// itself is intentionally `!Send + !Sync`; each thread must derive its own from a + /// shared `&Buffer`. + unsafe fn new(ptr: NonNull<u8>, len: Bytes) -> Self { + Self { + ptr, + len, + _lifetime: PhantomData, + } + } + + /// Create a new slice to the first `n.min(self.len())` bytes of `self`. + #[inline] + pub(crate) fn truncate(&self, n: Bytes) -> RawSlice<'a> { + // SAFETY: The `min` operation ensures we provide an argument <= `self.len()`. + unsafe { self.truncate_unchecked(self.len.min(n)) } + } + + /// Shorten the slice to `n` bytes. + /// + /// # Safety + /// + /// `n` must be less than or equal to `self.len()`. + #[inline] + pub(crate) unsafe fn truncate_unchecked(&self, n: Bytes) -> RawSlice<'a> { + debug_assert!(n <= self.len); + + // SAFETY: Inherited from the caller. + unsafe { Self::new(self.ptr, n) } + } + + /// Split `self` into two as `([ptr, ptr.add(m)), [ptr.add(m), ptr.add(self.len())))` + /// where `m = n.min(self.len())`. + #[inline] + pub(crate) fn split(&self, n: Bytes) -> (RawSlice<'a>, RawSlice<'a>) { + // SAFETY: the argument is <= `self.len()`. + unsafe { self.split_unchecked(self.len.min(n)) } + } + + /// Split `self` into two as `([ptr, ptr.add(n)), [ptr.add(n), ptr.add(self.len())))` + /// + /// # Safety + /// + /// `n` must be less than or equal to `self.len()`.
+ #[inline] + pub(crate) unsafe fn split_unchecked(&self, n: Bytes) -> (RawSlice<'a>, RawSlice<'a>) { + debug_assert!(n <= self.len); + + // SAFETY: the argument is <= `self.len()`. + unsafe { + ( + Self::new(self.ptr, n), + Self::new(self.ptr.add(n.value()), self.len.unchecked_sub(n)), + ) + } + } + + /// Return the length of the slice in bytes. + #[inline] + pub(crate) fn len(&self) -> Bytes { + self.len + } + + /// Return the base [`NonNull`] pointer of the slice. + pub(crate) fn as_non_null(&self) -> NonNull<u8> { + self.ptr + } + + /// Return the base pointer of the slice as `*const u8`. + pub(crate) fn as_ptr(&self) -> *const u8 { + self.ptr.as_ptr().cast_const() + } + + /// Return the base pointer of the slice as `*mut u8`. + /// + /// This returns a mutable pointer regardless of the receiver's mutability, matching + /// the raw-pointer semantics of [`RawSlice`]. + pub(crate) fn as_mut_ptr(&self) -> *mut u8 { + self.ptr.as_ptr() + } + + /// Materialize the raw slice as a true shared slice. + /// + /// # Safety + /// + /// Correct adherence to the API of [`RawSlice`] will ensure that the memory behind the + /// materialized slice resides within a single allocation. + /// + /// However, it is the responsibility of the caller to ensure that materializing this + /// slice does not violate Rust's borrowing rules. + #[inline] + pub(crate) unsafe fn as_slice(&self) -> &'a [u8] { + // SAFETY: Inherited from caller. + unsafe { std::slice::from_raw_parts(self.ptr.as_ptr(), self.len.value()) } + } + + /// Materialize the raw slice as a true mutable slice. + /// + /// # Safety + /// + /// Correct adherence to the API of [`RawSlice`] will ensure that the memory behind the + /// materialized slice resides within a single allocation. + /// + /// However, it is the responsibility of the caller to ensure that materializing this + /// slice does not violate Rust's borrowing rules.
+ #[inline] + pub(crate) unsafe fn as_mut_slice(&mut self) -> &'a mut [u8] { + // SAFETY: Inherited from caller. + unsafe { std::slice::from_raw_parts_mut(self.ptr.as_ptr(), self.len.value()) } + } +} + +/////////// +// Tests // +/////////// + +#[cfg(test)] +mod tests { + use super::*; + + use std::{sync::Barrier, thread}; + + #[derive(Debug)] + struct Ctx { + entries: usize, + bytes_per_entry: Bytes, + align: Align, + } + + impl std::fmt::Display for Ctx { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "entries = {}, bytes_per_entry = {}, align = {}", + self.entries, self.bytes_per_entry, self.align + ) + } + } + + fn test_buffer_inner(entries: usize, bytes_per_entry: Bytes, align: Align) { + let ctx = Ctx { + entries, + bytes_per_entry, + align, + }; + let mut buffer = Buffer::new(entries, bytes_per_entry, align).unwrap(); + + // Initial Checks + assert_eq!(buffer.len(), entries, "{}", ctx); + assert_eq!(buffer.stride(), bytes_per_entry, "{}", ctx); + + if entries != 0 && !bytes_per_entry.is_zero() { + let addr = buffer.as_ptr() as usize; + assert!( + addr.is_multiple_of(align.value()), + "pointer address {:#x} must be a multiple of the requested alignment: {}", + addr, + ctx, + ); + } + + // Verify zero initialization + assert_is_zeroed(&mut buffer, &ctx); + + // Check Slice Methods + check_slice_methods(&mut buffer, &ctx); + + // Check that concurrent mutation is allowed. + // + // This is mainly a Miri check. + zero(&mut buffer); + check_threaded(&mut buffer, &ctx); + } + + fn zero(buffer: &mut Buffer) { + // SAFETY NOTE: Exclusive reference to `buffer` guarantees no concurrent mutation. + for i in 0..buffer.len() { + let mut raw_slice = buffer.get(i).unwrap(); + assert_eq!(raw_slice.len(), buffer.stride()); + + // SAFETY: See safety note. 
+ let slice = unsafe { raw_slice.as_mut_slice() }; + assert_eq!(slice.len(), buffer.stride().value()); + slice.fill(0); + } + } + + fn assert_is_zeroed(buffer: &mut Buffer, ctx: &Ctx) { + // SAFETY NOTE: Exclusive reference to `buffer` guarantees no concurrent mutation. + // All `unsafe` calls below rely on this guarantee. + + for i in 0..buffer.len() { + let raw_slice = buffer.get(i).unwrap(); + assert_eq!(raw_slice.len(), buffer.stride()); + + assert_eq!(raw_slice.as_non_null().as_ptr(), raw_slice.as_mut_ptr()); + assert_eq!( + raw_slice.as_non_null().as_ptr().cast_const(), + raw_slice.as_ptr() + ); + + assert_eq!( + raw_slice.as_ptr(), + buffer + .as_ptr() + .wrapping_add(buffer.stride().checked_mul(i).unwrap().value()), + "stride mismatch - {}", + ctx + ); + + // SAFETY: See safety note. + let slice = unsafe { raw_slice.as_slice() }; + assert_eq!(slice.len(), buffer.stride().value()); + assert!(slice.iter().all(|&i| i == 0), "{}", ctx); + } + + // Verify that bounds-checking works. + assert!(buffer.get(buffer.len()).is_none(), "{}", ctx); + } + + fn check_slice_methods(buffer: &mut Buffer, ctx: &Ctx) { + // SAFETY NOTE: We take `buffer` by exclusive reference to guarantee that there + // is no possibility of concurrent mutation outside this method. All `unsafe` calls + // below rely on this guarantee unless otherwise noted. + + if buffer.len() == 0 { + return; + } + + let mut raw = buffer.get(0).unwrap(); + let base: u8 = 5; + let base_usize: usize = base.into(); + + // truncate // + + // SAFETY: see safety note. + iota(unsafe { raw.as_mut_slice() }, base); + for i in 0..raw.len().value() + base_usize { + let expected = i.min(raw.len().value()); + + let truncated = raw.truncate(Bytes::new(i)); + assert_eq!(truncated.len().value(), expected, "{}", ctx); + // SAFETY: see safety note. 
+ assert!(is_iota(unsafe { truncated.as_slice() }, base), "{}", ctx); + } + + // split // + + for i in 0..raw.len().value() + base_usize { + let first = i.min(raw.len().value()); + let last = raw.len().value() - first; + + let (mut prefix, mut suffix) = raw.split(Bytes::new(i)); + + assert_eq!(prefix.len().value(), first, "{}", ctx); + assert_eq!(suffix.len().value(), last, "{}", ctx); + + // SAFETY: see safety note. + assert!(is_iota(unsafe { prefix.as_slice() }, base), "{}", ctx); + + assert!( + // SAFETY: see safety note. + is_iota(unsafe { suffix.as_slice() }, base.wrapping_add(i as u8)), + "{}", + ctx + ); + + // Verify it's okay to mutate two disjoint slices concurrently. + // + // SAFETY: `prefix` and `suffix` are non-overlapping sub-ranges of the same + // entry, so materializing both as mutable is sound. + { + // SAFETY: see above + let prefix = unsafe { prefix.as_mut_slice() }; + // SAFETY: see above + let suffix = unsafe { suffix.as_mut_slice() }; + suffix.fill(0); + prefix.fill(0); + } + + // SAFETY: see safety note. + assert!(unsafe { raw.as_slice() }.iter().all(|i| *i == 0), "{}", ctx); + // SAFETY: see safety note. + iota(unsafe { raw.as_mut_slice() }, base); + } + } + + fn check_threaded(buffer: &mut Buffer, ctx: &Ctx) { + let spawns = buffer.len(); + + // The goal here is to ensure that threads hold concurrent mutable references to + // disjoint entries within the `Buffer` and that when they mutate them concurrently, + // we get a coherent result. + let pre = &Barrier::new(spawns); + let post = &Barrier::new(spawns); + { + let borrowed: &Buffer = buffer; + thread::scope(|s| { + for i in 0..spawns { + s.spawn(move || { + // SAFETY: The top level method has an exclusive reference to the buffer. + // + // This loop by construction accesses disjoint offsets. This is sufficient + // to guarantee exclusivity for this thread.
+ let slice = unsafe { borrowed.get(i).unwrap().as_mut_slice() }; + pre.wait(); + iota(slice, i as u8); + post.wait(); + }); + } + }); + } + + for i in 0..spawns { + // SAFETY: at this point we have exclusive access to `buffer`. + let slice = unsafe { buffer.get(i).unwrap().as_slice() }; + assert!(is_iota(slice, i as u8), "i = {} -- {}", i, ctx); + } + } + + fn iota(x: &mut [u8], base: u8) { + for (i, v) in x.iter_mut().enumerate() { + *v = base.wrapping_add(i as u8); + } + } + + #[must_use] + fn is_iota(x: &[u8], base: u8) -> bool { + for (i, v) in x.iter().enumerate() { + if *v != base.wrapping_add(i as u8) { + return false; + } + } + true + } + + #[test] + fn test_buffer() { + let entries = [0, 1, 2, 5]; + let bytes_per_entry = [0, 1, 2, 5, 10].map(Bytes::new); + let align = [Align::_1, Align::_64]; + + for entries in entries { + for bytes_per_entry in bytes_per_entry { + for align in align { + test_buffer_inner(entries, bytes_per_entry, align); + } + } + } + } + + #[test] + fn test_buffer_overflow_mul() { + // entries * bytes_per_entry overflows usize. + let result = Buffer::new(usize::MAX, Bytes::new(2), Align::_1); + assert!(result.is_err()); + } + + #[test] + fn test_buffer_overflow_layout() { + // Total size exceeds isize::MAX (Layout rejects this). + let result = Buffer::new(isize::MAX as usize, Bytes::new(2), Align::_1); + assert!(result.is_err()); + } +} diff --git a/diskann-inmem/src/counters.rs b/diskann-inmem/src/counters.rs new file mode 100644 index 000000000..b53940dc2 --- /dev/null +++ b/diskann-inmem/src/counters.rs @@ -0,0 +1,173 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. 
+ */ + +pub(crate) use inner::{Counters, LocalCounters}; + +#[cfg(not(feature = "integration-test"))] +mod inner { + use std::marker::PhantomData; + + #[derive(Debug, Default)] + pub(crate) struct Counters; + + impl Counters { + pub(crate) fn new() -> Self { + Self + } + + pub(crate) fn local(&self) -> LocalCounters<'_> { + LocalCounters::new() + } + } + + #[derive(Debug)] + pub(crate) struct LocalCounters<'a> { + _marker: PhantomData<&'a ()>, + } + + impl LocalCounters<'_> { + fn new() -> Self { + Self { + _marker: PhantomData, + } + } + + pub(crate) fn fork(&self) -> Self { + Self::new() + } + + pub(crate) fn query_distance(&mut self, _i: u64) {} + pub(crate) fn distance_ref(&self, _i: u64) {} + pub(crate) fn get_vector(&mut self, _i: u64) {} + pub(crate) fn get_vector_ref(&self, _i: u64) {} + pub(crate) fn set_vector(&mut self, _i: u64) {} + pub(crate) fn get_neighbors(&mut self, _i: u64) {} + pub(crate) fn set_neighbors(&mut self, _i: u64) {} + pub(crate) fn append_vector(&mut self, _i: u64) {} + } +} + +#[cfg(feature = "integration-test")] +mod inner { + use std::sync::atomic::{AtomicU64, Ordering::Relaxed}; + + #[derive(Debug, Default)] + pub(crate) struct Counters { + query_distance: AtomicU64, + distance: AtomicU64, + get_vector: AtomicU64, + set_vector: AtomicU64, + get_neighbors: AtomicU64, + set_neighbors: AtomicU64, + append_neighbors: AtomicU64, + } + + impl Counters { + pub(crate) fn new() -> Self { + Self::default() + } + + pub(crate) fn local(&self) -> LocalCounters<'_> { + LocalCounters::new(self) + } + + pub(crate) fn snapshot(&self) -> crate::integration::counters::CounterSnapshot { + crate::integration::counters::CounterSnapshot { + query_distance: self.query_distance.load(Relaxed), + distance: self.distance.load(Relaxed), + get_vector: self.get_vector.load(Relaxed), + set_vector: self.set_vector.load(Relaxed), + get_neighbors: self.get_neighbors.load(Relaxed), + set_neighbors: self.set_neighbors.load(Relaxed), + append_neighbors: 
self.append_neighbors.load(Relaxed), + } + } + } + + #[derive(Debug)] + pub(crate) struct LocalCounters<'a> { + query_distance: u64, + // This field needs to be `AtomicU64` because we increment in some loops where we + // have to increment it behind a shared reference. + distance: AtomicU64, + // This field needs to be `AtomicU64` because we increment in some loops where we + // have to increment it behind a shared reference. + get_vector: AtomicU64, + set_vector: u64, + get_neighbors: u64, + set_neighbors: u64, + append_neighbors: u64, + parent: &'a Counters, + } + + impl<'a> LocalCounters<'a> { + fn new(parent: &'a Counters) -> Self { + Self { + query_distance: 0, + distance: AtomicU64::new(0), + get_vector: AtomicU64::new(0), + set_vector: 0, + get_neighbors: 0, + set_neighbors: 0, + append_neighbors: 0, + parent, + } + } + + pub(crate) fn fork(&self) -> LocalCounters<'a> { + Self::new(self.parent) + } + + pub(crate) fn query_distance(&mut self, i: u64) { + self.query_distance += i; + } + + pub(crate) fn distance_ref(&self, i: u64) { + self.distance.fetch_add(i, Relaxed); + } + + pub(crate) fn get_vector(&mut self, i: u64) { + *self.get_vector.get_mut() += i; + } + + pub(crate) fn get_vector_ref(&self, i: u64) { + self.get_vector.fetch_add(i, Relaxed); + } + + pub(crate) fn set_vector(&mut self, i: u64) { + self.set_vector += i; + } + + pub(crate) fn get_neighbors(&mut self, i: u64) { + self.get_neighbors += i; + } + + pub(crate) fn set_neighbors(&mut self, i: u64) { + self.set_neighbors += i; + } + + pub(crate) fn append_vector(&mut self, i: u64) { + self.append_neighbors += i; + } + } + + impl Drop for LocalCounters<'_> { + fn drop(&mut self) { + let parent = self.parent; + + fn update(dst: &AtomicU64, src: u64) { + dst.fetch_add(src, Relaxed); + } + + update(&parent.query_distance, self.query_distance); + update(&parent.distance, *self.distance.get_mut()); + update(&parent.get_vector, *self.get_vector.get_mut()); + update(&parent.set_vector,
self.set_vector); + update(&parent.get_neighbors, self.get_neighbors); + update(&parent.set_neighbors, self.set_neighbors); + update(&parent.append_neighbors, self.append_neighbors); + } + } +} diff --git a/diskann-inmem/src/epoch.rs b/diskann-inmem/src/epoch.rs new file mode 100644 index 000000000..e8e97af72 --- /dev/null +++ b/diskann-inmem/src/epoch.rs @@ -0,0 +1,847 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! The core logic for the epoch-based reclamation algorithm. +//! +//! ## What Problem is Being Solved? +//! +//! Epoch-based reclamation (EBR) can be used to safely implement read-heavy algorithms with +//! a moderate level of concurrent writes. In this context, we want readers to be able to ask +//! the question: "can I safely read some data" in a way that generates only read traffic to +//! the CPU caches. +//! +//! The crux is that after the safety check, a reader can hold a reference to the associated +//! data for an arbitrary period of time. Any actor trying to *write* to that data needs +//! to figure out when it is safe to do so. +//! +//! EBR solves this problem by separating when data is "retired" versus "reclaimed". +//! Retirement involves disabling the safety check. When an item is retired, concurrent +//! readers will fail the safety check and no longer try to read the associated data. +//! However, we still need to wait until we can prove that readers who passed the safety +//! check before retirement are no longer accessing the data. At this point, the data can be +//! "reclaimed" and written to safely. +//! +//! We can prove this by using a monotonically increasing epoch: if an item was "retired" +//! at epoch `N` its associated data could be in use by any reader belonging to any epoch +//! `N` or lower. Therefore, it is only safe to "reclaim" when all readers belong to epoch +//! `N+1` or higher. +//! +//! One consequence of this design is that misbehaving (e.g. 
long-lived) readers can delay +//! reclamation indefinitely. As such, this system must be used with care and in situations +//! where there is enough slack in the system to accommodate the lifetime of any readers. +//! +//! ## Primitives +//! +//! Actors call [`Registry::guard`] to receive a [`Guard`]. This guard protects items +//! at its creation epoch. Any items pushed to [`Guard::retire`] will be buffered until the +//! [`Registry`] can prove that all [`Guard`]s (correctly using the data structure) that +//! could have observed the retired item have been destroyed. +//! +//! Items can be reclaimed via [`Registry::try_advance`]. If successful, a [`Drain`] of +//! such items will be returned for processing. +//! +//! Note that retired payloads are fixed to `u32` ids (typically interpreted by the caller +//! as indices into some external storage); this is not a general-purpose deferred-drop EBR +//! system. + +use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering}; + +use crossbeam_queue::SegQueue; +use diskann::utils::IntoUsize; +use parking_lot::{Mutex, MutexGuard}; + +const CAPACITY: usize = 256; + +/// A registry of epoch-based [`Guard`]s. See the [module-level docs](self). +#[derive(Debug)] +pub(crate) struct Registry { + // A record of the active guards. + // + // * 0 = "available". + // * Anything else = "guarded". + guards: Box<[AtomicU64]>, + + // A hint for the next available registration slot. + hint: AtomicUsize, + + // The current epoch. This begins at 1 (to not be conflated with the 0 state in `guards`) + // and increments over time. + // + // NOTE: This can **only** be mutated in `try_advance`. + // + // Additionally, the logic in the module ensures that there are at most two active epochs + // at any given time. It is only safe to advance an epoch if *all* readers belong to the + // current `epoch`. + epoch: AtomicU64, + + // We can only retire a single generation at a time.
+ // + // This guard avoids situations where two threads concurrently advance the epoch and + // hand out overlapping `Drain`s referring to the same retiring queue. + drain: Mutex<()>, + + // We use four queues for storing retiring items. The rationale is documented below. + // + // ```text + // + // 1. Safe to drain + // +-------------------------- + // Items retired at N-1 can | 2. Epoch N-1 + // be observed by guards at | +----------------------- + // N. If we transition to | | 3. Epoch N + // N+1, guards at N can be +-------------------------- + // active still. Thus it is | 4. Epoch N+1 + // not safe to reclaim items +----------------------- + // from this queue until all 5. Epoch N+2 (reuse #1 queue) + // guards are at least N+1. + // ``` + // + // We cycle among the queues in a round-robin manner. + retiring: Box<[SegQueue<u32>; 4]>, +} + +// Return the queue index for the `epoch`. +fn queue(epoch: u64) -> usize { + epoch.into_usize() % 4 +} + +fn last_queue(epoch: u64) -> usize { + queue(epoch.wrapping_sub(2)) +} + +impl Registry { + /// Construct a new [`Registry`] with the default number of guard slots (256). + pub(crate) fn new() -> Self { + Self::with_capacity(CAPACITY) + } + + /// Construct a new [`Registry`] with `capacity` guard slots. + /// + /// This is the number of [`Guard`]s that can be registered concurrently. + pub(crate) fn with_capacity(capacity: usize) -> Self { + Self { + guards: (0..capacity).map(|_| AtomicU64::new(0)).collect(), + hint: AtomicUsize::new(0), + epoch: AtomicU64::new(1), + retiring: Box::new(core::array::from_fn(|_| SegQueue::new())), + drain: Mutex::new(()), + } + } + + /// Return the current epoch. + /// + /// This has [`Ordering::Acquire`] semantics. + pub(crate) fn epoch(&self) -> u64 { + self.epoch.load(Ordering::Acquire) + } + + /// Register the caller with `self`. + /// + /// Any items retired while [`Guard`] is held will be protected.
+ /// + /// # Errors + /// + /// Returns an error if the number of currently active guards exceeds the number of + /// internal guard slots and thus a new guard cannot be made. + pub(crate) fn guard(&self) -> Result<Guard<'_>, Unavailable> { + self.guard_inner(NoDelay) + } + + #[inline] + fn guard_inner<T>(&self, mut delay: T) -> Result<Guard<'_>, Unavailable> + where + T: GuardDelay, + { + // GUARD CHECK + let mut epoch = self.epoch(); + let hint = self.hint.fetch_add(1, Ordering::Relaxed); + delay.post_guard_check(); + let nguards = self.guards.len(); + for i in 0..nguards { + let slot = hint.wrapping_add(i) % nguards; + + let m = &self.guards[slot]; + delay.pre_cas(); + if m.compare_exchange(0, epoch, Ordering::Relaxed, Ordering::Relaxed) + .is_ok() + { + delay.post_cas(); + let mut reset = false; + loop { + // GUARD FENCE: This fence is paired with "WAITING FENCE". + // + // See that comment for details. + delay.pre_fence(); + std::sync::atomic::fence(Ordering::SeqCst); + delay.post_fence(); + + // GUARD RECHECK + let current = self.epoch(); + if current == epoch { + break; + } + + reset = true; + epoch = current; + } + + if reset { + m.store(epoch, Ordering::Relaxed); + } + + return Ok(Guard { + slot: m, + retire: &self.retiring[queue(epoch)], + #[cfg(test)] + epoch, + #[cfg(test)] + slot_index: slot, + }); + } + } + + Err(Unavailable) + } + + fn can_advance<T>(&self, delay: &mut T) -> (bool, u64) + where + T: CanAdvanceDelay, + { + // WAITING FENCE: This is a very important part for the correctness of the algorithm. + // + // What we're protecting against is a scenario where "registering" thread A reads an + // epoch, then "waiting" thread B does a scan, thinks everything is safe, and then + // thread A finishes its CAS for its registration. + // + // This is prevented by the sequentially consistent fences. Consider the following. + // + // 1. Thread A invokes "GUARD FENCE" after a successful CAS, and then checks the + // generation at "GUARD RECHECK". + // + // 2.
Thread B now enters this block of code, executes "WAITING FENCE", then + // reads the epoch tags for all guards. + // + // With the total order induced by sequential consistency, either thread A's + // fence executes first, or thread B's executes first. + // + // * If thread A's fence executes first, then thread B will see the CAS and the set + // value is guaranteed to be less-than or equal to "WAITING CHECK" because: + // + // 1. The epoch is monotonically increasing. + // 2. Writes to the epoch are also sequentially consistent. + // + // * If Thread B's fence executes first, then thread A's "GUARD RECHECK" will + // observe at least the result of "WAITING CHECK" and update itself on the retry. + // + // It's possible that thread B observes the CAS to "GUARD CHECK", but since + // thread A will monotonically increase it before exiting, the value thread B + // observes is conservative and not incorrect. + delay.pre_fence(); + std::sync::atomic::fence(Ordering::SeqCst); + delay.post_fence(); + + // WAITING CHECK + let current = self.epoch(); + let mut min = current; + + for s in self.guards.iter() { + let guarded = s.load(Ordering::Relaxed); + if guarded != 0 { + min = min.min(guarded); + } + } + + // This synchronizes with all the guards' `Release`s. + std::sync::atomic::fence(Ordering::Acquire); + (min == current, min) + } + + /// Try to advance the current epoch. + /// + /// If successful, returns a [`Drain`]. All items in the drain can be reclaimed. + /// + /// Returns `None` if the epoch cannot yet be advanced (some [`Guard`] still belongs to + /// a prior epoch) or if another [`Drain`] is currently active. + /// + /// # Panics + /// + /// Panics if the epoch counter is about to overflow `u64::MAX`. In practice this is + /// effectively unreachable.
+ pub(crate) fn try_advance(&self) -> Option<Drain<'_>> { + self.try_advance_inner(NoDelay) + } + + #[expect( + clippy::panic, + reason = "the panic is exceedingly unlikely to happen and if it does, we can't continue" + )] + fn try_advance_inner<T>(&self, mut delay: T) -> Option<Drain<'_>> + where + T: TryAdvanceDelay, + { + // We first try to acquire the `drain` lock. + // + // It can only fail if someone else is holding the drain lock, which means we can't + // proceed anyways. + // + // This can help save an expensive slot scan. + let drain = self.drain.try_lock()?; + + let (can_advance, current) = self.can_advance(&mut delay); + + // Don't wrap around! + if current == u64::MAX { + panic!( + "we've managed to go through nearly `u64::MAX` ids - this is unlikely in a real program" + ); + } + + // All waiters belong to the current epoch. Therefore, it is safe to release the old + // array queue + if can_advance { + // We are safe to use a `fetch_add` here because `drain` is ensuring exclusivity + // of the access. + // + // However, this still needs to be `SeqCst` so that this properly synchronizes + // with "GUARD FENCE" and "WAITING FENCE". + let _previous = self.epoch.fetch_add(1, Ordering::SeqCst); + debug_assert_eq!(_previous, current, "concurrency violation"); + + let queue = &self.retiring[last_queue(current)]; + Some(Drain { + queue, + _drain: drain, + }) + } else { + // Previous generation has not completely retired. + None + } + } + + #[cfg(test)] + fn assert_no_workers(&self) { + for s in self.guards.iter() { + assert_eq!(s.load(Ordering::Relaxed), 0); + } + } + + #[cfg(test)] + fn waiting(&self) -> u64 { + self.can_advance(&mut NoDelay).1 + } +} + +/// A handle registering the caller as a reader at a particular epoch. +/// +/// While this guard is held, the [`Registry`] can advance at most one epoch past the guard's +/// epoch, and items retired via *any* guard at that epoch (or the prior one) are not reclaimed. +/// +/// Obtained via [`Registry::guard`].
+#[derive(Debug)] +pub(crate) struct Guard<'a> { + slot: &'a AtomicU64, + retire: &'a SegQueue<u32>, + + #[cfg(test)] + pub(super) epoch: u64, + + #[cfg(test)] + slot_index: usize, +} + +impl Guard<'_> { + /// Retire the id `i` at this guard's epoch. + /// + /// `i` is a caller-defined id (typically an index into external storage). It will be + /// returned from a future [`Drain`] once the registry has advanced far enough that no + /// reader could observe it. + #[inline] + pub(crate) fn retire(&self, i: u32) { + self.retire.push(i) + } +} + +impl Drop for Guard<'_> { + fn drop(&mut self) { + self.slot.store(0, Ordering::Release); + } +} + +/// An iterator over ids that are safe to reclaim, returned from [`Registry::try_advance`]. +/// +/// While this drain is alive, no other thread can advance the [`Registry`]'s epoch. Drop +/// it promptly after processing. +#[derive(Debug)] +pub(crate) struct Drain<'a> { + queue: &'a SegQueue<u32>, + _drain: MutexGuard<'a, ()>, +} + +impl Drain<'_> { + /// Pop the next id ready for reclamation, or `None` if the drain is empty. + #[must_use = "reclaimed ids must be reclaimed"] + pub(crate) fn pop(&self) -> Option<u32> { + self.queue.pop() + } + + /// Return the number of ids remaining in this drain. + pub(crate) fn len(&self) -> usize { + self.queue.len() + } + + #[cfg(test)] + fn is_empty(&self) -> bool { + self.len() == 0 + } +} + +impl Iterator for Drain<'_> { + type Item = u32; + fn next(&mut self) -> Option<Self::Item> { + self.pop() + } + + fn size_hint(&self) -> (usize, Option<usize>) { + (self.len(), Some(self.len())) + } +} + +// NOTE: This relies on `Drain` holding the `drain` guard. In this state, we are guaranteed +// that no-one is writing into the queue, which would otherwise invalidate the exact-size +// iterator guarantee. +impl ExactSizeIterator for Drain<'_> {} + +/// Returned by [`Registry::guard`] when all guard slots are occupied.
+#[derive(Debug)] +#[non_exhaustive] +pub(crate) struct Unavailable; + +impl std::fmt::Display for Unavailable { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("all available registry guard slots are occupied") + } +} + +impl std::error::Error for Unavailable {} + +crate::opaque!(Unavailable); + +// Delays +// +// To help test standard race scenarios without advanced tooling, we use optional delays +// that our tests can introduce to ensure threads are in various intermediate points. +// +// This does not necessarily test that the memory orderings are correct, but at least +// is a smoke test that various (known) races are handled properly. + +#[derive(Debug)] +struct NoDelay; + +trait GuardDelay { + fn post_guard_check(&mut self) {} + fn pre_cas(&mut self) {} + fn post_cas(&mut self) {} + fn pre_fence(&mut self) {} + fn post_fence(&mut self) {} +} + +impl GuardDelay for NoDelay {} + +trait CanAdvanceDelay { + fn pre_fence(&mut self) {} + fn post_fence(&mut self) {} +} + +impl CanAdvanceDelay for NoDelay {} + +trait TryAdvanceDelay: CanAdvanceDelay {} + +impl TryAdvanceDelay for NoDelay {} + +/////////// +// Tests // +/////////// + +#[cfg(test)] +mod tests { + use super::*; + + use crate::test::Sequencer; + + // This test ensures that two threads racing on `hint` will correctly resolve themselves + // when claiming a slot. + #[test] + fn test_cas_race() { + let seq = Sequencer::new(); + + let mut thread_a_loop_count = 0; + let mut thread_b_loop_count = 0; + let delay = TestGuardDelay::default() + .post_guard_check(|| seq.wait_for(0)) + .with_post_fence(|| thread_a_loop_count += 1); + + let registry = Registry::with_capacity(2); + std::thread::scope(|s| { + // Thread A + s.spawn(|| { + let g = registry.guard_inner(delay).unwrap(); + assert_eq!(g.slot_index, 1); + seq.wait_for(1); + }); + + // Thread B + s.spawn(|| { + // wait for Thread A to reach the delay point. 
+ seq.until_waiting_for(0); + { + let delay = + TestGuardDelay::default().with_post_fence(|| thread_b_loop_count += 1); + let g = registry.guard_inner(delay).unwrap(); + assert_eq!(g.slot_index, 1); + } + let g = registry.guard_inner(NoDelay).unwrap(); + assert_eq!(g.slot_index, 0); + seq.advance_past(0); + seq.advance_past(1); + }); + }); + + assert_eq!(thread_a_loop_count, 1); + assert_eq!(thread_b_loop_count, 1); + + registry.assert_no_workers(); + } + + #[test] + fn test_register_wait() { + // This tests the case where a thread enters registration, reads a generation, then + // sleeps for several generation advances. It ensures that the thread recovers properly. + let seq = Sequencer::new(); + + let mut loop_count = 0; + let delay = TestGuardDelay::default() + .post_guard_check(|| seq.wait_for(0)) + .with_post_cas(|| seq.wait_for(1)) + .with_pre_fence(|| loop_count += 1); + + let registry = Registry::with_capacity(2); + + std::thread::scope(|s| { + let handle = s.spawn(|| { + let guard = registry.guard_inner(delay).unwrap(); + + // Since we hit the CAS loop - this serves as a sanity check that we have + // the correct drain buffer. + guard.retire(10); + guard.retire(1); + guard.retire(2); + guard.retire(3); + guard + }); + + // Wait for the spawned thread to reach the critical section. + seq.until_waiting_for(0); + + assert_eq!(registry.waiting(), 1); + { + let drain = registry.try_advance().unwrap(); + assert!(drain.is_empty()); + assert_eq!(registry.epoch(), 2); + } + + { + let drain = registry.try_advance().unwrap(); + assert!(drain.is_empty()); + assert_eq!(registry.epoch(), 3); + } + + // We allow the registering thread to make it past the CAS. + // + // We pause it again because we want to verify that it registers an old generation. 
+ seq.advance_past(0); + seq.until_waiting_for(1); + let (can_advance, waiter) = registry.can_advance(&mut NoDelay); + assert!(!can_advance); + assert_eq!( + waiter, 1, + "waiting thread registers an older generation before observing the change" + ); + seq.advance_past(1); + + let expected = 3; + + // The generation should be the last set one - even though this thread was + // parked during the transition. + let r = handle.join().unwrap(); + assert_eq!(r.epoch, expected); + assert_eq!(registry.waiting(), expected); + }); + + assert_eq!( + loop_count, 2, + "the registering thread should have looped to update its generation" + ); + + registry.assert_no_workers(); + + // Verify that we reclaim the ID flushed by the registering thread. + // + // This requires three epoch advancements. + { + let drain = registry.try_advance().unwrap(); + assert!(drain.is_empty()); + } + + { + let drain = registry.try_advance().unwrap(); + assert!(drain.is_empty()); + } + + { + let drain = registry.try_advance().unwrap(); + let ids: Vec<_> = drain.collect(); + assert_eq!(ids, &[10, 1, 2, 3]); + } + } + + // Verifies that filling every slot causes `register` to return `Unavailable`, and that + // dropping an existing guard frees up its slot for a subsequent registration. + #[test] + fn test_slot_exhaustion() { + let registry = Registry::with_capacity(2); + + let g0 = registry.guard().unwrap(); + let g1 = registry.guard().unwrap(); + + // All guard slots are now occupied. The next registration must fail. + assert!(matches!(registry.guard(), Err(Unavailable))); + assert!(matches!(registry.guard(), Err(Unavailable))); + + // Dropping a guard releases its slot. + let freed_slot = g0.slot_index; + drop(g0); + + let g2 = registry.guard().unwrap(); + assert_eq!( + g2.slot_index, freed_slot, + "newly freed slot should be reclaimed" + ); + + // Registry is full again. 
+ assert!(matches!(registry.guard(), Err(Unavailable))); + + drop(g1); + drop(g2); + + registry.assert_no_workers(); + } + + #[test] + fn test_slot_wrap_around() { + let registry = Registry::with_capacity(4); + + let (g2, g3) = { + let _g0 = registry.guard().unwrap(); + let _g1 = registry.guard().unwrap(); + + let g2 = registry.guard().unwrap(); + let g3 = registry.guard().unwrap(); + (g2, g3) + }; + + assert_eq!(g2.slot_index, 2); + assert_eq!(g3.slot_index, 3); + + let f = || { + // Keep wrapping and hitting the first two guard slots. + for _ in 0..10 { + let g0 = registry.guard().unwrap(); + let g1 = registry.guard().unwrap(); + + let s0 = g0.slot_index; + let s1 = g1.slot_index; + + // Due to how the hint works, the slots could be acquired in either order. + if s0 < s1 { + assert_eq!((s0, s1), (0, 1)); + } else { + assert_eq!((s0, s1), (1, 0)); + }; + + assert!(matches!(registry.guard(), Err(Unavailable))); + } + }; + + // Run with the default hint. + f(); + + // Set the hint just below `usize::MAX` (`usize::MAX - 10`). + registry.hint.store(usize::MAX - 10, Ordering::Relaxed); + + // Run tests again to ensure we can properly handle wrap-around. + f(); + + drop((g2, g3)); + registry.assert_no_workers(); + } + + // Verifies that `try_advance` short-circuits to `None` when another thread already holds + // the `drain` mutex, even if `can_advance` would otherwise succeed. This guards the + // early `try_lock` that avoids a redundant slot scan. + #[test] + fn test_concurrent_try_advance() { + let registry = Registry::with_capacity(2); + + // No outstanding registrations, so `can_advance` would succeed for any caller. + let drain = registry + .try_advance() + .expect("first try_advance must succeed"); + let gen_after_first = registry.epoch(); + assert_eq!(gen_after_first, 2); + + // While the first `Drain` is alive (holding the drain mutex), a concurrent + // `try_advance` must return `None` without advancing the generation. 
+ std::thread::scope(|s| { + s.spawn(|| { + assert!( + registry.try_advance().is_none(), + "try_advance must fail while another holds the drain mutex" + ); + assert_eq!( + registry.epoch(), + gen_after_first, + "generation must not advance when drain is contended" + ); + }); + }); + + // Releasing the drain unblocks subsequent advances. + drop(drain); + + let _drain2 = registry + .try_advance() + .expect("try_advance must succeed once drain is released"); + assert_eq!(registry.epoch(), 3); + } + + // Verifies the 3-queue rotation invariant: items retired at generation `G` are drained + // on the third `try_advance` after `G`. The first two advances return queues belonging + // to other generations, so they must NOT contain items from `G`. + #[test] + fn test_drain_rotation() { + let registry = Registry::with_capacity(1); + + // Helper: register, retire one item, drop. Returns the generation we retired at. + let retire_at = |id: u32| { + let g = registry.guard().unwrap(); + let epoch = g.epoch; + g.retire(id); + epoch + }; + + // Retire 100 at generation A (= 1). + let gen_a = retire_at(100); + assert_eq!(gen_a, 1); + + // 1st advance after A: must NOT drain item 100. + { + let drain = registry.try_advance().unwrap(); + assert!( + drain.is_empty(), + "100 must not drain on 1st advance after A" + ); + } + + // Retire 200 at generation B (= A + 1). + let gen_b = retire_at(200); + assert_eq!(gen_b, gen_a + 1); + + // 2nd advance after A: must NOT drain item 100. + { + let drain = registry.try_advance().unwrap(); + assert!( + drain.is_empty(), + "100 must not drain on 2nd advance after A" + ); + } + + // Retire 300 at generation C. + let _gen_c = retire_at(300); + + // 3rd advance after A (2nd after B): drains A's queue → [100]. + { + let drained: Vec<_> = registry.try_advance().unwrap().collect(); + assert_eq!(drained, &[100]); + } + + // 3rd advance after B: drains B's queue → [200]. 
+ { + let drained: Vec<_> = registry.try_advance().unwrap().collect(); + assert_eq!(drained, &[200]); + } + + // 3rd advance after C: drains C's queue → [300]. + { + let drained: Vec<_> = registry.try_advance().unwrap().collect(); + assert_eq!(drained, &[300]); + } + + // Rotation has cycled back to where A's queue used to live — must be empty, + // proving the queue slot was drained cleanly and is reusable. + { + let drain = registry.try_advance().unwrap(); + assert!( + drain.is_empty(), + "rotation should leave queues empty after one cycle" + ); + } + + registry.assert_no_workers(); + } + + //-------------// + // Test Delays // + //-------------// + + macro_rules! tester { + ($struct:ident, $trait:ident, $($with:ident => $f:ident),* $(,)?) => { + #[derive(Default)] + struct $struct<'a> { + $($f: Option>,)* + } + + impl<'a> $struct<'a> { + $( + fn $with(mut self, f: F) -> Self + where + F: FnMut() + Send + 'a + { + self.$f = Some(Box::new(f)); + self + } + )* + } + + impl $trait for $struct<'_> { + $( + fn $f(&mut self) { + if let Some(f) = self.$f.as_mut() { + f() + } + } + )* + } + } + } + + tester! { + TestGuardDelay, + GuardDelay, + post_guard_check => post_guard_check, + with_post_cas => post_cas, + with_pre_fence => pre_fence, + with_post_fence => post_fence, + } +} diff --git a/diskann-inmem/src/freelist.rs b/diskann-inmem/src/freelist.rs new file mode 100644 index 000000000..a0400e672 --- /dev/null +++ b/diskann-inmem/src/freelist.rs @@ -0,0 +1,444 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Finding Unused IDs +//! +//! When working with slots into an index, finding an available slot efficiently can be +//! challenging. This module provides a [`Freelist`] to make this more efficient. +//! +//! IDs are retrieved in the following order of precedence: +//! +//! ## Recycles +//! +//! Previously reclaimed slots can be recycled and are the preferred way of finding slots. +//! 
Reclaimed slot IDs live inside an atomic queue and as such, the size of this queue is +//! bounded to conserve memory. +//! +//! ## Minted +//! If no slots live in the recycled queue, new slots can be "minted" up to the configured +//! maximum. This simply tracks the maximum slot ID that has been yielded so far and returns +//! the next one. +//! +//! This path only works during the initial filling of the managed slots and exists to +//! provide a fast-path for static index builds. Once the maximum slot has been yielded, +//! minting no longer applies. +//! +//! ## Scanning +//! +//! If a slot cannot be found via recycling or via minting, a scan is requested. Scans +//! typically involve searching over an authoritative source of slot usage to find and +//! claim an unused slot. +//! +//! The [`Freelist`] assists with scans in several ways: +//! +//! 1. [`Freelist::scan`]: Receive a range of managed IDs to scan. Multiple threads +//! can call this method to receive disjoint ranges to process. +//! +//! 2. [`Freelist::push`]: Available slots can be placed into the +//! freelist for recycling. +//! +//! 3. [`Freelist::pop_recycled`]: Attempt to retrieve a slot ID directly from the recycled +//! buffer. +//! +//! Together, these tools can be used to build a cooperative scan. A thread scans a block of +//! IDs returned by [`Freelist::scan`]. If a slot is claimed this way, the thread can continue +//! scanning the rest of the block, pushing any available slot IDs to the freelist. +//! +//! Other threads that are unsuccessfully scanning can periodically check +//! [`Freelist::pop_recycled`] to benefit from the work done by another more successful thread. +//! +//! # Non-Authoritative +//! +//! Note that the [`Freelist`] does not attempt to be authoritative on the list of slot IDs +//! that are used and unused. Its job is mainly to improve performance. +//! +//! An authoritative collection of [`AtomicTag`](super::AtomicTag)s must be used to correctly +//! manage slots. 
+ +use std::{ + num::NonZeroU32, + sync::atomic::{AtomicU32, Ordering}, +}; + +use crossbeam_queue::ArrayQueue; +use diskann::utils::IntoUsize; + +// NOTE: We want the scan size to be relatively big. Each tag occupies just a single byte, +// so a scan needs to be at least 64 to ensure a thread is working with just a single cache +// line. +const SCAN_SIZE: u32 = 256; + +/// A tool for quickly finding unused slots in an index. +/// +/// See [freelist](self) for details. +#[derive(Debug)] +pub(crate) struct Freelist { + // Bounded fast queue of retired slots. + recycled: ArrayQueue, + + // The exclusive upper bound on managed IDs. IDs `>= max` are rejected by `push` + // and the minting path will not yield them. + max: u32, + + // The next `unminted` Id. This becomes unused once this reaches `max`. + next: AtomicU32, + + // The current bucket for scanning. + scan_bucket: AtomicU32, +} + +impl Freelist { + /// Construct a new [`Freelist`] that manages `max` ids. + /// + /// The internal fast recycled list will hold up to `recycled` items. + /// + /// The memory occupied by this struct is `O(recycled)`. + pub(crate) fn new(max: u32, recycled: NonZeroU32) -> Self { + Self { + recycled: ArrayQueue::new(recycled.get().into_usize()), + max, + next: AtomicU32::new(0), + scan_bucket: AtomicU32::new(0), + } + } + + /// Try to retrieve an id. + /// + /// If successful, returns [`Id::Found`]. Otherwise, returns [`Id::Scan`]. + pub(crate) fn pop(&self) -> Id { + if let Some(id) = self.recycled.pop() { + return Id::Found(id); + } + + // Missed in the recycled buffer. Try pulling from the high-water mark. + let mut next = self.next.load(Ordering::Relaxed); + while next < self.max { + match self + .next + .compare_exchange(next, next + 1, Ordering::Relaxed, Ordering::Relaxed) + { + Ok(next) => return Id::Found(next), + Err(actual) => { + next = actual; + } + } + } + + // Missed in the recycle bin and from unallocated IDs. Time to indicate a scan. 
+ Id::Scan + } + + /// Attempt to retrieve an ID directly from the recycled list. + /// + /// This may be used during scans to retrieve IDs found by other threads. + pub(crate) fn pop_recycled(&self) -> Option { + self.recycled.pop() + } + + /// Return a new [`Scan`] containing a range of IDs to check. + /// + /// This is managed such that multiple threads calling this function will receive + /// disjoint ranges to scan. + pub(crate) fn scan(&self) -> Scan { + if self.max == 0 { + return Scan { start: 0, stop: 0 }; + } + + let num_buckets = self.max.div_ceil(SCAN_SIZE); + + // It's possible that if `scan_bucket` wraps, we do a bit of redundant scanning. + // + // This is fine as this should happen rarely. + let bucket = self.scan_bucket.fetch_add(1, Ordering::Relaxed) % num_buckets; + + let start = bucket * SCAN_SIZE; + let stop = match start.checked_add(SCAN_SIZE) { + Some(stop) => stop.min(self.max), + None => self.max, + }; + + Scan { start, stop } + } + + /// Attempt to push `id` into the recycled list. Return `true` if `id` was inserted. + /// + /// If `false` is returned, it is likely because the internal recycle buffer is full. + /// + /// IDs at or above [`Self::max`] are discarded. + pub(crate) fn push(&self, id: u32) -> bool { + if id < self.max { + self.recycled.push(id).is_ok() + } else { + false + } + } +} + +/// The result of [`Freelist::pop`]. +#[derive(Debug, Clone, Copy)] +#[must_use] +pub(crate) enum Id { + /// An ID was found directly in the [`Freelist`]. + Found(u32), + /// No ID was found in the [`Freelist`] and an exhaustive scan is recommended. + Scan, +} + +#[cfg(test)] +impl Id { + fn unwrap(self) -> u32 { + match self { + Self::Found(i) => i, + Self::Scan => panic!("expected Id::Found, got Id::Scan"), + } + } + + fn is_scan(self) -> bool { + matches!(self, Self::Scan) + } +} + +/// An [`ExactSizeIterator`] over IDs to scan. Returned by [`Freelist::scan`]. 
+#[derive(Debug)] +pub(crate) struct Scan { + start: u32, + stop: u32, +} + +impl Scan { + #[cfg(test)] + fn as_range(&self) -> std::ops::Range { + self.start..self.stop + } +} + +impl Iterator for Scan { + type Item = u32; + fn next(&mut self) -> Option { + if self.start >= self.stop { + None + } else { + let i = self.start; + self.start += 1; + Some(i) + } + } + + fn size_hint(&self) -> (usize, Option) { + let len = (self.stop - self.start).into_usize(); + (len, Some(len)) + } +} + +impl ExactSizeIterator for Scan {} + +/////////// +// Tests // +/////////// + +#[cfg(test)] +mod tests { + use super::*; + + use std::{collections::HashSet, sync::Barrier, thread}; + + fn freelist(max: u32, recycled: u32) -> Freelist { + Freelist::new(max, NonZeroU32::new(recycled).unwrap()) + } + + //---------// + // Minting // + //---------// + + #[test] + fn pop_mints_sequentially_until_exhausted() { + let fl = freelist(4, 8); + + let mut got = Vec::new(); + for _ in 0..4 { + got.push(fl.pop().unwrap()); + } + assert_eq!(got, vec![0, 1, 2, 3]); + assert!(fl.pop().is_scan()); + assert!(fl.pop().is_scan()); + } + + #[test] + fn pop_returns_scan_when_max_zero() { + let fl = freelist(0, 1); + assert!(fl.pop().is_scan()); + } + + #[test] + fn recycled_ids_take_precedence_over_minting() { + let fl = freelist(4, 8); + // Seed the recycled queue. + assert!(fl.push(2)); + // First pop must come from the recycled queue, not mint 0. + assert_eq!(fl.pop().unwrap(), 2); + // Subsequent pops mint from 0. + assert_eq!(fl.pop().unwrap(), 0); + } + + //------// + // Push // + //------// + + #[test] + fn push_rejects_ids_at_or_above_max() { + let fl = freelist(4, 8); + assert!(!fl.push(4)); + assert!(!fl.push(u32::MAX)); + assert!(fl.push(3)); + + assert_eq!(fl.pop_recycled().unwrap(), 3); + } + + #[test] + fn push_returns_false_when_recycled_full() { + let fl = freelist(16, 2); + assert!(fl.push(2)); + assert!(fl.push(3)); + assert!(!fl.push(5)); + + // Drained from recycled queue. 
+ assert_eq!(fl.pop().unwrap(), 2); + assert_eq!(fl.pop().unwrap(), 3); + } + + #[test] + fn pop_recycled_empty_returns_none() { + let fl = freelist(4, 4); + assert!(fl.pop_recycled().is_none()); + } + + #[test] + fn pop_recycled_does_not_mint() { + let fl = freelist(4, 4); + // No pushes, no recycled entries — `pop_recycled` must not fall through to minting. + assert!(fl.pop_recycled().is_none()); + // The minting counter should be untouched. + assert_eq!(fl.pop().unwrap(), 0); + } + + //------// + // Scan // + //------// + + fn as_vec(itr: I) -> Vec + where + I: Iterator, + { + itr.collect() + } + + #[test] + fn scan_on_empty_freelist_yields_nothing() { + let fl = freelist(0, 1); + let mut scan = fl.scan(); + assert_eq!(scan.len(), 0); + assert!(scan.next().is_none()); + } + + #[test] + fn scan_covers_full_range_in_one_pass() { + // Choose `max` to force a partial last bucket. + let max = 2 * SCAN_SIZE + 50; + let fl = freelist(max, 4); + + // First Round + + let scan = fl.scan(); + assert_eq!(scan.as_range(), 0..SCAN_SIZE); + assert_eq!(scan.len(), SCAN_SIZE.into_usize()); + assert_eq!(as_vec(scan), as_vec(0..SCAN_SIZE)); + + let scan = fl.scan(); + assert_eq!(scan.as_range(), SCAN_SIZE..2 * SCAN_SIZE); + assert_eq!(scan.len(), SCAN_SIZE.into_usize()); + assert_eq!(as_vec(scan), as_vec(SCAN_SIZE..2 * SCAN_SIZE)); + + let scan = fl.scan(); + assert_eq!(scan.as_range(), 2 * SCAN_SIZE..(2 * SCAN_SIZE + 50)); + assert_eq!(scan.len(), 50); + assert_eq!(as_vec(scan), as_vec((2 * SCAN_SIZE)..(2 * SCAN_SIZE + 50))); + + // Check Wrapping + + let scan = fl.scan(); + assert_eq!(scan.as_range(), 0..SCAN_SIZE); + assert_eq!(scan.len(), SCAN_SIZE.into_usize()); + assert_eq!(as_vec(scan), as_vec(0..SCAN_SIZE)); + } + + //-------------// + // Concurrency // + //-------------// + + #[test] + fn concurrent_pop_yields_unique_ids() { + let max = 4096u32; + let fl = Freelist::new(max, NonZeroU32::new(8).unwrap()); + let nthreads = 8; + let barrier = Barrier::new(nthreads); + + let 
results: Vec> = thread::scope(|s| { + let handles: Vec<_> = (0..nthreads) + .map(|_| { + s.spawn(|| { + let mut out = Vec::new(); + barrier.wait(); + while let Id::Found(id) = fl.pop() { + out.push(id); + } + out + }) + }) + .collect(); + handles.into_iter().map(|h| h.join().unwrap()).collect() + }); + + let mut all: Vec = results.into_iter().flatten().collect(); + all.sort(); + let expected: Vec = (0..max).collect(); + assert_eq!(all, expected, "all ids in [0, max) minted exactly once"); + } + + #[test] + fn concurrent_scan_partitions_one_pass() { + let max = SCAN_SIZE * 4; + let fl = Freelist::new(max, NonZeroU32::new(4).unwrap()); + let num_buckets = max.div_ceil(SCAN_SIZE) as usize; + let nthreads = num_buckets; + let barrier = Barrier::new(nthreads); + + let ids: Vec = thread::scope(|s| { + let handles: Vec<_> = (0..nthreads) + .map(|_| { + s.spawn(|| { + barrier.wait(); + fl.scan().collect::>() + }) + }) + .collect(); + handles + .into_iter() + .flat_map(|h| h.join().unwrap()) + .collect() + }); + + let unique: HashSet = ids.iter().copied().collect(); + assert_eq!( + unique.len(), + ids.len(), + "no id appeared twice across threads" + ); + assert_eq!( + unique.len() as u32, + max, + "scans covered every id in [0, max)" + ); + } +} diff --git a/diskann-inmem/src/integration/counters.rs b/diskann-inmem/src/integration/counters.rs new file mode 100644 index 000000000..fcb16d6cd --- /dev/null +++ b/diskann-inmem/src/integration/counters.rs @@ -0,0 +1,17 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +/// A snapshot of global [`Counters`](crate::counters::Counters). 
+#[derive(Debug, Clone)] +#[non_exhaustive] +pub struct CounterSnapshot { + pub query_distance: u64, + pub distance: u64, + pub get_vector: u64, + pub set_vector: u64, + pub get_neighbors: u64, + pub set_neighbors: u64, + pub append_neighbors: u64, +} diff --git a/diskann-inmem/src/integration/mod.rs b/diskann-inmem/src/integration/mod.rs new file mode 100644 index 000000000..312e41f4c --- /dev/null +++ b/diskann-inmem/src/integration/mod.rs @@ -0,0 +1,7 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +pub mod counters; +pub mod store; diff --git a/diskann-inmem/src/integration/store.rs b/diskann-inmem/src/integration/store.rs new file mode 100644 index 000000000..ff8b5f797 --- /dev/null +++ b/diskann-inmem/src/integration/store.rs @@ -0,0 +1,100 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +#![expect( + clippy::expect_used, + reason = "integration test tools are not production code" +)] + +use diskann_utils::views::Matrix; + +use crate::{num::Bytes, store}; + +#[derive(Debug)] +pub struct Store { + store: store::Store, +} + +impl Store { + /// Construct a store with `capacity` writable slots, each holding `entry_bytes` bytes. + /// + /// A single zeroed frozen point is created internally to satisfy the underlying + /// store's requirement of at least one frozen entry; it occupies the highest slot + /// index and is always readable. + /// + /// # Panics + /// + /// Panics if the underlying store could not be constructed (e.g. `capacity` plus the + /// frozen point exceeds `u32::MAX`). + pub fn new(capacity: usize, entry_bytes: usize) -> Self { + let data = Matrix::new(0u8, 1, entry_bytes); + let store = store::Store::new(capacity, Bytes::new(entry_bytes), 0, data.as_view()) + .expect("failed to construct store"); + Self { store } + } + + /// Return the total number of slots, including the frozen point. 
+ pub fn slots(&self) -> usize { + self.store.frozen().end as usize + } + + /// Return the range of writable (non-frozen) slot indices. + pub fn writable(&self) -> std::ops::Range { + 0..(self.store.frozen().start as usize) + } + + /// Attempt to reclaim retired slots, returning the number reclaimed if any. + pub fn reclaim(&self) -> Option { + self.store.try_drain() + } + + pub fn acquire(&self) -> Option> { + self.store.acquire().map(Writer::new) + } + + #[must_use = "result indicates success or failure"] + pub fn retire(&self, i: usize) -> bool { + self.store.retire(i).is_ok() + } + + pub fn reader(&self) -> Option> { + match self.store.reader() { + Ok(reader) => Some(Reader::new(reader)), + Err(crate::epoch::Unavailable) => None, + } + } +} + +pub struct Reader<'a> { + reader: store::Reader<'a>, +} + +impl<'a> Reader<'a> { + fn new(reader: store::Reader<'a>) -> Self { + Self { reader } + } + + pub fn read(&self, i: usize) -> Option<&[u8]> { + self.reader.read(i) + } +} + +pub struct Writer<'a> { + slot: store::Slot<'a>, +} + +impl<'a> Writer<'a> { + fn new(slot: store::Slot<'a>) -> Self { + Self { slot } + } + + pub fn publish(self) { + self.slot.publish(); + } + + pub fn as_mut_slice(&mut self) -> &mut [u8] { + self.slot.as_mut_slice() + } +} diff --git a/diskann-inmem/src/layers/full.rs b/diskann-inmem/src/layers/full.rs new file mode 100644 index 000000000..b9807282e --- /dev/null +++ b/diskann-inmem/src/layers/full.rs @@ -0,0 +1,703 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. 
+ */ + +use std::{fmt::Debug, marker::PhantomData}; + +use diskann::{ANNError, ANNResult}; +use diskann_vector::{ + UnalignedSlice, + conversion::SliceCast, + distance::{ + self, Cosine, CosineNormalized, DistanceProvider, InnerProduct, Metric, Specialize, + SquaredL2, + }, +}; +use diskann_wide::{ + ARCH, + arch::{Current, FTarget2}, +}; +use half::f16; +use thiserror::Error; + +use crate::{Hidden, layers, num::Bytes}; + +/// A useful trait bound for types compatible with [`Full`]. +/// +/// This encompasses *everything* required for `Full: layers::Insert` and can be used as +/// a single bound. +pub trait FullPrecision: bytemuck::Pod + std::fmt::Debug + Send + Sync { + #[doc(hidden)] + fn __new(_: Hidden, dim: usize, metric: Metric) -> Full; + + #[doc(hidden)] + fn __query_distance<'a, V>( + _: Hidden, + full: &'a Full, + query: &'a [Self], + visitor: V, + ) -> ANNResult + where + V: layers::QueryVisitor<'a>; +} + +/// Full-precision data layer. +#[derive(Debug)] +pub struct Full +where + T: 'static, +{ + distance: Distance, + metric: Metric, +} + +impl Full +where + T: 'static, +{ + /// Create a new full-precision layer for data with the given `dim` and `metric`. + pub fn new(dim: usize, metric: Metric) -> Self + where + T: FullPrecision, + { + T::__new(Hidden::new(), dim, metric) + } + + fn from_distance_provider(dim: usize, metric: Metric) -> Self + where + T: DistanceProvider, + { + let distance = Distance { + f: T::distance_comparer(metric, Some(dim)), + dim, + }; + + Self { distance, metric } + } + + /// Return the logical dimension of the data handled by this [`layers::Layer`]. + pub fn dim(&self) -> usize { + self.distance.dim + } + + /// Return the number of bytes of the data handled by this [`layers::Layer`]. 
+ pub fn bytes(&self) -> Bytes { + Bytes::new(self.dim() * std::mem::size_of::()) + } + + fn check_dim(&self, dim: usize) -> Result<(), QueryDistanceError> { + if self.dim() != dim { + Err(QueryDistanceError { + expected: self.dim(), + xlen: dim, + }) + } else { + Ok(()) + } + } +} + +impl layers::Layer for Full +where + T: FullPrecision, +{ + fn bytes(&self) -> Bytes { + >::bytes(self) + } +} + +impl layers::Set<&[T]> for Full +where + T: FullPrecision, +{ + fn set(&self, v: &[T], bytes: &mut [u8]) -> ANNResult<()> { + if v.len() != self.dim() { + Err(ANNError::from(SetError::Dim { + got: v.len(), + expected: self.dim(), + })) + } else if bytes.len() != self.bytes().value() { + Err(ANNError::from(SetError::Bytes { + got: bytes.len(), + expected: self.bytes().value(), + })) + } else { + bytes.copy_from_slice(bytemuck::must_cast_slice::(v)); + Ok(()) + } + } +} + +#[derive(Debug, Error)] +enum SetError { + #[error( + "data of dimension {} does not match full precision layer's dimension {}", + got, + expected + )] + Dim { got: usize, expected: usize }, + #[error( + "raw byte slice of length {} does not match expected length {}", + got, + expected + )] + Bytes { got: usize, expected: usize }, +} + +crate::opaque!(SetError); + +impl layers::AsDistance for Full +where + T: FullPrecision, +{ + fn as_distance(&self) -> &dyn layers::Distance { + &self.distance + } +} + +impl layers::Search for Full +where + T: FullPrecision, +{ + type Query<'a> = &'a [T]; + + fn query_distance<'a, V>(&'a self, query: &'a [T], visitor: V) -> ANNResult + where + V: layers::QueryVisitor<'a>, + { + T::__query_distance(Hidden::new(), self, query, visitor) + } +} + +impl layers::Insert for Full where T: FullPrecision {} + +////////////// +// Distance // +////////////// + +#[derive(Debug)] +#[doc(hidden)] +pub struct Distance +where + T: 'static, + U: 'static, +{ + f: distance::Distance, + dim: usize, +} + +impl Clone for Distance { + fn clone(&self) -> Self { + *self + } +} + +impl Copy for 
Distance {} + +impl Distance +where + T: 'static, + U: 'static, +{ + #[cold] + #[inline(never)] + fn error(&self, x: &[u8], y: &[u8]) -> ANNResult { + let error = DistanceError { + expected: self.bytes(), + xlen: x.len(), + ylen: y.len(), + }; + + Err(ANNError::opaque(error)) + } + + fn dim(&self) -> usize { + self.dim + } + + fn bytes(&self) -> usize { + self.dim() * std::mem::size_of::() + } +} + +impl layers::Distance for Distance +where + T: Debug + 'static, +{ + fn evaluate(&self, x: &[u8], y: &[u8]) -> ANNResult { + let bytes = self.bytes(); + if x.len() != bytes || y.len() != bytes { + self.error(x, y) + } else { + // SAFETY: We've checked that both `x` and `y` are valid for + // `size_of::() * self.dim` bytes. + let ux = unsafe { UnalignedSlice::new(x.as_ptr().cast::(), self.dim) }; + + // SAFETY: Same as above + let uy = unsafe { UnalignedSlice::new(y.as_ptr().cast::(), self.dim) }; + Ok(self.f.call_unaligned(ux, uy)) + } + } +} + +#[derive(Debug, Error)] +#[error( + "expected slices of length {} - instead got {} and {}", + self.expected, + self.xlen, + self.ylen +)] +struct DistanceError { + expected: usize, + xlen: usize, + ylen: usize, +} + +/////////////////// +// QueryDistance // +/////////////////// + +// A baby [`std::borrow::Cow`]. +#[derive(Debug)] +enum Calf<'a, T> { + Borrowed(&'a [T]), + Owned(Box<[T]>), +} + +impl std::ops::Deref for Calf<'_, T> { + type Target = [T]; + fn deref(&self) -> &Self::Target { + match self { + Self::Borrowed(slice) => slice, + Self::Owned(boxed) => boxed, + } + } +} + +/// A fused query distance based on [`PureDistanceFunction`] to enable inlining of the final +/// distance function (`D`). +/// +/// The type of the embedded query (`T`) is distinct from the expected data-set (`U`) to +/// allow `f16` queries to be pre-converted to `f32`, saving on-the-fly conversion that +/// would otherwise be needed. 
+#[derive(Debug)] +struct QueryDistance<'a, T, U, D> { + query: Calf<'a, T>, + // The type of the data in the original dataset. + _data: PhantomData, + // The type of the `PureDistanceFunction` used for the implementation. + _distance: PhantomData, +} + +impl<'a, T, U, D> QueryDistance<'a, T, U, D> { + fn new(query: Calf<'a, T>) -> Self { + Self { + query, + _data: PhantomData, + _distance: PhantomData, + } + } + + fn bytes(&self) -> usize { + std::mem::size_of::() * self.query.len() + } + + #[inline(never)] + fn error(&self, len: usize) -> ANNResult { + let error = QueryDistanceError { + expected: self.bytes(), + xlen: len, + }; + + Err(ANNError::opaque(error)) + } +} + +impl layers::QueryDistance for QueryDistance<'_, T, U, D> +where + T: Send + Sync + 'static + Debug, + U: Send + Sync + 'static + Debug, + D: for<'a> FTarget2, UnalignedSlice<'a, U>> + + Send + + Sync + + Debug, +{ + #[inline(always)] + fn evaluate(&self, x: &[u8]) -> ANNResult { + if x.len() != self.bytes() { + self.error(x.len()) + } else { + // SAFETY: We've validated that `x` has the correct length. + let x = unsafe { UnalignedSlice::new(x.as_ptr().cast::(), self.query.len()) }; + Ok(D::run(ARCH, (*self.query).into(), x)) + } + } +} + +#[derive(Debug, Error)] +#[error( + "expected slice of length {} - instead got {}", + self.expected, + self.xlen, +)] +struct QueryDistanceError { + expected: usize, + xlen: usize, +} + +crate::opaque!(QueryDistanceError); + +macro_rules! 
mint { + ($query:ident, $visitor:ident, $T:ty => { $N:literal, $f:ident }) => {{ + mint!($query, $visitor, { $T, $T } => { $N, $f }) + }}; + ($query:ident, $visitor:ident, { $T:ty, $U:ty } => { $N:literal, $f:ident }) => {{ + let inner = QueryDistance::<$T, $U, Specialize<$N, $f>>::new($query); + $visitor.visit_sized::<{ $N * std::mem::size_of::<$U>() }, _>(inner) + }}; + ($query:ident, $visitor:ident, $T:ty => $f:ident) => {{ + mint!($query, $visitor, { $T, $T } => $f) + }}; + ($query:ident, $visitor:ident, { $T:ty, $U:ty } => $f:ident) => {{ + let inner = QueryDistance::<$T, $U, $f>::new($query); + $visitor.visit(inner) + }}; +} + +impl FullPrecision for f32 { + fn __new(_: Hidden, dim: usize, metric: Metric) -> Full { + Full::from_distance_provider(dim, metric) + } + + fn __query_distance<'a, V>( + _: Hidden, + full: &'a Full, + query: &'a [f32], + visitor: V, + ) -> ANNResult + where + V: layers::QueryVisitor<'a>, + { + full.check_dim(query.len())?; + + let query = Calf::Borrowed(query); + + let output = match full.metric { + Metric::L2 => { + if full.dim() == 100 { + mint!(query, visitor, f32 => { 100, SquaredL2 }) + } else { + mint!(query, visitor, f32 => SquaredL2) + } + } + Metric::InnerProduct => { + mint!(query, visitor, f32 => InnerProduct) + } + Metric::Cosine => mint!(query, visitor, f32 => Cosine), + Metric::CosineNormalized => mint!(query, visitor, f32 => CosineNormalized), + }; + + Ok(output) + } +} + +impl FullPrecision for f16 { + fn __new(_: Hidden, dim: usize, metric: Metric) -> Full { + Full::from_distance_provider(dim, metric) + } + + fn __query_distance<'a, V>( + _: Hidden, + full: &'a Full, + query: &'a [f16], + visitor: V, + ) -> ANNResult + where + V: layers::QueryVisitor<'a>, + { + full.check_dim(query.len())?; + + let mut as_f32: Box<[f32]> = std::iter::repeat_n(0.0, full.dim()).collect(); + diskann_wide::arch::dispatch2(SliceCast::new(), &mut *as_f32, query); + let query = Calf::Owned(as_f32); + + let output = match full.metric { + 
Metric::L2 => { + if full.dim() == 100 { + mint!(query, visitor, { f32, f16 } => { 100, SquaredL2 }) + } else { + mint!(query, visitor, { f32, f16 } => SquaredL2) + } + } + Metric::InnerProduct => mint!(query, visitor, { f32, f16 } => InnerProduct), + Metric::Cosine => mint!(query, visitor, { f32, f16 } => Cosine), + Metric::CosineNormalized => mint!(query, visitor, { f32, f16 } => CosineNormalized), + }; + + Ok(output) + } +} + +impl FullPrecision for u8 { + fn __new(_: Hidden, dim: usize, metric: Metric) -> Full { + Full::from_distance_provider(dim, metric) + } + + fn __query_distance<'a, V>( + _: Hidden, + full: &'a Full, + query: &'a [u8], + visitor: V, + ) -> ANNResult + where + V: layers::QueryVisitor<'a>, + { + full.check_dim(query.len())?; + + let query = Calf::Borrowed(query); + + let output = match full.metric { + Metric::L2 => { + if full.dim() == 128 { + mint!(query, visitor, u8 => { 128, SquaredL2 }) + } else { + mint!(query, visitor, u8 => SquaredL2) + } + } + Metric::InnerProduct => mint!(query, visitor, u8 => InnerProduct), + Metric::Cosine => mint!(query, visitor, u8 => Cosine), + Metric::CosineNormalized => mint!(query, visitor, u8 => Cosine), + }; + + Ok(output) + } +} + +impl FullPrecision for i8 { + fn __new(_: Hidden, dim: usize, metric: Metric) -> Full { + Full::from_distance_provider(dim, metric) + } + + fn __query_distance<'a, V>( + _: Hidden, + full: &'a Full, + query: &'a [i8], + visitor: V, + ) -> ANNResult + where + V: layers::QueryVisitor<'a>, + { + full.check_dim(query.len())?; + + let query = Calf::Borrowed(query); + + let output = match full.metric { + Metric::L2 => mint!(query, visitor, i8 => SquaredL2), + Metric::InnerProduct => mint!(query, visitor, i8 => InnerProduct), + Metric::Cosine => mint!(query, visitor, i8 => Cosine), + Metric::CosineNormalized => mint!(query, visitor, i8 => Cosine), + }; + + Ok(output) + } +} + +/////////// +// Tests // +/////////// + +#[cfg(test)] +#[cfg(not(miri))] +mod tests { + use std::fmt::Display; 
+ + use rand::{Rng, SeedableRng, rngs::StdRng}; + + use super::*; + // Bring the inherent-call traits into method scope. The `Distance` / `QueryDistance` + // traits are not imported: their methods are reached through `&dyn _` trait objects, + // which does not require the trait to be in scope. + use crate::layers::{AsDistance as _, QueryVisitor, Search as _, Set as _}; + + /// Generate random elements of a layer's data type from a seeded RNG. + trait Sample: bytemuck::Pod { + fn sample(rng: &mut R) -> Self; + } + + impl Sample for f32 { + fn sample(rng: &mut R) -> Self { + rng.random_range(-1.0f32..1.0f32) + } + } + + impl Sample for f16 { + fn sample(rng: &mut R) -> Self { + f16::from_f32(rng.random_range(-1.0f32..1.0f32)) + } + } + + impl Sample for u8 { + fn sample(rng: &mut R) -> Self { + rng.random() + } + } + + impl Sample for i8 { + fn sample(rng: &mut R) -> Self { + rng.random() + } + } + + fn gen_vec(rng: &mut R, dim: usize) -> Vec { + (0..dim).map(|_| T::sample(rng)).collect() + } + + /// A [`QueryVisitor`] that simply boxes the minted kernel so the test can probe it + /// directly. Exercises both `visit` (dynamic) and `visit_sized` (specialized) paths. + struct Collect; + + impl<'a> QueryVisitor<'a> for Collect { + type Output = Box; + + fn visit(self, distance: Q) -> Self::Output + where + Q: layers::QueryDistance + 'a, + { + Box::new(distance) + } + } + + /// Compare two distances allowing for floating-point reassociation between the + /// specialized / converted kernels and the dynamic reference. + fn approx_eq(got: f32, want: f32) -> bool { + (got - want).abs() <= 1e-3 + 1e-4 * want.abs() + } + + /// Exercise every `Full` API across dimensions `1..=max_dim`. + /// + /// For each dimension we check that `bytes`/`set` agree, that `distance` and + /// `query_distance` are consistent with `DistanceProvider`, and that all of these + /// reject byte slices that are too long or too short. 
+ fn test_impl(max_dim: usize, ctx: &dyn Display) + where + T: FullPrecision + Sample + DistanceProvider, + { + let mut rng = StdRng::seed_from_u64(0x0D15_0ACE ^ max_dim as u64); + let metrics = [ + Metric::L2, + Metric::InnerProduct, + Metric::Cosine, + Metric::CosineNormalized, + ]; + + for dim in 1..=max_dim { + let a = gen_vec::(&mut rng, dim); + let b = gen_vec::(&mut rng, dim); + + // `bytes` and `set` agree: the encoded buffer equals the raw cast bytes. + let layer = Full::::new(dim, Metric::L2); + assert_eq!( + layer.bytes().value(), + dim * std::mem::size_of::(), + "{ctx}: dim {dim}: unexpected byte length", + ); + + let mut a_bytes = vec![0u8; layer.bytes().value()]; + layer.set(&a, &mut a_bytes).unwrap(); + assert_eq!( + a_bytes.as_slice(), + bytemuck::cast_slice::(&a), + "{ctx}: dim {dim}: set mismatch", + ); + + let mut b_bytes = vec![0u8; layer.bytes().value()]; + layer.set(&b, &mut b_bytes).unwrap(); + + for metric in metrics { + let full = Full::::new(dim, metric); + + // Reference value straight from `DistanceProvider`. + let reference = + >::distance_comparer(metric, Some(dim)).call(&a, &b); + + // `distance` is built from the same comparer, so it must match exactly. + let distance = full.as_distance(); + let via_distance = distance.evaluate(&a_bytes, &b_bytes).unwrap(); + assert_eq!( + via_distance, reference, + "{ctx}: dim {dim}, metric {metric:?}: distance != DistanceProvider", + ); + + // `query_distance` computes the same geometry. Specialized and f16-converted + // kernels may reassociate the summation, so compare approximately. + let query = full.query_distance(a.as_slice(), Collect).unwrap(); + let via_query = query.evaluate(&b_bytes).unwrap(); + assert!( + approx_eq(via_query, via_distance), + "{ctx}: dim {dim}, metric {metric:?}: query {via_query} != distance {via_distance}", + ); + + // Every distance API rejects byte slices that are too long or too short. 
+ let short = &a_bytes[..a_bytes.len() - 1]; + let mut long = a_bytes.clone(); + long.push(0); + + assert!(distance.evaluate(short, &b_bytes).is_err()); + assert!(distance.evaluate(&long, &b_bytes).is_err()); + assert!(distance.evaluate(&a_bytes, short).is_err()); + assert!(distance.evaluate(&a_bytes, &long).is_err()); + + assert!(query.evaluate(short).is_err()); + assert!(query.evaluate(&long).is_err()); + } + + // `set` rejects mis-sized element and buffer slices. + let mut buf = vec![0u8; layer.bytes().value()]; + let too_many = gen_vec::(&mut rng, dim + 1); + assert!( + layer.set(&too_many, &mut buf).is_err(), + "{ctx}: dim {dim}: set accepted an over-long element slice", + ); + + assert!( + layer.query_distance(&too_many, Collect).is_err(), + "{ctx}: dim {dim}: incorrect query lengths should be rejected" + ); + + let mut short_buf = vec![0u8; layer.bytes().value().saturating_sub(1)]; + assert!( + layer.set(&a, &mut short_buf).is_err(), + "{ctx}: dim {dim}: set accepted an under-sized buffer", + ); + + let too_few = gen_vec::(&mut rng, dim - 1); + assert!( + layer.query_distance(&too_few, Collect).is_err(), + "{ctx}: dim {dim}: incorrect query lengths should be rejected" + ); + } + } + + // `max_dim` must exceed the largest specialized dimension for each type so the + // const-generic (`visit_sized`) paths are covered alongside the dynamic ones. + #[test] + fn full_f32() { + test_impl::(256, &"f32"); + } + + #[test] + fn full_f16() { + test_impl::(256, &"f16"); + } + + #[test] + fn full_u8() { + test_impl::(160, &"u8"); + } + + #[test] + fn full_i8() { + test_impl::(160, &"i8"); + } +} diff --git a/diskann-inmem/src/layers/mod.rs b/diskann-inmem/src/layers/mod.rs new file mode 100644 index 000000000..13179d9bc --- /dev/null +++ b/diskann-inmem/src/layers/mod.rs @@ -0,0 +1,127 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Distance layers indexing. +//! +//! 
An important assumption made by this module is that the data within each layer is +//! uniformly sized: each entry occupies the same number of bytes. Furthermore, the data +//! to be stored may not assume any particular alignment. Implementations will strive to +//! achieve a reasonable alignment, but this may not be relied on. +//! +//! # Query Distance Specialization +//! +//! The design of this module allows aggressive optimization of graph search kernels via +//! the [`Search`] and [`QueryVisitor`] pairs of traits. +//! +//! Implementations of [`Search`] can pass a [`QueryDistance`] kernel specialized to +//! a specific geometry (dimensionality or metric type) which upstream [`QueryVisitor`] +//! will fuse into larger kernels. While this allows for high performance graph kernels, +//! some considerations should be taken into account: +//! +//! 1. For correctness purposes, upstream callers cannot do any kind of caching. As such, +//! the dispatch layer used to select the kernel passed to the [`QueryVisitor`] should +//! be relatively efficient. +//! +//! 2. Keep the number of specializations bounded for compile time reasons. + +use diskann::ANNResult; + +use crate::num::Bytes; + +mod full; +pub use full::{Full, FullPrecision}; + +/// Base layer for data representations. +pub trait Layer: Send + Sync + 'static { + /// Return the number of bytes needed by this layer representation. + /// + /// To be well-behaved, this function must be idempotent. + fn bytes(&self) -> Bytes; +} + +/// Store an element of type `T` into a raw byte buffer. +/// +/// Implementations may assume that `bytes.len()` is equal to [`Layer::bytes`]. +pub trait Set: Layer { + /// Write into the stored representation. + fn set(&self, element: T, bytes: &mut [u8]) -> ANNResult<()>; +} + +/// A distance computation on raw byte slices. +/// +/// When paired with [`Layer`] via helpers like [`AsDistance`], implementations may assume +/// that `x` and `y` have length [`Layer::bytes`]. 
+/// +/// No alignment guarantees are made for `x` and `y`, though in practice they are likely +/// to be aligned to 32 or 64 bytes. +pub trait Distance: Send + Sync + std::fmt::Debug { + fn evaluate(&self, x: &[u8], y: &[u8]) -> ANNResult; +} + +/// Return a [`Distance`] function for a [`Layer`]. +pub trait AsDistance: Send + Sync + std::fmt::Debug { + fn as_distance(&self) -> &dyn Distance; +} + +/// A unary query distance on raw byte slices. +/// +/// When paired with [`Layer`] via helpers like [`Search`], implementations may assume +/// that `x` has length [`Layer::bytes`]. +/// +/// No alignment guarantees are made for `x`, though in practice it is likely to be +/// aligned to 32 or 64 bytes. +pub trait QueryDistance: Send + Sync + std::fmt::Debug { + fn evaluate(&self, x: &[u8]) -> ANNResult; +} + +/// Enable search over vectors defined by a [`Layer`]. +pub trait Search: Send + Sync + 'static { + /// The type of the query. This should be equivalent to the generic parameter in + /// [`Set`], but needs to be replicated here due to limitations in the current trait + /// design. + type Query<'a>; + + /// Create a distance computer specialized for `query` and provide it to `visitor`. + fn query_distance<'a, V>(&'a self, query: Self::Query<'a>, visitor: V) -> ANNResult + where + V: QueryVisitor<'a>; +} + +/// Specialize a kernel around a [`QueryDistance`] implementation. +pub trait QueryVisitor<'a>: Sized { + /// The type of the type-erased output. + type Output; + + /// Specialize [`Self::Output`] for `distance`. + fn visit(self, distance: T) -> Self::Output + where + T: QueryDistance + 'a; + + /// Specialize [`Self::Output`] for `distance` accepting a hint that `distance` has been + /// specialized to work on data elements of exactly `BYTES` bytes long. + /// + /// This can be used to tailor surrounding code (e.g. software prefetches) for exactly + /// the length of the data being processed. 
+ fn visit_sized(self, distance: T) -> Self::Output + where + T: QueryDistance + 'a, + { + self.visit(distance) + } +} + +/// An insert-specific specialization of [`Search`]. +/// +/// Note that the bounds for this trait are unnecessarily complicated, but rely on changes +/// to `diskann` to fully resolve. +pub trait Insert: Search + for<'a> Set> + AsDistance { + /// A specialization of [`Search::query_distance`] targeting vector insert specifically. + fn insert_distance<'a, V>(&'a self, query: Self::Query<'a>, visitor: V) -> ANNResult + where + V: QueryVisitor<'a>, + { + self.query_distance(query, visitor) + } +} diff --git a/diskann-inmem/src/lib.rs b/diskann-inmem/src/lib.rs new file mode 100644 index 000000000..172a7478d --- /dev/null +++ b/diskann-inmem/src/lib.rs @@ -0,0 +1,63 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! The inmem index for DiskANN. + +#![deny(rustdoc::broken_intra_doc_links)] + +pub mod num; + +mod buffer; +mod counters; +mod epoch; +mod freelist; +mod neighbors; +mod sharded; +mod tag; + +mod store; + +pub mod layers; +pub mod provider; + +pub use provider::{Context, Provider, Strategy}; + +#[cfg(test)] +mod test; + +#[cfg(feature = "integration-test")] +#[doc(hidden)] +pub mod integration; + +//----------------// +// Internal Tools // +//----------------// + +/// A "public" type that can only be constructed by this crate. +/// +/// This helps with public traits with internal methods that we don't want users to call. +#[doc(hidden)] +#[derive(Debug)] +pub struct Hidden(()); + +impl Hidden { + const fn new() -> Self { + Self(()) + } +} + +macro_rules! 
opaque { + ($T:ty) => { + impl From<$T> for diskann::ANNError { + #[track_caller] + #[cold] + fn from(err: $T) -> diskann::ANNError { + diskann::ANNError::opaque(err) + } + } + }; +} + +pub(crate) use opaque; diff --git a/diskann-inmem/src/neighbors.rs b/diskann-inmem/src/neighbors.rs new file mode 100644 index 000000000..14600502e --- /dev/null +++ b/diskann-inmem/src/neighbors.rs @@ -0,0 +1,718 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! A Concurrent Graph Structure +//! +//! The [`Neighbors`] data structure is a concurrent graph managed out of a single allocation. +//! The use of a single allocation puts a hard upper-bound on the length of each adjacency list, +//! which is enforced by the types in this module. +//! +//! Concurrency is obtained using sharded read/write locks, with [`Neighbors::get`] and +//! [`Neighbors::set`] acquiring read and write locks respectively. +//! +//! To implement atomic read-modify-write operations, [`Neighbors::lock`] can be used to +//! obtain a [`Lock`]ed list. +//! +//! Due to lock sharding, attempting to acquire multiple [`Lock`]s to a single [`Neighbors`] +//! simultaneously can lead to deadlock. +//! +//! ## Performance Considerations +//! +//! Adjacency lists written through the APIs exposed in this module are not validated for +//! uniqueness nor for being in-bounds. These are the caller's responsibility. + +use std::ptr::NonNull; + +use diskann::{graph::AdjacencyList, utils::IntoUsize}; +use parking_lot::{RwLock, RwLockWriteGuard}; +use thiserror::Error; + +use crate::{ + buffer::{Buffer, BufferError}, + num::{Align, Bytes}, +}; + +type Id = u32; + +/// Locks are shared among groups of adjacency lists. +/// +/// Adjacency lists whose indices map to the same lock group (i.e. `i / LOCK_GRANULARITY`) +/// share a single `RwLock`. This means that holding a [`Lock`] on slot `i` will also block +/// operations on any slot `j` in the same group. 
+/// +/// **Deadlock hazard**: attempting to acquire two [`Lock`]s simultaneously can deadlock if +/// they fall in the same lock group — or even across groups, depending on acquisition order. +/// Callers must not hold more than one [`Lock`] at a time. +const LOCK_GRANULARITY: usize = 16; + +fn lock_index(i: u32) -> usize { + i.into_usize() / LOCK_GRANULARITY +} + +/// A concurrent graph data structure with a fixed number of adjacency lists and a fixed +/// upper-bound for each adjacency list's length. See the [module level docs](self) for +/// more detail. +/// +/// Adjacency lists are indexed by `[0, Neighbors::entries)`. +#[derive(Debug)] +pub(crate) struct Neighbors { + neighbors: Buffer, + locks: Vec>, +} + +impl Neighbors { + /// Construct a new [`Neighbors`] capable of holding `entries` adjacency lists with a + /// maximum length of `max_length`. + /// + /// # Errors + /// + /// Returns an error if `(max_length + 1) * size_of::()` overflows `usize` + /// (unreachable on 64-bit targets) or the resulting allocation would exceed + /// `isize::MAX` bytes. + pub(crate) fn new(entries: u32, max_length: u32) -> Result { + let bytes = max_length + .into_usize() + .checked_add(1) + .and_then(|len| len.checked_mul(std::mem::size_of::())) + .map(Bytes::new) + .ok_or(NeighborsError::Overflow(max_length))?; + + // We materialize slices of `Id` into the raw byte buffers. + // + // To make this sound, the base allocation must be that of `Id` so the slice + // materialization is properly aligned. + const ALIGN: Align = Align::_128; + const { + assert!( + ALIGN.value() >= Align::of::().value(), + "buffer alignment must be at least that of the ID" + ); + } + + let neighbors = Buffer::new(entries.into_usize(), bytes, ALIGN)?; + + let locks = std::iter::repeat_with(|| RwLock::new(())) + .take(entries.into_usize().div_ceil(LOCK_GRANULARITY)) + .collect(); + + Ok(Self { neighbors, locks }) + } + + /// Return the maximum length for any adjacency list. 
+ pub(crate) fn max_length(&self) -> usize { + // We reserve 4 bytes at the beginning for the length of the adjacency list. + (self.neighbors.stride().value() - std::mem::size_of::()) / std::mem::size_of::() + } + + /// Return the maximum length for any adjacency list as a 32-bit integer. + pub(crate) fn max_length_u32(&self) -> u32 { + // Lossless by the invariants on `Self::new`. + self.max_length() as u32 + } + + /// Return the number of adjacency lists contained by this graph. + pub(crate) fn entries(&self) -> u32 { + // Cast is lossless by construction. + self.neighbors.len() as u32 + } + + /// Copy the contents of adjacency list `i` into `neighbors`. + /// + /// Returns an error if `i` is not less than [`Self::entries`]. + pub(crate) fn get( + &self, + i: u32, + neighbors: &mut AdjacencyList, + ) -> Result<(), OutOfBounds> { + self.check(i)?; + + // SAFETY: We've checked that `i` is in-bounds. + let lock = unsafe { self.locks.get_unchecked(lock_index(i)) }; + + let _guard = lock.read(); + + // SAFETY: `self.check(i)` verified `i < self.entries()`, and `self.entries()` is + // the number of slots in `self.neighbors`, so slot `i` is in-bounds there. + let (prefix, rest) = + unsafe { self.neighbors.get_unchecked(i.into_usize()) }.split(Bytes::size_of::()); + + debug_assert_eq!(prefix.len(), Bytes::size_of::()); + debug_assert!(prefix.as_ptr().cast::().is_aligned()); + + // SAFETY: We hold the read-lock, so reading is safe. From our bounds checks, we + // know that this pointer is valid. + let len: usize = unsafe { prefix.as_ptr().cast::().read() } + .min(self.max_length_u32()) + .into_usize(); + + let mut resizer = neighbors.resize(len); + + // SAFETY: We've validated that the two slices are valid. They cannot overlap + // because `neighbors` is provided externally by exclusive reference. 
+ unsafe { + std::ptr::copy_nonoverlapping( + rest.as_mut_ptr(), + resizer.as_mut_ptr().cast::(), + len * std::mem::size_of::(), + ) + }; + resizer.finish(len); + Ok(()) + } + + /// Lock adjacency list `i` for read-modify-write operations. + /// + /// Returns an error if `i` is not less than [`Self::entries`]. + pub(crate) fn lock(&self, i: u32) -> Result, OutOfBounds> { + self.check(i)?; + + // SAFETY: `i` is in-bounds. + Ok(unsafe { self.lock_unchecked(i) }) + } + + /// Lock adjacency-list `i` without bounds-checking. + /// + /// # Safety + /// + /// `i` must be in-bounds. + unsafe fn lock_unchecked(&self, i: u32) -> Lock<'_> { + // SAFETY: `i` is in-bounds. + let lock = unsafe { self.locks.get_unchecked(lock_index(i)) }.write(); + + // SAFETY: The caller guarantees that `i` is in-bounds, i.e. less than + // `self.entries()`, which is the number of slots in `self.neighbors`. + let slice = unsafe { self.neighbors.get_unchecked(i.into_usize()) }; + + debug_assert!(slice.as_ptr().cast::().is_aligned()); + + Lock { + ptr: slice.as_non_null().cast::(), + capacity: self.max_length().into_usize(), + _lock: lock, + } + } + + /// Overwrite the contents of adjacency list `i` with `neighbors`. + /// + /// # Errors + /// + /// Returns an error if: + /// + /// * `i` is not less than [`Self::entries`]. + /// * `neighbors.len()` exceeds [`Self::max_length_u32`]. + /// + /// If an error is returned, the graph is left unmodified. + pub(crate) fn set(&self, i: u32, neighbors: &[u32]) -> Result<(), SetError> { + self.check(i).map_err(SetError::OutOfBounds)?; + + // We can check the length of `neighbors` before acquiring any locks as an early exit. + if neighbors.len() > self.max_length().into_usize() { + return Err(SetError::TooLong(TooLong { + got: neighbors.len(), + max: self.max_length_u32(), + })); + } + + // SAFETY: We've checked `i` is in-bounds. + let lock = unsafe { self.lock_unchecked(i) }; + + // SAFETY: `neighbors.len() <= self.max_length()`. 
+ unsafe { lock.write_unchecked(neighbors) }; + Ok(()) + } + + fn check(&self, i: u32) -> Result<(), OutOfBounds> { + if i >= self.entries() { + Err(OutOfBounds(i)) + } else { + Ok(()) + } + } +} + +/// Errors returned by [`Neighbors::new`]. +#[derive(Debug, Error)] +pub(crate) enum NeighborsError { + /// Computing the per-list byte size `(max_length + 1) * size_of::()` overflowed + /// `usize`. + /// + /// Unreachable on 64-bit targets. + #[error("adjacency list length of {0} is too long")] + Overflow(u32), + + /// Allocation of the underlying buffer failed. + /// + /// This can occur if the total allocation size (`entries * per-list bytes`) + /// would exceed `isize::MAX`, or if the underlying allocator returns an error. + #[error("neighbor buffer allocation failed")] + AllocationFailed(#[from] BufferError), +} + +/// Attempted to access a [`Neighbors`] at an out-of-bounds index. +#[derive(Debug, Clone, Copy, Error)] +#[error("index {} is out-of-bounds", self.0)] +pub(crate) struct OutOfBounds(u32); + +crate::opaque!(OutOfBounds); + +/// A neighbor list was longer than the configured per-list capacity. +/// +/// `got` is the caller-supplied length (any `usize`); `max` is the per-list capacity, +/// which is bounded by `u32` per [`Neighbors::new`]. +#[derive(Debug, Clone, Copy, Error)] +#[error("length {} exceeds the max length {}", self.got, self.max)] +pub(crate) struct TooLong { + got: usize, + max: u32, +} + +crate::opaque!(TooLong); + +/// Errors during [`Neighbors::set`]. +#[derive(Debug, Clone, Copy, Error)] +pub(crate) enum SetError { + /// Attempted to access an out-of-bounds index. + #[error(transparent)] + OutOfBounds(OutOfBounds), + + /// The new adjacency list was too long. + #[error(transparent)] + TooLong(TooLong), +} + +crate::opaque!(SetError); + +/// A locked adjacency list to implement atomic read-modify-write operations. +/// +/// Callers must not hold more than one `Lock` at a time. 
See [`LOCK_GRANULARITY`] for +/// details on the deadlock hazard. +pub(crate) struct Lock<'a> { + ptr: NonNull, + capacity: usize, + _lock: RwLockWriteGuard<'a, ()>, +} + +impl Lock<'_> { + /// Return the capacity of the neighbor buffer. + pub(crate) fn capacity(&self) -> usize { + self.capacity + } + + /// Return the current length of the neighbor list. + /// + /// This is guaranteed to be less than or equal to [`capacity`](Self::capacity). + pub(crate) fn len(&self) -> usize { + // SAFETY: By construction, `self.ptr` has a length of at least 1. + // + // The `min` operation defensively clamps in case the stored length has been + // corrupted; under normal operation it should already be `<= capacity`. + unsafe { self.ptr.read() }.into_usize().min(self.capacity()) + } + + /// Consume `self`, appending `neighbors` to the list. + /// + /// Returns an error if the concatenated list would exceed [`Self::capacity`] without + /// modifying the adjacency list. + /// + /// This method does not attempt to deduplicate `neighbors`. + pub(crate) fn append(self, neighbors: &[u32]) -> Result<(), TooLong> { + let len = self.len(); + let newlen = len.saturating_add(neighbors.len()); + + if newlen > self.capacity() { + return Err(TooLong { + got: newlen, + max: self.capacity as u32, + }); + } + + // SAFETY: We've verified that both regions are in-bounds. + // + // The slices have to be disjoint because `self` effectively owns its data while + // it is alive and this method receives by-value. + unsafe { + std::ptr::copy_nonoverlapping( + neighbors.as_ptr(), + self.ptr.add(len + 1).as_ptr(), + neighbors.len(), + ) + } + + // SAFETY: `self.ptr` is guaranteed to be valid for at least 4-bytes, and we own the + // underlying data until `drop`. + unsafe { self.ptr.write(newlen as u32) }; + + Ok(()) + } + + /// Write the contents of `neighbors` into `self` without validating lengths. + /// + /// # Safety + /// + /// `neighbors.len() <= self.capacity()`. 
+ unsafe fn write_unchecked(self, neighbors: &[u32]) { + let len = neighbors.len(); + debug_assert!(len <= self.capacity()); + + // SAFETY: the caller asserts that the pointer arithmetic is sound. + // + // The slices are disjoint because `self` owns its data and this method receives + // by value. + unsafe { std::ptr::copy_nonoverlapping(neighbors.as_ptr(), self.ptr.as_ptr().add(1), len) } + + // SAFETY: `self.ptr` is guaranteed to be valid for at least 4-bytes, and we own the + // underlying data until `drop`. + unsafe { self.ptr.write(len as u32) }; + } + + #[cfg(test)] + fn as_slice(&self) -> &[u32] { + let len = self.len(); + + // SAFETY: by construction - this access is in-bounds and `Lock` has exclusive + // access to its data, so we're free to hand out a raw slice. + unsafe { std::slice::from_raw_parts(self.ptr.add(1).as_ptr().cast_const(), len) } + } + + #[cfg(test)] + fn write(self, neighbors: &[u32]) -> Result<(), TooLong> { + if neighbors.len() > self.capacity() { + return Err(TooLong { + got: neighbors.len(), + max: self.capacity as u32, + }); + } + + // SAFETY: We've checked that `neighbors.len() <= self.capacity()`. + unsafe { self.write_unchecked(neighbors) }; + Ok(()) + } + + #[cfg(test)] + fn is_empty(&self) -> bool { + self.len() == 0 + } +} + +impl std::fmt::Debug for Lock<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Lock") + .field("ptr", &self.ptr) + .field("capacity", &self.capacity) + .field("lock", &()) + .finish() + } +} + +/////////// +// Tests // +/////////// + +#[cfg(test)] +mod tests { + use super::*; + + use crate::test::Sequencer; + + // -- OutOfBounds checks -- + + #[test] + fn out_of_bounds_rejects_indices_beyond_entries() { + let n = Neighbors::new(4, 4).unwrap(); + // entries == 4, so valid indices are 0..=3. + // Regression test: a buggy `check` using `i == entries()` would let + // `entries+1`, `entries+2`, ... slip through to UB. 
+ let mut out = AdjacencyList::with_capacity(4); + for bad in [4u32, 5, 100, u32::MAX] { + assert!(matches!(n.get(bad, &mut out), Err(OutOfBounds(_)))); + assert!(matches!(n.set(bad, &[]), Err(SetError::OutOfBounds(_)))); + assert!(matches!(n.lock(bad), Err(OutOfBounds(_)))); + } + } + + #[test] + fn empty_neighbors_rejects_all_access() { + let n = Neighbors::new(0, 4).unwrap(); + let mut out = AdjacencyList::with_capacity(4); + for i in [0u32, 1, u32::MAX] { + assert!(matches!(n.get(i, &mut out), Err(OutOfBounds(_)))); + assert!(matches!(n.set(i, &[]), Err(SetError::OutOfBounds(_)))); + assert!(matches!(n.lock(i), Err(OutOfBounds(_)))); + } + } + + // TooLong errors + + #[test] + fn set_rejects_oversized_neighbors() { + let n = Neighbors::new(4, 3).unwrap(); + let too_many = &[1, 2, 3, 4]; + assert!(matches!(n.set(0, too_many), Err(SetError::TooLong(_)))); + } + + #[test] + fn lock_write_rejects_oversized_neighbors() { + let n = Neighbors::new(4, 3).unwrap(); + let lock = n.lock(0).unwrap(); + assert!(lock.write(&[1, 2, 3, 4]).is_err()); + } + + #[test] + fn lock_append_rejects_overflow() { + let n = Neighbors::new(4, 3).unwrap(); + n.set(0, &[1, 2]).unwrap(); + let lock = n.lock(0).unwrap(); + assert!(lock.append(&[3, 4]).is_err()); + } + + #[test] + fn lock_implements_debug() { + let n = Neighbors::new(4, 3).unwrap(); + let lock = n.lock(0).unwrap(); + let _ = format!("{:?}", lock); + } + + // -- Lock::append -- + + #[test] + fn append_preserves_existing_and_adds_new() { + let n = Neighbors::new(4, 6).unwrap(); + n.set(0, &[10, 20]).unwrap(); + + let lock = n.lock(0).unwrap(); + assert_eq!(lock.as_slice(), &[10, 20]); + lock.append(&[30, 40, 50]).unwrap(); + + let mut out = AdjacencyList::with_capacity(6); + n.get(0, &mut out).unwrap(); + assert_eq!(&*out, &[10, 20, 30, 40, 50]); + } + + #[test] + fn append_to_empty() { + let n = Neighbors::new(4, 4).unwrap(); + + let lock = n.lock(0).unwrap(); + assert_eq!(lock.as_slice(), &[]); + lock.append(&[1, 2, 
3]).unwrap(); + + let mut out = AdjacencyList::with_capacity(4); + n.get(0, &mut out).unwrap(); + assert_eq!(&*out, &[1, 2, 3]); + } + + #[test] + fn append_fills_to_capacity() { + let n = Neighbors::new(1, 3).unwrap(); + n.set(0, &[1]).unwrap(); + + let lock = n.lock(0).unwrap(); + lock.append(&[2, 3]).unwrap(); + + let mut out = AdjacencyList::with_capacity(3); + n.get(0, &mut out).unwrap(); + assert_eq!(&*out, &[1, 2, 3]); + } + + #[test] + fn append_empty_slice_is_noop() { + let n = Neighbors::new(1, 4).unwrap(); + n.set(0, &[10, 20]).unwrap(); + + let lock = n.lock(0).unwrap(); + lock.append(&[]).unwrap(); + + let mut out = AdjacencyList::with_capacity(4); + n.get(0, &mut out).unwrap(); + assert_eq!(&*out, &[10, 20]); + } + + #[test] + fn write_overwrites_longer_list() { + let n = Neighbors::new(1, 5).unwrap(); + n.set(0, &[1, 2, 3, 4, 5]).unwrap(); + + // Overwrite with a shorter list. + let lock = n.lock(0).unwrap(); + assert_eq!(lock.len(), 5); + lock.write(&[99]).unwrap(); + + // The length must reflect the new shorter list, not the old one. + let mut out = AdjacencyList::with_capacity(5); + n.get(0, &mut out).unwrap(); + assert_eq!(&*out, &[99]); + } + + // Clear the adjacency list in `neighbors`. + // + // Receives by `&mut` to ensure exclusivity. 
+ fn clear(neighbors: &mut Neighbors) { + for i in 0..neighbors.entries() { + neighbors.set(i, &[]).unwrap(); + } + + assert_is_cleared(neighbors); + } + + fn assert_is_cleared(neighbors: &mut Neighbors) { + for i in 0..neighbors.entries() { + assert!(neighbors.lock(i).unwrap().is_empty()); + } + } + + #[test] + fn basic_test() { + let mut neighbors = Neighbors::new(10, 4).unwrap(); + assert_eq!(neighbors.entries(), 10); + assert_eq!(neighbors.max_length(), 4); + + let mut list = AdjacencyList::new(); + for i in 0..neighbors.entries() { + list.clear(); + list.extend_from_slice(&[1, 2, 3, 4]); + neighbors.get(i, &mut list).unwrap(); + assert!(list.is_empty()); + + let lock = neighbors.lock(i).unwrap(); + assert_eq!(lock.capacity(), neighbors.max_length()); + assert_eq!(lock.len(), 0); + assert!(lock.is_empty()); + assert_eq!(lock.as_slice(), &[]); + } + + // Verify out-of-bounds accesses error. + let oob = neighbors.entries(); + assert!(matches!(neighbors.get(oob, &mut list), Err(OutOfBounds(_)))); + assert!(matches!(neighbors.lock(oob), Err(OutOfBounds(_)))); + assert!(matches!( + neighbors.set(oob, &[1, 2, 3, 4, 5, 6]), + Err(SetError::OutOfBounds(_)) + )); + + let generate = + |round: u32, entry: u32| -> Vec { (0..(round + 1)).map(|r| entry + r).collect() }; + + // Test mutation via `Neighbors::set`. + for round in 0..neighbors.max_length_u32() { + for i in 0..neighbors.entries() { + let v = generate(round, i); + neighbors.set(i, &v).unwrap(); + } + + for i in 0..neighbors.entries() { + let expected = generate(round, i); + neighbors.get(i, &mut list).unwrap(); + assert_eq!(&*list, &*expected); + + let lock = neighbors.lock(i).unwrap(); + assert_eq!(lock.as_slice(), &*expected); + } + } + + clear(&mut neighbors); + + // Test mutation via `lock + write`. 
+ for round in 0..neighbors.max_length_u32() { + for i in 0..neighbors.entries() { + let v = generate(round, i); + neighbors.lock(i).unwrap().write(&v).unwrap(); + } + + for i in 0..neighbors.entries() { + let expected = generate(round, i); + neighbors.get(i, &mut list).unwrap(); + assert_eq!(&*list, &*expected); + + let lock = neighbors.lock(i).unwrap(); + assert_eq!(lock.as_slice(), &*expected); + } + } + + clear(&mut neighbors); + + // Test mutation via `lock + append`. + for round in 0..neighbors.max_length_u32() { + for i in 0..neighbors.entries() { + neighbors.lock(i).unwrap().append(&[round + i]).unwrap(); + } + + for i in 0..neighbors.entries() { + let expected = generate(round, i); + + neighbors.get(i, &mut list).unwrap(); + assert_eq!(&*list, &*expected); + + let lock = neighbors.lock(i).unwrap(); + assert_eq!(lock.as_slice(), &*expected); + } + } + + clear(&mut neighbors); + } + + //-------------------// + // Concurrency Tests // + //-------------------// + + // Verify that holding a `Lock` correctly blocks reads for the same adjacency list. 
+ #[test] + fn lock_blocks_get() { + for _ in 0..10 { + let neighbors = Neighbors::new(3, 4).unwrap(); + let seq = Sequencer::new(); + + std::thread::scope(|s| { + let handle = s.spawn(|| { + seq.wait_for(0); + let mut list = AdjacencyList::new(); + neighbors.get(0, &mut list).unwrap(); + list + }); + + seq.until_waiting_for(0); + let lock = neighbors.lock(0).unwrap(); + seq.advance_past(0); + + lock.write(&[1, 2, 3, 4]).unwrap(); + let list = handle.join().unwrap(); + assert_eq!(&*list, &[1, 2, 3, 4]); + }); + } + } + + #[test] + fn many_appends() { + let max_length = if cfg!(miri) { 100 } else { 1000 }; + + let neighbors = Neighbors::new(1, max_length).unwrap(); + + let num_threads = 4; + let barrier = std::sync::Barrier::new(num_threads); + + std::thread::scope(|s| { + let neighbors_ref = &neighbors; + let barrier_ref = &barrier; + + for thread_id in 0..num_threads { + s.spawn(move || { + barrier_ref.wait(); + let mut i = thread_id as u32; + let upper = neighbors_ref.max_length() as u32; + while i < upper { + neighbors_ref.lock(0).unwrap().append(&[i]).unwrap(); + i += num_threads as u32; + } + }); + } + }); + + let mut list = AdjacencyList::new(); + let expected: Vec<_> = (0..neighbors.max_length()).map(|i| i as u32).collect(); + neighbors.get(0, &mut list).unwrap(); + list.sort(); + + assert_eq!(&*list, &*expected); + } +} diff --git a/diskann-inmem/src/num.rs b/diskann-inmem/src/num.rs new file mode 100644 index 000000000..98c20d82b --- /dev/null +++ b/diskann-inmem/src/num.rs @@ -0,0 +1,327 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +use std::num::NonZeroUsize; + +/// An unsigned number of bytes. +#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +pub struct Bytes(usize); + +impl Bytes { + /// The approximate number of bytes in a CPU cache line. + pub const CACHELINE: Self = Self::new(64); + + /// Zero bytes. + pub const ZERO: Self = Self::new(0); + + /// Construct a new [`Bytes`]. 
+ #[inline] + pub const fn new(bytes: usize) -> Self { + Self(bytes) + } + + /// Return the current value of `self`. + #[inline] + pub const fn value(self) -> usize { + self.0 + } + + /// Add `self` and `other`, returning `None` if the sum would overflow `usize`. + #[inline] + pub(crate) const fn checked_add(self, other: Bytes) -> Option { + match self.value().checked_add(other.value()) { + Some(v) => Some(Bytes::new(v)), + None => None, + } + } + + /// Multiply `self` by `other`, returning `None` if the product would overflow `usize`. + #[inline] + pub(crate) const fn checked_mul(self, other: usize) -> Option { + match self.value().checked_mul(other) { + Some(v) => Some(Bytes::new(v)), + None => None, + } + } + + /// Perform integer division of `self` by `other`. + #[inline] + pub(crate) const fn div(self, other: NonZeroUsize) -> Bytes { + Bytes::new(self.value() / other.get()) + } + + /// Subtract `other` from `self` without checking for underflow. + /// + /// The caller must ensure `other <= self`; the subtraction uses the plain `-` operator. + #[inline] + pub(crate) const fn unchecked_sub(self, other: Bytes) -> Bytes { + Self::new(self.value() - other.value()) + } + + /// Return the smallest multiple of `other` greater-than or equal to `self`. + /// + /// Returns `None` if `other` is zero or if the next multiple exceeds `usize::MAX`. + #[inline] + pub(crate) const fn checked_next_multiple_of(self, other: Bytes) -> Option { + match self.value().checked_next_multiple_of(other.value()) { + Some(v) => Some(Bytes::new(v)), + None => None, + } + } + + /// Return the size of `T` in [`Bytes`]. + #[inline] + pub const fn size_of() -> Self { + Self::new(std::mem::size_of::()) + } + + /// Return `true` if `self` is zero. + pub const fn is_zero(self) -> bool { + self.0 == 0 + } +} + +impl std::fmt::Display for Bytes { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{} bytes", self.value()) + } +} + +/// An alignment for an allocation. +/// +/// All alignments are guaranteed to be powers of two. 
+#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +#[repr(transparent)] +pub struct Align(NonZeroUsize); + +impl Align { + /// Construct a new [`Align`] from `value`, returning `None` if `value` is not a power + /// of two. + pub const fn new(value: usize) -> Option { + match NonZeroUsize::new(value) { + Some(value) => { + if value.is_power_of_two() { + Some(Self(value)) + } else { + None + } + } + None => None, + } + } + + /// Return the raw value of `self`. + pub const fn value(self) -> usize { + self.0.get() + } + + /// Construct a new [`Align`] with the raw `value`. + /// + /// # Safety + /// + /// `value` must be a power of two. + pub const unsafe fn new_unchecked(value: usize) -> Self { + debug_assert!(value.is_power_of_two()); + + // SAFETY: powers of two must be non-zero. + Self(unsafe { NonZeroUsize::new_unchecked(value) }) + } + + /// Return the alignment of a type `T`. + pub const fn of() -> Self { + // SAFETY: `std::mem::align_of` is guaranteed to return a power of 2. + unsafe { Self::new_unchecked(std::mem::align_of::()) } + } + + /// Construct a new [`Align`] from a [`std::alloc::Layout`]. + pub const fn from_layout(layout: std::alloc::Layout) -> Self { + // SAFETY: `Layout::align` is guaranteed to be a power of 2. + unsafe { Self::new_unchecked(layout.align()) } + } + + // Constants. 
+ pub const _1: Self = Self::new(1).unwrap(); + pub const _2: Self = Self::new(2).unwrap(); + pub const _4: Self = Self::new(4).unwrap(); + pub const _8: Self = Self::new(8).unwrap(); + pub const _16: Self = Self::new(16).unwrap(); + pub const _32: Self = Self::new(32).unwrap(); + pub const _64: Self = Self::new(64).unwrap(); + pub const _128: Self = Self::new(128).unwrap(); + pub const _256: Self = Self::new(256).unwrap(); + pub const _512: Self = Self::new(512).unwrap(); + pub const _1024: Self = Self::new(1024).unwrap(); + pub const _2048: Self = Self::new(2048).unwrap(); + pub const _4096: Self = Self::new(4096).unwrap(); +} + +impl std::fmt::Display for Align { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +/////////// +// Tests // +/////////// + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn new_and_value_roundtrip() { + assert_eq!(Bytes::new(42).value(), 42); + assert_eq!(Bytes::new(0).value(), 0); + } + + #[test] + fn cacheline_constant() { + assert_eq!(Bytes::CACHELINE, Bytes::new(64)); + } + + #[test] + fn size_of_returns_correct_size() { + assert_eq!(Bytes::size_of::(), Bytes::new(1)); + assert_eq!(Bytes::size_of::(), Bytes::new(8)); + assert_eq!(Bytes::size_of::<[u8; 128]>(), Bytes::new(128)); + } + + #[test] + fn checked_add_success() { + assert_eq!( + Bytes::new(10).checked_add(Bytes::new(20)), + Some(Bytes::new(30)) + ); + } + + #[test] + fn checked_add_overflow() { + assert_eq!(Bytes::new(usize::MAX).checked_add(Bytes::new(1)), None); + } + + #[test] + fn checked_mul_success() { + assert_eq!(Bytes::new(64).checked_mul(4), Some(Bytes::new(256))); + } + + #[test] + fn checked_mul_overflow() { + assert_eq!(Bytes::new(usize::MAX).checked_mul(2), None); + } + + #[test] + fn checked_mul_by_zero() { + assert_eq!(Bytes::new(100).checked_mul(0), Some(Bytes::new(0))); + } + + #[test] + fn unchecked_sub() { + assert_eq!( + Bytes::new(100).unchecked_sub(Bytes::new(30)), + Bytes::new(70) + 
); + } + + #[test] + fn checked_next_multiple_of_already_aligned() { + assert_eq!( + Bytes::new(128).checked_next_multiple_of(Bytes::new(64)), + Some(Bytes::new(128)) + ); + } + + #[test] + fn checked_next_multiple_of_rounds_up() { + assert_eq!( + Bytes::new(100).checked_next_multiple_of(Bytes::new(64)), + Some(Bytes::new(128)) + ); + } + + #[test] + fn checked_next_multiple_of_overflow() { + assert_eq!( + Bytes::new(usize::MAX).checked_next_multiple_of(Bytes::new(2)), + None + ); + } + + #[test] + fn ordering() { + assert!(Bytes::new(10) < Bytes::new(20)); + assert!(Bytes::new(20) > Bytes::new(10)); + assert_eq!(Bytes::new(5), Bytes::new(5)); + } + + #[test] + fn display() { + assert_eq!(format!("{}", Bytes::new(256)), "256 bytes"); + } + + // Align tests + + #[test] + fn align_new_power_of_two() { + assert_eq!(Align::new(1).unwrap().value(), 1); + assert_eq!(Align::new(2).unwrap().value(), 2); + assert_eq!(Align::new(64).unwrap().value(), 64); + assert_eq!(Align::new(4096).unwrap().value(), 4096); + } + + #[test] + fn align_new_rejects_zero() { + assert!(Align::new(0).is_none()); + } + + #[test] + fn align_new_rejects_non_power_of_two() { + assert!(Align::new(3).is_none()); + assert!(Align::new(5).is_none()); + assert!(Align::new(6).is_none()); + assert!(Align::new(100).is_none()); + } + + #[test] + fn align_of_matches_std() { + assert_eq!(Align::of::<()>().value(), 1); + assert_eq!(Align::of::().value(), std::mem::align_of::()); + assert_eq!(Align::of::().value(), std::mem::align_of::()); + assert_eq!(Align::of::().value(), std::mem::align_of::()); + } + + #[test] + fn align_from_layout() { + let layout = std::alloc::Layout::from_size_align(256, 128).unwrap(); + assert_eq!(Align::from_layout(layout).value(), 128); + } + + #[test] + fn align_constants() { + assert_eq!(Align::_1.value(), 1); + assert_eq!(Align::_2.value(), 2); + assert_eq!(Align::_4.value(), 4); + assert_eq!(Align::_8.value(), 8); + assert_eq!(Align::_16.value(), 16); + 
 assert_eq!(Align::_32.value(), 32); + assert_eq!(Align::_64.value(), 64); + assert_eq!(Align::_128.value(), 128); + assert_eq!(Align::_256.value(), 256); + assert_eq!(Align::_512.value(), 512); + assert_eq!(Align::_1024.value(), 1024); + assert_eq!(Align::_2048.value(), 2048); + assert_eq!(Align::_4096.value(), 4096); + } + + #[test] + fn align_ordering() { + assert!(Align::_1 < Align::_64); + assert!(Align::_128 > Align::_64); + assert_eq!(Align::_32, Align::new(32).unwrap()); + } + + #[test] + fn align_display() { + assert_eq!(format!("{}", Align::_64), "64"); + } +} diff --git a/diskann-inmem/src/provider.rs b/diskann-inmem/src/provider.rs new file mode 100644 index 000000000..f0569b701 --- /dev/null +++ b/diskann-inmem/src/provider.rs @@ -0,0 +1,1237 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! An in-memory provider for the DiskANN graph index. +//! +//! This type supports the following: +//! +//! * Arbitrary external IDs for store data (provided they satisfy [`Id`]). +//! * Support for concurrent insertions, deletions, and searches. +//! * Specialized implementations of [`glue::SearchAccessor::expand_beam`] enabling full +//! inlining of distance kernels. +//! +//! Known areas for future work: +//! +//! * Insert and delete protection: The [`DiskANNIndex`](diskann::graph::DiskANNIndex) doesn't +//! support ergonomic insert or delete guards to protect slots during insert or delete +//! operations. This leaves open a situation where an item can be inserted and during +//! the insertion algorithm, it is deleted, and then re-inserted. +//! +//! This can cause some issues within the main indexing algorithms which assume the inserted +//! ID is present but requires upstream changes to properly fix. +//! +//! * Failed insert rollback: again, this needs some upstream changes to fully support. +//! +//! * Quantization + reranking: The current version of this index targets just a single +//! 
 data-store and is planned to be addressed in the near future. +//! +//! * Lack of save/load support: The index is currently ephemeral, but there are plans to +//! address this gap. + +use std::{hash::Hash, num::NonZeroUsize}; + +use diskann::{ + ANNError, ANNErrorKind, ANNResult, + graph::{ + AdjacencyList, SearchOutputBuffer, + glue::{self, HybridPredicate}, + workingset, + }, + neighbor::Neighbor, + provider, + utils::IntoUsize, +}; +use diskann_utils::views::Matrix; +use thiserror::Error; + +use crate::{ + counters::{Counters, LocalCounters}, + layers::{self, QueryDistance}, + num::Bytes, + sharded::Sharded, + store::{self, Store}, +}; + +/// Aggregate trait for the external ID type of [`Provider`]. +pub trait Id: Send + Sync + Hash + Eq + Clone + 'static {} + +impl Id for T where T: Send + Sync + Hash + Eq + Clone + 'static {} + +/// An in-memory data-provider for DiskANN's graph indexing algorithms. +/// +/// The first type parameter `L` is a [`layers::Layer`] for describing the kind of data +/// stored within the provider. The second parameter `M` is the external ID type (see [`Id`]) +/// of items inserted into the provider. +#[derive(Debug)] +pub struct Provider +where + M: Id, +{ + // The raw binary store + store: Store, + // Data representation. + layer: L, + // ID translation. + mapping: Sharded, + // Construction `Config`. + config: Config, + + // `Counters` is only non-trivial under the `integration-test` feature flag. Otherwise, + // all counter related operations are no-ops. + counters: Counters, +} + +impl Provider +where + M: Id, +{ + /// Construct a new [`Provider`]. + /// + /// The list of `start_points` must be compatible with `layer`. 
+ pub fn new(layer: L, config: Config, start_points: I) -> Result + where + I: IntoIterator, + L: layers::Set, + { + let start_points: Vec<_> = start_points.into_iter().collect(); + let bytes = layers::Layer::bytes(&layer); + let mut data = Matrix::new(0u8, start_points.len(), bytes.value()); + + for (row, point) in std::iter::zip(data.row_iter_mut(), start_points.into_iter()) { + layers::Set::set(&layer, point, row)?; + } + + let store = Store::new( + config.capacity(), + bytes, + config.max_degree(), + data.as_view(), + ) + .map_err(|err| ProviderError::CreatingStore(Box::new(err)))?; + + let mapping = Sharded::new(config.capacity()); + + Ok(Self { + store, + layer, + mapping, + config, + counters: Counters::new(), + }) + } + + /// A local set of counters that update the provider-wide counters in bulk. + fn local_counters(&self) -> LocalCounters<'_> { + self.counters.local() + } + + /// Return the maximum number of neighbors that can be stored in the provider's graph. + pub fn max_degree(&self) -> usize { + self.store.max_degree() + } + + /// Return a snapshot of the current event counters. + #[cfg(feature = "integration-test")] + pub fn counters(&self) -> crate::integration::counters::CounterSnapshot { + self.counters.snapshot() + } +} + +#[derive(Debug, Error)] +pub enum ProviderError { + #[error("error when trying to set start points")] + SettingStartPoints(#[from] ANNError), + #[error("could not create data store")] + CreatingStore(#[source] Box), +} + +/// Configuration for [`Provider`]. +#[derive(Debug)] +pub struct Config { + capacity: usize, + max_degree: usize, + prefetch_lookahead: Option, +} + +impl Config { + const DEFAULT_PREFETCH_LOOKAHEAD: NonZeroUsize = NonZeroUsize::new(8).unwrap(); + + /// Construct a new [`Config`]. + /// + /// * `capacity`: The number of dynamic entries in the resulting provider. + /// * `max_degree`: The maximum degree of any adjacency list in the graph. 
+ pub fn new(capacity: usize, max_degree: usize) -> Self { + Self { + capacity, + max_degree, + prefetch_lookahead: Some(Self::DEFAULT_PREFETCH_LOOKAHEAD), + } + } + + /// Return the number of dynamic entries in the resulting provider. + pub fn capacity(&self) -> usize { + self.capacity + } + + /// Return the maximum degree of any adjacency list. + pub fn max_degree(&self) -> usize { + self.max_degree + } + + /// Configure the prefetch lookahead. + /// + /// This is used during beam expansion to prefetch data into CPU caches. + /// Passing `None` disables prefetching. + pub fn prefetch_lookahead(&mut self, prefetch_lookahead: Option) { + self.prefetch_lookahead = prefetch_lookahead; + } +} + +/////////////////// +// Data Provider // +/////////////////// + +/// A zero-sized [`diskann::provider::ExecutionContext`] for [`Provider`]. +#[derive(Debug, Clone, Default)] +pub struct Context; + +impl diskann::provider::ExecutionContext for Context {} + +impl diskann::provider::DataProvider for Provider +where + T: Send + Sync + 'static, + M: Id, +{ + type Context = Context; + type InternalId = u32; + type ExternalId = M; + type Error = ANNError; + type Guard = diskann::provider::NoopGuard; + + /// Translate an external id to its corresponding internal id. + fn to_internal_id( + &self, + _context: &Self::Context, + gid: &M, + ) -> Result { + match self.mapping.to_internal(gid) { + Some(id) => Ok(id), + None => Err(ANNError::message(ANNErrorKind::Opaque, "no mapping")), + } + } + + /// Translate an internal id to its corresponding external id. + fn to_external_id( + &self, + _context: &Self::Context, + id: Self::InternalId, + ) -> Result { + match self.mapping.to_external(id) { + Some(gid) => Ok(gid), + None => Err(ANNError::message(ANNErrorKind::Opaque, "no mapping")), + } + } +} + +// TODO: The element-status checks here are profoundly approximate because we try to avoid +// any kind of EBR registration. +// +// `diskann` has plans to move deletion checks behind an accessor trait, which will help +// with this situation. 
+impl diskann::provider::Delete for Provider +where + L: Send + Sync + 'static, + M: Id, +{ + async fn delete(&self, _context: &Context, gid: &M) -> ANNResult<()> { + // This guarantees that we have a valid mapping, but defers the actual deletion until + // we know it's also safe to retire the internal slot. + // + // This ensures both either succeed or are aborted. + let entry = match self.mapping.occupied_entry(gid.clone()) { + None => { + return Err(ANNError::message( + ANNErrorKind::Opaque, + "id already deleted", + )); + } + Some(e) => e, + }; + + match self.store.retire(entry.internal().into_usize()) { + Ok(()) => { + // Successfully retired the internal slot. We can safely release the ID mapping. + entry.delete(); + Ok(()) + } + Err(err) => Err(ANNError::opaque(err)), + } + } + + async fn release(&self, _context: &Context, _id: Self::InternalId) -> ANNResult<()> { + Ok(()) + } + + async fn status_by_internal_id( + &self, + _context: &Context, + id: u32, + ) -> ANNResult { + // Note that this check is approximate. A full check requires materialization of + // a `reader`. 
+ match self.store.can_read_approximate(id.into_usize()) { + Some(true) => Ok(diskann::provider::ElementStatus::Valid), + Some(false) => Ok(diskann::provider::ElementStatus::Deleted), + None => Err(ANNError::message( + ANNErrorKind::Opaque, + "accessed invalid internal ID", + )), + } + } + + async fn status_by_external_id( + &self, + _context: &Context, + gid: &M, + ) -> ANNResult { + if self.mapping.contains_external(gid) { + Ok(diskann::provider::ElementStatus::Valid) + } else { + Ok(diskann::provider::ElementStatus::Deleted) + } + } +} + +fn ready(f: F) -> std::future::Ready +where + F: FnOnce() -> R, +{ + std::future::ready(f()) +} + +impl diskann::provider::SetElement for Provider +where + L: layers::Set, + M: Id, +{ + type SetError = ANNError; + + fn set_element( + &self, + _context: &Self::Context, + id: &M, + element: T, + ) -> impl std::future::Future> + Send { + let work = move || { + let mut slot = self.store.acquire().ok_or_else(|| { + ANNError::message(ANNErrorKind::Opaque, "could not allocate a new slot") + })?; + + // TODO: Proper cleanup via `Guard` or some other mechanism on the event of + // insert failure after `set_element` returns. + >::set(&self.layer, element, slot.as_mut_slice())?; + self.mapping.insert(id.clone(), slot.slot())?; + + // Now that insert has succeeded - publish the slot. This method cannot fail, so + // we do not need to worry about potentially unwinding the ID mapping. + let id = slot.publish(); + + // This is a rather expensive update. + // + // However, counters are only active with the `integration-test` feature, which + // is not expected to be enabled for general use. + self.local_counters().set_vector(1); + + Ok(diskann::provider::NoopGuard::new(id)) + }; + + ready(work) + } +} + +//////////// +// Search // +//////////// + +/// A [`glue::SearchAccessor`] for [`Provider`]. +/// +/// This type intentionally avoids generic parameters and instead compiles optimized +/// `expand_beam` kernels that get reused. 
 The idea is to generate an efficient graph search +/// kernel once and reuse it to balance compile times and performance. +#[derive(Debug)] +pub struct SearchAccessor<'a> { + reader: store::Reader<'a>, + ids: AdjacencyList, + expand_beam: Box, + buffer: Vec<(u32, f32)>, + + // The parent provider for the accessor. + provider: &'a (dyn std::any::Any + Send + Sync), + start_points: std::ops::Range, + counters: LocalCounters<'a>, +} + +impl diskann::provider::HasId for SearchAccessor<'_> { + type Id = u32; +} + +impl glue::SearchAccessor for SearchAccessor<'_> { + fn starting_points( + &self, + ) -> impl std::future::Future>> + Send { + std::future::ready(Ok(self.start_points.clone().collect())) + } + + fn start_point_distances( + &mut self, + mut f: F, + ) -> impl std::future::Future> + Send + where + F: FnMut(Self::Id, f32) + Send, + { + let work = move || { + for p in self.start_points.clone() { + match self.reader.read(p.into_usize()) { + Some(point) => { + // Counters are no-ops without `integration-test`. + self.counters.get_vector(1); + self.counters.query_distance(1); + + f(p, self.expand_beam.evaluate(point)?); + } + None => { + return Err(ANNError::message( + ANNErrorKind::Opaque, + "could not retrieve start point", + )); + } + } + } + Ok(()) + }; + + ready(work) + } + + fn expand_beam( + &mut self, + ids: Itr, + mut pred: P, + mut on_neighbors: F, + ) -> impl std::future::Future> + Send + where + Itr: Iterator + Send, + P: HybridPredicate + Send + Sync, + F: FnMut(Self::Id, f32) + Send, + { + let work = move || -> ANNResult<()> { + for i in ids { + self.reader.neighbors().get(i, &mut self.ids)?; + self.counters.get_neighbors(1); + + // Keep only the IDs accepted by `pred` and ensure that all the IDs we are about + // to process are in-bounds for `self.reader`. + self.ids + .retain(|i| pred.eval_mut(i) && self.reader.is_in_bounds(i.into_usize())); + + // This should always hold, but let's double check. 
+ assert!(self.buffer.len() >= self.ids.len()); + + // SAFETY: We've verified that each entry in `self.ids` is in-bounds and the + // `self.buffer` is long enough to hold all the IDs. + let processed = unsafe { + self.expand_beam + .expand_beam(&self.ids, &self.reader, &mut self.buffer) + }?; + + self.counters.get_vector(processed as u64); + self.counters.query_distance(processed as u64); + + self.buffer + .iter() + .take(processed) + .for_each(|(id, dist)| on_neighbors(*id, *dist)); + } + + Ok(()) + }; + + ready(work) + } +} + +trait ExpandBeam: Send + Sync + std::fmt::Debug { + /// Evaluate a raw distance function. + fn evaluate(&self, x: &[u8]) -> ANNResult; + + /// Compute the distance between the query and each neighbor in `list`. + /// + /// Returns the number of `(id, distance)` pairs written to the front of `buffer`. + /// + /// # Safety + /// + /// * All items in `list` must be in-bounds with respect to `reader`. + /// * `buffer.len() >= list.len()`. + unsafe fn expand_beam( + &self, + list: &[u32], + reader: &store::Reader<'_>, + buffer: &mut [(u32, f32)], + ) -> ANNResult; +} + +#[derive(Debug)] +struct ExpandBeamImpl { + inner: T, + prefetch_lookahead: usize, +} + +impl ExpandBeamImpl { + fn new(inner: T, prefetch_lookahead: usize) -> Self { + Self { + inner, + prefetch_lookahead, + } + } +} + +impl ExpandBeam for ExpandBeamImpl +where + T: layers::QueryDistance, +{ + fn evaluate(&self, x: &[u8]) -> ANNResult { + self.inner.evaluate(x) + } + + unsafe fn expand_beam( + &self, + list: &[u32], + reader: &store::Reader<'_>, + buffer: &mut [(u32, f32)], + ) -> ANNResult { + // SAFETY: Inherited from caller. 
+ unsafe { + expand_beam_inner::( + &self.inner, + list, + self.prefetch_lookahead, + reader, + buffer, + ) + } + } +} + +#[derive(Debug)] +struct ExpandBeamVisitor { + bytes: Bytes, + prefetch_lookahead: usize, +} + +impl<'a> layers::QueryVisitor<'a> for ExpandBeamVisitor { + type Output = Box; + + fn visit_sized(self, distance: T) -> Self::Output + where + T: QueryDistance + 'a, + { + // This is critical to ensure we emit the correct number of prefetches. + assert!(Bytes::new(BYTES + store::TAG_SIZE.value()) <= self.bytes); + Box::new(ExpandBeamImpl::<_, BYTES>::new( + distance, + self.prefetch_lookahead, + )) + } + + fn visit(self, distance: T) -> Self::Output + where + T: QueryDistance + 'a, + { + Box::new(ExpandBeamImpl::<_, 0>::new( + distance, + self.prefetch_lookahead, + )) + } +} + +/// Prefetch `len` bytes beginning at `ptr`. +/// +/// The last cache line is prefetched first, followed by the rest in ascending order. +/// +/// # Safety +/// +/// The memory range `[ptr, ptr.add(len))` must be valid. +#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))] +#[inline(always)] +unsafe fn prefetch(ptr: *const u8, len: usize) { + use std::arch::x86_64::*; + + // Fetch the last cache line (the one with the tag) first. + let stride = Bytes::CACHELINE.value(); + let ptr = ptr.cast::(); + let lines = len.div_ceil(stride); + if lines == 0 { + return; + } + + // SAFETY: Inherited from caller. + unsafe { _mm_prefetch(ptr.add(stride * (lines - 1)), _MM_HINT_T0) }; + for i in 0..(lines - 1) { + // SAFETY: Inherited from caller. + unsafe { + _mm_prefetch(ptr.add(stride * i), _MM_HINT_T0); + } + } +} + +/// Prefetch `len` bytes beginning at `ptr`. +/// +/// This fallback variant is a no-op: it issues no prefetches. +/// +/// # Safety +/// +/// The memory range `[ptr, ptr.add(len))` must be valid. 
+// This gate is the exact complement of the `all(x86_64, avx2)` gate on the other `prefetch`, +// so exactly one definition of `prefetch` is compiled for every target. +#[cfg(not(all(target_arch = "x86_64", target_feature = "avx2")))] +unsafe fn prefetch(_ptr: *const u8, _len: usize) {} + +/// # Safety +/// +/// * All items in `list` must be in-bounds with respect to `reader`. +/// * `BYTES + store::TAG_SIZE` must not exceed `reader.bytes()` (the prefetch length is derived from it). +/// * `buffer.len() >= list.len()`. +#[inline] +unsafe fn expand_beam_inner( + distance: &T, + list: &[u32], + lookahead: usize, + reader: &store::Reader<'_>, + buffer: &mut [(u32, f32)], +) -> ANNResult +where + T: layers::QueryDistance, +{ + debug_assert!( + BYTES + store::TAG_SIZE.value() <= reader.bytes().value(), + "we really rely on this: {}, bytes = {}", + BYTES + store::TAG_SIZE.value(), + reader.bytes() + ); + + debug_assert!(buffer.len() >= list.len()); + + let bytes = if BYTES == 0 { + reader.bytes().value() + } else { + BYTES + store::TAG_SIZE.value() + }; + + let len = list.len(); + let lookahead = lookahead.min(len); + + for j in 0..lookahead { + // SAFETY: The in-bounds constraint is assured by the caller, both for `j` as well + // as the validity of the prefetch bounds. + unsafe { + prefetch( + reader + .read_raw_unchecked(list.get_unchecked(j).into_usize()) + .as_ptr() + .cast(), + bytes, + ) + } + } + + // Disable prefetching if the lookahead is 0. + let mut j = if lookahead == 0 { len } else { lookahead }; + let mut processed = 0; + for &i in list.iter() { + if j != len { + // SAFETY: The in-bounds constraint is assured by the caller, both for `j` as + // well as the validity of the prefetch bounds. + unsafe { + prefetch( + reader + .read_raw_unchecked(list.get_unchecked(j).into_usize()) + .as_ptr() + .cast(), + bytes, + ) + } + j += 1; + } + + // SAFETY: Caller asserts that `i` is in-bounds. + if let Some(data) = unsafe { reader.read_in_bounds(i.into_usize()) } { + // SAFETY: Inherited from caller. 
+ *unsafe { buffer.get_unchecked_mut(processed) } = (i, distance.evaluate(data)?); + processed += 1; + } + } + + Ok(processed) +} + +//////////// +// Insert // +//////////// + +/// The [`glue::PruneAccessor`] implementation for [`Provider`]. +/// +/// This type implements zero-copy access to the data within its parent provider during prunes. +#[derive(Debug)] +pub struct PruneAccessor<'a> { + reader: store::Reader<'a>, + distance: &'a dyn layers::Distance, + counters: LocalCounters<'a>, +} + +/// The distance computer for [`PruneAccessor`]. +#[derive(Debug)] +pub struct Distance<'a> { + distance: &'a dyn layers::Distance, + counters: LocalCounters<'a>, +} + +impl<'a> Distance<'a> { + fn new(distance: &'a dyn layers::Distance, counters: LocalCounters<'a>) -> Self { + Self { distance, counters } + } +} + +#[expect( + clippy::unwrap_used, + reason = "prune does not allow fallible distance functions yet" +)] +impl diskann_vector::DistanceFunction<&[u8], &[u8], f32> for Distance<'_> { + #[inline] + fn evaluate_similarity(&self, x: &[u8], y: &[u8]) -> f32 { + self.counters.distance_ref(1); + self.distance.evaluate(x, y).unwrap() + } +} + +impl diskann::provider::HasId for PruneAccessor<'_> { + type Id = u32; +} + +impl glue::PruneAccessor for PruneAccessor<'_> { + type Neighbors<'a> + = provider::Neighbors<'a, Self> + where + Self: 'a; + + type ElementRef<'a> = &'a [u8]; + + type View<'a> + = &'a Self + where + Self: 'a; + + type Distance<'a> + = Distance<'a> + where + Self: 'a; + + fn neighbors(&mut self) -> Self::Neighbors<'_> { + provider::Neighbors(self) + } + + async fn fill<'a, Itr>( + &'a mut self, + _itr: Itr, + ) -> ANNResult<(Self::View<'a>, Self::Distance<'a>)> + where + Itr: ExactSizeIterator + Clone + Send + Sync, + { + Ok((self, Distance::new(self.distance, self.counters.fork()))) + } +} + +impl provider::NeighborAccessor for PruneAccessor<'_> { + fn get_neighbors( + &mut self, + id: Self::Id, + neighbors: &mut AdjacencyList, + ) -> impl 
std::future::Future> + Send { + let work = move || { + self.counters.get_neighbors(1); + Ok(self.reader.neighbors().get(id, neighbors)?) + }; + ready(work) + } +} + +impl provider::NeighborAccessorMut for PruneAccessor<'_> { + fn set_neighbors( + &mut self, + id: Self::Id, + neighbors: &[Self::Id], + ) -> impl std::future::Future> + Send { + let work = move || { + self.counters.set_neighbors(1); + Ok(self.reader.neighbors().set(id, neighbors)?) + }; + ready(work) + } + + fn append_vector( + &mut self, + id: Self::Id, + neighbors: &[Self::Id], + ) -> impl std::future::Future> + Send { + let work = move || -> ANNResult<()> { + self.counters.append_vector(1); + let lock = self.reader.neighbors().lock(id)?; + + // Due to race conditions between calls to `get_neighbors` and `append_vector` + // in `diskann` - it's possible that the state of the adjacency list has changed + // and we're now trying to add too many neighbors. + // + // We take care of that here by simply truncating. + // + // TODO: Introduce proper atomicity in the core algorithm. 
+ if lock.len() + neighbors.len() > lock.capacity() { + let slack = lock.capacity() - lock.len(); + lock.append(&neighbors[..slack])?; + } else { + lock.append(neighbors)?; + } + + Ok(()) + }; + + ready(work) + } +} + +impl workingset::View for &PruneAccessor<'_> { + type ElementRef<'a> = &'a [u8]; + type Element<'a> + = &'a [u8] + where + Self: 'a; + fn get(&self, id: u32) -> Option<&[u8]> { + match self.reader.read(id.into_usize()) { + Some(data) => { + self.counters.get_vector_ref(1); + Some(data) + } + None => None, + } + } +} + +//////////////// +// Strategies // +//////////////// + +#[derive(Debug, Clone, Copy)] +pub struct Strategy; + +impl<'a, L, M> glue::SearchStrategy<'a, Provider, L::Query<'a>> for Strategy +where + L: layers::Search, + M: Id, +{ + type SearchAccessor = SearchAccessor<'a>; + type SearchAccessorError = ANNError; + + fn search_accessor( + &'a self, + provider: &'a Provider, + _context: &'a Context, + query: L::Query<'a>, + ) -> ANNResult> { + let reader = provider.store.reader()?; + let expand_beam = ::query_distance( + &provider.layer, + query, + ExpandBeamVisitor { + bytes: provider.store.bytes(), + prefetch_lookahead: provider.config.prefetch_lookahead.map_or(0, |x| x.get()), + }, + )?; + + let accessor = SearchAccessor { + reader, + ids: AdjacencyList::new(), + expand_beam, + buffer: vec![(0, 0.0); provider.max_degree()], + provider, + start_points: provider.store.frozen(), + counters: provider.local_counters(), + }; + Ok(accessor) + } +} + +// This is a utility for helping inspect the generated code for `ExpandBeam`. +// +pub fn test_function<'a>( + x: &'a Provider>, + strategy: &'a Strategy, + context: &'a Context, + query: &'a [f32], +) -> ANNResult> { + glue::SearchStrategy::search_accessor(strategy, x, context, query) +} + +/// Perform ID translation during post-processing. 
+#[derive(Debug, Clone, Copy)] +pub struct Translate(std::marker::PhantomData<(L, M)>); + +impl Default for Translate { + fn default() -> Self { + Self(std::marker::PhantomData) + } +} + +impl<'a, L, M> glue::SearchPostProcess, L::Query<'a>, M> for Translate +where + L: layers::Search, + M: Id, +{ + type Error = ANNError; + + fn post_process( + &self, + accessor: &mut SearchAccessor<'_>, + _query: L::Query<'a>, + candidates: I, + output: &mut B, + ) -> impl std::future::Future> + Send + where + I: Iterator> + Send, + B: SearchOutputBuffer + Send + ?Sized, + { + let work = move || { + // By construction - the downcast should succeed. Otherwise, this is a program bug. + let provider = match accessor.provider.downcast_ref::>() { + Some(provider) => provider, + None => return Err(ANNError::message(ANNErrorKind::Opaque, "bad any cast")), + }; + + let mut count = 0; + for c in candidates { + if let Some(ext) = provider.mapping.to_external(c.id) { + if output.push(ext, c.distance).is_available() { + count += 1; + } else { + break; + } + } + } + Ok(count) + }; + + ready(work) + } +} + +impl<'a, L, M> glue::DefaultPostProcessor<'a, Provider, L::Query<'a>, M> for Strategy +where + L: layers::Search, + M: Id, +{ + diskann::default_post_processor!(Translate); +} + +impl glue::PruneStrategy> for Strategy +where + L: layers::Layer + layers::AsDistance, + M: Id, +{ + type PruneAccessor<'a> = PruneAccessor<'a>; + type PruneAccessorError = ANNError; + + fn prune_accessor<'a>( + &self, + provider: &'a Provider, + _context: &'a Context, + _capacity: usize, + ) -> ANNResult> { + Ok(PruneAccessor { + reader: provider.store.reader()?, + distance: ::as_distance(&provider.layer), + counters: provider.local_counters(), + }) + } +} + +impl<'a, L, M> glue::InsertStrategy<'a, Provider, L::Query<'a>> for Strategy +where + L: layers::Insert, + M: Id, +{ + type PruneStrategy = Self; + fn prune_strategy(&self) -> Self::PruneStrategy { + *self + } +} + +impl glue::InplaceDeleteStrategy, M>> for 
Strategy +where + M: Id, + T: layers::FullPrecision, +{ + type DeleteElement<'a> = &'a [T]; + type DeleteElementGuard = Box<[T]>; + type DeleteElementError = ANNError; + + type PruneStrategy = Self; + type DeleteSearchAccessor<'a> = SearchAccessor<'a>; + type SearchPostProcessor = glue::CopyIds; + type SearchStrategy = Self; + + fn prune_strategy(&self) -> Self { + *self + } + + fn search_strategy(&self) -> Self { + *self + } + + fn search_post_processor(&self) -> Self::SearchPostProcessor { + glue::CopyIds + } + + fn get_delete_element<'a>( + &'a self, + provider: &'a Provider, M>, + _context: &'a Context, + id: u32, + ) -> impl Future> + Send + { + let work = move || { + let reader = provider.store.reader()?; + let data = match reader.read(id.into_usize()) { + Some(data) => data, + None => { + return Err(ANNError::message( + ANNErrorKind::Opaque, + "item could not be read", + )); + } + }; + + let mut buf: Box<[_]> = + std::iter::repeat_n(T::zeroed(), provider.layer.dim()).collect(); + + bytemuck::must_cast_slice_mut::(&mut buf).copy_from_slice(data); + Ok(buf) + }; + ready(work) + } +} + +/////////// +// Tests // +/////////// + +#[cfg(test)] +mod tests { + use super::*; + + use diskann::{ + graph::{DiskANNIndex, InplaceDeleteMethod, search::Knn, test::synthetic::Grid}, + neighbor::Neighbor, + provider::{DataProvider, Delete}, + }; + use diskann_vector::distance::Metric; + + use crate::layers::Full; + + /// The true tests live in the integration tests for this repo. + /// + /// The smoke test here uses a 2D grid of points to verify that our provider + /// implementations are more-or-less correct. + /// + /// Note that since `Provider` separates internal and external IDs, we multiply the + /// coordinates of each element in the grid by 10 and add 1 to verify that the ID + /// translation is behaving properly. 
+ /// + /// For clarity, the expected structure of the grid is as follows: + /// + /// + /// 41 91 141 191 241 + /// 31 81 131 181 231 + /// 21 71 121 171 221 + /// 11 61 111 161 211 + /// 1 51 101 151 201 + /// + #[tokio::test] + async fn smoke() { + let grid = Grid::Two; + let size = 5; + let data = grid.data(size); + let start = grid.start_point(size); + let degree = 6; + + let full = Full::::new(grid.dim().into(), Metric::L2); + + let config = Config::new(grid.num_points(size), degree); + + let provider = + Provider::<_, u64>::new(full, config, std::iter::once(start.as_slice())).unwrap(); + assert_eq!(provider.max_degree(), degree); + + let config = diskann::graph::config::Builder::new( + 2 * (grid.dim() as usize), + diskann::graph::config::MaxDegree::new(provider.max_degree()), + 10, + (Metric::L2).into(), + ) + .build() + .unwrap(); + + let index = DiskANNIndex::new(config, provider, None); + + for (i, data) in data.row_iter().enumerate() { + index + .insert(&Strategy, &Context, &((10 * i + 1) as u64), data) + .await + .unwrap(); + } + + // Verify that each ID round trips. + for i in 0..data.nrows() { + let i = (10 * i + 1) as u64; + let internal = index.provider().to_internal_id(&Context, &i).unwrap(); + assert_ne!(internal as u64, i); + assert_eq!( + index.provider().to_external_id(&Context, internal).unwrap(), + i + ); + + assert!( + !index + .provider() + .status_by_external_id(&Context, &i) + .await + .unwrap() + .is_deleted() + ); + assert!( + !index + .provider() + .status_by_internal_id(&Context, internal) + .await + .unwrap() + .is_deleted() + ); + } + + // Assert that out-of-bounds translations returns errors. + assert!(index.provider().to_internal_id(&Context, &0).is_err()); + assert!(index.provider().to_external_id(&Context, 26).is_err()); + + // Searches should return something reasonable. 
+ let knn = Knn::new(10, 10, None).unwrap(); + let mut neighbors = Vec::>::new(); + index + .search(knn, &Strategy, &Context, &[0.0, 0.0], &mut neighbors) + .await + .unwrap(); + + assert_eq!(neighbors[0].as_tuple(), (1, 0.0)); + assert_eq!(neighbors[1].as_tuple(), (11, 1.0)); // this can be swapped with 2 + assert_eq!(neighbors[2].as_tuple(), (51, 1.0)); + assert_eq!(neighbors[3].as_tuple(), (61, 2.0)); + + // If we run inplace delete on point 61, it should no longer be present. + index + .inplace_delete( + Strategy, + &Context, + &61, + 3, + InplaceDeleteMethod::VisitedAndTopK { + k_value: 10, + l_value: 10, + }, + ) + .await + .unwrap(); + + assert!( + index + .provider() + .status_by_external_id(&Context, &61) + .await + .unwrap() + .is_deleted() + ); + + // We can't delete the same thing twice. + assert!( + index + .inplace_delete( + Strategy, + &Context, + &61, + 3, + InplaceDeleteMethod::VisitedAndTopK { + k_value: 10, + l_value: 10 + }, + ) + .await + .is_err() + ); + + // Rerun search - the point 61 should now be gone. + let mut neighbors = Vec::>::new(); + index + .search(knn, &Strategy, &Context, &[0.0, 0.0], &mut neighbors) + .await + .unwrap(); + + assert_eq!(neighbors[0].as_tuple(), (1, 0.0)); + assert_eq!(neighbors[1].as_tuple(), (51, 1.0)); // this can be swapped with 2 + assert_eq!(neighbors[2].as_tuple(), (11, 1.0)); + assert_eq!(neighbors[3].as_tuple(), (101, 4.0)); // we can also accept "21" + + // We can't insert an existing ID. + assert!( + index + .insert(&Strategy, &Context, &1, &[10.0, 10.0]) + .await + .is_err() + ); + + // If we insert a new ID but the query vector is too long - make sure we leave the + // provider untouched. + assert!( + index + .insert(&Strategy, &Context, &2, &[10.0, 10.0, 10.0]) + .await + .is_err() + ); + + // Check that we can reinsert the same point with a different ID and have it be + // returned from search.
+ index + .insert(&Strategy, &Context, &62, &[1.0, 1.0]) + .await + .unwrap(); + + // We can't insert an ID - but this time it's because we don't have any more internal + // slots. + assert!( + index + .insert(&Strategy, &Context, &62, &[0.0, 0.0]) + .await + .is_err() + ); + + // Rerun search - the point 62 should be present. + let mut neighbors = Vec::>::new(); + index + .search(knn, &Strategy, &Context, &[0.0, 0.0], &mut neighbors) + .await + .unwrap(); + + assert_eq!(neighbors[0].as_tuple(), (1, 0.0)); + assert_eq!(neighbors[1].as_tuple(), (11, 1.0)); // this can be swapped with 2 + assert_eq!(neighbors[2].as_tuple(), (51, 1.0)); + assert_eq!(neighbors[3].as_tuple(), (62, 2.0)); + } +} diff --git a/diskann-inmem/src/sharded.rs b/diskann-inmem/src/sharded.rs new file mode 100644 index 000000000..50a59b9e4 --- /dev/null +++ b/diskann-inmem/src/sharded.rs @@ -0,0 +1,364 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +use std::hash::Hash; + +use dashmap::{ + DashMap, + mapref::entry::{self, OccupiedEntry}, +}; +use diskann::utils::IntoUsize; +use parking_lot::{RwLock, RwLockWriteGuard}; +use thiserror::Error; + +const SHARD_SIZE: usize = 1024; + +/// Bidirectional mapping between an external id `I` and a dense internal `u32` id. +#[derive(Debug)] +pub(crate) struct Sharded +where + I: Hash + Eq, +{ + forward: DashMap, + backward: Vec]>>>, + capacity: usize, +} + +impl Sharded +where + I: Hash + Eq, +{ + pub(crate) fn new(capacity: usize) -> Self { + let backward = std::iter::repeat_with(|| { + let shard = std::iter::repeat_with(|| None).take(SHARD_SIZE).collect(); + RwLock::new(shard) + }) + .take(capacity.div_ceil(SHARD_SIZE)) + .collect(); + + Self { + forward: DashMap::new(), + backward, + capacity, + } + } + + /// Establish a mapping between `external` and `internal`. + /// + /// # Errors + /// + /// Returns [`InsertError::OutOfBounds`] if `internal` is outside the table's capacity. 
+ /// Returns [`InsertError::ExternalExists`] if `external` is already mapped. + /// Returns [`InsertError::InternalExists`] if `internal` is already mapped. + pub(crate) fn insert(&self, external: I, internal: u32) -> Result<(), InsertError> + where + I: Eq + Hash + Clone, + { + if internal.into_usize() >= self.capacity { + return Err(InsertError::OutOfBounds); + } + + let Shard { outer, inner } = self.shard(internal); + + // Take the forward entry first and hold it vacant until the reverse slot is + // confirmed empty. This makes the pair-write atomic with respect to other + // `insert` callers: another thread racing on the same `external` will block + // on the dashmap shard, and another thread racing on the same `internal` will + // block on the backward shard's write lock. + let forward = match self.forward.entry(external.clone()) { + entry::Entry::Occupied(_) => return Err(InsertError::ExternalExists), + entry::Entry::Vacant(vacant) => vacant, + }; + + let mut shard = self.backward[outer].write(); + if shard[inner].is_some() { + // Forward entry drops as vacant — no insertion happened. + return Err(InsertError::InternalExists); + } + shard[inner] = Some(external); + forward.insert(internal); + Ok(()) + } + + pub(crate) fn contains_external(&self, external: &Q) -> bool + where + I: std::borrow::Borrow, + Q: Eq + Hash + ?Sized, + { + self.forward.contains_key(external) + } + + /// Look up the internal id for an external id. + pub(crate) fn to_internal(&self, external: &Q) -> Option + where + I: std::borrow::Borrow, + Q: Eq + Hash + ?Sized, + { + self.forward.get(external).map(|v| *v) + } + + /// Look up the external id for an internal id. 
+ pub(crate) fn to_external(&self, internal: u32) -> Option + where + I: Clone, + { + if internal.into_usize() >= self.capacity { + return None; + } + + let Shard { outer, inner } = self.shard(internal); + self.backward[outer].read()[inner].clone() + } + + /// Validate that a mapping exists for `external` and return an [`Entry`] if successful. + /// + /// The [`Entry`] provides a means of error-free deferred deletion to enable coordinated + /// deletion of slots among multiple stores. + pub(crate) fn occupied_entry(&self, external: I) -> Option> + where + I: Eq + Hash, + { + match self.forward.entry(external) { + entry::Entry::Vacant(_) => None, + entry::Entry::Occupied(forward) => { + let internal = *forward.get(); + let Shard { outer, inner } = self.shard(internal); + let backward = self.backward[outer].write(); + assert!( + backward[inner].is_some(), + "id {} removed improperly", + internal + ); + + Some(Entry { + forward, + backward, + entry: inner, + }) + } + } + } + + fn shard(&self, i: u32) -> Shard { + let i = i.into_usize(); + Shard { + outer: i / SHARD_SIZE, + inner: i % SHARD_SIZE, + } + } + + #[cfg(test)] + fn capacity(&self) -> usize { + self.capacity + } +} + +struct Shard { + outer: usize, + inner: usize, +} + +#[derive(Debug, Error)] +pub(crate) enum InsertError { + #[error("internal id is out of bounds")] + OutOfBounds, + #[error("the external id is already mapped")] + ExternalExists, + #[error("the internal id is already mapped")] + InternalExists, +} + +crate::opaque!(InsertError); + +/// A handle to a valid entry in a [`Sharded`]. +/// +/// This can be used to guarantee the presence of an entry prior to deletion to support +/// atomic deletes. 
+pub(crate) struct Entry<'a, I> +where + I: Eq + Hash, +{ + forward: OccupiedEntry<'a, I, u32>, + backward: RwLockWriteGuard<'a, Box<[Option]>>, + entry: usize, +} + +impl<'a, I> Entry<'a, I> +where + I: Eq + Hash, +{ + pub(crate) fn internal(&self) -> u32 { + *self.forward.get() + } + + pub(crate) fn delete(mut self) { + self.forward.remove(); + self.backward[self.entry] = None; + } + + #[cfg(test)] + fn external(&self) -> &I { + self.forward.key() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn new_reports_capacity() { + for capacity in [ + 0, + 1, + SHARD_SIZE - 1, + SHARD_SIZE, + SHARD_SIZE + 1, + 3 * SHARD_SIZE, + ] { + let map = Sharded::::new(capacity); + assert_eq!(map.capacity(), capacity); + } + } + + #[test] + fn insert_round_trips() { + let map = Sharded::::new(16); + assert!(map.insert(100, 3).is_ok()); + + assert_eq!(map.to_internal(&100), Some(3)); + assert_eq!(map.to_external(3), Some(100)); + assert!(map.contains_external(&100)); + + // Unmapped ids return nothing. + assert_eq!(map.to_internal(&101), None); + assert_eq!(map.to_external(4), None); + assert!(!map.contains_external(&101)); + } + + #[test] + fn insert_rejects_out_of_bounds_internal() { + let map = Sharded::::new(16); + assert!(matches!(map.insert(0, 16), Err(InsertError::OutOfBounds))); + assert!(matches!( + map.insert(0, u32::MAX), + Err(InsertError::OutOfBounds) + )); + + // The largest in-bounds id is accepted. + assert!(map.insert(0, 15).is_ok()); + } + + #[test] + fn insert_rejects_duplicate_external_and_preserves_state() { + let map = Sharded::::new(16); + map.insert(7, 5).unwrap(); + + assert!(matches!(map.insert(7, 6), Err(InsertError::ExternalExists))); + + // The failed insert must not have established any partial mapping. 
+ assert_eq!(map.to_internal(&7), Some(5)); + assert_eq!(map.to_external(6), None); + assert!(!map.contains_external(&6)); + } + + #[test] + fn insert_rejects_duplicate_internal_and_preserves_state() { + let map = Sharded::::new(16); + map.insert(7, 5).unwrap(); + + assert!(matches!(map.insert(8, 5), Err(InsertError::InternalExists))); + + // The failed insert must not have established any partial mapping. + assert_eq!(map.to_external(5), Some(7)); + assert_eq!(map.to_internal(&8), None); + assert!(!map.contains_external(&8)); + } + + #[test] + fn to_external_handles_bounds_and_empty_slots() { + let map = Sharded::::new(16); + // In-bounds but unmapped slot. + assert_eq!(map.to_external(5), None); + // Out-of-bounds slot. + assert_eq!(map.to_external(16), None); + } + + #[test] + fn mappings_span_shard_boundaries() { + let capacity = 3 * SHARD_SIZE; + let map = Sharded::::new(capacity); + + // Ids straddling every internal shard boundary. + let ids: [u32; 6] = [ + 0, + (SHARD_SIZE - 1) as u32, + SHARD_SIZE as u32, + (2 * SHARD_SIZE - 1) as u32, + (2 * SHARD_SIZE) as u32, + (capacity - 1) as u32, + ]; + + for (external, &internal) in ids.iter().enumerate() { + map.insert(external as u32, internal).unwrap(); + } + + for (external, &internal) in ids.iter().enumerate() { + assert_eq!(map.to_internal(&(external as u32)), Some(internal)); + assert_eq!(map.to_external(internal), Some(external as u32)); + } + } + + #[test] + fn lookup_supports_borrowed_query() { + let map = Sharded::::new(16); + map.insert("alpha".to_string(), 1).unwrap(); + + // Borrowed `&str` lookups against `String` keys. 
+ assert!(map.contains_external("alpha")); + assert_eq!(map.to_internal("alpha"), Some(1)); + assert!(!map.contains_external("beta")); + assert_eq!(map.to_internal("beta"), None); + } + + #[test] + fn occupied_entry_exposes_mapping() { + let map = Sharded::::new(16); + map.insert(42, 9).unwrap(); + + let entry = map.occupied_entry(42).expect("entry should exist"); + assert_eq!(entry.internal(), 9); + assert_eq!(*entry.external(), 42); + } + + #[test] + fn occupied_entry_absent_for_unmapped() { + let map = Sharded::::new(16); + assert!(map.occupied_entry(42).is_none()); + } + + #[test] + fn entry_delete_clears_both_directions() { + let map = Sharded::::new(16); + map.insert(42, 9).unwrap(); + + // Just creating and dropping an `occupied_entry` does not clear it. + { + let _ = map.occupied_entry(42).unwrap(); + assert!(map.contains_external(&42)); + assert_eq!(map.to_internal(&42), Some(9)); + assert_eq!(map.to_external(9), Some(42)); + } + + map.occupied_entry(42).expect("entry should exist").delete(); + + assert!(!map.contains_external(&42)); + assert_eq!(map.to_internal(&42), None); + assert_eq!(map.to_external(9), None); + + // The freed external and internal ids can be reused. + assert!(map.insert(42, 9).is_ok()); + } +} diff --git a/diskann-inmem/src/store.rs b/diskann-inmem/src/store.rs new file mode 100644 index 000000000..5aa04890c --- /dev/null +++ b/diskann-inmem/src/store.rs @@ -0,0 +1,963 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! A concurrent in-memory data store for uniformly sized data. +//! +//! This supports concurrent data access, deletes, and inserts through a safe interface. +//! Data is stored internally in slots indexed from `[0..N)` with `K` points reserved at the +//! end at positions `[N..N+K)`. +//! +//! ## Reading +//! +//! Read access requires a [`Reader`] produced by [`Store::reader`]. [`Reader::read`] +//! provides read-only access to data at slot `i` if the data is valid for reads. 
+//! +//! ## Writing +//! +//! [`Store::acquire`] is used to find and claim an unused internal [`Slot`]. A [`Slot`] +//! provides write access to its corresponding data which is published when the [`Slot`] is +//! dropped. +//! +//! The index of the slot chosen may be obtained via [`Slot::slot`]. +//! +//! ## Deleting +//! +//! Data is deleted via [`Store::retire`]. This immediately marks the corresponding slot as +//! unavailable for future readers. However, the retired slot will not be reused until the +//! [`Store`] can guarantee that no [`Reader`]s that could be using the data are active. +//! +//! Slots are automatically reclaimed as part of slot acquisition in the "writing" phase. +//! +//! ## Neighbor Access +//! +//! The [`Store`] also contains a [`Neighbors`] instance to store adjacency lists. Since +//! neighbors are generally accessed less frequently than data with a higher volume of write +//! traffic, fine-grained locks are used for this data structure. +//! +//! # Details +//! +//! This uses an implementation of epoch-based reclamation (EBR) provided by [`Registry`]. +//! Concurrency tags are mirrored inline with the stored data (just after the data payload) +//! to keep memory access localized. As such, high-performance implementations will want to +//! fetch the last cache line of data first to ensure the tag is resident in cache for faster +//! data checks. +//! +//! The EBR scheme allows readers to safely access data while only generating read traffic to +//! the CPU caches. The cost is that there is a delay between when slots are retired and when +//! they can be reused, with a long-lived [`Reader`] blocking this reclamation. As such, +//! users of this data structure should ensure that [`Reader`]s are reasonably short-lived. +//! +//! Internally, the data belongs to a single allocation.
+ +use std::{ + iter::repeat_n, + num::{NonZeroU32, NonZeroUsize}, + sync::atomic::Ordering, +}; + +use diskann::utils::IntoUsize; +use diskann_utils::views::MatrixView; +use thiserror::Error; + +use crate::{ + buffer::{Buffer, BufferError, RawSlice}, + epoch::{self, Registry}, + freelist::{self, Freelist}, + neighbors::{Neighbors, NeighborsError}, + num::{Align, Bytes}, + tag::{AtomicTag, Tag}, +}; + +/// A concurrent data and graph store. +#[derive(Debug)] +pub(crate) struct Store { + // The invasive store where concurrency tags are stored inline with the data. + // + // These tags are mirrored from `tags` - with the latter being used for secondary scans + // offering slightly better locality. + // + // The inline tags are stored after the data. + buffer: Buffer, + + // The unpadded size of each row in `buffer`. This includes both the data **and** the + // 1-byte tag. Tags are located at byte `unpadded - 1`. + unpadded: Bytes, + + // The number of unfrozen points. This is guaranteed to be less than `buffer.len()` + // because at least one frozen point is always stored after the unfrozen slots. + unfrozen: usize, + + // The authoritative source of truth for the state of each slot. + tags: Vec, + freelist: Freelist, + + // EBR registry. + registry: Registry, + + // Graph. + neighbors: Neighbors, +} + +/// The number of bytes occupied by the in-line concurrency tag. +pub(crate) const TAG_SIZE: Bytes = Bytes::size_of::(); + +const TWO: NonZeroUsize = NonZeroUsize::new(2).unwrap(); + +// TODO: This is a guess and probably needs tuning. +const RETRY_LIMIT: usize = 20; + +impl Store { + /// Create a new [`Store`] capable of holding `entries` non-frozen slots each of + /// length `bytes`.
+ pub(crate) fn new( + entries: usize, + bytes: Bytes, + max_neighbors: usize, + init: MatrixView<'_, u8>, + ) -> Result { + if init.ncols() != bytes.value() { + return Err(StoreError::mismatched_frozen_point_dim(init.ncols(), bytes)); + } + + if init.nrows() == 0 { + return Err(StoreError::need_frozen_point()); + } + + #[expect( + clippy::expect_used, + reason = "we expect `init` to have at least one row, so this should never happen" + )] + let unpadded = bytes + .checked_add(TAG_SIZE) + .expect("unreachable because `init` cannot exceed `isize::MAX` bytes"); + + // Pad to half a cache line. When data occupies just part of a cache line, this + // results in the same total number of cache lines being fetched while potentially + // enabling more compact memory. + #[expect( + clippy::expect_used, + reason = "we expect `init` to have at least one row, so this should never happen" + )] + let padded_bytes = unpadded + .checked_next_multiple_of(Bytes::CACHELINE.div(TWO)) + .expect("unreachabel because `init` cannot exceed `isize::MAX` bytes"); + + let too_many_entries = || StoreError::too_many_entries(entries, init.nrows()); + + // We have a hard upper-bound of `u32::MAX` total slots. + // + // This enforces that bound.
+ let entries: u32 = entries.try_into().map_err(|_| too_many_entries())?; + + let frozen: u32 = init.nrows().try_into().map_err(|_| too_many_entries())?; + + let total: u32 = entries.checked_add(frozen).ok_or_else(too_many_entries)?; + + let max_neighbors: u32 = max_neighbors + .try_into() + .map_err(|_| StoreError::too_many_neighbors(max_neighbors))?; + + const FREELIST_SIZE: NonZeroU32 = NonZeroU32::new(1024).unwrap(); + + let me = Self { + buffer: Buffer::new(total.into_usize(), padded_bytes, Align::_128)?, + unpadded, + unfrozen: entries.into_usize(), + tags: repeat_n(Tag::AVAILABLE, total.into_usize()) + .map(AtomicTag::new) + .collect(), + + // NOTE: The `Freelist` is initialized to `entries` and not `total` because + // we do not want it to release frozen IDs. + freelist: Freelist::new(entries, FREELIST_SIZE), + registry: Registry::new(), + neighbors: Neighbors::new(total, max_neighbors)?, + }; + + // Populate frozen points. + for (i, data) in init.row_iter().enumerate() { + // We have checked that the total number of entries fits in `u32`, so this + // arithmetic cannot overflow. + #[expect(clippy::expect_used, reason = "this should always succeed")] + let mut slot = me + .slot(entries + (i as u32)) + .expect("store was just created - claiming the slot must succeed"); + + slot.as_mut_slice().copy_from_slice(data); + slot.freeze(); + } + + Ok(me) + } + + /// Return the range of slots containing frozen items in `self`. + pub(crate) fn frozen(&self) -> std::ops::Range { + (self.unfrozen as u32)..(self.buffer.len() as u32) + } + + /// Return the number of bytes occupied by each entry. + pub(crate) fn bytes(&self) -> Bytes { + self.unpadded + } + + /// Return the maximum degree that can be stored in the graph. + pub(crate) fn max_degree(&self) -> usize { + self.neighbors.max_length() + } + + /// Attempt to reclaim retired slots. + /// + /// If successful, returns the number of slots reclaimed. 
+ pub(crate) fn try_drain(&self) -> Option { + // Move `tag` from `RETIRING` back to `AVAILABLE`. Observing any other state means + // the slot was touched while it sat in the retirement queue, so we panic and name + // which copy of the tag (`kind`) was wrong. + #[expect(clippy::panic, reason = "we cannot proceed if we observe this")] + fn release(tag: &AtomicTag, kind: &'static str) { + // Relaxed ordering is sufficient as all readers/writers are synchronized on + // the central generation. + if let Err(got) = tag.compare_exchange( + Tag::RETIRING, + Tag::AVAILABLE, + Ordering::Relaxed, + Ordering::Relaxed, + ) { + // The exchange expects `RETIRING`, so that is the state to report as expected. + panic!( + "CONCURRENCY VIOLATION: {} - expected {} - got {}", + kind, + Tag::RETIRING, + got, + ); + } + } + + let drain = self.registry.try_advance()?; + let items = drain.len(); + for i in drain { + assert!( + i.into_usize() < self.buffer.len(), + "received an invalid ID ({}) while reclaiming slots - max allowed is {}", + i, + self.buffer.len(), + ); + + // We release the mirror before the main tag. The other direction would + // prematurely advertise availability. + // + // SAFETY: We've verified that `i` is in-bounds. + let (mirror, _) = unsafe { self.data_unchecked(i.into_usize()) }; + release(mirror, "mirror"); + release(&self.tags[i.into_usize()], "tag"); + self.freelist.push(i); + } + Some(items) + } + + /// Return a [`Reader`] into the store. + /// + /// # Errors + /// + /// Returns [`epoch::Unavailable`] if there are too many active readers. + pub(crate) fn reader(&self) -> Result, epoch::Unavailable> { + Ok(Reader { + buffer: &self.buffer, + unpadded: self.unpadded, + neighbors: &self.neighbors, + _guard: self.registry.guard()?, + }) + } + + /// Attempt to acquire a new [`Slot`] for writing. + /// + /// This method first consults the freelist and falls back to scanning the tags list + /// if no ID is available from the fast path.
+ pub(crate) fn acquire(&self) -> Option> { + for _ in 0..RETRY_LIMIT { + match self.freelist.pop() { + freelist::Id::Found(id) => { + if let Some(slot) = self.slot(id) { + return Some(slot); + } + } + freelist::Id::Scan => match self.scan_acquire() { + Some(slot) => return Some(slot), + None => { + self.try_drain(); + } + }, + } + } + None + } + + /// Attempt to retire slot `i`. If successful, this slot will be placed in an internal + /// retirement queue for reclamation once we can prove no readers are active that could + /// have observed this transition. + /// + /// Returns `Ok(())` if the slot was successfully retired. + /// + /// # Errors + /// + /// Returns an error in any of the following conditions: + /// + /// * The slot index `i` is out-of-bounds. + /// * The slot is not in a state that can be retired (e.g., it is already retired or + /// is owned by a different thread). + /// * An [`epoch::Guard`] could not be obtained due to registration slot exhaustion. + /// * An attempt to acquire the slot after these checks races with another thread and + /// the race was lost. + pub(crate) fn retire(&self, i: usize) -> Result<(), RetireError> { + let tag = self.tags.get(i).ok_or(RetireError::OutOfBounds)?; + let current = tag.load(Ordering::Relaxed); + + // We can only perform a deletion if the generation is not in a reserved state. + if current.is_reserved() { + return Err(RetireError::SlotIsReserved { tag: current }); + } + + let guard = self + .registry + .guard() + .map_err(RetireError::GuardUnavailable)?; + + let retiring = Tag::RETIRING; + + // Even if we make this change, we can't access any data until we wait for the + // epoch to be bumped. As such, relaxed semantics are fine. + match tag.compare_exchange(current, retiring, Ordering::Relaxed, Ordering::Relaxed) { + Ok(_) => { + // Set the metadata in the mirror as well. + // + // SAFETY: We've checked that `i` is in-bounds. 
+ let (mirror, _) = unsafe { self.data_unchecked(i) }; + mirror.store(retiring, Ordering::Relaxed); + guard.retire(i as u32); + Ok(()) + } + Err(_) => Err(RetireError::CouldNotClaimSlot), + } + } + + /// A somewhat crude algorithm for cooperatively performing slot scanning. + /// + /// This uses [`Freelist::scan`] to acquire a disjoint chunk of the ID space for scanning, + /// spreading out the search across multiple threads. + /// + /// If we successfully acquire a slot, we continue for the rest of the bucket returned + /// by [`Freelist::scan`] and add any available slots to the freelist (allowing other + /// threads to find them). + /// + /// Periodically, the freelist is checked to see if another thread has found an available + /// slot for us. + fn scan_acquire(&self) -> Option> { + // This is potentially quite slow - but stop if we've scanned the entire range + // without finding anything. + let mut remaining = self.unfrozen; + let mut chunks_since_freelist_check = 0; + let mut acquired: Option> = None; + + while remaining != 0 { + let chunk = self.freelist.scan(); + remaining = remaining.saturating_sub(chunk.len()); + + for slot in chunk { + #[expect( + clippy::expect_used, + reason = "this is a serious bug with the freelist" + )] + let tag = self + .tags + .get(slot.into_usize()) + .expect("freelist scan should not give out invalid IDs"); + + // If this slot is available and we haven't claimed a slot yet, try to + // claim it. Otherwise, continue with the scan to partially repopulate the + // freelist for other threads. + if tag.load(Ordering::Relaxed) == Tag::AVAILABLE { + if acquired.is_none() { + // SAFETY: We're guaranteed that `tag` belongs to `slot`. 
+ acquired = unsafe { self.try_acquire(tag, slot) }; + } else { + self.freelist.push(slot); + } + } + } + + if acquired.is_some() { + return acquired; + } + + chunks_since_freelist_check += 1; + if chunks_since_freelist_check == 4 { + if let Some(id) = self.freelist.pop_recycled() + && let Some(slot) = self.slot(id) + { + return Some(slot); + } + chunks_since_freelist_check = 0; + } + } + None + } + + fn slot(&self, i: u32) -> Option> { + let tag = &self.tags.get(i.into_usize())?; + + // SAFETY: We've guaranteed that `tag` belongs to `slot`. + unsafe { self.try_acquire(tag, i) } + } + + /// Try to acquire `slot` with the associated `tag`. + /// + /// # Safety + /// + /// Caller asserts that `tag` was obtained from `self.tags[slot]`. This is meant as + /// a performance optimization where `tag` is first queried for potential availability. + unsafe fn try_acquire<'a>(&'a self, tag: &'a AtomicTag, slot: u32) -> Option> { + match tag.compare_exchange( + Tag::AVAILABLE, + Tag::OWNED, + Ordering::Relaxed, + Ordering::Relaxed, + ) { + Ok(_) => { + // SAFETY: Inherited from caller - `slot` is in-bounds. + let (mirror, data) = unsafe { self.data_unchecked(slot.into_usize()) }; + Some(Slot { + tag, + mirror, + data, + slot, + }) + } + Err(_) => None, + } + } + + /// Return the data at position `i` without bound-checking. + /// + /// # Safety + /// + /// The index `i` must be less than `self.buffer.len()`. + unsafe fn data_unchecked(&self, i: usize) -> (&AtomicTag, RawSlice<'_>) { + // SAFETY: inherited from caller. + let (data, mirror) = unsafe { self.buffer.get_unchecked(i) } + .truncate(self.unpadded) + .split(self.unpadded.unchecked_sub(TAG_SIZE)); + ( + // SAFETY: We're careful in this module to ensure the inline tags are only + // ever accessed atomically. + unsafe { AtomicTag::from_ptr(mirror.as_mut_ptr().cast()) }, + data, + ) + } + + /// Return whether or not it is probably okay to read from the slot `i`. + /// + /// This check is approximate and non-synchronizing.
To fully check, [`Reader::can_read`] + /// must be used. + /// + /// Returns `None` if index `i` is out-of-bounds. + pub(crate) fn can_read_approximate(&self, i: usize) -> Option { + self.tags + .get(i) + .map(|tag| tag.load(Ordering::Relaxed).can_read()) + } + + #[cfg(test)] + fn writable(&self) -> std::ops::Range { + 0..self.unfrozen as u32 + } +} + +/// Errors occurring during [`Store::new`]. +#[derive(Debug, Error)] +#[error(transparent)] +pub(crate) struct StoreError(StoreErrorInner); + +impl StoreError { + fn mismatched_frozen_point_dim(dim: usize, bytes: Bytes) -> Self { + Self(StoreErrorInner::MismatchedFrozenPointDim { dim, bytes }) + } + + fn need_frozen_point() -> Self { + Self(StoreErrorInner::NeedFrozenPoint) + } + + fn too_many_entries(entries: usize, frozen: usize) -> Self { + Self(StoreErrorInner::TooManyEntries { entries, frozen }) + } + + fn too_many_neighbors(neighbors: usize) -> Self { + Self(StoreErrorInner::TooManyNeighbors { neighbors }) + } +} + +impl From for StoreError { + fn from(err: BufferError) -> Self { + Self(err.into()) + } +} + +impl From for StoreError { + fn from(err: NeighborsError) -> Self { + Self(err.into()) + } +} + +#[derive(Debug, Error)] +enum StoreErrorInner { + #[error( + "frozen point dim ({}) must have the same dimensionality as requested bytes ({})", + dim, + bytes + )] + MismatchedFrozenPointDim { dim: usize, bytes: Bytes }, + #[error("at least one frozen point must be provided")] + NeedFrozenPoint, + #[error( + "total points ({} + {} frozen) must not exceed `u32::MAX`", + entries, + frozen + )] + TooManyEntries { entries: usize, frozen: usize }, + #[error("number of neighbors ({}) may not exceed `u32::MAX`", neighbors)] + TooManyNeighbors { neighbors: usize }, + #[error(transparent)] + BufferError(#[from] BufferError), + #[error(transparent)] + NeighborsError(#[from] NeighborsError), +} + +/// Error conditions for [`Store::retire`].
+#[derive(Debug, Error)] +pub(crate) enum RetireError { + /// Slot index was out-of-bounds. + #[error("index out of bounds")] + OutOfBounds, + /// The slot cannot be retired because it is in a reserved state. + #[error("slot is reserved: {}", tag)] + SlotIsReserved { tag: Tag }, + /// An [`epoch::Guard`] could not be acquired. + #[error(transparent)] + GuardUnavailable(epoch::Unavailable), + /// Another thread won the retirement race. + #[error("could not claim slot")] + CouldNotClaimSlot, +} + +/// An epoch protected reader into a [`Store`]. +/// +/// Created via [`Store::reader`]. +#[derive(Debug)] +pub(crate) struct Reader<'a> { + buffer: &'a Buffer, + unpadded: Bytes, + neighbors: &'a Neighbors, + // It's important that we hold onto this, even if we don't use it. + _guard: epoch::Guard<'a>, +} + +impl<'a> Reader<'a> { + /// Attempt to read the value at index `i`. This can fail for any of the + /// following reasons: + /// + /// 1. Index `i` is out-of-bounds. + /// 2. The read cannot be guaranteed to be race-free. + #[inline] + pub(crate) fn read(&self, i: usize) -> Option<&[u8]> { + if self.is_in_bounds(i) { + // SAFETY: `i` is in-bounds. + unsafe { self.read_in_bounds(i) } + } else { + None + } + } + + /// Return `true` if the index `i` is in-bounds. + #[inline] + #[must_use = "this function has no side-effects"] + pub(crate) fn is_in_bounds(&self, i: usize) -> bool { + i < self.buffer.len() + } + + /// Return `true` if it is safe to read the data at position `i`. + /// + /// This guarantee only holds while `self` is alive. Construction of a new [`Reader`] + /// requires a separate check. + #[cfg(test)] + pub(crate) fn can_read(&self, i: usize) -> Option { + if !self.is_in_bounds(i) { + return None; + } + + // SAFETY: We've checked that `i` is in-bounds. + // + // Further, we guarantee that `self.unpadded >= TAG_SIZE`, so the pointer arithmetic + // is in-bounds. 
+ let tag_ptr = unsafe { + self.buffer + .get_unchecked(i) + .as_mut_ptr() + .add(self.unpadded.unchecked_sub(TAG_SIZE).value()) + }; + + // SAFETY: We only access tag pointers atomically. + let can_read = unsafe { AtomicTag::from_ptr(tag_ptr.cast()) } + .load(Ordering::Acquire) + .can_read(); + + Some(can_read) + } + + /// Read the data as position `i` if it is guaranteed to be race-free without bounds + /// checking. + /// + /// # Safety + /// + /// The index `i` must satisfy [`Self::is_in_bounds`]. + #[inline] + pub(crate) unsafe fn read_in_bounds(&self, i: usize) -> Option<&[u8]> { + debug_assert!(self.is_in_bounds(i)); + + // SAFETY: + // + // * The caller asserts `i` is in-bounds. + // * We maintain an internal invariant that `self.buffer.stride() <= self.unpadded`. + // * Further, we maintain that `self.unpadded >= TAG_SIZE`. + let (data, tag_ptr) = unsafe { + self.buffer + .get_unchecked(i) + .truncate_unchecked(self.unpadded) + .split_unchecked(self.unpadded.unchecked_sub(TAG_SIZE)) + }; + + // NOTE: Must be `Acquire` to correctly synchronize with writes. + // + // SAFETY: We are careful in this module to ensure that inline tags are only accessed + // atomically. + let can_read = unsafe { AtomicTag::from_ptr(tag_ptr.as_mut_ptr().cast()) } + .load(Ordering::Acquire) + .can_read(); + + if can_read { + // SAFETY: We've passed the `can_read` check - `_guard` will ensure the read + // slice is valid and race-free. + Some(unsafe { data.as_slice() }) + } else { + None + } + } + + /// Return the raw data slice for index `i` without any race guarantees. + /// + /// # Safety + /// + /// The index `i` must be satisfy [`Self::is_in_bounds`]. + #[inline] + pub(crate) unsafe fn read_raw_unchecked(&self, i: usize) -> RawSlice<'_> { + // SAFETY: Inherited from caller: `i` is inbounds. + unsafe { self.buffer.get_unchecked(i) }.truncate(self.unpadded) + } + + /// Return the number of bytes for each entry. 
+ pub(crate) fn bytes(&self) -> Bytes { + self.unpadded + } + + /// Return [`Neighbors`]. + pub(crate) fn neighbors(&self) -> &Neighbors { + self.neighbors + } +} + +/// A writable buffer into the data managed by a [`Store`], obtained from [`Store::acquire`]. +#[derive(Debug)] +pub(crate) struct Slot<'a> { + tag: &'a AtomicTag, + mirror: &'a AtomicTag, + data: RawSlice<'a>, + slot: u32, +} + +impl<'a> Slot<'a> { + /// View the managed data as a mutable slice. + pub(crate) fn as_mut_slice(&mut self) -> &mut [u8] { + // SAFETY: The slot guarantees exclusive access to its corresponding data. + unsafe { self.data.as_mut_slice() } + } + + /// Return the slot associated with this write. + pub(crate) fn slot(&self) -> u32 { + self.slot + } + + fn freeze(self) { + let me = std::mem::ManuallyDrop::new(self); + me.mirror.store(Tag::FROZEN, Ordering::Release); + me.tag.store(Tag::FROZEN, Ordering::Release); + } + + /// Consume the slot and publish the written data for all readers. + /// + /// Return the internal slot ID. + pub(crate) fn publish(self) -> u32 { + let id = self.slot(); + let me = std::mem::ManuallyDrop::new(self); + me.mirror.store(Tag::PUBLISHED, Ordering::Release); + me.tag.store(Tag::PUBLISHED, Ordering::Release); + id + } +} + +impl Drop for Slot<'_> { + fn drop(&mut self) { + self.mirror.store(Tag::AVAILABLE, Ordering::Release); + self.tag.store(Tag::AVAILABLE, Ordering::Release); + } +} + +/////////// +// Tests // +/////////// + +/// These tests are basic functionality tests for the store. +/// +/// Longer running conurrency tests are in the integration test suite. +#[cfg(test)] +mod tests { + use super::*; + + use diskann_utils::views::Matrix; + + // Build a store with `entries` writable slots of `entry_bytes` each, backed by `frozen` + // zeroed frozen points. The frozen points occupy the highest slot indices. 
+ fn store(entries: usize, entry_bytes: usize, frozen: usize) -> Result { + let mut data = Matrix::new(0u8, frozen, entry_bytes); + let mut base = 0u8; + for row in data.row_iter_mut() { + row.fill(base); + base = base.wrapping_add(1); + } + + Store::new(entries, Bytes::new(entry_bytes), 0, data.as_view()) + } + + //------------------------// + // Constructor validation // + //------------------------// + + #[test] + fn new_rejects_mismatched_frozen_dim() { + // Frozen point has 8 columns but the store is asked for 16-byte entries. + let data = Matrix::new(0u8, 1, 8); + let err = Store::new(4, Bytes::new(16), 0, data.as_view()).unwrap_err(); + assert!(matches!( + err.0, + StoreErrorInner::MismatchedFrozenPointDim { dim: 8, .. } + )); + } + + #[test] + fn new_requires_a_frozen_point() { + let err = store(4, 8, 0).unwrap_err(); + assert!(matches!(err.0, StoreErrorInner::NeedFrozenPoint)); + } + + #[test] + fn new_rejects_total_slot_overflow() { + // `entries` alone fits in u32, but `entries + frozen` overflows it. + let data = Matrix::new(0u8, 1, 8); + let err = Store::new(u32::MAX as usize, Bytes::new(8), 0, data.as_view()).unwrap_err(); + assert!(matches!(err.0, StoreErrorInner::TooManyEntries { .. })); + } + + #[test] + fn new_rejects_too_many_neighbors() { + let data = Matrix::new(0u8, 1, 8); + let err = + Store::new(4, Bytes::new(8), u32::MAX.into_usize() + 1, data.as_view()).unwrap_err(); + assert!(matches!(err.0, StoreErrorInner::TooManyNeighbors { .. })); + } + + //--------// + // Layout // + //--------// + + #[test] + fn frozen_range_follows_writable_slots() { + let s = store(4, 8, 2).unwrap(); + + // Writable slots are [0, 4); frozen points occupy [4, 6). 
+ assert_eq!(s.frozen(), 4..6); + + let reader = s.reader().unwrap(); + for i in 0..4 { + assert!(!s.can_read_approximate(i).unwrap()); + assert!(!reader.can_read(i).unwrap()); + assert!(reader.read(i).is_none()); + } + + assert!(s.can_read_approximate(4).unwrap()); + assert!(reader.can_read(4).unwrap()); + assert_eq!(reader.read(4).unwrap(), &[0, 0, 0, 0, 0, 0, 0, 0]); + + assert!(s.can_read_approximate(5).unwrap()); + assert!(reader.can_read(5).unwrap()); + assert_eq!(reader.read(5).unwrap(), &[1, 1, 1, 1, 1, 1, 1, 1]); + + assert!(s.can_read_approximate(6).is_none()); + assert!(reader.can_read(6).is_none()); + assert!(reader.read(6).is_none()); + } + + /////////////// + // Lifecycle // + /////////////// + + #[test] + fn acquire_write_publish_read_roundtrip() { + let s = store(4, 8, 1).unwrap(); + + let reader = s.reader().expect("reader guard available"); + + let idx = { + let mut slot = s.acquire().expect("a fresh store has free slots"); + let idx = slot.slot() as usize; + slot.as_mut_slice() + .copy_from_slice(&[1, 2, 3, 4, 5, 6, 7, 8]); + + // Before the slot is dropped - we should not be able to read it. + assert!(reader.read(idx).is_none()); + assert!(!s.can_read_approximate(idx).unwrap()); + slot.publish(); + idx + }; + + assert_eq!(reader.read(idx), Some([1, 2, 3, 4, 5, 6, 7, 8].as_slice())); + assert!(s.can_read_approximate(idx).unwrap()); + } + + #[test] + fn unpublished_slots_are_immediately_available() { + let s = store(4, 8, 1).unwrap(); + + let reader = s.reader().expect("reader guard available"); + + let idx = { + let mut slot = s.acquire().expect("a fresh store has free slots"); + let idx = slot.slot() as usize; + slot.as_mut_slice() + .copy_from_slice(&[1, 2, 3, 4, 5, 6, 7, 8]); + + // Before the slot is dropped - we should not be able to read it. + assert!(reader.read(idx).is_none()); + assert!(!s.can_read_approximate(idx).unwrap()); + + // NOTE: We do not explicitly publish the slot. 
+ idx + }; + + assert!(reader.read(idx).is_none()); + assert!(!s.can_read_approximate(idx).unwrap()); + } + + #[test] + fn acquire_exhausts_then_reports_none() { + let s = store(2, 8, 1).unwrap(); + // Hold the guards so the slots stay owned. + let _a = s.acquire().expect("first writable slot"); + let _b = s.acquire().expect("second writable slot"); + assert!( + s.acquire().is_none(), + "all writable slots are owned, so acquire must fail" + ); + } + + //--------// + // Retire // + //--------// + + #[test] + fn retire_out_of_bounds() { + let s = store(4, 8, 1).unwrap(); + assert!(matches!(s.retire(999), Err(RetireError::OutOfBounds))); + } + + #[test] + fn retire_rejects_reserved_slots() { + let s = store(4, 8, 1).unwrap(); + // An untouched writable slot is AVAILABLE, which is a reserved state. + assert!(matches!( + s.retire(0), + Err(RetireError::SlotIsReserved { .. }) + )); + // A frozen slot is likewise reserved. + let frozen = s.frozen().start as usize; + assert!(matches!( + s.retire(frozen), + Err(RetireError::SlotIsReserved { .. }) + )); + // An owned slot is not retirable. + let slot = s.acquire().unwrap(); + assert!(matches!( + s.retire(slot.slot() as usize), + Err(RetireError::SlotIsReserved { .. }) + )); + } + + #[test] + fn retire_published_slot_then_unreadable() { + let s = store(4, 8, 1).unwrap(); + + let idx = { + let slot = s.acquire().unwrap(); + slot.publish() as usize + }; + + assert!(s.retire(idx).is_ok()); + + // A reader opened after retirement must not observe the retired slot. + let reader = s.reader().unwrap(); + assert_eq!(reader.read(idx), None); + assert_eq!(reader.can_read(idx), Some(false)); + + // The slot can also not be retired again. + assert!(matches!( + s.retire(idx), + Err(RetireError::SlotIsReserved { .. }) + )); + } + + //---------// + // Recycle // + //---------// + + #[test] + fn test_recycling() { + let entries = if cfg!(miri) { 16 } else { 2048 }; + + let s = store(entries, 4, 2).unwrap(); + + // Claim all slots. 
+ let mut count = 0; + while let Some(slot) = s.acquire() { + slot.publish(); + count += 1; + } + + assert_eq!(count, s.writable().len()); + + // Now that all slots are claimed - retire all slots. + for i in s.writable() { + s.retire(i.into_usize()).unwrap(); + } + + // Verify that we can claim all slots again. + let mut count = 0; + while let Some(slot) = s.acquire() { + slot.publish(); + count += 1; + } + + assert_eq!(count, s.writable().len()); + } +} diff --git a/diskann-inmem/src/tag.rs b/diskann-inmem/src/tag.rs new file mode 100644 index 000000000..e96e234f7 --- /dev/null +++ b/diskann-inmem/src/tag.rs @@ -0,0 +1,328 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! State tags for slots participating in the EBR protocol. +//! +//! This module defines [`Tag`] and [`AtomicTag`], a small state machine used to label +//! individual slots in concurrent data structures. Tags pair with the epoch-based +//! reclamation machinery in [`super::epoch`]: epochs decide *when* it is safe to reclaim a +//! slot, while tags decide *whether* a given slot is currently readable, owned, or in +//! transition. +//! +//! Note that the type system does not enforce the tag protocol — only the documented +//! transitions on [`Tag`] are sound, and it is the caller's responsibility to follow them. + +use std::sync::atomic::{AtomicU8, Ordering}; + +/// A tag for controlling concurrent access to data. +/// +/// Tag updates and reads should use [`AtomicTag`]. +/// +/// A reader holding a [`Guard`](super::epoch::Guard) performs an [`Ordering::Acquire`] load +/// on an [`AtomicTag`]; if [`Tag::can_read`] returns `true`, the reader may access the +/// data this tag protects. +/// +/// # Named Tags +/// +/// * [`Tag::PUBLISHED`]: The associated slot has been published and may be freely accessed +/// by readers. +/// +/// * [`Tag::FROZEN`]: This data is protected and is not expected to be mutated. Readers +/// may still freely access this data. 
`FROZEN` has no defined transitions in this +/// protocol; once a slot is frozen it remains so for the lifetime of the structure. +/// +/// * [`Tag::AVAILABLE`]: The associated slot is not currently storing valid data +/// and is available to use. +/// +/// Ownership is acquired via a CAS from `AVAILABLE` to `OWNED`. +/// +/// * [`Tag::OWNED`]: The associated data is owned by some thread. Only the thread +/// owning this slot may update it. +/// +/// Note that ownership may be transferred between threads as long as this ownership +/// transfer is unambiguous and properly synchronized. +/// +/// In this state, the owning thread may write to the associated data. +/// +/// * [`Tag::RETIRING`]: Indicates that this slot is currently being [retired](super::epoch). +/// Readers may not access associated data after reading this tag, but readers who accessed +/// the tag before retirement may still exist. +/// +/// Only transition away from this value when the corresponding slot is returned from a +/// [`Drain`](super::epoch::Drain). +/// +/// # Allowed Transitions +/// +/// The following protocol must be used when working with [`AtomicTag`]ged data and a +/// [`Registry`](super::Registry). +/// +/// * [`Tag::AVAILABLE`] -> [`Tag::OWNED`]: Use a CAS to ensure unique ownership. Once in +/// the owned state, unsynchronized writes can be made to associated data. +/// +/// * [`Tag::OWNED`] -> [`Tag::PUBLISHED`]: Must be done as an [`Ordering::Release`] store +/// and only by the thread that acquired ownership. +/// +/// * [`Tag::PUBLISHED`] -> [`Tag::RETIRING`]: Must be done while under a +/// [`Guard`](super::epoch::Guard) and may be done with relaxed atomics. Writes to +/// associated data may not be made. Place into [`Guard::retire`](super::epoch::Guard::retire) +/// for final reclamation. +/// +/// * [`Tag::RETIRING`] -> [`Tag::AVAILABLE`]: May only be done if the corresponding slot is +/// retrieved from a [`Drain`](super::epoch::Drain). 
Writes may occur to associated data +/// and if so, this transition must be made with [`Ordering::Release`]. +/// +/// # Reading +/// +/// Checks to [`Tag::can_read`] can be made following [`Ordering::Acquire`] loads. +#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +#[repr(transparent)] +pub(crate) struct Tag(u8); + +impl Tag { + //-------------// + // High Values // + //-------------// + + /// The slot is permanently readable and never mutated again. See [`Tag`]. + pub(crate) const FROZEN: Self = Self::new(u8::MAX); + + /// The slot has been published and is freely readable. See [`Tag`]. + pub(crate) const PUBLISHED: Self = Self::new(u8::MAX - 1); + + //------------// + // Low Values // + //------------// + + /// The slot holds no valid data and may be claimed via CAS to [`Tag::OWNED`]. + /// See [`Tag`]. + pub(crate) const AVAILABLE: Self = Self::new(0); + + /// The slot is exclusively owned by a single thread that may write its data. + /// See [`Tag`]. + pub(crate) const OWNED: Self = Self::new(1); + + /// The slot is in the process of being retired and is no longer readable to new + /// readers. See [`Tag`]. + pub(crate) const RETIRING: Self = Self::new(2); + + /// NOTE: We rely on reserved values being contiguous so `is_reserved` can be + /// implemented relatively efficiently. + const RESERVED: Self = Self::RETIRING; + + /// Return `true` if `self` is one of the protocol's reserved tag values. + /// + /// Reserved tags are part of the protocol's fixed vocabulary and are never delivered + /// as retirement payloads. + #[must_use = "this function has no side-effects"] + pub(crate) fn is_reserved(self) -> bool { + (self <= Self::RESERVED) || (self == Self::FROZEN) + } + + /// Return `true` if `self` is in a state where it is legal to access tagged data. 
+ #[must_use = "this function has no side-effects"] + pub(crate) fn can_read(self) -> bool { + // Tags are split into `high` (readable) and `low` (non-readable) values so this + // check reduces to a single comparison. + self >= Self::PUBLISHED + } + + /// Construct a new [`Tag`] with `value`. + #[inline] + const fn new(value: u8) -> Self { + Self(value) + } + + /// Return the value of `self`. + #[inline] + const fn value(self) -> u8 { + self.0 + } +} + +impl std::fmt::Display for Tag { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let me = *self; + if me == Self::AVAILABLE { + f.write_str("Tag(AVAILABLE)") + } else if me == Self::OWNED { + f.write_str("Tag(OWNED)") + } else if me == Self::RETIRING { + f.write_str("Tag(RETIRING)") + } else if me == Self::FROZEN { + f.write_str("Tag(FROZEN)") + } else if me == Self::PUBLISHED { + f.write_str("Tag(PUBLISHED)") + } else { + write!(f, "Tag({})", me.value()) + } + } +} + +/// An atomic [`Tag`]. +/// +/// Memory orderings are the caller's responsibility and must be chosen consistent with the +/// protocol described on [`Tag`]. +#[derive(Debug)] +#[repr(transparent)] +pub(crate) struct AtomicTag(AtomicU8); + +impl AtomicTag { + /// Construct a new [`AtomicTag`] initialized to `tag`. + pub(crate) const fn new(tag: Tag) -> Self { + Self(AtomicU8::new(tag.value())) + } + + /// Creates a new reference to a `AtomicTag` from a raw pointer. + /// + /// # Safety + /// + /// * `ptr` must be aligned to `align_of::()`. + /// * `ptr` must be valid for both reads and writes for the whole lifetime `'a`. + /// * The caller chooses `'a`; the underlying allocation must outlive `'a`. + /// * This must adhere to the memory model for atomic accesses. In particular, it must + /// not admit conflicting atomic and non-atomic accesses, or atomic accesses of + /// different sizes without synchronization. 
+ /// + /// See: + pub(crate) unsafe fn from_ptr<'a>(ptr: *mut AtomicTag) -> &'a Self { + // SAFETY: inherited from caller. + unsafe { &*ptr } + } + + /// Perform an atomic compare-exchange with the provided orderings. + /// + /// Note that this does not enforce the [`Tag`] transition protocol; the caller must + /// ensure `current` and `new` correspond to a legal transition. + /// + /// See: [`AtomicU8::compare_exchange`]. + pub(crate) fn compare_exchange( + &self, + current: Tag, + new: Tag, + success: Ordering, + failure: Ordering, + ) -> Result { + self.0 + .compare_exchange(current.value(), new.value(), success, failure) + .map(Tag::new) + .map_err(Tag::new) + } + + /// Perform an atomic load with the provided ordering. + /// + /// See: [`AtomicU8::load`]. + pub(crate) fn load(&self, ordering: Ordering) -> Tag { + Tag::new(self.0.load(ordering)) + } + + /// Perform an atomic store with the provided ordering. + /// + /// See: [`AtomicU8::store`]. + pub(crate) fn store(&self, val: Tag, ordering: Ordering) { + self.0.store(val.value(), ordering) + } +} + +/////////// +// Tests // +/////////// + +#[cfg(test)] +mod tests { + use super::*; + + use std::{sync::Barrier, thread}; + + use crate::{ + buffer::Buffer, + num::{Align, Bytes}, + }; + + fn spin_decrement(m: &AtomicTag, count: usize) { + for _ in 0..count { + let mut current = m.load(Ordering::Relaxed); + while let Err(c) = m.compare_exchange( + current, + Tag::new(current.value().wrapping_sub(1)), + Ordering::Relaxed, + Ordering::Relaxed, + ) { + current = c; + } + } + } + + #[test] + fn test_atomic() { + let threads = 4; + let barrier = &Barrier::new(threads); + + // This dance basically verifies that we can view the tag though a proper-aligned + // raw pointer. + let buffer = + Buffer::new(1, Bytes::size_of::(), Align::of::()).unwrap(); + let ptr = buffer.get(0).unwrap().as_mut_ptr().cast::(); + + { + // SAFETY: We only access these atomically. 
+ let tag = unsafe { AtomicTag::from_ptr(ptr) }; + tag.store(Tag::FROZEN, Ordering::Relaxed); + } + + let count = 1000; + thread::scope(|s| { + for _ in 0..threads { + s.spawn(|| { + // Re-derive `p` to avoid issues with `Send`. + let p = buffer.get(0).unwrap().as_mut_ptr().cast::(); + + // SAFETY: We only access this atomically. + let tag = unsafe { AtomicTag::from_ptr(p) }; + barrier.wait(); + spin_decrement(tag, count); + }); + } + }); + + { + // SAFETY: We only access this atomically. + let g = unsafe { AtomicTag::from_ptr(ptr) }.load(Ordering::Relaxed); + assert_eq!(g, Tag::new(u8::MAX.wrapping_sub((count * threads) as u8))); + } + } + + #[test] + fn test_is_reserved() { + assert!(Tag::FROZEN.is_reserved()); + assert!(!Tag::PUBLISHED.is_reserved()); + + assert!(Tag::AVAILABLE.is_reserved()); + assert!(Tag::OWNED.is_reserved()); + assert!(Tag::RETIRING.is_reserved()); + } + + #[test] + fn test_can_read() { + assert!(Tag::FROZEN.can_read()); + assert!(Tag::PUBLISHED.can_read()); + + assert!(!Tag::AVAILABLE.can_read()); + assert!(!Tag::OWNED.can_read()); + assert!(!Tag::RETIRING.can_read()); + } + + #[test] + fn test_display() { + assert_eq!(Tag::AVAILABLE.to_string(), "Tag(AVAILABLE)"); + assert_eq!(Tag::OWNED.to_string(), "Tag(OWNED)"); + assert_eq!(Tag::RETIRING.to_string(), "Tag(RETIRING)"); + assert_eq!(Tag::FROZEN.to_string(), "Tag(FROZEN)"); + assert_eq!(Tag::PUBLISHED.to_string(), "Tag(PUBLISHED)"); + + // Guard against future changes. + assert_eq!(Tag::new(Tag::RETIRING.value() + 1).to_string(), "Tag(3)"); + assert_eq!(Tag::new(Tag::PUBLISHED.value() - 1).to_string(), "Tag(253)"); + } +} diff --git a/diskann-inmem/src/test/epoch.rs b/diskann-inmem/src/test/epoch.rs new file mode 100644 index 000000000..f97f4dfe3 --- /dev/null +++ b/diskann-inmem/src/test/epoch.rs @@ -0,0 +1,302 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Directed stress tests for `Registry`. 
+ +use std::{ + cell::UnsafeCell, + mem::MaybeUninit, + sync::atomic::{AtomicUsize, Ordering}, +}; + +use rand::{Rng, distr::StandardUniform}; + +use crate::{ + epoch::Registry, + tag::{AtomicTag, Tag}, +}; + +type Data = [u32; 4]; + +struct Slot { + tag: AtomicTag, + payload: UnsafeCell>>, +} + +impl Slot { + fn new() -> Self { + Self { + tag: AtomicTag::new(Tag::AVAILABLE), + payload: UnsafeCell::new(MaybeUninit::uninit()), + } + } + + fn try_claim(&self, payload: Data, f: F) + where + F: FnOnce(), + { + if self + .tag + .compare_exchange( + Tag::AVAILABLE, + Tag::OWNED, + Ordering::Acquire, + Ordering::Relaxed, + ) + .is_ok() + { + // SAFETY: By transitioning from AVAILABLE to OWNED, we've acquired ownership + // of this slot and are thus free to write to the `UnsafeCell`. + unsafe { &mut *self.payload.get() }.write(Box::new(payload)); + f(); + self.tag.store(Tag::PUBLISHED, Ordering::Release); + } + } + + fn try_read(&self) -> Option<&Data> { + if self.tag.load(Ordering::Acquire).can_read() { + // SAFETY: We've checked that we can read this cell. + let payload = unsafe { &*self.payload.get() }; + + // SAFETY: Items that can be read **must** be initialized. + Some(unsafe { payload.assume_init_ref() }) + } else { + None + } + } + + #[must_use] + fn retire(&self) -> bool { + let tag = self.tag.load(Ordering::Relaxed); + if tag != Tag::PUBLISHED { + return false; + } + + self.tag + .compare_exchange( + Tag::PUBLISHED, + Tag::RETIRING, + Ordering::Relaxed, + Ordering::Relaxed, + ) + .is_ok() + } + + unsafe fn make_available(&self) { + assert_eq!(self.tag.load(Ordering::Relaxed), Tag::RETIRING); + + // SAFETY: Items tagged as `RETIRING` must be initialized. 
+ unsafe { (&mut *self.payload.get()).assume_init_drop() }; + + if self + .tag + .compare_exchange( + Tag::RETIRING, + Tag::AVAILABLE, + Ordering::Release, + Ordering::Relaxed, + ) + .is_err() + { + panic!("concurrency violation"); + } + } +} + +impl Drop for Slot { + fn drop(&mut self) { + if self.tag.load(Ordering::Relaxed) != Tag::AVAILABLE { + let payload = self.payload.get_mut(); + + // SAFETY: We have exclusive access and by convention, if the tag is not + // available, then the corresponding payload is initialized. + unsafe { payload.assume_init_drop() }; + } + } +} + +// SAFETY: We control concurrency, so can safely share this. +unsafe impl Sync for Slot {} + +fn make_payload(epoch: u64, index: usize) -> Data { + [ + index as u32, + epoch as u32, + (epoch >> 32) as u32, + (index as u32) ^ (epoch as u32) ^ ((epoch >> 32) as u32), + ] +} + +fn verify_payload(data: &Data) -> (usize, u64) { + let checksum = data[0] ^ data[1] ^ data[2]; + assert_eq!( + data[3], checksum, + "torn or corrupted read: payload {data:?}, expected checksum {checksum}" + ); + let index = data[0] as usize; + let epoch = data[1] as u64 | ((data[2] as u64) << 32); + (index, epoch) +} + +struct Record { + epoch: u64, + index: usize, + data: Data, +} + +fn read_job( + registry: &Registry, + slots: &[Slot], + stop_at: u64, + retire_rate: f64, + active: &AtomicUsize, +) -> Vec { + assert!(retire_rate > 0.0); + assert!(retire_rate < 1.0); + + let mut records = Vec::new(); + let mut rng = rand::rng(); + + loop { + let mut reads = Vec::<&Data>::new(); + let guard = registry.guard().unwrap(); + if guard.epoch >= stop_at { + break; + } + + for (i, slot) in slots.iter().enumerate() { + if let Some(read) = slot.try_read() { + reads.push(read); + + let sample: f64 = rng.sample(StandardUniform); + if sample < retire_rate && slot.retire() { + guard.retire(i as u32); + active.fetch_sub(1, Ordering::Release); + + std::thread::yield_now(); + records.push(Record { + epoch: guard.epoch, + index: i, + data: 
*read, + }); + } + } + } + } + + records +} + +fn retire_job(registry: &Registry, slots: &[Slot], stop_at: u64, active: &AtomicUsize) { + loop { + let epoch = registry.epoch(); + if epoch >= stop_at { + return; + } + + if active.load(Ordering::Acquire) != 0 { + std::thread::yield_now(); + continue; + } + + if let Some(drain) = registry.try_advance() { + for i in drain { + // SAFETY: retrieving from the drain gives us exclusive access. + unsafe { slots[i as usize].make_available() }; + } + } + } +} + +fn write_job(registry: &Registry, slots: &[Slot], stop_at: u64, active: &AtomicUsize) { + loop { + let epoch = registry.epoch(); + if epoch >= stop_at { + return; + } + + for (i, slot) in slots.iter().enumerate() { + slot.try_claim(make_payload(epoch, i), || { + active.fetch_add(1, Ordering::Relaxed); + }); + } + + std::thread::yield_now(); + } +} + +#[test] +fn registry_stress_test() { + let registry = Registry::new(); + let slots: Vec<_> = std::iter::repeat_with(Slot::new).take(10).collect(); + let active = AtomicUsize::new(0); + + let stop_at = if cfg!(miri) { 11 } else { 50_000 }; + let retire_rate = if cfg!(miri) { 0.95 } else { 0.1 }; + + // We use two threads for each job to be extra adversarial. + let barrier = std::sync::Barrier::new(6); + let result = std::thread::scope(|s| { + // Spin up readers. 
+ let r0 = s.spawn(|| { + barrier.wait(); + read_job(®istry, &slots, stop_at, retire_rate, &active) + }); + + let r1 = s.spawn(|| { + barrier.wait(); + read_job(®istry, &slots, stop_at, retire_rate, &active) + }); + + // Spin up writers + s.spawn(|| { + barrier.wait(); + write_job(®istry, &slots, stop_at, &active); + }); + + s.spawn(|| { + barrier.wait(); + write_job(®istry, &slots, stop_at, &active); + }); + + // Spin up retirers + s.spawn(|| { + barrier.wait(); + retire_job(®istry, &slots, stop_at, &active); + }); + s.spawn(|| { + barrier.wait(); + retire_job(®istry, &slots, stop_at, &active); + }); + + let mut r0 = r0.join().unwrap(); + let r1 = r1.join().unwrap(); + r0.extend(r1); + r0 + }); + + for record in &result { + let (index, write_epoch) = verify_payload(&record.data); + + // The index encoded in the payload must match the slot we read from. + assert_eq!( + index, record.index, + "slot identity mismatch: payload says slot {index}, record says slot {}", + record.index + ); + + // The slot was written at `write_gen` and read at `record.generation. + // Since generations increase (newer = larger), write_gen <= record.generation + // means the write happened at or before the reader's epoch. + // + // Note that a reader can observe one epoch change during its tenure, so we *can* + // observe writes from one higher epoch. + assert!( + write_epoch <= (record.epoch + 1), + "read data from the future: write_gen={write_epoch}, read_gen={}", + record.epoch + ); + } +} diff --git a/diskann-inmem/src/test/mod.rs b/diskann-inmem/src/test/mod.rs new file mode 100644 index 000000000..e91e30f81 --- /dev/null +++ b/diskann-inmem/src/test/mod.rs @@ -0,0 +1,10 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. 
+ */ + +mod sequencer; +pub(crate) use sequencer::Sequencer; + +// Longer Running Tests +mod epoch; diff --git a/diskann-inmem/src/test/sequencer.rs b/diskann-inmem/src/test/sequencer.rs new file mode 100644 index 000000000..81c9cffd9 --- /dev/null +++ b/diskann-inmem/src/test/sequencer.rs @@ -0,0 +1,86 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +use std::sync::Arc; + +use parking_lot::{Condvar, Mutex}; + +#[derive(Clone)] +pub(crate) struct Sequencer(Arc); + +struct SequencerInner { + state: Mutex, + condvar: Condvar, +} + +#[derive(Debug, Clone, Copy, PartialEq)] +enum State { + Empty, + Parked(usize), + Released(usize), +} + +impl Sequencer { + pub(crate) fn new() -> Self { + Self(Arc::new(SequencerInner { + state: Mutex::new(State::Empty), + condvar: Condvar::new(), + })) + } + + pub(crate) fn wait_for(&self, stage: usize) { + let mut state = self.0.state.lock(); + if stage == 0 { + assert_eq!(*state, State::Empty) + } else { + assert_eq!(*state, State::Released(stage - 1)) + } + + *state = State::Parked(stage); + self.0.condvar.notify_all(); + self.0 + .condvar + .wait_while(&mut state, move |s| *s != State::Released(stage)); + } + + pub(crate) fn advance_past(&self, stage: usize) { + let mut state = self.0.state.lock(); + self.0 + .condvar + .wait_while(&mut state, move |s| Self::check_release(*s, stage)); + *state = State::Released(stage); + self.0.condvar.notify_all(); + } + + pub(crate) fn until_waiting_for(&self, stage: usize) { + let mut state = self.0.state.lock(); + if *state != State::Parked(stage) { + self.0 + .condvar + .wait_while(&mut state, move |s| Self::check_release(*s, stage)) + } + } + + fn check_release(current: State, stage: usize) -> bool { + match current { + State::Empty => { + assert_eq!(stage, 0); + true + } + State::Released(s) => { + if s + 1 != stage { + panic!("observed {:?} while releasing stage {}", current, stage); + } + true + } + State::Parked(s) => { + if s != stage { + 
panic!("observed {:?} while releasing stage {}", current, stage) + } + false + } + } + } +} diff --git a/diskann-vector/src/distance/implementations.rs b/diskann-vector/src/distance/implementations.rs index 25a281bff..30bd80431 100644 --- a/diskann-vector/src/distance/implementations.rs +++ b/diskann-vector/src/distance/implementations.rs @@ -41,7 +41,7 @@ macro_rules! architecture_hook { { #[inline(always)] fn run(arch: A, left: L, right: R) -> T { - arch.run2(Self::default(), left, right) + arch.run2_inline(Self::default(), left, right) } } }; diff --git a/rfcs/01206-inmem2.md b/rfcs/01206-inmem2.md new file mode 100644 index 000000000..55bd0f07e --- /dev/null +++ b/rfcs/01206-inmem2.md @@ -0,0 +1,270 @@ +# Concurrent In-Memory Index + +| | | +|---|---| +| **Authors** | Mark Hildebrand | +| **Contributors** | | +| **Created** | 2026-07-01 | + +## Summary + +Let's make our in-memory index robust under concurrent operations. + +## Motivation + +### Background + +There are [two methods](https://github.com/microsoft/DiskANN/blob/b603ec009ea5c3cdbbd5358ca88b4a2a30c8d52b/diskann-providers/src/model/graph/provider/async_/common.rs#L148-L195) at the heart of our current in-memory index that deeply bother me: `get_slice` and `get_slice_mut`. +These methods allow concurrent, unsynchronized access between mutable and immutable data. +While there is some measure of safety on the [write path](https://github.com/microsoft/DiskANN/blob/b603ec009ea5c3cdbbd5358ca88b4a2a30c8d52b/diskann-providers/src/model/graph/provider/async_/fast_memory_vector_provider.rs#L44-L47), this is insufficient to prevent a concurrent reader and a writer and there is no systematic safety protocol in place to prevent this situation: safety comments are basically "yolo" (I can make fun of them because I've written such comments myself). + +This is problematic for `diskann` for a number of reasons: + +* It prevents us from stress testing our algorithm under concurrent inserts/search/deletes etc. 
+ Unfortunately, this is the situation under which most of our database integrations actually operate. + +* Coupled with the lack of an ID translation mechanism for the inmem provider, we have no protection against races inserting multiple internal IDs simultaneously or a coherent story between concurrent inserts and deletes to the same ID. + +* It makes me sad. + +So why is it this way? +Performance! +Yoloing a pointer read is cheap; concurrency is expensive. +An obvious safety mechanism would be to slap a mutex or [fine-grained RCU](https://github.com/microsoft/DiskANN/blob/b603ec009ea5c3cdbbd5358ca88b4a2a30c8d52b/diskann-providers/src/model/graph/provider/async_/memory_vector_provider.rs#L23) around each slot. +Unfortunately, approaches like this that require cache-line **writes** before (and after) each read operation are prohibitively slow and completely unnecessary for [static builds](https://github.com/microsoft/DiskANN/blob/main/diskann-disk/src/build/builder/inmem_builder.rs). + +### Problem Statement + +Let's try to fix the in-memory provider to make it: + +* Well defined under concurrent operations. +* Enable co-development of the core algorithm to patch our remaining concurrency issues. +* Easier to use than our current providers. +* Faster to compile. +* Support external/internal ID translation. +* Have much better tests. +* Without completely compromising performance. + +## Proposal + +This RFC accompanies [1206](https://github.com/microsoft/DiskANN/pull/1206), which is an MVP of the proposed design. + +The design assumes that traffic to the data store is primarily **read** heavy (to facilitate searches). +As such, we're willing to sacrifice some write performance (or more specifically, the time between when a slot is deleted and when it is reused) to make reads fast. +Importantly, readers can determine whether it is safe (or not) to access data with a single 1-byte atomic read.
+**Following** a successful check, readers may access the associated data without fear of races. + +### Concurrency + +**Definition**: A "slot" is any data (potentially spread across multiple containers) that is uniquely associated with an internal ID. +This part of the discussion focuses on slots and concurrently controlling access to the data contained within. + +At a high level, consider associating every slot with a 1-byte atomic tag with states + +* "available": This slot is free to claim but cannot be read. + The data managed by this slot should be considered invalid. + +* "owned": This slot is exclusively owned by a thread. + No other threads may access data in the slot and the owner is free to write. + +* "published": The slot is publicly available. + Readers that observe a "published" tag can read the associated data. + +#### Attempt 1 + +Consider a simple (but wholly invalid) concurrency protocol where a writer acquires a slot by transitioning (via [CAS](https://doc.rust-lang.org/std/sync/atomic/type.AtomicUsize.html#method.compare_exchange)) the tag from "available" to "owned", writes data, then transitions the slot to "published". +A delete could happen by transitioning from "published" to "owned", doing its thing, and then either re-publishing or making it "available". +This does not work as a solution to our goal. + +The problem is that even if an intrepid writer transitions the tag to "owned", it has no guarantee that readers who observed the "published" state prior to this transition are done using the data. +As a diagram: +``` +Time | Reader | Writer +-------+----------------------------+----------- + 1 | Reads "published" | + | Decides it's safe to read | +-------+----------------------------+---------------------------------------- + 2 | | Transitions from "published" to "owned" +-------+----------------------------+---------------------------------------- +RACE 3 | Starts Reading | Starts Writing +``` +We get undefined behavior at step 3!
+ +#### Attempt 2 + +We could augment attempt 1 by having the reader thread check the tag after it finishes its operation. +If it observes a non-"published" state, it can abort its operation. +This cannot be done safely with our current tag scheme due to the ABA problem: a reader could see "published", start its operation, and then get preempted by the operating system (OS). +In the meantime, a writer can transition from "published" to "owned", do its thing, and then set it back to "published". +When the original reader gets back from its vacation, it still sees "published" and thinks everything is okay even though it potentially operated on invalid data. + +Schemes like this can be used by constructs like [sequence locks](https://en.wikipedia.org/wiki/Seqlock). +However, these are + +* Famously impossible to represent (without UB) in the semantics of high-level programming languages. + +* Basically the only operation you can do "safely" under a sequence lock is a memcpy. + You can't compute - and we'd really like to be able to compute a distance on the data in-place. + +* Require larger counters to avoid longer-ranged versions of the previously described ABA problem where the counter fully wraps around. + +#### Solution + +What we need is a little more communication. +Conceptually, the writer transitions a tag from "published" to "owned", but then **waits** until it can guarantee that no readers are alive any more that could have observed that change. +Only at this point can the writer do its thing without undefined behavior (don't worry, in the implementation here, writers are *not* blocked waiting for readers). + +This is where [epoch-based reclamation](https://docs.rs/crossbeam-epoch/latest/crossbeam_epoch/) enters the picture. +For each provider, we maintain a monotonically increasing "epoch". +The idea is that for every search/insert/delete operation, a `Reader` is first created which registers itself as using the current epoch (call it `E`).
+This `Reader` behaves as previously described, reading tags and, if it observes "published", reading the associated data. +When the operation finishes, the `Reader` deregisters itself. + +**Importantly**, a `Reader` can also retire slots by transitioning their tag (via CAS) from "published" to a new "retired" state and inserting the tag index into its epoch-specific queue. +This "retired" state (1) prevents future readers from accessing the data, and (2) prevents other threads from trying to claim the slot. +The epoch-specific queue holds onto the slot ID until all readers who could have observed that transition have been deregistered. + +Cleaning up retired items from an epoch queue happens during epoch advancement. +Periodically (e.g., once in every `N` inserts or searches or via a background process), we try to advance the epoch. +An epoch can **only** be advanced from `E` to `E+1` if all registered readers belong to epoch `E`. +Any reader at epoch `E-1` will prevent the transition. +When we successfully advance the epoch, we get the epoch queue associated with epoch `E-2`. +This queue contains the slot indexes for those retired at epoch `E-2`. +Because all current readers belong to `E` or `E+1`, we are guaranteed that all current readers agree that the state of the tag is "retired" and will not be trying to read any data in the slots contained within the queue. +As such, the thread processing epoch `E-2` is free to write to the data in these slots and transition them to other states without fear of a race, thus solving our problem! + +With this scheme, we can keep recycling the same four epoch queues because the scheme guarantees that only two are ever written to at a time. +Literature often claims that just three queues are needed. +We need an extra one because each `Reader` pushes items into the epoch queue associated with the `Reader`'s creation epoch.
+A `Reader` in epoch `E` can retire a slot into the `E` queue, but a `Reader` in epoch `E+1` **can** observe this retirement. +Then the epoch advances to `E+2`. +If we pulled the offending slot out of the `E` queue, the `Reader` in `E+1` could still be reading it and we're back to undefined behavior. +Introducing a fourth queue fixes this issue. + +#### Implementation + +The core components of this protocol are split across three files in #1206: + +* `tag.rs`: The implementation of atomic slot tags. + The PR contains a few more states for slots to enable slots to be in special (e.g., "frozen") states, but follows the main "available", "published", "owned", "retired" scheme outlined above. + +* `epoch.rs`: The logic for `Reader` registration, deregistration, epoch advancement, and epoch queues. + Unsurprisingly, convincing a bunch of concurrent threads to get along with minimal locking is subtle. + +* `store.rs`: A package data store with slots that completes the implementation by providing + a safe `Reader`-based abstraction on the store. + +Within `store.rs`, there are actually two tags per slot. +One is an authoritative tag in a `Vec`; the other is a mirrored tag that lives inline with the data being stored in the slot. +The idea here is that during search, we can emit prefetches for the data we're going to process and either get the mirror for free, or take advantage of locality to avoid something like a page fault. +This mirror tag can then be used for the safety check instead. +For quantization algorithms like spherical that don't always generate a nice power-of-two number of bytes, there is likely unused space in our cache line padding so this 1-byte tag can be stored for free. + +### Reconciling Performance + +Even though we've brought the concurrency overhead down to just 1 byte per slot with a lightweight check (about 6 instructions), this is still strictly more data than the current index.
+This is particularly painful for datasets like "sift", where any additional read pulls in an additional cache line, moving from 2 cache lines to 3. +There is a little bit of work that can be done. +PR [1067](https://github.com/microsoft/DiskANN/pull/1067) moved the search contract behind a single `expand_beam` function. +PR 1026 uses a variation of [bring your own type-erasure](https://github.com/microsoft/DiskANN/pull/1068) to enable distance layers to + +* Inline their final distance functions directly into the `expand_beam` implementation rather than relying on [function pointers](https://github.com/microsoft/DiskANN/blob/main/diskann-vector/src/distance/distance_provider.rs). +* Further, length-specialized implementations can communicate their element byte size via const generics, allowing the final `expand_beam` implementation to emit the exact number of prefetch instructions. + +As an example, the `expand_beam` inner loop for 100-dimensional L2 vectors compiles to the following assembly: +``` +.LBB4_27: | Check if all neighbors have been processed + mov r10, r15 | + add rdx, 4 | + mov rax, rsi | + cmp r14, rdx | + je .LBB4_30 | +.LBB4_23: + mov r14, r10 + mov r8, qword ptr [rbx + 16] | Prefetch if there are still items to prefetch + mov r10, qword ptr [rbx + 24] | + mov rsi, rbp | + cmp rax, rbp | + je .LBB4_25 | + mov esi, dword ptr [r15 + 4*rax] | + imul rsi, r10 | + prefetcht0 byte ptr [r8 + rsi + 384] | + prefetcht0 byte ptr [r8 + rsi] | + prefetcht0 byte ptr [r8 + rsi + 64] | + prefetcht0 byte ptr [r8 + rsi + 128] | + prefetcht0 byte ptr [r8 + rsi + 192] | + prefetcht0 byte ptr [r8 + rsi + 256] | + prefetcht0 byte ptr [r8 + rsi + 320] | + inc rax + mov rsi, rax +.LBB4_25: + mov eax, dword ptr [r15 + rdx] | Safety tag check + imul r10, rax | + add r8, r10 | + movzx r10d, byte ptr [r11 + r8] | + cmp r10b, -2 | + jb .LBB4_27 | + vmovups ymm2, ymmword ptr [rdi] | Inlined Distance Computation + vmovups ymm3, ymmword ptr [rdi + 32] | + vmovups ymm4, ymmword ptr
[rdi + 64] | + vsubps ymm2, ymm2, ymmword ptr [r8] | + vmovups ymm5, ymmword ptr [rdi + 96] | + vfmadd213ps ymm2, ymm2, ymm0 | + ... repeats a lot | + vfmadd213ps ymm3, ymm3, ymm4 | + vaddps ymm3, ymm5, ymm3 | + vmaskmovps ymm4, ymm1, ymmword ptr [rcx] | + vaddps ymm2, ymm2, ymm3 | + vmaskmovps ymm3, ymm1, ymmword ptr [r8 + 384] | + vsubps ymm3, ymm4, ymm3 | + vfmadd213ps ymm3, ymm3, ymm2 | + vextractf128 xmm2, ymm3, 1 | + vaddps xmm2, xmm3, xmm2 | + vshufpd xmm3, xmm2, xmm2, 1 | + vaddps xmm2, xmm2, xmm3 | + vmovshdup xmm3, xmm2 | + vaddss xmm2, xmm2, xmm3 | + mov r8, qword ptr [rsp + 8] + mov dword ptr [r9 + 8*r8], eax | Write Back + vmovss dword ptr [r9 + 8*r8 + 4], xmm2 | + inc r8 | + mov qword ptr [rsp + 8], r8 | + jmp .LBB4_27 +``` +To prevent compile-time explosions, these aggressively optimized inner loops are only generated once and then packaged in a trait object. +This avoids re-monomorphization for each of the different closures and iterators that can be passed to the top-level `SearchAccessor::expand_beam` method. + +The hope with this specialization hook is that we can tune and optimize `expand_beam` more aggressively than our current providers to offset the extra byte read (and search accessor creation times due to epoch registration). + +### Testing + +Our current in-memory index uses a [very large](https://github.com/microsoft/DiskANN/blob/main/diskann-providers/src/index/diskann_async.rs) test file with ad-hoc tests. +PR 1206 uses the [A/B test functionality](https://github.com/microsoft/DiskANN/pull/900) in `diskann-benchmark-runner` to + +* Execute [longer running](https://github.com/microsoft/DiskANN/pull/1199) tests. +* Gather richer metrics and recall stats for these tests. +* Compare against a checked-in JSON baseline and notify of any changes. + +To allow for future adaptability, the baseline can be regenerated with the `DISKANN_TEST=overwrite` environment variable setting.
+Since the baseline is raw JSON, changes will show up in the git diff for reviewers to inspect. + +The goal here is to enable more robust testing of the in-memory index and, by extension, the core DiskANN algorithm. + +## Trade-offs + +All of this is fairly complex stuff to solve an insidious safety loophole. +And unfortunately, the concurrency infrastructure is not strictly needed for static in-memory builds. +I've been thinking about this a lot, and have never really been able to come up with a scheme that provides the read-only property with so little overhead. +I am more than happy to entertain alternative ideas. + +## Benchmark Results + +Incoming, but inmem2 is generally on par with inmem1 (except for sift, where it has about a 10% performance regression). +For streaming workloads, the hard deletes required by inmem2 may actually lead to higher recall. + +## Future Work + +- [ ] Add quantization (this will require figuring out how an extra blob can be protected by `store` - this is a solvable problem). +- [ ] Implement saving and loading. +- [ ] Optimize `expand_beam` a little more. +- [ ] Migrate existing users over to inmem2.