From 42f99100ee0efdbb986ac14cd3f7c5cc30f556c3 Mon Sep 17 00:00:00 2001 From: "Alex Razumov (from Dev Box)" Date: Fri, 12 Jun 2026 17:26:04 -0700 Subject: [PATCH 01/12] Added the first version --- diskann-benchmark/example/flat-index.json | 22 + diskann-benchmark/src/backend/flat/mod.rs | 12 + diskann-benchmark/src/backend/flat/search.rs | 504 +++++++++++++++++++ diskann-benchmark/src/backend/mod.rs | 2 + diskann-benchmark/src/inputs/flat.rs | 129 +++++ diskann-benchmark/src/inputs/mod.rs | 1 + diskann-benchmark/src/main.rs | 12 +- 7 files changed, 681 insertions(+), 1 deletion(-) create mode 100644 diskann-benchmark/example/flat-index.json create mode 100644 diskann-benchmark/src/backend/flat/mod.rs create mode 100644 diskann-benchmark/src/backend/flat/search.rs create mode 100644 diskann-benchmark/src/inputs/flat.rs diff --git a/diskann-benchmark/example/flat-index.json b/diskann-benchmark/example/flat-index.json new file mode 100644 index 000000000..18d9170cd --- /dev/null +++ b/diskann-benchmark/example/flat-index.json @@ -0,0 +1,22 @@ +{ + "search_directories": [ + "test_data/disk_index_search" + ], + "jobs": [ + { + "type": "flat-search", + "content": { + "data": "disk_index_siftsmall_learn_256pts_data.fbin", + "data_type": "float32", + "distance": "squared_l2", + "search": { + "queries": "disk_index_sample_query_10pts.fbin", + "groundtruth": "disk_index_10pts_idx_uint32_truth_search_res.bin", + "k": 10, + "num_threads": [1], + "reps": 1 + } + } + } + ] +} diff --git a/diskann-benchmark/src/backend/flat/mod.rs b/diskann-benchmark/src/backend/flat/mod.rs new file mode 100644 index 000000000..d7fe34f15 --- /dev/null +++ b/diskann-benchmark/src/backend/flat/mod.rs @@ -0,0 +1,12 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +use diskann_benchmark_runner::Registry; + +mod search; + +pub(crate) fn register_benchmarks(registry: &mut Registry) -> anyhow::Result<()> { + search::register_benchmarks(registry) +} diff --git a/diskann-benchmark/src/backend/flat/search.rs b/diskann-benchmark/src/backend/flat/search.rs new file mode 100644 index 000000000..994ba4267 --- /dev/null +++ b/diskann-benchmark/src/backend/flat/search.rs @@ -0,0 +1,504 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Backend for flat-index (brute-force kNN) benchmarks. +//! +//! This exercises [`diskann::flat::FlatIndex::knn_search`] over an in-memory +//! provider, measuring recall and latency. + +use std::{io::Write, num::NonZeroUsize, sync::Arc}; + +use diskann::{ + flat::{DistancesUnordered, FlatIndex, SearchStrategy}, + graph::SearchOutputBuffer, + provider::{DataProvider, DefaultContext, HasId, NoopGuard}, + utils::VectorRepr, + ANNResult, +}; +use diskann_benchmark_core::{self as benchmark_core, recall::GroundTruthMode, search}; +use diskann_benchmark_runner::{ + benchmark::{FailureScore, MatchScore}, + output::Output, + utils::{datatype::AsDataType, percentiles, MicroSeconds}, + Benchmark, Checkpoint, Registry, +}; +use diskann_utils::{future::SendFuture, views::Matrix}; +use diskann_vector::{distance::Metric, PreprocessedDistanceFunction}; +use half::f16; +use serde::Serialize; + +use crate::{ + inputs::flat::FlatSearch, + utils::{self, datafiles, recall::RecallMetrics}, +}; + +//////////////////////////// +// Benchmark Registration // +//////////////////////////// + +pub(super) fn register_benchmarks(registry: &mut Registry) -> anyhow::Result<()> { + registry.register("flat-index-f32", FlatBenchmark::::new())?; + registry.register("flat-index-f16", FlatBenchmark::::new())?; + registry.register("flat-index-u8", FlatBenchmark::::new())?; + registry.register("flat-index-i8", FlatBenchmark::::new())?; + Ok(()) +} + +///////////////// +// FlatSearch // +///////////////// + +/// A minimal in-memory provider for flat search benchmarks. +/// +/// Wraps a loaded [`Matrix`] and implements [`DataProvider`] with identity +/// ID mapping. +struct InMemProvider { + data: Arc>, +} + +impl DataProvider for InMemProvider { + type Context = DefaultContext; + type InternalId = u32; + type ExternalId = u32; + type Error = diskann::ANNError; + type Guard = NoopGuard; + + fn to_internal_id(&self, _ctx: &DefaultContext, gid: &u32) -> Result { + Ok(*gid) + } + + fn to_external_id(&self, _ctx: &DefaultContext, id: u32) -> Result { + Ok(id) + } +} + +struct FlatBenchmark { + _phantom: std::marker::PhantomData, +} + +impl FlatBenchmark { + fn new() -> Self { + Self { + _phantom: std::marker::PhantomData, + } + } +} + +impl Benchmark for FlatBenchmark +where + T: VectorRepr + AsDataType, +{ + type Input = FlatSearch; + type Output = FlatResult; + + fn try_match(&self, input: &FlatSearch) -> Result { + utils::match_data_type::(input.data_type) + } + + fn description( + &self, + f: &mut std::fmt::Formatter<'_>, + input: Option<&FlatSearch>, + ) -> std::fmt::Result { + match input { + Some(i) => { + let desc = T::describe(i.data_type); + if !desc.is_match() { + writeln!(f, "Data Type: {}", desc)?; + } + Ok(()) + } + None => writeln!(f, "Data Type: {}", T::DATA_TYPE), + } + } + + fn run( + &self, + input: &FlatSearch, + _checkpoint: Checkpoint<'_>, + mut output: &mut dyn Output, + ) -> anyhow::Result { + writeln!(output, "{}", input)?; + + let metric: Metric = input.distance.into(); + + // Load dataset + writeln!(output, "Loading dataset...")?; + let data: Matrix = datafiles::load_dataset(datafiles::BinFile(&input.data))?; + let nrows = data.nrows(); + let ncols = data.ncols(); + writeln!(output, " Loaded {} vectors of dimension {}", nrows, ncols)?; + + // Build the provider and wrap in FlatIndex + let data = Arc::new(data); + let provider = InMemProvider { data: data.clone() }; + let index = FlatIndex::new(provider); + let index = Arc::new(index); + + // Load queries and groundtruth + let queries: Matrix = + datafiles::load_dataset(datafiles::BinFile(&input.search.queries))?; + let groundtruth: Matrix = + datafiles::load_dataset(datafiles::BinFile(&input.search.groundtruth))?; + + writeln!( + output, + " Queries: {}, Groundtruth: {}x{}", + queries.nrows(), + groundtruth.nrows(), + groundtruth.ncols(), + )?; + + // Run searches for each thread count + let k = input.search.k; + let reps = input.search.reps; + + let mut results = Vec::new(); + + for &threads in &input.search.num_threads { + let searcher = Arc::new(FlatSearcher { + index: index.clone(), + queries: Arc::new(queries.clone()), + metric, + k, + }); + + let setup = search::Setup { + threads, + tasks: threads, + reps, + }; + + let run = search::Run::new(k, setup); + let aggregated = search::search_all( + searcher, + std::iter::once(run), + FlatAggregator::new(&groundtruth, k.get()), + )?; + + for item in aggregated { + results.push(item); + } + } + + let result = FlatResult { results }; + writeln!(output, "\n\n{}", result)?; + Ok(result) + } +} + +/////////////////////// +// Flat SearchStrategy // +/////////////////////// + +/// A [`SearchStrategy`] implementation for [`InMemProvider`] that drives +/// a full sequential scan over all vectors. +struct FlatScanStrategy { + metric: Metric, + num_vectors: usize, + _phantom: std::marker::PhantomData, +} + +impl FlatScanStrategy { + fn new(metric: Metric, num_vectors: usize) -> Self { + Self { + metric, + num_vectors, + _phantom: std::marker::PhantomData, + } + } +} + +/// The visitor that iterates over all vectors in the provider. +struct FlatVisitor<'a, T> { + data: &'a Matrix, + num_vectors: usize, +} + +impl HasId for FlatVisitor<'_, T> { + type Id = u32; +} + +impl DistancesUnordered for FlatVisitor<'_, T> { + type ElementRef<'a> = &'a [T]; + type Id = u32; + type Error = diskann::error::Infallible; + + fn distances_unordered( + &mut self, + computer: &T::QueryDistance, + mut f: F, + ) -> impl SendFuture> + where + F: Send + FnMut(Self::Id, f32), + { + async move { + for i in 0..self.num_vectors { + let vector = self.data.row(i); + let dist = computer.evaluate_similarity(vector); + f(i as u32, dist); + } + Ok(()) + } + } +} + +impl SearchStrategy, &[T]> for FlatScanStrategy { + type ElementRef<'a> = &'a [T]; + type Id = u32; + type QueryComputer = T::QueryDistance; + type QueryComputerError = diskann::error::Infallible; + type Visitor<'a> + = FlatVisitor<'a, T> + where + Self: 'a, + InMemProvider: 'a; + type Error = diskann::error::Infallible; + + fn create_visitor<'a>( + &'a self, + provider: &'a InMemProvider, + _context: &'a DefaultContext, + ) -> Result, Self::Error> { + Ok(FlatVisitor { + data: &provider.data, + num_vectors: self.num_vectors, + }) + } + + fn build_query_computer( + &self, + query: &[T], + ) -> Result { + Ok(T::query_distance(query, self.metric)) + } +} + +////////////////////////////////////////// +// benchmark_core::search::Search impl // +////////////////////////////////////////// + +/// Wraps a [`FlatIndex`] and queries to implement the [`Search`] trait from benchmark_core. +struct FlatSearcher { + index: Arc>>, + queries: Arc>, + metric: Metric, + k: NonZeroUsize, +} + +/// Additional metrics collected during flat search. +#[derive(Debug, Clone, Copy)] +struct FlatMetrics { + /// The number of distance comparisons performed. + pub comparisons: u32, +} + +impl search::Search for FlatSearcher +where + T: VectorRepr, +{ + type Id = u32; + type Parameters = NonZeroUsize; // k value + type Output = FlatMetrics; + + fn num_queries(&self) -> usize { + self.queries.nrows() + } + + fn id_count(&self, parameters: &Self::Parameters) -> search::IdCount { + search::IdCount::Fixed(*parameters) + } + + async fn search( + &self, + _parameters: &Self::Parameters, + buffer: &mut O, + index: usize, + ) -> ANNResult + where + O: SearchOutputBuffer + Send, + { + let strategy = FlatScanStrategy::::new(self.metric, self.index.provider().data.nrows()); + let context = DefaultContext; + let query = self.queries.row(index); + + let stats = self + .index + .knn_search(self.k, &strategy, &context, query, buffer) + .await?; + + Ok(FlatMetrics { + comparisons: stats.cmps, + }) + } +} + +////////////////// +// Aggregation // +////////////////// + +/// Aggregates results from multiple flat search runs, computing recall metrics. +struct FlatAggregator<'a> { + groundtruth: &'a Matrix, + recall_k: usize, +} + +impl<'a> FlatAggregator<'a> { + fn new(groundtruth: &'a Matrix, recall_k: usize) -> Self { + Self { + groundtruth, + recall_k, + } + } +} + +/// Results of a single flat search run. +#[derive(Debug, Clone, Serialize)] +struct FlatSearchResults { + num_tasks: usize, + k: usize, + qps: Vec, + search_latencies: Vec, + mean_latencies: Vec, + p90_latencies: Vec, + p99_latencies: Vec, + recall: RecallMetrics, + mean_cmps: f32, +} + +impl search::Aggregate for FlatAggregator<'_> { + type Output = FlatSearchResults; + + fn aggregate( + &mut self, + run: search::Run, + mut results: Vec>, + ) -> anyhow::Result { + // Compute recall using the first repetition's results. + let recall = match results.first() { + Some(first) => benchmark_core::recall::knn( + self.groundtruth, + None, + first.ids().as_rows(), + self.recall_k, + run.parameters().get(), + GroundTruthMode::Fixed, + )?, + None => anyhow::bail!("Results must be non-empty"), + }; + + let mut mean_latencies = Vec::with_capacity(results.len()); + let mut p90_latencies = Vec::with_capacity(results.len()); + let mut p99_latencies = Vec::with_capacity(results.len()); + + results.iter_mut().for_each(|r| { + match percentiles::compute_percentiles(r.latencies_mut()) { + Ok(values) => { + let percentiles::Percentiles { mean, p90, p99, .. } = values; + mean_latencies.push(mean); + p90_latencies.push(p90); + p99_latencies.push(p99); + } + Err(_) => { + mean_latencies.push(0.0); + p90_latencies.push(MicroSeconds::new(0)); + p99_latencies.push(MicroSeconds::new(0)); + } + } + }); + + let qps: Vec = results + .iter() + .map(|r| recall.num_queries as f64 / r.end_to_end_latency().as_seconds()) + .collect(); + + let mean_cmps = { + let (sum, count) = results + .iter() + .flat_map(|r| r.output().iter().map(|o| o.comparisons as f64)) + .fold((0.0f64, 0usize), |(s, c), v| (s + v, c + 1)); + if count == 0 { + 0.0 + } else { + sum / count as f64 + } + } as f32; + + Ok(FlatSearchResults { + num_tasks: run.setup().tasks.into(), + k: run.parameters().get(), + qps, + search_latencies: results.iter().map(|r| r.end_to_end_latency()).collect(), + mean_latencies, + p90_latencies, + p99_latencies, + recall: (&recall).into(), + mean_cmps, + }) + } +} + +////////////// +// Results // +////////////// + +#[derive(Debug, Serialize)] +struct FlatResult { + results: Vec, +} + +impl std::fmt::Display for FlatResult { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if self.results.is_empty() { + return Ok(()); + } + + let headers: &[&str] = &[ + "K", + "Avg cmps", + "QPS - mean(max)", + "Avg Latency", + "p99 Latency", + "Recall", + "Threads", + ]; + + let mut table = + diskann_benchmark_runner::utils::fmt::Table::new(headers, self.results.len()); + for (i, r) in self.results.iter().enumerate() { + let mut row = table.row(i); + row.insert(r.k, 0); + row.insert(r.mean_cmps, 1); + row.insert( + format!( + "{:.1} ({:.1})", + utils::MaybeDisplay(percentiles::mean(&r.qps), "missing"), + utils::MaybeDisplay(percentiles::max_f64(&r.qps), "missing"), + ), + 2, + ); + row.insert( + format!( + "{:.1}us ({:.1}us)", + utils::MaybeDisplay(percentiles::mean(&r.mean_latencies), "missing"), + utils::MaybeDisplay(percentiles::max_f64(&r.mean_latencies), "missing"), + ), + 3, + ); + row.insert( + format!( + "{:.1}us ({:.1})", + utils::MaybeDisplay(percentiles::mean(&r.p99_latencies), "missing"), + utils::MaybeDisplay(r.p99_latencies.iter().max(), "missing"), + ), + 4, + ); + row.insert(format!("{:3}", r.recall.average), 5); + row.insert(r.num_tasks, 6); + } + + write!(f, "{}", table) + } +} diff --git a/diskann-benchmark/src/backend/mod.rs b/diskann-benchmark/src/backend/mod.rs index d04bae158..44904886f 100644 --- a/diskann-benchmark/src/backend/mod.rs +++ b/diskann-benchmark/src/backend/mod.rs @@ -8,12 +8,14 @@ use diskann_benchmark_runner::Registry; mod disk_index; mod exhaustive; mod filters; +mod flat; mod index; mod multi_vector; pub(crate) fn register_benchmarks(registry: &mut Registry) -> anyhow::Result<()> { exhaustive::register_benchmarks(registry)?; disk_index::register_benchmarks(registry)?; + flat::register_benchmarks(registry)?; index::register_benchmarks(registry)?; filters::register_benchmarks(registry)?; multi_vector::register_benchmarks(registry)?; diff --git a/diskann-benchmark/src/inputs/flat.rs b/diskann-benchmark/src/inputs/flat.rs new file mode 100644 index 000000000..b3a176acf --- /dev/null +++ b/diskann-benchmark/src/inputs/flat.rs @@ -0,0 +1,129 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +use std::num::NonZeroUsize; + +use anyhow::Context; +use diskann_benchmark_runner::{files::InputFile, utils::datatype::DataType, Checker}; +use serde::{Deserialize, Serialize}; + +use crate::{ + inputs::{as_input, write_field, Example, PRINT_WIDTH}, + utils::SimilarityMeasure, +}; + +////////////// +// Registry // +////////////// + +as_input!(FlatSearch); + +/////////// +// Input // +/////////// + +/// Input specification for a flat-index (brute-force kNN) benchmark. +#[derive(Debug, Serialize, Deserialize)] +pub(crate) struct FlatSearch { + /// Path to the dataset vectors (`.bin` format). + pub(crate) data: InputFile, + + /// The on-disk data type of the dataset. + pub(crate) data_type: DataType, + + /// The distance metric to use. + pub(crate) distance: SimilarityMeasure, + + /// Search configuration. + pub(crate) search: SearchPhase, +} + +impl FlatSearch { + pub(crate) const fn tag() -> &'static str { + "flat-search" + } + + pub(crate) fn validate(&mut self, checker: &mut Checker) -> anyhow::Result<()> { + self.data.resolve(checker)?; + self.search.validate(checker)?; + Ok(()) + } +} + +impl std::fmt::Display for FlatSearch { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write_field!(f, "Data", self.data.display())?; + write_field!(f, "Data Type", self.data_type)?; + write_field!(f, "Distance", self.distance)?; + write_field!(f, "Queries", self.search.queries.display())?; + write_field!(f, "Groundtruth", self.search.groundtruth.display())?; + write_field!(f, "K", self.search.k)?; + write_field!(f, "Threads", self.search.num_threads.len())?; + write_field!(f, "Reps", self.search.reps)?; + Ok(()) + } +} + +impl Example for FlatSearch { + fn example() -> Self { + Self { + data: InputFile::new("path/to/data.bin"), + data_type: DataType::Float32, + distance: SimilarityMeasure::SquaredL2, + search: SearchPhase::example(), + } + } +} + +/////////////////// +// Search Phase // +/////////////////// + +/// Parameters controlling the search phase of a flat benchmark. +#[derive(Debug, Serialize, Deserialize)] +pub(crate) struct SearchPhase { + /// Path to the query vectors (`.bin` format). + pub(crate) queries: InputFile, + + /// Path to the groundtruth file (`.bin` format). + pub(crate) groundtruth: InputFile, + + /// The number of nearest neighbors to retrieve per query. + pub(crate) k: NonZeroUsize, + + /// Number of threads to use for parallel query execution. + pub(crate) num_threads: Vec, + + /// Number of repetitions per configuration for stable timing. + pub(crate) reps: NonZeroUsize, +} + +impl SearchPhase { + pub(crate) fn validate(&mut self, checker: &mut Checker) -> anyhow::Result<()> { + self.queries + .resolve(checker) + .context("resolving queries file")?; + self.groundtruth + .resolve(checker) + .context("resolving groundtruth file")?; + Ok(()) + } +} + +impl Example for SearchPhase { + fn example() -> Self { + Self { + queries: InputFile::new("path/to/queries.bin"), + groundtruth: InputFile::new("path/to/groundtruth.bin"), + k: NonZeroUsize::new(10).unwrap(), + num_threads: vec![ + NonZeroUsize::new(1).unwrap(), + NonZeroUsize::new(4).unwrap(), + NonZeroUsize::new(8).unwrap(), + ], + reps: NonZeroUsize::new(5).unwrap(), + } + } +} diff --git a/diskann-benchmark/src/inputs/mod.rs b/diskann-benchmark/src/inputs/mod.rs index ed49f145e..e21375a51 100644 --- a/diskann-benchmark/src/inputs/mod.rs +++ b/diskann-benchmark/src/inputs/mod.rs @@ -6,6 +6,7 @@ pub(crate) mod disk; pub(crate) mod exhaustive; pub(crate) mod filters; +pub(crate) mod flat; pub(crate) mod graph_index; pub(crate) mod multi_vector; pub(crate) mod save_and_load; diff --git a/diskann-benchmark/src/main.rs b/diskann-benchmark/src/main.rs index aa21c4ae5..6096c2259 100644 --- a/diskann-benchmark/src/main.rs +++ b/diskann-benchmark/src/main.rs @@ -262,7 +262,7 @@ mod tests { let tempdir = tempfile::tempdir().unwrap(); - let input_path = tempdir.path().join("graph-index.json"); + let input_path = tempdir.path().join("input.json"); save_to_file(&input_path, &raw); let output_path = tempdir.path().join("output.json"); @@ -300,6 +300,16 @@ mod tests { run_integration_test(raw); } + ///////////////////////// + // Flat Search // + ///////////////////////// + + #[test] + fn flat_search_integration() { + let raw = value_from_file(&example_directory().join("flat-index.json")); + run_integration_test(raw); + } + //////////////////////////// // Dynamic Index // //////////////////////////// From 826154507fde1f9568478705476c02d050b72164 Mon Sep 17 00:00:00 2001 From: "Alex Razumov (from Dev Box)" Date: Fri, 12 Jun 2026 18:36:13 -0700 Subject: [PATCH 02/12] Added benchmark file --- .../wikipedia-100K-flat-index.json | 22 +++++++++++++++++++ diskann-benchmark/src/inputs/flat.rs | 11 +++++++++- 2 files changed, 32 insertions(+), 1 deletion(-) create mode 100644 diskann-benchmark/perf_test_inputs/wikipedia-100K-flat-index.json diff --git a/diskann-benchmark/perf_test_inputs/wikipedia-100K-flat-index.json b/diskann-benchmark/perf_test_inputs/wikipedia-100K-flat-index.json new file mode 100644 index 000000000..cabb62b89 --- /dev/null +++ b/diskann-benchmark/perf_test_inputs/wikipedia-100K-flat-index.json @@ -0,0 +1,22 @@ +{ + "search_directories": [ + "target/tmp" + ], + "jobs": [ + { + "type": "flat-search", + "content": { + "data": "wikipedia_cohere/wikipedia_base.bin.crop_nb_100000", + "data_type": "float32", + "distance": "inner_product", + "search": { + "queries": "wikipedia_cohere/wikipedia_query.bin", + "groundtruth": "wikipedia_cohere/wikipedia-100K", + "k": 100, + "num_threads": [4, 8], + "reps": 1 + } + } + } + ] +} diff --git a/diskann-benchmark/src/inputs/flat.rs b/diskann-benchmark/src/inputs/flat.rs index b3a176acf..7a775f764 100644 --- a/diskann-benchmark/src/inputs/flat.rs +++ b/diskann-benchmark/src/inputs/flat.rs @@ -60,7 +60,16 @@ impl std::fmt::Display for FlatSearch { write_field!(f, "Queries", self.search.queries.display())?; write_field!(f, "Groundtruth", self.search.groundtruth.display())?; write_field!(f, "K", self.search.k)?; - write_field!(f, "Threads", self.search.num_threads.len())?; + write_field!( + f, + "Threads", + self.search + .num_threads + .iter() + .map(|t| t.to_string()) + .collect::>() + .join(", ") + )?; write_field!(f, "Reps", self.search.reps)?; Ok(()) } From ec5db9a7109be7fb5e20bd963f6e76601397e411 Mon Sep 17 00:00:00 2001 From: "Alex Razumov (from Dev Box)" Date: Fri, 12 Jun 2026 19:11:11 -0700 Subject: [PATCH 03/12] Use NAME --- diskann-benchmark/src/backend/flat/search.rs | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/diskann-benchmark/src/backend/flat/search.rs b/diskann-benchmark/src/backend/flat/search.rs index 994ba4267..00d92636c 100644 --- a/diskann-benchmark/src/backend/flat/search.rs +++ b/diskann-benchmark/src/backend/flat/search.rs @@ -38,11 +38,13 @@ use crate::{ // Benchmark Registration // //////////////////////////// +const NAME: &str = "flat-index"; + pub(super) fn register_benchmarks(registry: &mut Registry) -> anyhow::Result<()> { - registry.register("flat-index-f32", FlatBenchmark::::new())?; - registry.register("flat-index-f16", FlatBenchmark::::new())?; - registry.register("flat-index-u8", FlatBenchmark::::new())?; - registry.register("flat-index-i8", FlatBenchmark::::new())?; + registry.register(NAME, FlatBenchmark::::new())?; + registry.register(NAME, FlatBenchmark::::new())?; + registry.register(NAME, FlatBenchmark::::new())?; + registry.register(NAME, FlatBenchmark::::new())?; Ok(()) } From 9f80e12547b1b3e3c618c69e4b58808e5ad24f6b Mon Sep 17 00:00:00 2001 From: "Alex Razumov (from Dev Box)" Date: Mon, 15 Jun 2026 15:01:00 -0700 Subject: [PATCH 04/12] Add FlatSearchParameters --- diskann-benchmark/src/backend/flat/search.rs | 24 ++++++++++++-------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/diskann-benchmark/src/backend/flat/search.rs b/diskann-benchmark/src/backend/flat/search.rs index 00d92636c..e9f10329c 100644 --- a/diskann-benchmark/src/backend/flat/search.rs +++ b/diskann-benchmark/src/backend/flat/search.rs @@ -164,7 +164,6 @@ where index: index.clone(), queries: Arc::new(queries.clone()), metric, - k, }); let setup = search::Setup { @@ -173,7 +172,7 @@ where reps, }; - let run = search::Run::new(k, setup); + let run = search::Run::new(FlatSearchParameters { k }, setup); let aggregated = search::search_all( searcher, std::iter::once(run), @@ -287,6 +286,11 @@ struct FlatSearcher { index: Arc>>, queries: Arc>, metric: Metric, +} + +/// Search parameters for flat-index benchmarks. +#[derive(Debug, Clone, Copy)] +struct FlatSearchParameters { k: NonZeroUsize, } @@ -302,7 +306,7 @@ where T: VectorRepr, { type Id = u32; - type Parameters = NonZeroUsize; // k value + type Parameters = FlatSearchParameters; type Output = FlatMetrics; fn num_queries(&self) -> usize { @@ -310,12 +314,12 @@ where } fn id_count(&self, parameters: &Self::Parameters) -> search::IdCount { - search::IdCount::Fixed(*parameters) + search::IdCount::Fixed(parameters.k) } async fn search( &self, - _parameters: &Self::Parameters, + parameters: &Self::Parameters, buffer: &mut O, index: usize, ) -> ANNResult @@ -328,7 +332,7 @@ where let stats = self .index - .knn_search(self.k, &strategy, &context, query, buffer) + .knn_search(parameters.k, &strategy, &context, query, buffer) .await?; Ok(FlatMetrics { @@ -370,12 +374,12 @@ struct FlatSearchResults { mean_cmps: f32, } -impl search::Aggregate for FlatAggregator<'_> { +impl search::Aggregate for FlatAggregator<'_> { type Output = FlatSearchResults; fn aggregate( &mut self, - run: search::Run, + run: search::Run, mut results: Vec>, ) -> anyhow::Result { // Compute recall using the first repetition's results. @@ -385,7 +389,7 @@ impl search::Aggregate for FlatAggregator<'_> { None, first.ids().as_rows(), self.recall_k, - run.parameters().get(), + run.parameters().k.get(), GroundTruthMode::Fixed, )?, None => anyhow::bail!("Results must be non-empty"), @@ -430,7 +434,7 @@ impl search::Aggregate for FlatAggregator<'_> { Ok(FlatSearchResults { num_tasks: run.setup().tasks.into(), - k: run.parameters().get(), + k: run.parameters().k.get(), qps, search_latencies: results.iter().map(|r| r.end_to_end_latency()).collect(), mean_latencies, From 36d7f0ee0349f164a63a766a4cbbad10ab8e41c6 Mon Sep 17 00:00:00 2001 From: "Alex Razumov (from Dev Box)" Date: Mon, 15 Jun 2026 15:13:18 -0700 Subject: [PATCH 05/12] Removed num_vectors --- diskann-benchmark/src/backend/flat/search.rs | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/diskann-benchmark/src/backend/flat/search.rs b/diskann-benchmark/src/backend/flat/search.rs index e9f10329c..0342137c6 100644 --- a/diskann-benchmark/src/backend/flat/search.rs +++ b/diskann-benchmark/src/backend/flat/search.rs @@ -198,15 +198,13 @@ where /// a full sequential scan over all vectors. struct FlatScanStrategy { metric: Metric, - num_vectors: usize, _phantom: std::marker::PhantomData, } impl FlatScanStrategy { - fn new(metric: Metric, num_vectors: usize) -> Self { + fn new(metric: Metric) -> Self { Self { metric, - num_vectors, _phantom: std::marker::PhantomData, } } @@ -215,7 +213,6 @@ impl FlatScanStrategy { /// The visitor that iterates over all vectors in the provider. struct FlatVisitor<'a, T> { data: &'a Matrix, - num_vectors: usize, } impl HasId for FlatVisitor<'_, T> { @@ -236,8 +233,7 @@ impl DistancesUnordered for FlatVisitor<'_, T> F: Send + FnMut(Self::Id, f32), { async move { - for i in 0..self.num_vectors { - let vector = self.data.row(i); + for (i, vector) in self.data.row_iter().enumerate() { let dist = computer.evaluate_similarity(vector); f(i as u32, dist); } @@ -265,7 +261,6 @@ impl SearchStrategy, &[T]> for FlatScanStrategy< ) -> Result, Self::Error> { Ok(FlatVisitor { data: &provider.data, - num_vectors: self.num_vectors, }) } @@ -326,7 +321,7 @@ where where O: SearchOutputBuffer + Send, { - let strategy = FlatScanStrategy::::new(self.metric, self.index.provider().data.nrows()); + let strategy = FlatScanStrategy::::new(self.metric); let context = DefaultContext; let query = self.queries.row(index); From b6e742a0516eeab8b40492f4ccb382a698ce70c0 Mon Sep 17 00:00:00 2001 From: "Alex Razumov (from Dev Box)" Date: Wed, 17 Jun 2026 16:54:25 -0700 Subject: [PATCH 06/12] Switch to FastMemoryVectorProviderAsync --- diskann-benchmark/src/flat/search.rs | 38 ++++++++++++++++++---------- 1 file changed, 25 insertions(+), 13 deletions(-) diff --git a/diskann-benchmark/src/flat/search.rs b/diskann-benchmark/src/flat/search.rs index 99ac1b670..fbfe093c4 100644 --- a/diskann-benchmark/src/flat/search.rs +++ b/diskann-benchmark/src/flat/search.rs @@ -24,6 +24,8 @@ use diskann_benchmark_runner::{ utils::{datatype::AsDataType, percentiles, MicroSeconds}, Benchmark, Checkpoint, Registry, }; +use diskann_providers::model::graph::provider::async_::FastMemoryVectorProviderAsync; +use diskann_providers::storage::FileStorageProvider; use diskann_utils::{future::SendFuture, views::Matrix}; use diskann_vector::{distance::Metric, PreprocessedDistanceFunction}; use half::f16; @@ -54,10 +56,10 @@ pub(super) fn register_benchmarks(registry: &mut Registry) -> anyhow::Result<()> /// A minimal in-memory provider for flat search benchmarks. /// -/// Wraps a loaded [`Matrix`] and implements [`DataProvider`] with identity -/// ID mapping. -struct InMemProvider { - data: Arc>, +/// Wraps a [`FastMemoryVectorProviderAsync`] and implements [`DataProvider`] +/// with identity ID mapping. +struct InMemProvider { + data: FastMemoryVectorProviderAsync, } impl DataProvider for InMemProvider { @@ -128,14 +130,21 @@ where // Load dataset writeln!(output, "Loading dataset...")?; - let data: Matrix = datafiles::load_dataset(datafiles::BinFile(&input.data))?; - let nrows = data.nrows(); - let ncols = data.ncols(); + let provider = { + let fmvp = FastMemoryVectorProviderAsync::::load_from_bin( + &FileStorageProvider, + input.data.to_str().unwrap(), + metric, + None, + None, + )?; + InMemProvider { data: fmvp } + }; + let nrows = provider.data.total(); + let ncols = provider.data.dim(); writeln!(output, " Loaded {} vectors of dimension {}", nrows, ncols)?; - // Build the provider and wrap in FlatIndex - let data = Arc::new(data); - let provider = InMemProvider { data: data.clone() }; + // Build the FlatIndex let index = FlatIndex::new(provider); let index = Arc::new(index); @@ -211,8 +220,8 @@ impl FlatScanStrategy { } /// The visitor that iterates over all vectors in the provider. -struct FlatVisitor<'a, T> { - data: &'a Matrix, +struct FlatVisitor<'a, T: VectorRepr> { + data: &'a FastMemoryVectorProviderAsync, } impl HasId for FlatVisitor<'_, T> { @@ -232,7 +241,10 @@ impl DistancesUnordered for FlatVisitor<'_, T> F: Send + FnMut(Self::Id, f32), { async move { - for (i, vector) in self.data.row_iter().enumerate() { + let total = self.data.total(); + for i in 0..total { + // SAFETY: single-writer load completed before search; no concurrent mutation. + let vector = unsafe { self.data.get_vector_sync(i) }; let dist = computer.evaluate_similarity(vector); f(i as u32, dist); } From 721f83a973671ab4425962daab75b4fceabed0c7 Mon Sep 17 00:00:00 2001 From: "Alex Razumov (from Dev Box)" Date: Thu, 18 Jun 2026 20:13:41 -0700 Subject: [PATCH 07/12] Minor fixes --- diskann-benchmark/src/flat/search.rs | 40 ++++++++++++++++++---------- 1 file changed, 26 insertions(+), 14 deletions(-) diff --git a/diskann-benchmark/src/flat/search.rs b/diskann-benchmark/src/flat/search.rs index fbfe093c4..70a2b6100 100644 --- a/diskann-benchmark/src/flat/search.rs +++ b/diskann-benchmark/src/flat/search.rs @@ -133,7 +133,7 @@ where let provider = { let fmvp = FastMemoryVectorProviderAsync::::load_from_bin( &FileStorageProvider, - input.data.to_str().unwrap(), + &input.data.to_string_lossy(), metric, None, None, @@ -146,7 +146,6 @@ where // Build the FlatIndex let index = FlatIndex::new(provider); - let index = Arc::new(index); // Load queries and groundtruth let queries: Matrix = @@ -154,6 +153,13 @@ where let groundtruth: Matrix = datafiles::load_dataset(datafiles::BinFile(&input.search.groundtruth))?; + anyhow::ensure!( + ncols == queries.ncols(), + "dataset dimension ({}) does not match query dimension ({})", + ncols, + queries.ncols(), + ); + writeln!( output, " Queries: {}, Groundtruth: {}x{}", @@ -168,13 +174,13 @@ where let mut results = Vec::new(); - for &threads in &input.search.num_threads { - let searcher = Arc::new(FlatSearcher { - index: index.clone(), - queries: Arc::new(queries.clone()), - metric, - }); + let searcher = Arc::new(FlatSearcher { + index, + queries, + strategy: FlatScanStrategy::new(metric), + }); + for &threads in &input.search.num_threads { let setup = search::Setup { threads, tasks: threads, @@ -183,7 +189,7 @@ where let run = search::Run::new(FlatSearchParameters { k }, setup); let aggregated = search::search_all( - searcher, + searcher.clone(), std::iter::once(run), FlatAggregator::new(&groundtruth, k.get()), )?; @@ -288,9 +294,9 @@ impl SearchStrategy, &[T]> for FlatScanStrategy< /// Wraps a [`FlatIndex`] and queries to implement the [`Search`] trait from benchmark_core. struct FlatSearcher { - index: Arc>>, - queries: Arc>, - metric: Metric, + index: FlatIndex>, + queries: Matrix, + strategy: FlatScanStrategy, } /// Search parameters for flat-index benchmarks. @@ -331,13 +337,19 @@ where where O: SearchOutputBuffer + Send, { - let strategy = FlatScanStrategy::::new(self.metric); let context = DefaultContext; let query = self.queries.row(index); let stats = self .index - .knn_search(parameters.k, &strategy, CopyIds, &context, query, buffer) + .knn_search( + parameters.k, + &self.strategy, + CopyIds, + &context, + query, + buffer, + ) .await?; Ok(FlatMetrics { From 4b7fd9913016dd0d62e7f4bae66ba43115db18e6 Mon Sep 17 00:00:00 2001 From: Alex Razumov Date: Thu, 18 Jun 2026 21:57:34 -0700 Subject: [PATCH 08/12] Apply suggestions from code review Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- diskann-benchmark/src/flat/search.rs | 41 ++++++++++++++++------------ 1 file changed, 23 insertions(+), 18 deletions(-) diff --git a/diskann-benchmark/src/flat/search.rs b/diskann-benchmark/src/flat/search.rs index 70a2b6100..98eccfc7c 100644 --- a/diskann-benchmark/src/flat/search.rs +++ b/diskann-benchmark/src/flat/search.rs @@ -142,6 +142,12 @@ where }; let nrows = provider.data.total(); let ncols = provider.data.dim(); + anyhow::ensure!( + nrows <= u32::MAX as usize, + "flat-index benchmark requires <= {} vectors (got {}) to fit in u32 ids", + u32::MAX, + nrows, + ); writeln!(output, " Loaded {} vectors of dimension {}", nrows, ncols)?; // Build the FlatIndex @@ -150,9 +156,10 @@ where // Load queries and groundtruth let queries: Matrix = datafiles::load_dataset(datafiles::BinFile(&input.search.queries))?; - let groundtruth: Matrix = - datafiles::load_dataset(datafiles::BinFile(&input.search.groundtruth))?; - + let groundtruth = datafiles::load_groundtruth( + datafiles::BinFile(&input.search.groundtruth), + Some(input.search.k.get()), + )?; anyhow::ensure!( ncols == queries.ncols(), "dataset dimension ({}) does not match query dimension ({})", @@ -171,6 +178,12 @@ where // Run searches for each thread count let k = input.search.k; let reps = input.search.reps; + anyhow::ensure!( + k.get() <= nrows, + "k ({}) must be <= number of dataset vectors ({})", + k, + nrows, + ); let mut results = Vec::new(); @@ -416,21 +429,13 @@ impl search::Aggregate for FlatAggregato let mut p90_latencies = Vec::with_capacity(results.len()); let mut p99_latencies = Vec::with_capacity(results.len()); - results.iter_mut().for_each(|r| { - match percentiles::compute_percentiles(r.latencies_mut()) { - Ok(values) => { - let percentiles::Percentiles { mean, p90, p99, .. } = values; - mean_latencies.push(mean); - p90_latencies.push(p90); - p99_latencies.push(p99); - } - Err(_) => { - mean_latencies.push(0.0); - p90_latencies.push(MicroSeconds::new(0)); - p99_latencies.push(MicroSeconds::new(0)); - } - } - }); + for r in results.iter_mut() { + let percentiles::Percentiles { mean, p90, p99, .. } = + percentiles::compute_percentiles(r.latencies_mut())?; + mean_latencies.push(mean); + p90_latencies.push(p90); + p99_latencies.push(p99); + } let qps: Vec = results .iter() From f55446bbbb60b6d06698f3d24af533fad235f96b Mon Sep 17 00:00:00 2001 From: "Alex Razumov (from Dev Box)" Date: Fri, 19 Jun 2026 13:05:52 -0700 Subject: [PATCH 09/12] Increased test coverage --- diskann-benchmark/src/flat/search.rs | 84 ++++++++++++++++++++++++++++ diskann-benchmark/src/inputs/flat.rs | 22 ++++++++ 2 files changed, 106 insertions(+) diff --git a/diskann-benchmark/src/flat/search.rs b/diskann-benchmark/src/flat/search.rs index 98eccfc7c..940612b1a 100644 --- a/diskann-benchmark/src/flat/search.rs +++ b/diskann-benchmark/src/flat/search.rs @@ -530,3 +530,87 @@ impl std::fmt::Display for FlatResult { write!(f, "{}", table) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::inputs::Example; + use diskann_benchmark_runner::utils::MicroSeconds; + + fn make_dummy_results(num_results: usize) -> FlatResult { + let results = (0..num_results) + .map(|i| FlatSearchResults { + num_tasks: i + 1, + k: 10, + qps: vec![100.0], + search_latencies: vec![MicroSeconds::new(1000)], + mean_latencies: vec![10.0], + p90_latencies: vec![MicroSeconds::new(900)], + p99_latencies: vec![MicroSeconds::new(990)], + recall: RecallMetrics { + recall_k: 10, + recall_n: 10, + num_queries: 100, + average: 0.95, + }, + mean_cmps: 256.0, + }) + .collect(); + FlatResult { results } + } + + #[test] + fn display_empty_flat_result() { + let result = FlatResult { + results: Vec::new(), + }; + let text = format!("{}", result); + assert!(text.is_empty()); + } + + #[test] + fn display_flat_result_with_data() { + let result = make_dummy_results(1); + let text = format!("{}", result); + assert!(text.contains("K")); + assert!(text.contains("Recall")); + } + + #[test] + fn description_with_matching_type() { + let benchmark = FlatBenchmark::::new(); + let input = crate::inputs::flat::FlatSearch::example(); + let text = format!("{}", DescriptionHelper(&benchmark, Some(&input))); + // When the type matches, description writes nothing (is_match() == true) + assert!(!text.contains("Data Type:")); + } + + #[test] + fn description_without_input() { + let benchmark = FlatBenchmark::::new(); + let text = format!("{}", DescriptionHelper::(&benchmark, None)); + assert!(text.contains("Data Type: float32")); + } + + #[test] + fn description_with_mismatched_type() { + use diskann_benchmark_runner::utils::datatype::DataType; + let benchmark = FlatBenchmark::::new(); + let mut input = crate::inputs::flat::FlatSearch::example(); + input.data_type = DataType::UInt8; + let text = format!("{}", DescriptionHelper(&benchmark, Some(&input))); + assert!(text.contains("Data Type: expected \"float32\" but found \"uint8\"")); + } + + /// Helper to call `description()` via `Display`. + struct DescriptionHelper<'a, T: VectorRepr + AsDataType>( + &'a FlatBenchmark, + Option<&'a crate::inputs::flat::FlatSearch>, + ); + + impl std::fmt::Display for DescriptionHelper<'_, T> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.description(f, self.1) + } + } +} diff --git a/diskann-benchmark/src/inputs/flat.rs b/diskann-benchmark/src/inputs/flat.rs index 7a775f764..d25e77cf7 100644 --- a/diskann-benchmark/src/inputs/flat.rs +++ b/diskann-benchmark/src/inputs/flat.rs @@ -136,3 +136,25 @@ impl Example for SearchPhase { } } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::inputs::Example; + + #[test] + fn example_flat_search_round_trips() { + let example = FlatSearch::example(); + let json = serde_json::to_value(&example).unwrap(); + let _: FlatSearch = serde_json::from_value(json).unwrap(); + } + + #[test] + fn display_flat_search() { + let example = FlatSearch::example(); + let text = format!("{}", example); + assert!(text.contains("Data")); + assert!(text.contains("Threads")); + assert!(text.contains("Reps")); + } +} From e960a7dfb2191fcded8478aac5b530d48e66333e Mon Sep 17 00:00:00 2001 From: "Alex Razumov (from Dev Box)" Date: Thu, 25 Jun 2026 20:37:08 -0700 Subject: [PATCH 10/12] =?UTF-8?q?Promote=20average=5Fall=20to=C2=A0pub,=20?= =?UTF-8?q?reuse=20it?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- diskann-benchmark-core/src/lib.rs | 2 +- diskann-benchmark-core/src/utils.rs | 3 ++- diskann-benchmark/src/flat/search.rs | 14 ++++---------- 3 files changed, 7 insertions(+), 12 deletions(-) diff --git a/diskann-benchmark-core/src/lib.rs b/diskann-benchmark-core/src/lib.rs index 920a7fe20..828359f66 100644 --- a/diskann-benchmark-core/src/lib.rs +++ b/diskann-benchmark-core/src/lib.rs @@ -40,7 +40,7 @@ //! in which method can fail, the [`anyhow::Error`] type balances generality and fidelity. mod internal; -pub(crate) mod utils; +pub mod utils; // Public Utility Modules pub mod recall; diff --git a/diskann-benchmark-core/src/utils.rs b/diskann-benchmark-core/src/utils.rs index 231754099..acff8c1c6 100644 --- a/diskann-benchmark-core/src/utils.rs +++ b/diskann-benchmark-core/src/utils.rs @@ -5,7 +5,8 @@ use diskann_benchmark_runner::utils::percentiles::AsF64Lossy; -pub(crate) fn average_all(x: I) -> f64 +/// Computes the arithmetic mean of an iterator of values, returning `0.0` when empty. +pub fn average_all(x: I) -> f64 where I: IntoIterator, { diff --git a/diskann-benchmark/src/flat/search.rs b/diskann-benchmark/src/flat/search.rs index 940612b1a..60aca726c 100644 --- a/diskann-benchmark/src/flat/search.rs +++ b/diskann-benchmark/src/flat/search.rs @@ -442,17 +442,11 @@ impl search::Aggregate for FlatAggregato .map(|r| recall.num_queries as f64 / r.end_to_end_latency().as_seconds()) .collect(); - let mean_cmps = { - let (sum, count) = results + let mean_cmps = benchmark_core::utils::average_all( + results .iter() - .flat_map(|r| r.output().iter().map(|o| o.comparisons as f64)) - .fold((0.0f64, 0usize), |(s, c), v| (s + v, c + 1)); - if count == 0 { - 0.0 - } else { - sum / count as f64 - } - } as f32; + .flat_map(|r| r.output().iter().map(|o| o.comparisons)), + ) as f32; Ok(FlatSearchResults { num_tasks: run.setup().tasks.into(), From dbf18a4a1540e95384b14208399192f74e93e270 Mon Sep 17 00:00:00 2001 From: "Alex Razumov (from Dev Box)" Date: Thu, 25 Jun 2026 21:05:08 -0700 Subject: [PATCH 11/12] Removed Flat prefix --- diskann-benchmark/src/flat/search.rs | 84 ++++++++++++++-------------- 1 file changed, 42 insertions(+), 42 deletions(-) diff --git a/diskann-benchmark/src/flat/search.rs b/diskann-benchmark/src/flat/search.rs index 60aca726c..9e7e3d636 100644 --- a/diskann-benchmark/src/flat/search.rs +++ b/diskann-benchmark/src/flat/search.rs @@ -43,10 +43,10 @@ use crate::{ const NAME: &str = "flat-index"; pub(super) fn register_benchmarks(registry: &mut Registry) -> anyhow::Result<()> { - registry.register(NAME, FlatBenchmark::::new())?; - registry.register(NAME, FlatBenchmark::::new())?; - registry.register(NAME, FlatBenchmark::::new())?; - registry.register(NAME, FlatBenchmark::::new())?; + registry.register(NAME, Flat::::new())?; + registry.register(NAME, Flat::::new())?; + registry.register(NAME, Flat::::new())?; + registry.register(NAME, Flat::::new())?; Ok(()) } @@ -78,11 +78,11 @@ impl DataProvider for InMemProvider { } } -struct FlatBenchmark { +struct Flat { _phantom: std::marker::PhantomData, } -impl FlatBenchmark { +impl Flat { fn new() -> Self { Self { _phantom: std::marker::PhantomData, @@ -90,7 +90,7 @@ impl FlatBenchmark { } } -impl Benchmark for FlatBenchmark +impl Benchmark for Flat where T: VectorRepr + AsDataType, { @@ -187,10 +187,10 @@ where let mut results = Vec::new(); - let searcher = Arc::new(FlatSearcher { + let searcher = Arc::new(Searcher { index, queries, - strategy: FlatScanStrategy::new(metric), + strategy: Strategy::new(metric), }); for &threads in &input.search.num_threads { @@ -200,11 +200,11 @@ where reps, }; - let run = search::Run::new(FlatSearchParameters { k }, setup); + let run = search::Run::new(SearchParameters { k }, setup); let aggregated = search::search_all( searcher.clone(), std::iter::once(run), - FlatAggregator::new(&groundtruth, k.get()), + Aggregator::new(&groundtruth, k.get()), )?; for item in aggregated { @@ -224,12 +224,12 @@ where /// A [`SearchStrategy`] implementation for [`InMemProvider`] that drives /// a full sequential scan over all vectors. -struct FlatScanStrategy { +struct Strategy { metric: Metric, _phantom: std::marker::PhantomData, } -impl FlatScanStrategy { +impl Strategy { fn new(metric: Metric) -> Self { Self { metric, @@ -239,15 +239,15 @@ impl FlatScanStrategy { } /// The visitor that iterates over all vectors in the provider. -struct FlatVisitor<'a, T: VectorRepr> { +struct Visitor<'a, T: VectorRepr> { data: &'a FastMemoryVectorProviderAsync, } -impl HasId for FlatVisitor<'_, T> { +impl HasId for Visitor<'_, T> { type Id = u32; } -impl DistancesUnordered for FlatVisitor<'_, T> { +impl DistancesUnordered for Visitor<'_, T> { type ElementRef<'a> = &'a [T]; type Error = diskann::error::Infallible; @@ -272,12 +272,12 @@ impl DistancesUnordered for FlatVisitor<'_, T> } } -impl SearchStrategy, &[T]> for FlatScanStrategy { +impl SearchStrategy, &[T]> for Strategy { type ElementRef<'a> = &'a [T]; type QueryComputer = T::QueryDistance; type QueryComputerError = diskann::error::Infallible; type Visitor<'a> - = FlatVisitor<'a, T> + = Visitor<'a, T> where Self: 'a, InMemProvider: 'a; @@ -288,7 +288,7 @@ impl SearchStrategy, &[T]> for FlatScanStrategy< provider: &'a InMemProvider, _context: &'a DefaultContext, ) -> Result, Self::Error> { - Ok(FlatVisitor { + Ok(Visitor { data: &provider.data, }) } @@ -306,32 +306,32 @@ impl SearchStrategy, &[T]> for FlatScanStrategy< ////////////////////////////////////////// /// Wraps a [`FlatIndex`] and queries to implement the [`Search`] trait from benchmark_core. -struct FlatSearcher { +struct Searcher { index: FlatIndex>, queries: Matrix, - strategy: FlatScanStrategy, + strategy: Strategy, } /// Search parameters for flat-index benchmarks. #[derive(Debug, Clone, Copy)] -struct FlatSearchParameters { +struct SearchParameters { k: NonZeroUsize, } /// Additional metrics collected during flat search. #[derive(Debug, Clone, Copy)] -struct FlatMetrics { +struct Metrics { /// The number of distance comparisons performed. pub comparisons: u32, } -impl search::Search for FlatSearcher +impl search::Search for Searcher where T: VectorRepr, { type Id = u32; - type Parameters = FlatSearchParameters; - type Output = FlatMetrics; + type Parameters = SearchParameters; + type Output = Metrics; fn num_queries(&self) -> usize { self.queries.nrows() @@ -365,7 +365,7 @@ where ) .await?; - Ok(FlatMetrics { + Ok(Metrics { comparisons: stats.cmps, }) } @@ -376,12 +376,12 @@ where ////////////////// /// Aggregates results from multiple flat search runs, computing recall metrics. -struct FlatAggregator<'a> { +struct Aggregator<'a> { groundtruth: &'a Matrix, recall_k: usize, } -impl<'a> FlatAggregator<'a> { +impl<'a> Aggregator<'a> { fn new(groundtruth: &'a Matrix, recall_k: usize) -> Self { Self { groundtruth, @@ -392,7 +392,7 @@ impl<'a> FlatAggregator<'a> { /// Results of a single flat search run. #[derive(Debug, Clone, Serialize)] -struct FlatSearchResults { +struct SearchResults { num_tasks: usize, k: usize, qps: Vec, @@ -404,14 +404,14 @@ struct FlatSearchResults { mean_cmps: f32, } -impl search::Aggregate for FlatAggregator<'_> { - type Output = FlatSearchResults; +impl search::Aggregate for Aggregator<'_> { + type Output = SearchResults; fn aggregate( &mut self, - run: search::Run, - mut results: Vec>, - ) -> anyhow::Result { + run: search::Run, + mut results: Vec>, + ) -> anyhow::Result { // Compute recall using the first repetition's results. let recall = match results.first() { Some(first) => benchmark_core::recall::knn( @@ -448,7 +448,7 @@ impl search::Aggregate for FlatAggregato .flat_map(|r| r.output().iter().map(|o| o.comparisons)), ) as f32; - Ok(FlatSearchResults { + Ok(SearchResults { num_tasks: run.setup().tasks.into(), k: run.parameters().k.get(), qps, @@ -468,7 +468,7 @@ impl search::Aggregate for FlatAggregato #[derive(Debug, Serialize)] struct FlatResult { - results: Vec, + results: Vec, } impl std::fmt::Display for FlatResult { @@ -533,7 +533,7 @@ mod tests { fn make_dummy_results(num_results: usize) -> FlatResult { let results = (0..num_results) - .map(|i| FlatSearchResults { + .map(|i| SearchResults { num_tasks: i + 1, k: 10, qps: vec![100.0], @@ -572,7 +572,7 @@ mod tests { #[test] fn description_with_matching_type() { - let benchmark = FlatBenchmark::::new(); + let benchmark = Flat::::new(); let input = crate::inputs::flat::FlatSearch::example(); let text = format!("{}", DescriptionHelper(&benchmark, Some(&input))); // When the type matches, description writes nothing (is_match() == true) @@ -581,7 +581,7 @@ mod tests { #[test] fn description_without_input() { - let benchmark = FlatBenchmark::::new(); + let benchmark = Flat::::new(); let text = format!("{}", DescriptionHelper::(&benchmark, None)); assert!(text.contains("Data Type: float32")); } @@ -589,7 +589,7 @@ mod tests { #[test] fn description_with_mismatched_type() { use diskann_benchmark_runner::utils::datatype::DataType; - let benchmark = FlatBenchmark::::new(); + let benchmark = Flat::::new(); let mut input = crate::inputs::flat::FlatSearch::example(); input.data_type = DataType::UInt8; let text = format!("{}", DescriptionHelper(&benchmark, Some(&input))); @@ -598,7 +598,7 @@ mod tests { /// Helper to call `description()` via `Display`. struct DescriptionHelper<'a, T: VectorRepr + AsDataType>( - &'a FlatBenchmark, + &'a Flat, Option<&'a crate::inputs::flat::FlatSearch>, ); From c33cdfb0488c659b663a521b3961d3a8a0e25299 Mon Sep 17 00:00:00 2001 From: "Alex Razumov (from Dev Box)" Date: Thu, 2 Jul 2026 16:43:34 -0700 Subject: [PATCH 12/12] Revert "Switch to FastMemoryVectorProviderAsync" This reverts commit b6e742a0516eeab8b40492f4ccb382a698ce70c0. --- diskann-benchmark/src/flat/search.rs | 38 ++++++++++------------------ 1 file changed, 13 insertions(+), 25 deletions(-) diff --git a/diskann-benchmark/src/flat/search.rs b/diskann-benchmark/src/flat/search.rs index 9e7e3d636..0f1c4f8cc 100644 --- a/diskann-benchmark/src/flat/search.rs +++ b/diskann-benchmark/src/flat/search.rs @@ -24,8 +24,6 @@ use diskann_benchmark_runner::{ utils::{datatype::AsDataType, percentiles, MicroSeconds}, Benchmark, Checkpoint, Registry, }; -use diskann_providers::model::graph::provider::async_::FastMemoryVectorProviderAsync; -use diskann_providers::storage::FileStorageProvider; use diskann_utils::{future::SendFuture, views::Matrix}; use diskann_vector::{distance::Metric, PreprocessedDistanceFunction}; use half::f16; @@ -56,10 +54,10 @@ pub(super) fn register_benchmarks(registry: &mut Registry) -> anyhow::Result<()> /// A minimal in-memory provider for flat search benchmarks. /// -/// Wraps a [`FastMemoryVectorProviderAsync`] and implements [`DataProvider`] -/// with identity ID mapping. -struct InMemProvider { - data: FastMemoryVectorProviderAsync, +/// Wraps a loaded [`Matrix`] and implements [`DataProvider`] with identity +/// ID mapping. +struct InMemProvider { + data: Arc>, } impl DataProvider for InMemProvider { @@ -130,18 +128,9 @@ where // Load dataset writeln!(output, "Loading dataset...")?; - let provider = { - let fmvp = FastMemoryVectorProviderAsync::::load_from_bin( - &FileStorageProvider, - &input.data.to_string_lossy(), - metric, - None, - None, - )?; - InMemProvider { data: fmvp } - }; - let nrows = provider.data.total(); - let ncols = provider.data.dim(); + let data: Matrix = datafiles::load_dataset(datafiles::BinFile(&input.data))?; + let nrows = data.nrows(); + let ncols = data.ncols(); anyhow::ensure!( nrows <= u32::MAX as usize, "flat-index benchmark requires <= {} vectors (got {}) to fit in u32 ids", @@ -150,7 +139,9 @@ where ); writeln!(output, " Loaded {} vectors of dimension {}", nrows, ncols)?; - // Build the FlatIndex + // Build the provider and wrap in FlatIndex + let data = Arc::new(data); + let provider = InMemProvider { data: data.clone() }; let index = FlatIndex::new(provider); // Load queries and groundtruth @@ -239,8 +230,8 @@ impl Strategy { } /// The visitor that iterates over all vectors in the provider. -struct Visitor<'a, T: VectorRepr> { - data: &'a FastMemoryVectorProviderAsync, +struct Visitor<'a, T> { + data: &'a Matrix, } impl HasId for Visitor<'_, T> { @@ -260,10 +251,7 @@ impl DistancesUnordered for Visitor<'_, T> { F: Send + FnMut(Self::Id, f32), { async move { - let total = self.data.total(); - for i in 0..total { - // SAFETY: single-writer load completed before search; no concurrent mutation. - let vector = unsafe { self.data.get_vector_sync(i) }; + for (i, vector) in self.data.row_iter().enumerate() { let dist = computer.evaluate_similarity(vector); f(i as u32, dist); }