diff --git a/diskann-benchmark/example/disk-index-filter.json b/diskann-benchmark/example/disk-index-filter.json new file mode 100644 index 000000000..a3f35ca91 --- /dev/null +++ b/diskann-benchmark/example/disk-index-filter.json @@ -0,0 +1,67 @@ +{ + "search_directories": [ + "test_data/disk_index_search" + ], + "jobs": [ + { + "type": "disk-index", + "content": { + "source": { + "disk-index-source": "Build", + "data_type": "float32", + "data": "disk_index_siftsmall_learn_256pts_data.fbin", + "distance": "squared_l2", + "dim": 128, + "max_degree": 32, + "l_build": 50, + "num_threads": 1, + "build_ram_limit_gb": 2.0, + "num_pq_chunks": 128, + "quantization_type": "FP", + "save_path": "siftsmall_index_filter_graph" + }, + "search_phase": { + "queries": "disk_index_sample_query_10pts.fbin", + "groundtruth": "disk_index_10pts_idx_uint32_truth_search_filter_res.bin", + "search_list": [10, 20, 40], + "beam_width": 4, + "recall_at": 10, + "num_threads": 1, + "is_flat_search": false, + "distance": "squared_l2", + "vector_filters_file": "disk_index_10pts_idx_uint32_range_res_r_100000.bin" + } + } + }, + { + "type": "disk-index", + "content": { + "source": { + "disk-index-source": "Build", + "data_type": "float32", + "data": "disk_index_siftsmall_learn_256pts_data.fbin", + "distance": "squared_l2", + "dim": 128, + "max_degree": 32, + "l_build": 50, + "num_threads": 1, + "build_ram_limit_gb": 2.0, + "num_pq_chunks": 128, + "quantization_type": "FP", + "save_path": "siftsmall_index_filter_flat" + }, + "search_phase": { + "queries": "disk_index_sample_query_10pts.fbin", + "groundtruth": "disk_index_10pts_idx_uint32_truth_search_filter_res.bin", + "search_list": [10, 20, 40], + "beam_width": 4, + "recall_at": 10, + "num_threads": 1, + "is_flat_search": true, + "distance": "squared_l2", + "vector_filters_file": "disk_index_10pts_idx_uint32_range_res_r_100000.bin" + } + } + } + ] +} diff --git a/diskann-benchmark/src/disk_index/search.rs b/diskann-benchmark/src/disk_index/search.rs index db7bccbdc..a1e84e79b 100644 --- a/diskann-benchmark/src/disk_index/search.rs +++ b/diskann-benchmark/src/disk_index/search.rs @@ -13,9 +13,12 @@ use diskann::utils::VectorRepr; use diskann_benchmark_runner::{files::InputFile, utils::MicroSeconds}; use diskann_disk::{ data_model::{AdHoc, CachingStrategy}, - search::provider::{ - disk_provider::{DiskIndexSearcher, SearchPostProcessorKind}, - disk_vertex_provider_factory::DiskVertexProviderFactory, + search::{ + provider::{ + disk_provider::DiskIndexSearcher, + disk_vertex_provider_factory::DiskVertexProviderFactory, + }, + search_mode::SearchMode, }, storage::disk_index_reader::DiskIndexReader, utils::{instrumentation::PerfLogger, statistics, AlignedFileReaderFactory, QueryStatistics}, @@ -33,10 +36,7 @@ use serde::{Deserialize, Serialize}; use crate::{ disk_index::json_spancollector::JsonSpanCollector, - inputs::{ - disk::{DiskIndexLoad, DiskSearchPhase}, - post_processor::TopkPostProcessor, - }, + inputs::disk::{DiskIndexLoad, DiskSearchPhase}, utils::{datafiles, SimilarityMeasure}, }; @@ -268,27 +268,22 @@ where zipped.for_each_in_pool( pool.as_ref(), |(((((q, vf), id_chunk), dist_chunk), stats), rc)| { - let post_processor = search_params.post_processor.as_ref().map_or( - SearchPostProcessorKind::None, - |TopkPostProcessor::DeterminantDiversity(params)| { - SearchPostProcessorKind::DeterminantDiversity(*params) - }, + // Construct the SearchMode from the JSON-driven + // `adaptive_l` is now encapsulated in `DiskSearchMode`, so the + // benchmark only supplies the per-query filter and post-processor. + let has_filter = search_params.vector_filters_file.is_some(); + let mode: SearchMode<'_> = search_params.search_mode.search_mode( + has_filter, + vf, + search_params.post_processor.as_ref(), ); - let vector_filter = if search_params.vector_filters_file.is_none() { - None - } else { - Some(Box::new(move |vid: &u32| vf.contains(vid)) - as Box bool + Send + Sync>) - }; match searcher.search( q, search_params.recall_at, l, Some(search_params.beam_width), - vector_filter, - post_processor, - search_params.is_flat_search, + mode, ) { Ok(search_result) => { *stats = search_result.stats.query_statistics; @@ -354,7 +349,7 @@ where num_threads: search_params.num_threads, beam_width: search_params.beam_width, recall_at: search_params.recall_at, - is_flat_search: search_params.is_flat_search, + is_flat_search: search_params.search_mode.is_flat_search, distance: search_params.distance, uses_vector_filters: search_params.vector_filters_file.is_some(), num_nodes_to_cache: search_params.num_nodes_to_cache, diff --git a/diskann-benchmark/src/inputs/disk.rs b/diskann-benchmark/src/inputs/disk.rs index f6b4cacac..7ed521a87 100644 --- a/diskann-benchmark/src/inputs/disk.rs +++ b/diskann-benchmark/src/inputs/disk.rs @@ -6,8 +6,15 @@ use std::{fmt, num::NonZeroUsize, path::Path}; use anyhow::Context; +#[cfg(feature = "disk-index")] +use std::collections::HashSet; + +#[cfg(feature = "disk-index")] +use diskann::graph; use diskann_benchmark_runner::{files::InputFile, utils::datatype::DataType, Checker}; #[cfg(feature = "disk-index")] +use diskann_disk::search::search_mode::SearchMode; +#[cfg(feature = "disk-index")] use diskann_disk::QuantizationType; use diskann_providers::storage::{get_compressed_pq_file, get_disk_index_file, get_pq_pivot_file}; use serde::{Deserialize, Serialize}; @@ -17,6 +24,9 @@ use crate::{ utils::SimilarityMeasure, }; +#[cfg(feature = "disk-index")] +use crate::inputs::graph_index::AdaptiveL; + ////////////// // Registry // ////////////// @@ -62,6 +72,80 @@ pub(crate) struct DiskIndexBuild { pub(crate) save_path: String, } +#[cfg(feature = "disk-index")] +#[derive(Debug, Serialize, Deserialize, Default)] +pub(crate) struct DiskSearchMode { + pub(crate) is_flat_search: bool, + #[serde(default)] + pub(crate) adaptive_l: Option, +} + +#[cfg(feature = "disk-index")] +impl DiskSearchMode { + pub(crate) fn search_mode<'a>( + &'a self, + has_vector_filters: bool, + vector_filter: &'a HashSet, + post_processor: Option<&TopkPostProcessor>, + ) -> SearchMode<'a> { + let adaptive_l = self.adaptive_l.as_ref().map(|adaptive_l| { + graph::search::AdaptiveL::new(adaptive_l.sample_count.into(), adaptive_l.scale_factor) + .expect("validated adaptive L must construct") + }); + + match ( + self.is_flat_search, + has_vector_filters, + post_processor, + adaptive_l, + ) { + (true, false, _, _) => SearchMode::flat(), + (true, true, _, _) => { + SearchMode::flat_filtered(move |vid: &u32| vector_filter.contains(vid)) + } + (false, false, Some(TopkPostProcessor::DeterminantDiversity(params)), _) => { + SearchMode::diverse_graph(*params) + } + (false, true, Some(TopkPostProcessor::DeterminantDiversity(params)), _) => { + SearchMode::diverse_graph_filtered( + move |vid: &u32| vector_filter.contains(vid), + *params, + ) + } + (false, false, None, Some(adaptive_l)) => { + SearchMode::inline_filter(|_| true, Some(adaptive_l)) + } + (false, true, None, Some(adaptive_l)) => SearchMode::inline_filter( + move |vid: &u32| vector_filter.contains(vid), + Some(adaptive_l), + ), + (false, false, None, None) => SearchMode::graph(), + (false, true, None, None) => { + SearchMode::graph_filtered(move |vid: &u32| vector_filter.contains(vid)) + } + } + } + + pub(crate) fn validate(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { + if let Some(adaptive_l) = self.adaptive_l.as_mut() { + adaptive_l.validate(checker)?; + } + Ok(()) + } +} + +#[cfg(feature = "disk-index")] +impl fmt::Display for DiskSearchMode { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let base = if self.is_flat_search { "flat" } else { "graph" }; + if self.adaptive_l.is_some() { + write!(f, "{} + adaptive-l", base) + } else { + write!(f, "{}", base) + } + } +} + /// Search phase configuration #[derive(Debug, Deserialize, Serialize)] pub(crate) struct DiskSearchPhase { @@ -71,6 +155,15 @@ pub(crate) struct DiskSearchPhase { pub(crate) beam_width: usize, pub(crate) search_list: Vec, pub(crate) recall_at: u32, + #[cfg(feature = "disk-index")] + #[serde(default)] + pub(crate) search_mode: DiskSearchMode, + // Backward compatibility for older benchmark inputs that used + // `is_flat_search` directly at the search-phase level. + #[cfg(feature = "disk-index")] + #[serde(default, skip_serializing)] + pub(crate) is_flat_search: Option, + #[cfg(not(feature = "disk-index"))] pub(crate) is_flat_search: bool, pub(crate) distance: SimilarityMeasure, pub(crate) vector_filters_file: Option, @@ -181,6 +274,16 @@ impl DiskSearchPhase { vf.resolve(checker).context("invalid vector_filters_file")?; } + #[cfg(feature = "disk-index")] + if let Some(is_flat_search) = self.is_flat_search { + self.search_mode.is_flat_search = is_flat_search; + } + + #[cfg(feature = "disk-index")] + self.search_mode + .validate(checker) + .context("invalid disk search mode")?; + // basic numeric sanity checks if self.search_list.is_empty() { anyhow::bail!("search_list must have at least one value"); @@ -250,6 +353,14 @@ impl Example for DiskIndexOperation { beam_width: 16, recall_at: 10, num_threads: 8, + #[cfg(feature = "disk-index")] + search_mode: DiskSearchMode { + is_flat_search: false, + adaptive_l: None, + }, + #[cfg(feature = "disk-index")] + is_flat_search: None, + #[cfg(not(feature = "disk-index"))] is_flat_search: false, distance: SimilarityMeasure::SquaredL2, vector_filters_file: None, @@ -367,6 +478,9 @@ impl DiskSearchPhase { write_field!(f, "Beam Width", self.beam_width)?; write_field!(f, "Recall@", self.recall_at)?; write_field!(f, "Threads", self.num_threads)?; + #[cfg(feature = "disk-index")] + write_field!(f, "Search Mode", self.search_mode)?; + #[cfg(not(feature = "disk-index"))] write_field!(f, "Flat Search", self.is_flat_search)?; write_field!(f, "Distance", self.distance)?; match &self.vector_filters_file { diff --git a/diskann-benchmark/src/main.rs b/diskann-benchmark/src/main.rs index a30a6ccdc..14ce2cc33 100644 --- a/diskann-benchmark/src/main.rs +++ b/diskann-benchmark/src/main.rs @@ -688,6 +688,37 @@ mod tests { run_integration_test(raw); } + /// Filtered disk search end-to-end: drives the disk-index backend through + /// `disk-index-filter.json` + #[test] + #[cfg(feature = "disk-index")] + fn disk_index_filter_integration() { + let mut raw = value_from_file(&example_directory().join("disk-index-filter.json")); + prefix_search_directories(&mut raw, &root_directory()); + + let tempdir = tempfile::tempdir().unwrap(); + let input_path = tempdir.path().join("disk-index-filter.json"); + save_to_file(&input_path, &raw); + let output_path = tempdir.path().join("output.json"); + + let command = Commands::Run { + input_file: input_path.to_owned(), + output_file: output_path.to_owned(), + dry_run: false, + allow_debug: true, + }; + let cli = Cli::from_commands(command, true); + let mut output = Memory::new(); + let result = cli.run(&mut output); + let output_str = String::from_utf8(output.into_inner()).unwrap(); + println!("output = {}", output_str); + result.expect("disk-index-filter run failed"); + + assert!(output_path.exists()); + let results: Vec = load_from_file(&output_path); + assert_eq!(results.len(), num_jobs(&raw)); + } + #[test] fn graph_index_inline_filter_yfcc_integration() { // First, parse and modify the input file to establish paths relative to the diff --git a/diskann-disk/src/build/builder/core.rs b/diskann-disk/src/build/builder/core.rs index 2fb24c07a..37882bb8a 100644 --- a/diskann-disk/src/build/builder/core.rs +++ b/diskann-disk/src/build/builder/core.rs @@ -1089,9 +1089,7 @@ pub(crate) mod disk_index_builder_tests { &mut indices, &mut distances, &mut associated_data, - None, - &|_| true, - false, + &crate::search::search_mode::SearchMode::graph(), ); diskann_providers::test_utils::assert_top_k_exactly_match( diff --git a/diskann-disk/src/search/mod.rs b/diskann-disk/src/search/mod.rs index 54f710c7f..b377148de 100644 --- a/diskann-disk/src/search/mod.rs +++ b/diskann-disk/src/search/mod.rs @@ -7,4 +7,5 @@ pub mod pq; pub mod provider; +pub mod search_mode; pub mod traits; diff --git a/diskann-disk/src/search/provider/disk_provider.rs b/diskann-disk/src/search/provider/disk_provider.rs index 8e95dfa4b..f7987ad60 100644 --- a/diskann-disk/src/search/provider/disk_provider.rs +++ b/diskann-disk/src/search/provider/disk_provider.rs @@ -18,8 +18,9 @@ use diskann::{ error::IntoANNResult, graph::{ self, + ext::labeled::{self, QueryLabelProvider}, glue::{self, DefaultPostProcessor, SearchPostProcess, SearchStrategy}, - search::Knn, + search::{AdaptiveL, InlineFilterSearch, Knn}, search_output_buffer, DiskANNIndex, }, neighbor::{Neighbor, NeighborPriorityQueue}, @@ -47,9 +48,9 @@ use tracing::debug; use crate::{ data_model::{CachingStrategy, GraphHeader}, - filter_parameter::{default_vector_filter, VectorFilter}, search::{ provider::disk_vertex_provider_factory::DiskVertexProviderFactory, + search_mode::SearchMode, traits::{VertexProvider, VertexProviderFactory}, }, storage::{api::AsyncDiskLoadContext, disk_index_reader::DiskIndexReader}, @@ -212,14 +213,36 @@ where /// from local data structures). Moving these components to the search strategy allows /// DiskProvider to satisfy 'static constraints while enabling flexible per-search /// resource management. +/// Borrowed predicate used internally by the disk search pipeline. +/// Spelled out here to keep the field/parameter signatures under +/// `clippy::type_complexity`'s default threshold. +type PostprocessFilter<'a> = &'a (dyn Fn(&u32) -> bool + Send + Sync); + +/// Encodes whether to accept all candidates at rerank time or apply a +/// specific predicate. Used by `RerankAndFilter` and +/// `DeterminantDiversityAndFilter` instead of `Option` +/// so call sites are self-documenting without relying on comments to +/// explain what `None` means. +#[derive(Clone, Copy)] +pub enum PostprocessStrategy<'a> { + /// Accept every candidate — no predicate is called. Used by `FlatScan` + /// (filtered at scan time) and `InlineFilter` (filtered at visit time). + AcceptAll, + /// Apply the given predicate; non-matching candidates are dropped. + Apply(PostprocessFilter<'a>), +} + pub struct DiskSearchStrategy<'a, Data, ProviderFactory> where Data: GraphDataType, ProviderFactory: VertexProviderFactory, { - // This needs to be Arc instead of Rc because DiskSearchStrategy has "Send" trait bound, though this is not expected to be shared across threads. - io_tracker: IOTracker, - vector_filter: &'a (dyn Fn(&u32) -> bool + Send + Sync), // Fn param is u32 as we validate "VectorIdType = u32" everywhere in this provider in trait bounds. + // Borrowed from `search_internal` so the strategy can be passed by value + io_tracker: &'a IOTracker, + /// Consumed only by `default_post_processor()` → `RerankAndFilter`. + /// `FlatScan` and `InlineFilter` filter earlier in their pipelines and + /// pass `AcceptAll` here to avoid a redundant second pass. + postprocess_filter: PostprocessStrategy<'a>, /// The vertex provider factory is used to create the vertex provider for each search instance. vertex_provider_factory: &'a ProviderFactory, @@ -267,23 +290,18 @@ impl IOTracker { #[derive(Clone, Copy)] pub struct RerankAndFilter<'a> { - filter: &'a (dyn Fn(&u32) -> bool + Send + Sync), + filter: PostprocessStrategy<'a>, } #[derive(Clone, Copy)] pub struct DeterminantDiversityAndFilter<'a> { - filter: &'a (dyn Fn(&u32) -> bool + Send + Sync), + filter: PostprocessStrategy<'a>, params: DeterminantDiversityParams, } -#[derive(Clone, Copy)] -pub enum SearchPostProcessorKind { - /// No post-processing; search results are returned as-is. - None, - RerankAndFilter, - DeterminantDiversity(DeterminantDiversityParams), -} - +/// Internal dispatch wrapper used by `search_internal`'s `DiverseGraph` arm +/// to feed `DiskANNIndex::search_with`. Hidden behind `SearchMode` from the +/// public API. #[derive(Clone, Copy)] pub enum DiskSearchPostProcessor<'a> { RerankAndFilter(RerankAndFilter<'a>), @@ -291,16 +309,13 @@ pub enum DiskSearchPostProcessor<'a> { } impl<'a> RerankAndFilter<'a> { - pub fn new(filter: &'a (dyn Fn(&u32) -> bool + Send + Sync)) -> Self { + pub fn new(filter: PostprocessStrategy<'a>) -> Self { Self { filter } } } impl<'a> DeterminantDiversityAndFilter<'a> { - pub fn new( - filter: &'a (dyn Fn(&u32) -> bool + Send + Sync), - params: DeterminantDiversityParams, - ) -> Self { + pub fn new(filter: PostprocessStrategy<'a>, params: DeterminantDiversityParams) -> Self { Self { filter, params } } } @@ -335,18 +350,27 @@ where let provider = accessor.provider; let mut uncached_ids = Vec::new(); - let mut reranked = candidates - .map(|n| n.id) - .filter(|id| (self.filter)(id)) - .filter_map(|n| { + let mut reranked = { + let mut process = |n: u32| { if let Some(entry) = accessor.scratch.distance_cache.get(&n) { Some(Ok::<((u32, _), f32), ANNError>(((n, entry.1), entry.0))) } else { uncached_ids.push(n); None } - }) - .collect::, _>>()?; + }; + match self.filter { + PostprocessStrategy::AcceptAll => candidates + .map(|n| n.id) + .filter_map(&mut process) + .collect::, _>>()?, + PostprocessStrategy::Apply(f) => candidates + .map(|n| n.id) + .filter(|id| f(id)) + .filter_map(&mut process) + .collect::, _>>()?, + } + }; if !uncached_ids.is_empty() { ensure_vertex_loaded(&mut accessor.scratch.vertex_provider, &uncached_ids)?; for n in &uncached_ids { @@ -395,10 +419,13 @@ where let provider = accessor.provider; let query_f32 = Data::VectorDataType::as_f32(query).map_err(Into::into)?; - let candidate_ids: Vec = candidates - .map(|candidate| candidate.id) - .filter(|id| (self.filter)(id)) - .collect(); + let candidate_ids: Vec = match self.filter { + PostprocessStrategy::AcceptAll => candidates.map(|candidate| candidate.id).collect(), + PostprocessStrategy::Apply(f) => candidates + .map(|candidate| candidate.id) + .filter(|id| f(id)) + .collect(), + }; if candidate_ids.is_empty() { return Ok(0); @@ -497,7 +524,7 @@ where ) -> Result { DiskAccessor::new( provider, - &self.io_tracker, + self.io_tracker, query, self.vertex_provider_factory, self.scratch_pool, @@ -522,7 +549,7 @@ where type Processor = RerankAndFilter<'this>; fn default_post_processor(&'this self) -> Self::Processor { - RerankAndFilter::new(self.vector_filter) + RerankAndFilter::new(self.postprocess_filter) } } @@ -894,14 +921,15 @@ where }) } - /// Helper method to create a DiskSearchStrategy with common parameters + /// Helper method to create a `DiskSearchStrategy` with common parameters. fn search_strategy<'a>( &'a self, - vector_filter: &'a (dyn Fn(&Data::VectorIdType) -> bool + Send + Sync), + io_tracker: &'a IOTracker, + postprocess_filter: PostprocessStrategy<'a>, ) -> DiskSearchStrategy<'a, Data, ProviderFactory> { DiskSearchStrategy { - io_tracker: IOTracker::default(), - vector_filter, + io_tracker, + postprocess_filter, vertex_provider_factory: &self.vertex_provider_factory, scratch_pool: &self.scratch_pool, } @@ -910,13 +938,16 @@ where /// Perform a brute-force linear scan of all points in the index, returning the /// nearest neighbors that pass `vector_filter`. /// + /// `vector_filter = None` scans every vector (recall baseline) and skips + /// the per-ID dyn-fn call entirely. + /// /// The top `neighbors_before_reranking` candidates from the quantized scan will be /// provided to full-precision reranking. async fn flat_search( &self, strategy: &DiskSearchStrategy<'_, Data, ProviderFactory>, query: &[Data::VectorDataType], - vector_filter: &(dyn Fn(&u32) -> bool + Send + Sync), + vector_filter: Option<&(dyn Fn(&u32) -> bool + Send + Sync)>, neighbors_before_reranking: usize, output: &mut OB, ) -> ANNResult @@ -947,7 +978,10 @@ where let mut best = NeighborPriorityQueue::new(neighbors_before_reranking); let mut cmps = 0u32; - let mut iter = (0..provider.num_points as u32).filter(vector_filter); + // `None` short-circuits to `true` — no dyn-fn call per node on the + // unfiltered (recall-baseline) path. + let mut iter = + (0..provider.num_points as u32).filter(|id| vector_filter.is_none_or(|f| f(id))); loop { id_buffer.clear(); id_buffer.extend(iter.by_ref().take(batch_size)); @@ -974,18 +1008,50 @@ where }) } + /// Run inline label-filtered graph search with optional adaptive-L sizing. + /// + /// Wraps `Knn` in an `InlineFilterSearch` that tracks matched candidates + /// during traversal. When `adaptive_l = Some(_)`, the beam (`l_search`) + /// is grown mid-query if the observed match specificity is low (see + /// `diskann::graph::search::AdaptiveL`). + /// + /// The label-provider trait object is built once in + /// `SearchMode::inline_filter` from a generic adapter, so each filter + /// evaluation costs exactly one indirect dispatch (through the + /// `&dyn QueryLabelProvider` boundary required by `labeled::Filtered`), + /// not two. + /// + /// Reuses the same `DiskAccessor` surface as the plain `Knn` graph path: + /// `start_point_distances` and `expand_beam`, both of which call + /// `pq_distances` internally. + async fn filter_search<'a, OB>( + &self, + strategy: DiskSearchStrategy<'a, Data, ProviderFactory>, + query: &[Data::VectorDataType], + knn: Knn, + label_provider: &(dyn QueryLabelProvider + 'a), + adaptive_l: Option, + output: &mut OB, + ) -> ANNResult + where + OB: search_output_buffer::SearchOutputBuffer<(u32, Data::AssociatedDataType)> + Send, + { + let filtered_strategy = labeled::Filtered::new(strategy, label_provider); + let search = InlineFilterSearch::new(knn, adaptive_l); + self.index + .search(search, &filtered_strategy, &DefaultContext, query, output) + .await + } + /// Perform a search on the disk index. /// return the list of nearest neighbors and associated data. - #[allow(clippy::too_many_arguments)] pub fn search( &self, query: &[Data::VectorDataType], return_list_size: u32, search_list_size: u32, beam_width: Option, - vector_filter: Option>, - post_processor: SearchPostProcessorKind, - is_flat_search: bool, + mode: SearchMode<'_>, ) -> ANNResult> { let mut query_stats = QueryStatistics::default(); let mut indices = vec![0u32; return_list_size as usize]; @@ -993,21 +1059,6 @@ where let mut associated_data = vec![Data::AssociatedDataType::default(); return_list_size as usize]; - let vector_filter = vector_filter.unwrap_or(default_vector_filter::()); - let post_processor = match post_processor { - SearchPostProcessorKind::None => None, - SearchPostProcessorKind::RerankAndFilter => { - Some(DiskSearchPostProcessor::RerankAndFilter( - RerankAndFilter::new(vector_filter.as_ref()), - )) - } - SearchPostProcessorKind::DeterminantDiversity(params) => { - Some(DiskSearchPostProcessor::DeterminantDiversity( - DeterminantDiversityAndFilter::new(vector_filter.as_ref(), params), - )) - } - }; - let stats = self.search_internal( query, return_list_size as usize, @@ -1017,9 +1068,7 @@ where &mut indices, &mut distances, &mut associated_data, - post_processor, - vector_filter.as_ref(), - is_flat_search, + &mode, )?; let mut search_result = SearchResult { @@ -1055,9 +1104,7 @@ where indices: &mut [u32], distances: &mut [f32], associated_data: &mut [Data::AssociatedDataType], - post_processor: Option>, - vector_filter: &(dyn Fn(&Data::VectorIdType) -> bool + Send + Sync), - is_flat_search: bool, + mode: &SearchMode<'_>, ) -> ANNResult { let mut result_output_buffer = search_output_buffer::IdDistanceAssociatedData::new( &mut indices[..k_value], @@ -1065,45 +1112,94 @@ where &mut associated_data[..k_value], ); - let strategy = self.search_strategy(vector_filter); let timer = Instant::now(); let k = k_value; let l = search_list_size as usize; - let stats = if is_flat_search { - self.runtime.block_on(self.flat_search( - &strategy, - query, - vector_filter, - l, - &mut result_output_buffer, - ))? - } else if let Some(processor) = post_processor { - self.runtime.block_on(self.index.search_with( - Knn::new(k, l, beam_width)?, - &strategy, - processor, - &DefaultContext, - query, - &mut result_output_buffer, - ))? - } else { - self.runtime.block_on(self.index.search( - Knn::new(k, l, beam_width)?, - &strategy, - &DefaultContext, - query, - &mut result_output_buffer, - ))? + + let io_tracker = IOTracker::default(); + + // * `FlatScan` — `flat_search` filters the scan iterator at + // construction; non-matching IDs never enter `best`. + // * `Graph` — plain greedy traversal doesn't consult any predicate; + // if a predicate is set, `RerankAndFilter` filters out + // non-matching nodes at rerank time. + // * `InlineFilter` — `InlineFilterSearch` only forwards `Accept` nodes + // into `matched_results`; no filtering in post-process. + // * `DiverseGraph` — `index.search_with` runs `DeterminantDiversityAndFilter` + // as the post-processor over the L candidate pool. + let stats = match mode { + SearchMode::FlatScan { filter } => { + let strategy = self.search_strategy(&io_tracker, PostprocessStrategy::AcceptAll); + self.runtime.block_on(self.flat_search( + &strategy, + query, + filter.as_deref(), + l, + &mut result_output_buffer, + ))? + } + SearchMode::Graph { filter } => { + let strategy = self.search_strategy( + &io_tracker, + filter + .as_deref() + .map_or(PostprocessStrategy::AcceptAll, PostprocessStrategy::Apply), + ); + let knn_search = Knn::new(k, l, beam_width)?; + self.runtime.block_on(self.index.search( + knn_search, + &strategy, + &DefaultContext, + query, + &mut result_output_buffer, + ))? + } + SearchMode::InlineFilter { filter, adaptive_l } => { + // Strategy is passed by value into `filter_search` so that the + // `labeled::Filtered` wrapper can own it; `io_tracker` keeps + // its counters reachable from this scope. + let strategy = self.search_strategy(&io_tracker, PostprocessStrategy::AcceptAll); + let knn_search = Knn::new(k, l, beam_width)?; + self.runtime.block_on(self.filter_search( + strategy, + query, + knn_search, + filter.as_ref(), + adaptive_l.clone(), + &mut result_output_buffer, + ))? + } + SearchMode::DiverseGraph { filter, params } => { + // Strategy installs the filter so `RerankAndFilter` would also + // honor it, but the active post-processor here is the + // diversity selector built from `DiskSearchPostProcessor`. + let postprocess_config = filter + .as_deref() + .map_or(PostprocessStrategy::AcceptAll, PostprocessStrategy::Apply); + let strategy = self.search_strategy(&io_tracker, postprocess_config); + let knn_search = Knn::new(k, l, beam_width)?; + let processor = DiskSearchPostProcessor::DeterminantDiversity( + DeterminantDiversityAndFilter::new(postprocess_config, *params), + ); + self.runtime.block_on(self.index.search_with( + knn_search, + &strategy, + processor, + &DefaultContext, + query, + &mut result_output_buffer, + ))? + } }; query_stats.total_comparisons = stats.cmps; query_stats.search_hops = stats.hops; query_stats.total_execution_time_us = timer.elapsed().as_micros(); - query_stats.io_time_us = IOTracker::time(&strategy.io_tracker.io_time_us) as u128; - query_stats.total_io_operations = strategy.io_tracker.io_count() as u32; - query_stats.total_vertices_loaded = strategy.io_tracker.io_count() as u32; + query_stats.io_time_us = IOTracker::time(&io_tracker.io_time_us) as u128; + query_stats.total_io_operations = io_tracker.io_count() as u32; + query_stats.total_vertices_loaded = io_tracker.io_count() as u32; query_stats.query_pq_preprocess_time_us = - IOTracker::time(&strategy.io_tracker.preprocess_time_us) as u128; + IOTracker::time(&io_tracker.preprocess_time_us) as u128; query_stats.cpu_time_us = query_stats.total_execution_time_us - query_stats.io_time_us - query_stats.query_pq_preprocess_time_us; @@ -1485,9 +1581,7 @@ mod disk_provider_tests { &mut indices, &mut distances, &mut associated_data, - None::>, - &(|_| true), - false, + &SearchMode::graph(), ); // Calculate the range of the truth_result for this query @@ -1539,9 +1633,7 @@ mod disk_provider_tests { params.k as u32, params.l as u32, beam_width, - None, - SearchPostProcessorKind::None, - false, + SearchMode::graph(), ) .unwrap(); let indices: Vec = result.results.iter().map(|item| item.vertex_id).collect(); @@ -1652,9 +1744,7 @@ mod disk_provider_tests { &mut indices, &mut distances, &mut associated_data, - None::>, - &|_| true, - false, + &SearchMode::graph(), ); assert!(result.is_err()); @@ -1687,7 +1777,8 @@ mod disk_provider_tests { &mut distances, &mut associated_data, ); - let strategy = search_engine.search_strategy(&|_| true); + let io_tracker = IOTracker::default(); + let strategy = search_engine.search_strategy(&io_tracker, PostprocessStrategy::AcceptAll); let mut search_record = VisitedSearchRecord::new(0); let search_params = Knn::new(10, 10, Some(4)).unwrap(); let recorded_search = @@ -1722,9 +1813,7 @@ mod disk_provider_tests { return_list_size, search_list_size, Some(4), - None, - SearchPostProcessorKind::None, - false, + SearchMode::graph(), ); assert!(result.is_ok(), "Expected search to succeed"); let search_result = result.unwrap(); @@ -1767,9 +1856,7 @@ mod disk_provider_tests { search_list_size, search_list_size, Some(4), - None, - SearchPostProcessorKind::None, - false, + SearchMode::graph(), ) .unwrap(); let baseline_ids: std::collections::HashSet = @@ -1787,9 +1874,7 @@ mod disk_provider_tests { return_list_size, search_list_size, Some(4), - None, - SearchPostProcessorKind::DeterminantDiversity(params), - false, + SearchMode::diverse_graph(params), ) .unwrap(); let det_div_ids: Vec = result.results.iter().map(|r| r.vertex_id).collect(); @@ -1827,9 +1912,7 @@ mod disk_provider_tests { return_list_size, search_list_size, Some(4), - None, - SearchPostProcessorKind::DeterminantDiversity(pure_params), - false, + SearchMode::diverse_graph(pure_params), ) .unwrap(); let pure_ids: Vec = pure_result.results.iter().map(|r| r.vertex_id).collect(); @@ -1844,16 +1927,13 @@ mod disk_provider_tests { // The vector_filter is honored by det-div: filter out the baseline top-1 and // verify it is excluded from the det-div results. let excluded = baseline_top1.vertex_id; - let filter: VectorFilter = Box::new(move |id| *id != excluded); let filtered = search_engine .search( &query_vector, return_list_size, search_list_size, Some(4), - Some(filter), - SearchPostProcessorKind::DeterminantDiversity(params), - false, + SearchMode::diverse_graph_filtered(move |id: &u32| *id != excluded, params), ) .unwrap(); let filtered_ids: Vec = filtered.results.iter().map(|r| r.vertex_id).collect(); @@ -1933,7 +2013,8 @@ mod disk_provider_tests { &mut distances, &mut associated_data, ); - let strategy = search_engine.search_strategy(&|_| true); + let io_tracker = IOTracker::default(); + let strategy = search_engine.search_strategy(&io_tracker, PostprocessStrategy::AcceptAll); // Create diverse search parameters with attribute provider let diverse_params = DiverseSearchParams::new( @@ -1980,7 +2061,8 @@ mod disk_provider_tests { &mut distances2, &mut associated_data2, ); - let strategy2 = search_engine.search_strategy(&|_| true); + let io_tracker2 = IOTracker::default(); + let strategy2 = search_engine.search_strategy(&io_tracker2, PostprocessStrategy::AcceptAll); let search_params2 = Knn::new(return_list_size as usize, search_list_size as usize, None).unwrap(); @@ -2176,6 +2258,16 @@ mod disk_provider_tests { let mut distances = vec![0f32; 10]; let mut associated_data = vec![(); 10]; + // Build the same `SearchMode` twice. `vector_filter` is a `fn` pointer + // (Copy), so each call reconstructs a fresh mode with the same filter. + let make_mode = || -> SearchMode<'static> { + if is_flat_search { + SearchMode::flat_filtered(vector_filter) + } else { + SearchMode::graph_filtered(vector_filter) + } + }; + let result = search_engine.search_internal( &query, 10, @@ -2185,9 +2277,7 @@ mod disk_provider_tests { &mut indices, &mut distances, &mut associated_data, - None::>, - &vector_filter, - is_flat_search, + &make_mode(), ); assert!(result.is_ok(), "Expected search to succeed"); @@ -2207,9 +2297,7 @@ mod disk_provider_tests { 10, 10, None, // beam_width - Some(Box::new(vector_filter)), - SearchPostProcessorKind::None, - is_flat_search, + make_mode(), ); assert!(result_with_filter.is_ok(), "Expected search to succeed"); @@ -2238,6 +2326,126 @@ mod disk_provider_tests { ); } + // =========================================================================== + // Inline filter + AdaptiveL behavioral tests + // =========================================================================== + // + // Two basic invariants from the design review: + // + // 1. `adaptive_l = Some(_)` with an always-true predicate visits every + // candidate as a "match," computes specificity = 100%, never triggers + // a resize, and produces the same top-k as plain `Knn`. This is the + // "no-op equivalence" guard. + // + // 2. `adaptive_l = Some(_)` with a selective predicate must produce a + // valid result set whose IDs all satisfy the predicate. Doesn't assert + // recall@k (would need filter-selective ground truth) — just that the + // inline path runs end-to-end and produces filter-conforming output. + + #[test] + fn test_adaptive_l_with_no_filter_matches_plain_knn() { + let storage_provider = Arc::new(VirtualStorageProvider::new_overlay(test_data_root())); + let search_engine = create_disk_index_searcher::( + CreateDiskIndexSearcherParams { + max_thread_num: 1, + pq_pivot_file_path: TEST_PQ_PIVOT_128DIM, + pq_compressed_file_path: TEST_PQ_COMPRESSED_128DIM, + index_path: TEST_INDEX_128DIM, + index_path_prefix: TEST_INDEX_PREFIX_128DIM, + ..Default::default() + }, + &storage_provider, + ); + let query = vec![0.1f32; 128]; + + let plain = search_engine + .search(&query, 10, 10, None, SearchMode::graph()) + .expect("plain Knn must succeed"); + + let inline_no_filter = search_engine + .search( + &query, + 10, + 10, + None, + SearchMode::inline_filter( + |_| true, + Some(AdaptiveL::new(5, 16.0).expect("valid AdaptiveL")), + ), + ) + .expect("inline filter with accept-all predicate must succeed"); + + let plain_ids: Vec = plain.results.iter().map(|r| r.vertex_id).collect(); + let inline_ids: Vec = inline_no_filter + .results + .iter() + .map(|r| r.vertex_id) + .collect(); + + assert_eq!( + plain.stats.result_count, inline_no_filter.stats.result_count, + "no-filter inline path must return same result count as plain Knn" + ); + assert_eq!( + plain_ids, inline_ids, + "no-filter inline path must return the same top-k IDs as plain Knn" + ); + } + + #[test] + fn test_adaptive_l_with_selective_predicate_returns_only_matches() { + let storage_provider = Arc::new(VirtualStorageProvider::new_overlay(test_data_root())); + let search_engine = create_disk_index_searcher::( + CreateDiskIndexSearcherParams { + max_thread_num: 1, + pq_pivot_file_path: TEST_PQ_PIVOT_128DIM, + pq_compressed_file_path: TEST_PQ_COMPRESSED_128DIM, + index_path: TEST_INDEX_128DIM, + index_path_prefix: TEST_INDEX_PREFIX_128DIM, + ..Default::default() + }, + &storage_provider, + ); + let query = vec![0.1f32; 128]; + // Predicate from `test_search_with_vector_filter::case_4` — three IDs + // known to be in the unfiltered top-10 for this query+fixture. + let predicate = |id: &u32| *id == 72 || *id == 87 || *id == 170; + + let result = search_engine + .search( + &query, + 10, + 10, + None, + SearchMode::inline_filter( + predicate, + Some(AdaptiveL::new(5, 16.0).expect("valid AdaptiveL")), + ), + ) + .expect("inline filter search with AdaptiveL must succeed"); + + // `result.results` is pre-allocated to `return_list_size`; only the + // first `result_count` entries are populated. The trailing entries + // are default zeros — not search output — so slice before asserting. + let count = result.stats.result_count as usize; + let ids: Vec = result + .results + .iter() + .take(count) + .map(|r| r.vertex_id) + .collect(); + for id in &ids { + assert!( + predicate(id), + "AdaptiveL result must only contain predicate-matching IDs; got {id} in {ids:?}" + ); + } + assert!( + !ids.is_empty(), + "AdaptiveL on a fixture with reachable matches must return at least one match" + ); + } + #[test] fn test_beam_search_respects_io_limit() { let io_limit = 11; // Set a small IO limit for testing @@ -2266,7 +2474,8 @@ mod disk_provider_tests { &mut associated_data, ); - let strategy = search_engine.search_strategy(&|_| true); + let io_tracker = IOTracker::default(); + let strategy = search_engine.search_strategy(&io_tracker, PostprocessStrategy::AcceptAll); let mut search_record = VisitedSearchRecord::new(0); let search_params = Knn::new(10, 10, Some(4)).unwrap(); diff --git a/diskann-disk/src/search/search_mode.rs b/diskann-disk/src/search/search_mode.rs new file mode 100644 index 000000000..49f7462f9 --- /dev/null +++ b/diskann-disk/src/search/search_mode.rs @@ -0,0 +1,211 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Top-level disk search mode. +//! +//! Encodes the algorithm + filter combination for `DiskIndexSearcher::search` +//! as a sum type so that invalid combinations (flat scan with adaptive L, +//! inline filter without a predicate) are unrepresentable at the API boundary. + +use diskann::graph::ext::labeled::QueryLabelProvider; +use diskann::graph::search::AdaptiveL; +use diskann_providers::model::graph::provider::DeterminantDiversityParams; + +/// Owned closure used to filter vector IDs during disk search. +/// +/// The `&u32` argument is the disk-side internal/external ID (they coincide +/// on the disk path by construction). +pub type SearchPredicate<'a> = Box bool + Send + Sync + 'a>; + +/// Top-level disk search mode. +/// +/// Three variants encode the algorithm + filter combination: +/// +/// * `FlatScan` — brute-force linear scan, with or without an inline filter. +/// * `Graph` — plain greedy beam search; the optional filter is applied as a +/// hard post-filter during reranking (no traversal-time effect). +/// * `InlineFilter` — label-filtered graph search; the predicate is consulted +/// at visit time (not just during rerank). `adaptive_l = Some(_)` grows the +/// beam mid-search if the observed match specificity is low. +/// * `DiverseGraph` — greedy graph search with determinant-diversity +/// post-processing; selects a maximally diverse top-k from the candidate +/// pool using `DeterminantDiversityParams`. Optional hard post-filter is +/// applied during the diversity selection step. +pub enum SearchMode<'a> { + FlatScan { + filter: Option>, + }, + + Graph { + filter: Option>, + }, + + InlineFilter { + filter: Box + 'a>, + adaptive_l: Option, + }, + + DiverseGraph { + filter: Option>, + params: DeterminantDiversityParams, + }, +} + +impl<'a> SearchMode<'a> { + /// Flat scan over all vectors. Recall baseline. + pub fn flat() -> Self { + Self::FlatScan { filter: None } + } + + /// Flat scan restricted to vectors that satisfy `predicate`. + pub fn flat_filtered(predicate: F) -> Self + where + F: Fn(&u32) -> bool + Send + Sync + 'a, + { + Self::FlatScan { + filter: Some(Box::new(predicate)), + } + } + + /// Plain greedy graph search; no filter. + pub fn graph() -> Self { + Self::Graph { filter: None } + } + + /// Plain greedy graph search with a hard post-filter applied during + /// reranking. Traversal is unaffected. + pub fn graph_filtered(predicate: F) -> Self + where + F: Fn(&u32) -> bool + Send + Sync + 'a, + { + Self::Graph { + filter: Some(Box::new(predicate)), + } + } + + /// Inline label-filtered graph search. `adaptive_l = Some(_)` enables + /// mid-search beam widening; `None` runs inline tracking only (no + /// resizing). + /// + /// The closure is wrapped in a generic adapter (`FnLabelProvider`) + /// that implements `QueryLabelProvider`. + pub fn inline_filter(predicate: F, adaptive_l: Option) -> Self + where + F: Fn(&u32) -> bool + Send + Sync + 'a, + { + struct FnLabelProvider(F); + + impl std::fmt::Debug for FnLabelProvider { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("FnLabelProvider").finish_non_exhaustive() + } + } + + impl QueryLabelProvider for FnLabelProvider + where + F: Fn(&u32) -> bool + Send + Sync, + { + fn is_match(&self, vec_id: u32) -> bool { + (self.0)(&vec_id) + } + } + + Self::InlineFilter { + filter: Box::new(FnLabelProvider(predicate)), + adaptive_l, + } + } + + /// Greedy graph search with determinant-diversity post-processing. + /// Selects a diverse top-k from the candidate pool found at L. + pub fn diverse_graph(params: DeterminantDiversityParams) -> Self { + Self::DiverseGraph { + filter: None, + params, + } + } + + /// Greedy graph search with determinant-diversity post-processing and a + /// hard post-filter. The filter is honored during the diverse-selection + /// step (non-matching IDs are excluded from the final top-k). + pub fn diverse_graph_filtered(predicate: F, params: DeterminantDiversityParams) -> Self + where + F: Fn(&u32) -> bool + Send + Sync + 'a, + { + Self::DiverseGraph { + filter: Some(Box::new(predicate)), + params, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn flat_no_filter_constructor() { + let mode = SearchMode::flat(); + assert!(matches!(mode, SearchMode::FlatScan { filter: None })); + } + + #[test] + fn flat_filtered_constructor() { + let mode = SearchMode::flat_filtered(|id| *id == 5); + match &mode { + SearchMode::FlatScan { filter: Some(p) } => { + assert!(p(&5)); + assert!(!p(&4)); + } + _ => panic!("expected FlatScan with filter"), + } + } + + #[test] + fn graph_no_filter_constructor() { + let mode = SearchMode::graph(); + assert!(matches!(mode, SearchMode::Graph { filter: None })); + } + + #[test] + fn graph_filtered_constructor() { + let mode = SearchMode::graph_filtered(|id| *id == 7); + match &mode { + SearchMode::Graph { filter: Some(p) } => { + assert!(p(&7)); + assert!(!p(&6)); + } + _ => panic!("expected Graph with filter"), + } + } + + #[test] + fn inline_filter_constructor_without_adaptive_l() { + let mode = SearchMode::inline_filter(|id| *id == 3, None); + match &mode { + SearchMode::InlineFilter { + filter, + adaptive_l: None, + } => { + assert!(filter.is_match(3)); + assert!(!filter.is_match(2)); + } + _ => panic!("expected InlineFilter with adaptive_l = None"), + } + } + + #[test] + fn inline_filter_constructor_with_adaptive_l() { + let adaptive = AdaptiveL::new(5, 16.0).expect("valid AdaptiveL"); + let mode = SearchMode::inline_filter(|id| *id == 11, Some(adaptive)); + match &mode { + SearchMode::InlineFilter { + adaptive_l: Some(_), + .. + } => {} + _ => panic!("expected InlineFilter with adaptive_l = Some"), + } + } +} diff --git a/diskann-tools/src/utils/search_disk_index.rs b/diskann-tools/src/utils/search_disk_index.rs index c38544d07..c01bbc8b7 100644 --- a/diskann-tools/src/utils/search_disk_index.rs +++ b/diskann-tools/src/utils/search_disk_index.rs @@ -8,9 +8,12 @@ use std::{collections::HashSet, sync::atomic::AtomicBool, time::Instant}; use diskann::utils::IntoUsize; use diskann_disk::{ data_model::{CachingStrategy, GraphDataType}, - search::provider::{ - disk_provider::{DiskIndexSearcher, SearchPostProcessorKind}, - disk_vertex_provider_factory::DiskVertexProviderFactory, + search::{ + provider::{ + disk_provider::DiskIndexSearcher, + disk_vertex_provider_factory::DiskVertexProviderFactory, + }, + search_mode::SearchMode, }, storage::disk_index_reader::DiskIndexReader, utils::{ @@ -251,21 +254,29 @@ where (((((_cmp, query), vector_filter), query_result_id), query_result_dist), stats), result_count, )| { - let vector_filter_function: Box bool + Send + Sync> = - if parameters.filter_bitmap_file.is_none() { - Box::new(|_: &u32| true) - } else { - Box::new(move |vector_id: &u32| vector_filter.contains(vector_id)) - }; + // Construct the mode from the CLI-driven + // `(is_flat_search, has_filter)` pair. CLI doesn't expose + // AdaptiveL yet, so `InlineFilter` is unreachable here. + let mode: SearchMode<'_> = match ( + parameters.is_flat_search, + parameters.filter_bitmap_file.is_some(), + ) { + (true, false) => SearchMode::flat(), + (true, true) => { + SearchMode::flat_filtered(move |vid: &u32| vector_filter.contains(vid)) + } + (false, false) => SearchMode::graph(), + (false, true) => { + SearchMode::graph_filtered(move |vid: &u32| vector_filter.contains(vid)) + } + }; let result = searcher.search( query, parameters.recall_at, l, Some(parameters.beam_width as usize), - Some(vector_filter_function), - SearchPostProcessorKind::None, - parameters.is_flat_search, + mode, ); match result {