From e176e431853a67814211121c296c7eb2434a1568 Mon Sep 17 00:00:00 2001 From: yaohongdeng Date: Tue, 16 Jun 2026 10:56:51 +0800 Subject: [PATCH 01/14] Implement filter search for disk path using AdpativeL --- .../src/backend/disk_index/search.rs | 1 + diskann-disk/src/build/builder/core.rs | 1 + .../src/search/provider/disk_provider.rs | 202 +++++++++++++++++- diskann-tools/src/utils/search_disk_index.rs | 1 + 4 files changed, 203 insertions(+), 2 deletions(-) diff --git a/diskann-benchmark/src/backend/disk_index/search.rs b/diskann-benchmark/src/backend/disk_index/search.rs index 2a6966870..bc7692807 100644 --- a/diskann-benchmark/src/backend/disk_index/search.rs +++ b/diskann-benchmark/src/backend/disk_index/search.rs @@ -278,6 +278,7 @@ where Some(search_params.beam_width), vector_filter, search_params.is_flat_search, + None, // adaptive_l — disk-side AdaptiveL not yet exposed via JSON ) { Ok(search_result) => { *stats = search_result.stats.query_statistics; diff --git a/diskann-disk/src/build/builder/core.rs b/diskann-disk/src/build/builder/core.rs index ed69058de..1a114eb69 100644 --- a/diskann-disk/src/build/builder/core.rs +++ b/diskann-disk/src/build/builder/core.rs @@ -1091,6 +1091,7 @@ pub(crate) mod disk_index_builder_tests { &mut associated_data, &|_| true, false, + None, // adaptive_l ); diskann_providers::test_utils::assert_top_k_exactly_match( diff --git a/diskann-disk/src/search/provider/disk_provider.rs b/diskann-disk/src/search/provider/disk_provider.rs index f2f24d5e1..bb599a81a 100644 --- a/diskann-disk/src/search/provider/disk_provider.rs +++ b/diskann-disk/src/search/provider/disk_provider.rs @@ -19,7 +19,8 @@ use diskann::{ graph::{ self, glue::{self, DefaultPostProcessor, SearchPostProcess, SearchStrategy}, - search::Knn, + index::QueryLabelProvider, + search::{AdaptiveL, InlineFilterSearch, Knn}, search_output_buffer, DiskANNIndex, }, neighbor::{Neighbor, NeighborPriorityQueue}, @@ -270,6 +271,26 @@ impl<'a> RerankAndFilter<'a> { } } +/// Adapter exposing the existing disk-side `&dyn Fn(&u32) -> bool` predicate +/// as a `QueryLabelProvider` for `InlineFilterSearch`. Lives entirely +/// inside `filter_search` — public `search()` keeps its existing predicate +/// type at the boundary. +struct PredicateLabelProvider<'a> { + predicate: &'a (dyn Fn(&u32) -> bool + Send + Sync), +} + +impl std::fmt::Debug for PredicateLabelProvider<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("PredicateLabelProvider").finish_non_exhaustive() + } +} + +impl QueryLabelProvider for PredicateLabelProvider<'_> { + fn is_match(&self, vec_id: u32) -> bool { + (self.predicate)(&vec_id) + } +} + impl SearchPostProcess< DiskAccessor<'_, Data, VP>, @@ -825,8 +846,49 @@ where }) } + /// Run inline label-filtered graph search with optional adaptive-L sizing. + /// + /// Wraps `Knn` in an `InlineFilterSearch` that tracks matched candidates + /// during traversal. When `adaptive_l = Some(_)`, the beam (`l_search`) + /// is grown mid-query if the observed match specificity is low (see + /// `diskann::graph::search::AdaptiveL`). + /// + /// The disk-side `&dyn Fn(&u32) -> bool` predicate is adapted to the + /// `QueryLabelProvider` interface that `InlineFilterSearch` consumes + /// via a stack-allocated `PredicateLabelProvider` shim — no allocation, + /// no lifetime threading past this function body. + /// + /// Reuses the same `DiskAccessor` surface as the plain `Knn` graph path: + /// `start_point_distances` and `expand_beam`, both of which call + /// `pq_distances` internally. + async fn filter_search( + &self, + strategy: &DiskSearchStrategy<'_, Data, ProviderFactory>, + query: &[Data::VectorDataType], + knn: Knn, + vector_filter: &(dyn Fn(&u32) -> bool + Send + Sync), + adaptive_l: Option, + output: &mut OB, + ) -> ANNResult + where + OB: search_output_buffer::SearchOutputBuffer<(u32, Data::AssociatedDataType)> + Send, + { + let label_provider = PredicateLabelProvider { + predicate: vector_filter, + }; + let search = InlineFilterSearch::new(knn, &label_provider, adaptive_l); + self.index + .search(search, strategy, &DefaultContext, query, output) + .await + } + /// Perform a search on the disk index. /// return the list of nearest neighbors and associated data. + /// + /// `adaptive_l = Some(_)` routes the graph path through inline + /// label-filtered search with adaptive-L sizing; `None` keeps the plain + /// `Knn` behavior. Ignored on the flat-scan path. + #[allow(clippy::too_many_arguments)] pub fn search( &self, query: &[Data::VectorDataType], @@ -835,6 +897,7 @@ where beam_width: Option, vector_filter: Option>, is_flat_search: bool, + adaptive_l: Option, ) -> ANNResult> { let mut query_stats = QueryStatistics::default(); let mut indices = vec![0u32; return_list_size as usize]; @@ -853,6 +916,7 @@ where &mut associated_data, &vector_filter.unwrap_or(default_vector_filter::()), is_flat_search, + adaptive_l, )?; let mut search_result = SearchResult { @@ -877,6 +941,10 @@ where /// Perform a raw search on the disk index. /// This is a lower-level API that allows more control over the search parameters and output buffers. + /// + /// `adaptive_l` routes the graph path through `filter_search` + /// (inline label-filtered search). `None` runs plain `Knn`. Has no + /// effect on the flat-scan path (`is_flat_search = true`). #[allow(clippy::too_many_arguments)] pub(crate) fn search_internal( &self, @@ -890,6 +958,7 @@ where associated_data: &mut [Data::AssociatedDataType], vector_filter: &(dyn Fn(&Data::VectorIdType) -> bool + Send + Sync), is_flat_search: bool, + adaptive_l: Option, ) -> ANNResult { let mut result_output_buffer = search_output_buffer::IdDistanceAssociatedData::new( &mut indices[..k_value], @@ -909,6 +978,16 @@ where l, &mut result_output_buffer, ))? + } else if adaptive_l.is_some() { + let knn_search = Knn::new(k, l, beam_width)?; + self.runtime.block_on(self.filter_search( + &strategy, + query, + knn_search, + vector_filter, + adaptive_l, + &mut result_output_buffer, + ))? } else { let knn_search = Knn::new(k, l, beam_width)?; self.runtime.block_on(self.index.search( @@ -1311,6 +1390,7 @@ mod disk_provider_tests { &mut associated_data, &(|_| true), false, + None, // adaptive_l ); // Calculate the range of the truth_result for this query @@ -1357,7 +1437,7 @@ mod disk_provider_tests { .for_each_in_pool(pool.as_ref(), |(i, query)| { let result = params .index_search_engine - .search(query, params.k as u32, params.l as u32, beam_width, None, false) + .search(query, params.k as u32, params.l as u32, beam_width, None, false, None) .unwrap(); let indices: Vec = result.results.iter().map(|item| item.vertex_id).collect(); let associated_data: Vec = @@ -1469,6 +1549,7 @@ mod disk_provider_tests { &mut associated_data, &|_| true, false, + None, // adaptive_l ); assert!(result.is_err()); @@ -1538,6 +1619,7 @@ mod disk_provider_tests { Some(4), None, false, + None, // adaptive_l ); assert!(result.is_ok(), "Expected search to succeed"); let search_result = result.unwrap(); @@ -1877,6 +1959,7 @@ mod disk_provider_tests { &mut associated_data, &vector_filter, is_flat_search, + None, // adaptive_l ); assert!(result.is_ok(), "Expected search to succeed"); @@ -1898,6 +1981,7 @@ mod disk_provider_tests { None, // beam_width Some(Box::new(vector_filter)), is_flat_search, + None, // adaptive_l ); assert!(result_with_filter.is_ok(), "Expected search to succeed"); @@ -1926,6 +2010,120 @@ mod disk_provider_tests { ); } + // =========================================================================== + // Inline filter + AdaptiveL behavioral tests + // =========================================================================== + // + // Two basic invariants from the design review: + // + // 1. `adaptive_l = Some(_)` with an always-true predicate visits every + // candidate as a "match," computes specificity = 100%, never triggers + // a resize, and produces the same top-k as plain `Knn`. This is the + // "no-op equivalence" guard. + // + // 2. `adaptive_l = Some(_)` with a selective predicate must produce a + // valid result set whose IDs all satisfy the predicate. Doesn't assert + // recall@k (would need filter-selective ground truth) — just that the + // inline path runs end-to-end and produces filter-conforming output. + + #[test] + fn test_adaptive_l_with_no_filter_matches_plain_knn() { + let storage_provider = Arc::new(VirtualStorageProvider::new_overlay(test_data_root())); + let search_engine = create_disk_index_searcher::( + CreateDiskIndexSearcherParams { + max_thread_num: 1, + pq_pivot_file_path: TEST_PQ_PIVOT_128DIM, + pq_compressed_file_path: TEST_PQ_COMPRESSED_128DIM, + index_path: TEST_INDEX_128DIM, + index_path_prefix: TEST_INDEX_PREFIX_128DIM, + ..Default::default() + }, + &storage_provider, + ); + let query = vec![0.1f32; 128]; + + let plain = search_engine + .search(&query, 10, 10, None, None, false, None) + .expect("plain Knn must succeed"); + + let inline_no_filter = search_engine + .search( + &query, + 10, + 10, + None, + Some(Box::new(|_| true)), + false, + Some(AdaptiveL::new(5, 16.0).expect("valid AdaptiveL")), + ) + .expect("inline filter with accept-all predicate must succeed"); + + let plain_ids: Vec = plain.results.iter().map(|r| r.vertex_id).collect(); + let inline_ids: Vec = inline_no_filter.results.iter().map(|r| r.vertex_id).collect(); + + assert_eq!( + plain.stats.result_count, inline_no_filter.stats.result_count, + "no-filter inline path must return same result count as plain Knn" + ); + assert_eq!( + plain_ids, inline_ids, + "no-filter inline path must return the same top-k IDs as plain Knn" + ); + } + + #[test] + fn test_adaptive_l_with_selective_predicate_returns_only_matches() { + let storage_provider = Arc::new(VirtualStorageProvider::new_overlay(test_data_root())); + let search_engine = create_disk_index_searcher::( + CreateDiskIndexSearcherParams { + max_thread_num: 1, + pq_pivot_file_path: TEST_PQ_PIVOT_128DIM, + pq_compressed_file_path: TEST_PQ_COMPRESSED_128DIM, + index_path: TEST_INDEX_128DIM, + index_path_prefix: TEST_INDEX_PREFIX_128DIM, + ..Default::default() + }, + &storage_provider, + ); + let query = vec![0.1f32; 128]; + // Predicate from `test_search_with_vector_filter::case_4` — three IDs + // known to be in the unfiltered top-10 for this query+fixture. + let predicate = |id: &u32| *id == 72 || *id == 87 || *id == 170; + + let result = search_engine + .search( + &query, + 10, + 10, + None, + Some(Box::new(predicate)), + false, + Some(AdaptiveL::new(5, 16.0).expect("valid AdaptiveL")), + ) + .expect("inline filter search with AdaptiveL must succeed"); + + // `result.results` is pre-allocated to `return_list_size`; only the + // first `result_count` entries are populated. The trailing entries + // are default zeros — not search output — so slice before asserting. + let count = result.stats.result_count as usize; + let ids: Vec = result + .results + .iter() + .take(count) + .map(|r| r.vertex_id) + .collect(); + for id in &ids { + assert!( + predicate(id), + "AdaptiveL result must only contain predicate-matching IDs; got {id} in {ids:?}" + ); + } + assert!( + !ids.is_empty(), + "AdaptiveL on a fixture with reachable matches must return at least one match" + ); + } + #[test] fn test_beam_search_respects_io_limit() { let io_limit = 11; // Set a small IO limit for testing diff --git a/diskann-tools/src/utils/search_disk_index.rs b/diskann-tools/src/utils/search_disk_index.rs index 8bbdb1c8f..3b74bb41d 100644 --- a/diskann-tools/src/utils/search_disk_index.rs +++ b/diskann-tools/src/utils/search_disk_index.rs @@ -260,6 +260,7 @@ where Some(parameters.beam_width as usize), Some(vector_filter_function), parameters.is_flat_search, + None, // adaptive_l — disk-side AdaptiveL not yet exposed via CLI ); match result { From 1e82dd7bd7545a74e40d59efb267cfd17a73e379 Mon Sep 17 00:00:00 2001 From: yaohongdeng Date: Tue, 16 Jun 2026 16:18:38 +0800 Subject: [PATCH 02/14] Use SearchPlan to avoid invalid parameter combinations --- .../src/backend/disk_index/search.rs | 28 ++- diskann-disk/src/build/builder/core.rs | 4 +- diskann-disk/src/search/mod.rs | 1 + .../src/search/provider/disk_provider.rs | 172 ++++++++++-------- diskann-tools/src/utils/search_disk_index.rs | 32 ++-- 5 files changed, 138 insertions(+), 99 deletions(-) diff --git a/diskann-benchmark/src/backend/disk_index/search.rs b/diskann-benchmark/src/backend/disk_index/search.rs index bc7692807..e007a711f 100644 --- a/diskann-benchmark/src/backend/disk_index/search.rs +++ b/diskann-benchmark/src/backend/disk_index/search.rs @@ -13,8 +13,12 @@ use diskann::utils::VectorRepr; use diskann_benchmark_runner::{files::InputFile, utils::MicroSeconds}; use diskann_disk::{ data_model::{AdHoc, CachingStrategy}, - search::provider::{ - disk_provider::DiskIndexSearcher, disk_vertex_provider_factory::DiskVertexProviderFactory, + search::{ + plan::SearchPlan, + provider::{ + disk_provider::DiskIndexSearcher, + disk_vertex_provider_factory::DiskVertexProviderFactory, + }, }, storage::disk_index_reader::DiskIndexReader, utils::{instrumentation::PerfLogger, statistics, AlignedFileReaderFactory, QueryStatistics}, @@ -264,11 +268,17 @@ where zipped.for_each_in_pool( pool.as_ref(), |(((((q, vf), id_chunk), dist_chunk), stats), rc)| { - let vector_filter = if search_params.vector_filters_file.is_none() { - None - } else { - Some(Box::new(move |vid: &u32| vf.contains(vid)) - as Box bool + Send + Sync>) + // Construct the plan from the JSON-driven + // `(is_flat_search, has_filter)` pair. JSON config doesn't + // expose AdaptiveL yet, so `InlineFilter` is unreachable here. + let has_filter = search_params.vector_filters_file.is_some(); + let plan: SearchPlan<'_> = match (search_params.is_flat_search, has_filter) { + (true, false) => SearchPlan::flat(), + (true, true) => SearchPlan::flat_filtered(move |vid: &u32| vf.contains(vid)), + (false, false) => SearchPlan::graph(), + (false, true) => { + SearchPlan::graph_filtered(move |vid: &u32| vf.contains(vid)) + } }; match searcher.search( @@ -276,9 +286,7 @@ where search_params.recall_at, l, Some(search_params.beam_width), - vector_filter, - search_params.is_flat_search, - None, // adaptive_l — disk-side AdaptiveL not yet exposed via JSON + plan, ) { Ok(search_result) => { *stats = search_result.stats.query_statistics; diff --git a/diskann-disk/src/build/builder/core.rs b/diskann-disk/src/build/builder/core.rs index 1a114eb69..6696877e7 100644 --- a/diskann-disk/src/build/builder/core.rs +++ b/diskann-disk/src/build/builder/core.rs @@ -1089,9 +1089,7 @@ pub(crate) mod disk_index_builder_tests { &mut indices, &mut distances, &mut associated_data, - &|_| true, - false, - None, // adaptive_l + &crate::search::plan::SearchPlan::graph(), ); diskann_providers::test_utils::assert_top_k_exactly_match( diff --git a/diskann-disk/src/search/mod.rs b/diskann-disk/src/search/mod.rs index 54f710c7f..e01983c58 100644 --- a/diskann-disk/src/search/mod.rs +++ b/diskann-disk/src/search/mod.rs @@ -5,6 +5,7 @@ //! Model module containing data structures, providers, and traits for disk index operations +pub mod plan; pub mod pq; pub mod provider; pub mod traits; diff --git a/diskann-disk/src/search/provider/disk_provider.rs b/diskann-disk/src/search/provider/disk_provider.rs index bb599a81a..0f39177eb 100644 --- a/diskann-disk/src/search/provider/disk_provider.rs +++ b/diskann-disk/src/search/provider/disk_provider.rs @@ -42,8 +42,8 @@ use tracing::debug; use crate::{ data_model::{CachingStrategy, GraphHeader}, - filter_parameter::{default_vector_filter, VectorFilter}, search::{ + plan::SearchPlan, provider::disk_vertex_provider_factory::DiskVertexProviderFactory, traits::{VertexProvider, VertexProviderFactory}, }, @@ -214,7 +214,11 @@ where { // This needs to be Arc instead of Rc because DiskSearchStrategy has "Send" trait bound, though this is not expected to be shared across threads. io_tracker: IOTracker, - vector_filter: &'a (dyn Fn(&u32) -> bool + Send + Sync), // Fn param is u32 as we validate "VectorIdType = u32" everywhere in this provider in trait bounds. + /// Borrowed predicate from `SearchPlan::predicate()`. `None` means + /// accept-all — downstream consumers short-circuit via `.map_or(true, ...)` + /// and skip the dyn-fn call entirely. Fn param is `u32` because + /// `VectorIdType = u32` is enforced by trait bounds throughout this provider. + vector_filter: Option<&'a (dyn Fn(&u32) -> bool + Send + Sync)>, /// The vertex provider factory is used to create the vertex provider for each search instance. vertex_provider_factory: &'a ProviderFactory, @@ -262,11 +266,13 @@ impl IOTracker { #[derive(Clone, Copy)] pub struct RerankAndFilter<'a> { - filter: &'a (dyn Fn(&u32) -> bool + Send + Sync), + /// `None` means accept-all; `post_process` short-circuits before invoking + /// the closure, so the unfiltered hot path pays no dyn-fn cost per node. + filter: Option<&'a (dyn Fn(&u32) -> bool + Send + Sync)>, } impl<'a> RerankAndFilter<'a> { - fn new(filter: &'a (dyn Fn(&u32) -> bool + Send + Sync)) -> Self { + fn new(filter: Option<&'a (dyn Fn(&u32) -> bool + Send + Sync)>) -> Self { Self { filter } } } @@ -323,7 +329,9 @@ where let mut uncached_ids = Vec::new(); let mut reranked = candidates .map(|n| n.id) - .filter(|id| (self.filter)(id)) + // `None` short-circuits to `true` — no dyn-fn call when there's + // no predicate set, which is the common unfiltered hot path. + .filter(|id| self.filter.map_or(true, |f| f(id))) .filter_map(|n| { if let Some(entry) = accessor.scratch.distance_cache.get(&n) { Some(Ok::<((u32, _), f32), ANNError>(((n, entry.1), entry.0))) @@ -766,10 +774,14 @@ where }) } - /// Helper method to create a DiskSearchStrategy with common parameters + /// Helper method to create a DiskSearchStrategy with common parameters. + /// + /// `vector_filter = None` means accept-all and propagates as `None` + /// through the strategy and downstream consumers, so the unfiltered hot + /// path pays no dyn-fn call per node. fn search_strategy<'a>( &'a self, - vector_filter: &'a (dyn Fn(&Data::VectorIdType) -> bool + Send + Sync), + vector_filter: Option<&'a (dyn Fn(&Data::VectorIdType) -> bool + Send + Sync)>, ) -> DiskSearchStrategy<'a, Data, ProviderFactory> { DiskSearchStrategy { io_tracker: IOTracker::default(), @@ -782,13 +794,16 @@ where /// Perform a brute-force linear scan of all points in the index, returning the /// nearest neighbors that pass `vector_filter`. /// + /// `vector_filter = None` scans every vector (recall baseline) and skips + /// the per-ID dyn-fn call entirely. + /// /// The top `neighbors_before_reranking` candidates from the quantized scan will be /// provided to full-precision reranking. async fn flat_search( &self, strategy: &DiskSearchStrategy<'_, Data, ProviderFactory>, query: &[Data::VectorDataType], - vector_filter: &(dyn Fn(&u32) -> bool + Send + Sync), + vector_filter: Option<&(dyn Fn(&u32) -> bool + Send + Sync)>, neighbors_before_reranking: usize, output: &mut OB, ) -> ANNResult @@ -819,7 +834,10 @@ where let mut best = NeighborPriorityQueue::new(neighbors_before_reranking); let mut cmps = 0u32; - let mut iter = (0..provider.num_points as u32).filter(vector_filter); + // `None` short-circuits to `true` — no dyn-fn call per node on the + // unfiltered (recall-baseline) path. + let mut iter = (0..provider.num_points as u32) + .filter(|id| vector_filter.map_or(true, |f| f(id))); loop { id_buffer.clear(); id_buffer.extend(iter.by_ref().take(batch_size)); @@ -885,19 +903,17 @@ where /// Perform a search on the disk index. /// return the list of nearest neighbors and associated data. /// - /// `adaptive_l = Some(_)` routes the graph path through inline - /// label-filtered search with adaptive-L sizing; `None` keeps the plain - /// `Knn` behavior. Ignored on the flat-scan path. - #[allow(clippy::too_many_arguments)] + /// The algorithm + filter combination is encoded by `plan` — see + /// [`SearchPlan`] for the available variants. The plan replaces the + /// previous `(vector_filter, is_flat_search, adaptive_l)` parameter + /// triple and makes invalid combinations unrepresentable. pub fn search( &self, query: &[Data::VectorDataType], return_list_size: u32, search_list_size: u32, beam_width: Option, - vector_filter: Option>, - is_flat_search: bool, - adaptive_l: Option, + plan: SearchPlan<'_>, ) -> ANNResult> { let mut query_stats = QueryStatistics::default(); let mut indices = vec![0u32; return_list_size as usize]; @@ -914,9 +930,7 @@ where &mut indices, &mut distances, &mut associated_data, - &vector_filter.unwrap_or(default_vector_filter::()), - is_flat_search, - adaptive_l, + &plan, )?; let mut search_result = SearchResult { @@ -942,9 +956,10 @@ where /// Perform a raw search on the disk index. /// This is a lower-level API that allows more control over the search parameters and output buffers. /// - /// `adaptive_l` routes the graph path through `filter_search` - /// (inline label-filtered search). `None` runs plain `Knn`. Has no - /// effect on the flat-scan path (`is_flat_search = true`). + /// Dispatches on `plan` variants: `FlatScan` → linear scan; `Graph` → + /// plain `Knn` with the optional post-filter applied during rerank; + /// `InlineFilter` → `filter_search` (`InlineFilterSearch` with optional + /// `AdaptiveL`). #[allow(clippy::too_many_arguments)] pub(crate) fn search_internal( &self, @@ -956,9 +971,7 @@ where indices: &mut [u32], distances: &mut [f32], associated_data: &mut [Data::AssociatedDataType], - vector_filter: &(dyn Fn(&Data::VectorIdType) -> bool + Send + Sync), - is_flat_search: bool, - adaptive_l: Option, + plan: &SearchPlan<'_>, ) -> ANNResult { let mut result_output_buffer = search_output_buffer::IdDistanceAssociatedData::new( &mut indices[..k_value], @@ -966,37 +979,44 @@ where &mut associated_data[..k_value], ); - let strategy = self.search_strategy(vector_filter); + // `None` predicate propagates through strategy and downstream + // consumers — they short-circuit before the dyn-fn call, so the + // unfiltered hot path pays no per-node closure cost. + let predicate = plan.predicate(); + + let strategy = self.search_strategy(predicate); let timer = Instant::now(); let k = k_value; let l = search_list_size as usize; - let stats = if is_flat_search { - self.runtime.block_on(self.flat_search( + let stats = match plan { + SearchPlan::FlatScan { .. } => self.runtime.block_on(self.flat_search( &strategy, query, - vector_filter, + predicate, l, &mut result_output_buffer, - ))? - } else if adaptive_l.is_some() { - let knn_search = Knn::new(k, l, beam_width)?; - self.runtime.block_on(self.filter_search( - &strategy, - query, - knn_search, - vector_filter, - adaptive_l, - &mut result_output_buffer, - ))? - } else { - let knn_search = Knn::new(k, l, beam_width)?; - self.runtime.block_on(self.index.search( - knn_search, - &strategy, - &DefaultContext, - query, - &mut result_output_buffer, - ))? + ))?, + SearchPlan::Graph { .. } => { + let knn_search = Knn::new(k, l, beam_width)?; + self.runtime.block_on(self.index.search( + knn_search, + &strategy, + &DefaultContext, + query, + &mut result_output_buffer, + ))? + } + SearchPlan::InlineFilter { predicate: p, adaptive_l } => { + let knn_search = Knn::new(k, l, beam_width)?; + self.runtime.block_on(self.filter_search( + &strategy, + query, + knn_search, + p.as_ref(), + adaptive_l.clone(), + &mut result_output_buffer, + ))? + } }; query_stats.total_comparisons = stats.cmps; query_stats.search_hops = stats.hops; @@ -1388,9 +1408,7 @@ mod disk_provider_tests { &mut indices, &mut distances, &mut associated_data, - &(|_| true), - false, - None, // adaptive_l + &SearchPlan::graph(), ); // Calculate the range of the truth_result for this query @@ -1437,7 +1455,7 @@ mod disk_provider_tests { .for_each_in_pool(pool.as_ref(), |(i, query)| { let result = params .index_search_engine - .search(query, params.k as u32, params.l as u32, beam_width, None, false, None) + .search(query, params.k as u32, params.l as u32, beam_width, SearchPlan::graph()) .unwrap(); let indices: Vec = result.results.iter().map(|item| item.vertex_id).collect(); let associated_data: Vec = @@ -1547,9 +1565,7 @@ mod disk_provider_tests { &mut indices, &mut distances, &mut associated_data, - &|_| true, - false, - None, // adaptive_l + &SearchPlan::graph(), ); assert!(result.is_err()); @@ -1582,7 +1598,7 @@ mod disk_provider_tests { &mut distances, &mut associated_data, ); - let strategy = search_engine.search_strategy(&|_| true); + let strategy = search_engine.search_strategy(None); let mut search_record = VisitedSearchRecord::new(0); let search_params = Knn::new(10, 10, Some(4)).unwrap(); let recorded_search = @@ -1617,9 +1633,7 @@ mod disk_provider_tests { return_list_size, search_list_size, Some(4), - None, - false, - None, // adaptive_l + SearchPlan::graph(), ); assert!(result.is_ok(), "Expected search to succeed"); let search_result = result.unwrap(); @@ -1705,7 +1719,7 @@ mod disk_provider_tests { &mut distances, &mut associated_data, ); - let strategy = search_engine.search_strategy(&|_| true); + let strategy = search_engine.search_strategy(None); // Create diverse search parameters with attribute provider let diverse_params = DiverseSearchParams::new( @@ -1948,6 +1962,16 @@ mod disk_provider_tests { let mut distances = vec![0f32; 10]; let mut associated_data = vec![(); 10]; + // Build the same `SearchPlan` twice. `vector_filter` is a `fn` pointer + // (Copy), so each call reconstructs a fresh plan with the same filter. + let make_plan = || -> SearchPlan<'static> { + if is_flat_search { + SearchPlan::flat_filtered(vector_filter) + } else { + SearchPlan::graph_filtered(vector_filter) + } + }; + let result = search_engine.search_internal( &query, 10, @@ -1957,9 +1981,7 @@ mod disk_provider_tests { &mut indices, &mut distances, &mut associated_data, - &vector_filter, - is_flat_search, - None, // adaptive_l + &make_plan(), ); assert!(result.is_ok(), "Expected search to succeed"); @@ -1979,9 +2001,7 @@ mod disk_provider_tests { 10, 10, None, // beam_width - Some(Box::new(vector_filter)), - is_flat_search, - None, // adaptive_l + make_plan(), ); assert!(result_with_filter.is_ok(), "Expected search to succeed"); @@ -2043,7 +2063,7 @@ mod disk_provider_tests { let query = vec![0.1f32; 128]; let plain = search_engine - .search(&query, 10, 10, None, None, false, None) + .search(&query, 10, 10, None, SearchPlan::graph()) .expect("plain Knn must succeed"); let inline_no_filter = search_engine @@ -2052,9 +2072,10 @@ mod disk_provider_tests { 10, 10, None, - Some(Box::new(|_| true)), - false, - Some(AdaptiveL::new(5, 16.0).expect("valid AdaptiveL")), + SearchPlan::inline_filter( + |_| true, + Some(AdaptiveL::new(5, 16.0).expect("valid AdaptiveL")), + ), ) .expect("inline filter with accept-all predicate must succeed"); @@ -2096,9 +2117,10 @@ mod disk_provider_tests { 10, 10, None, - Some(Box::new(predicate)), - false, - Some(AdaptiveL::new(5, 16.0).expect("valid AdaptiveL")), + SearchPlan::inline_filter( + predicate, + Some(AdaptiveL::new(5, 16.0).expect("valid AdaptiveL")), + ), ) .expect("inline filter search with AdaptiveL must succeed"); @@ -2152,7 +2174,7 @@ mod disk_provider_tests { &mut associated_data, ); - let strategy = search_engine.search_strategy(&|_| true); + let strategy = search_engine.search_strategy(None); let mut search_record = VisitedSearchRecord::new(0); let search_params = Knn::new(10, 10, Some(4)).unwrap(); diff --git a/diskann-tools/src/utils/search_disk_index.rs b/diskann-tools/src/utils/search_disk_index.rs index 3b74bb41d..14df6f1fc 100644 --- a/diskann-tools/src/utils/search_disk_index.rs +++ b/diskann-tools/src/utils/search_disk_index.rs @@ -8,8 +8,12 @@ use std::{collections::HashSet, sync::atomic::AtomicBool, time::Instant}; use diskann::utils::IntoUsize; use diskann_disk::{ data_model::{CachingStrategy, GraphDataType}, - search::provider::{ - disk_provider::DiskIndexSearcher, disk_vertex_provider_factory::DiskVertexProviderFactory, + search::{ + plan::SearchPlan, + provider::{ + disk_provider::DiskIndexSearcher, + disk_vertex_provider_factory::DiskVertexProviderFactory, + }, }, storage::disk_index_reader::DiskIndexReader, utils::{ @@ -246,21 +250,27 @@ where (((((_cmp, query), vector_filter), query_result_id), query_result_dist), stats), result_count, )| { - let vector_filter_function: Box bool + Send + Sync> = - if parameters.vector_filters_file.is_none() { - Box::new(|_: &u32| true) - } else { - Box::new(move |vector_id: &u32| vector_filter.contains(vector_id)) - }; + // Construct the plan from the CLI-driven + // `(is_flat_search, has_filter)` pair. CLI doesn't expose + // AdaptiveL yet, so `InlineFilter` is unreachable here. + let has_filter = parameters.vector_filters_file.is_some(); + let plan: SearchPlan<'_> = match (parameters.is_flat_search, has_filter) { + (true, false) => SearchPlan::flat(), + (true, true) => SearchPlan::flat_filtered(move |vid: &u32| { + vector_filter.contains(vid) + }), + (false, false) => SearchPlan::graph(), + (false, true) => SearchPlan::graph_filtered(move |vid: &u32| { + vector_filter.contains(vid) + }), + }; let result = searcher.search( query, parameters.recall_at, l, Some(parameters.beam_width as usize), - Some(vector_filter_function), - parameters.is_flat_search, - None, // adaptive_l — disk-side AdaptiveL not yet exposed via CLI + plan, ); match result { From 2d239f5ac02c03769f9a558977f462cad8b08b51 Mon Sep 17 00:00:00 2001 From: yaohongdeng Date: Tue, 16 Jun 2026 23:46:12 +0800 Subject: [PATCH 03/14] Add SearchPlan file --- diskann-disk/src/search/plan.rs | 179 ++++++++++++++++++++++++++++++++ 1 file changed, 179 insertions(+) create mode 100644 diskann-disk/src/search/plan.rs diff --git a/diskann-disk/src/search/plan.rs b/diskann-disk/src/search/plan.rs new file mode 100644 index 000000000..716314d81 --- /dev/null +++ b/diskann-disk/src/search/plan.rs @@ -0,0 +1,179 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Top-level disk search plan. +//! +//! Replaces the previous `(vector_filter, is_flat_search, adaptive_l)` +//! parameter triple on `DiskIndexSearcher::search` with a sum type, making +//! invalid combinations (flat scan with adaptive L, inline filter without a +//! predicate) unrepresentable. + +use diskann::graph::search::AdaptiveL; + +/// Owned closure used to filter vector IDs during disk search. +/// +/// The `&u32` argument is the disk-side internal/external ID (they coincide +/// on the disk path by construction). +pub type SearchPredicate<'a> = Box bool + Send + Sync + 'a>; + +/// Top-level disk search plan. +/// +/// Three variants encode the algorithm + filter combination: +/// +/// * `FlatScan` — brute-force linear scan, with or without an inline filter. +/// * `Graph` — plain greedy beam search; the optional filter is applied as a +/// hard post-filter during reranking (no traversal-time effect). +/// * `InlineFilter` — label-filtered graph search; the predicate is consulted +/// at visit time (not just during rerank). `adaptive_l = Some(_)` grows the +/// beam mid-search if the observed match specificity is low. +/// +/// `AdaptiveL` is unreachable except via `InlineFilter` — flat scan doesn't +/// benefit from beam widening, and plain graph search without inline tracking +/// can't compute specificity. `InlineFilter` always carries a predicate; an +/// inline filter with no filter degrades to plain Knn with extra bookkeeping +/// (verified — same top-k, slower), so the variant requires one explicitly. +pub enum SearchPlan<'a> { + /// Brute-force linear scan. `Some(p)` applies `p` inline; `None` scans + /// every vector (recall baseline). + FlatScan { filter: Option> }, + + /// Plain greedy beam search. The optional post-filter is applied during + /// reranking via `RerankAndFilter`; traversal is identical to the + /// unfiltered case. + Graph { filter: Option> }, + + /// Inline label-filtered graph search. The predicate is consulted at + /// visit time (`QueryLabelProvider::on_visit`) and again at rerank time. + /// `adaptive_l = Some(_)` enables mid-search beam widening based on + /// observed match specificity. + InlineFilter { + predicate: SearchPredicate<'a>, + adaptive_l: Option, + }, +} + +impl<'a> SearchPlan<'a> { + /// Flat scan over all vectors. Recall baseline. + pub fn flat() -> Self { + Self::FlatScan { filter: None } + } + + /// Flat scan restricted to vectors that satisfy `predicate`. + pub fn flat_filtered(predicate: F) -> Self + where + F: Fn(&u32) -> bool + Send + Sync + 'a, + { + Self::FlatScan { + filter: Some(Box::new(predicate)), + } + } + + /// Plain greedy graph search; no filter. + pub fn graph() -> Self { + Self::Graph { filter: None } + } + + /// Plain greedy graph search with a hard post-filter applied during + /// reranking. Traversal is unaffected. + pub fn graph_filtered(predicate: F) -> Self + where + F: Fn(&u32) -> bool + Send + Sync + 'a, + { + Self::Graph { + filter: Some(Box::new(predicate)), + } + } + + /// Inline label-filtered graph search. `adaptive_l = Some(_)` enables + /// mid-search beam widening; `None` runs inline tracking only (no + /// resizing). + pub fn inline_filter(predicate: F, adaptive_l: Option) -> Self + where + F: Fn(&u32) -> bool + Send + Sync + 'a, + { + Self::InlineFilter { + predicate: Box::new(predicate), + adaptive_l, + } + } + + /// Borrow the predicate carried by the plan, if any. + /// + /// `FlatScan { None }` and `Graph { None }` return `None` (accept-all, + /// consumers short-circuit). `FlatScan { Some(p) }`, `Graph { Some(p) }`, + /// and `InlineFilter { predicate: p, .. }` all return `Some(p)`. + pub fn predicate(&self) -> Option<&(dyn Fn(&u32) -> bool + Send + Sync)> { + match self { + SearchPlan::FlatScan { filter: None } | SearchPlan::Graph { filter: None } => None, + SearchPlan::FlatScan { filter: Some(p) } + | SearchPlan::Graph { filter: Some(p) } + | SearchPlan::InlineFilter { predicate: p, .. } => Some(p.as_ref()), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn flat_no_filter_constructor() { + let plan = SearchPlan::flat(); + assert!(matches!(plan, SearchPlan::FlatScan { filter: None })); + assert!(plan.predicate().is_none()); + } + + #[test] + fn flat_filtered_constructor() { + let plan = SearchPlan::flat_filtered(|id| *id == 5); + let p = plan.predicate().expect("FlatScan { Some } exposes predicate"); + assert!(p(&5)); + assert!(!p(&4)); + } + + #[test] + fn graph_no_filter_constructor() { + let plan = SearchPlan::graph(); + assert!(matches!(plan, SearchPlan::Graph { filter: None })); + assert!(plan.predicate().is_none()); + } + + #[test] + fn graph_filtered_constructor() { + let plan = SearchPlan::graph_filtered(|id| *id == 7); + let p = plan.predicate().expect("Graph { Some } exposes predicate"); + assert!(p(&7)); + assert!(!p(&6)); + } + + #[test] + fn inline_filter_constructor_without_adaptive_l() { + let plan = SearchPlan::inline_filter(|id| *id == 3, None); + match &plan { + SearchPlan::InlineFilter { + adaptive_l: None, .. + } => {} + _ => panic!("expected InlineFilter with adaptive_l = None"), + } + let p = plan + .predicate() + .expect("InlineFilter always exposes a predicate"); + assert!(p(&3)); + assert!(!p(&2)); + } + + #[test] + fn inline_filter_constructor_with_adaptive_l() { + let adaptive = AdaptiveL::new(5, 16.0).expect("valid AdaptiveL"); + let plan = SearchPlan::inline_filter(|id| *id == 11, Some(adaptive)); + match &plan { + SearchPlan::InlineFilter { + adaptive_l: Some(_), + .. + } => {} + _ => panic!("expected InlineFilter with adaptive_l = Some"), + } + } +} From fc0b540efa9735d87d4d29f4a148c8f05360870a Mon Sep 17 00:00:00 2001 From: yaohongdeng Date: Wed, 17 Jun 2026 17:40:50 +0800 Subject: [PATCH 04/14] resolve comments --- .../example/disk-index-filter.json | 67 +++++++++++++ diskann-benchmark/src/main.rs | 31 ++++++ diskann-disk/src/search/plan.rs | 72 +++++--------- .../src/search/provider/disk_provider.rs | 98 ++++++++++--------- 4 files changed, 174 insertions(+), 94 deletions(-) create mode 100644 diskann-benchmark/example/disk-index-filter.json diff --git a/diskann-benchmark/example/disk-index-filter.json b/diskann-benchmark/example/disk-index-filter.json new file mode 100644 index 000000000..a3f35ca91 --- /dev/null +++ b/diskann-benchmark/example/disk-index-filter.json @@ -0,0 +1,67 @@ +{ + "search_directories": [ + "test_data/disk_index_search" + ], + "jobs": [ + { + "type": "disk-index", + "content": { + "source": { + "disk-index-source": "Build", + "data_type": "float32", + "data": "disk_index_siftsmall_learn_256pts_data.fbin", + "distance": "squared_l2", + "dim": 128, + "max_degree": 32, + "l_build": 50, + "num_threads": 1, + "build_ram_limit_gb": 2.0, + "num_pq_chunks": 128, + "quantization_type": "FP", + "save_path": "siftsmall_index_filter_graph" + }, + "search_phase": { + "queries": "disk_index_sample_query_10pts.fbin", + "groundtruth": "disk_index_10pts_idx_uint32_truth_search_filter_res.bin", + "search_list": [10, 20, 40], + "beam_width": 4, + "recall_at": 10, + "num_threads": 1, + "is_flat_search": false, + "distance": "squared_l2", + "vector_filters_file": "disk_index_10pts_idx_uint32_range_res_r_100000.bin" + } + } + }, + { + "type": "disk-index", + "content": { + "source": { + "disk-index-source": "Build", + "data_type": "float32", + "data": "disk_index_siftsmall_learn_256pts_data.fbin", + "distance": "squared_l2", + "dim": 128, + "max_degree": 32, + "l_build": 50, + "num_threads": 1, + "build_ram_limit_gb": 2.0, + "num_pq_chunks": 128, + "quantization_type": "FP", + "save_path": "siftsmall_index_filter_flat" + }, + "search_phase": { + "queries": "disk_index_sample_query_10pts.fbin", + "groundtruth": "disk_index_10pts_idx_uint32_truth_search_filter_res.bin", + "search_list": [10, 20, 40], + "beam_width": 4, + "recall_at": 10, + "num_threads": 1, + "is_flat_search": true, + "distance": "squared_l2", + "vector_filters_file": "disk_index_10pts_idx_uint32_range_res_r_100000.bin" + } + } + } + ] +} diff --git a/diskann-benchmark/src/main.rs b/diskann-benchmark/src/main.rs index aa21c4ae5..63c8cc3d1 100644 --- a/diskann-benchmark/src/main.rs +++ b/diskann-benchmark/src/main.rs @@ -578,6 +578,37 @@ mod tests { let raw = value_from_file(&example_directory().join("graph-index-inline-filter.json")); run_integration_test(raw); } + + /// Filtered disk search end-to-end: drives the disk-index backend through + /// `disk-index-filter.json` + #[test] + #[cfg(feature = "disk-index")] + fn disk_index_filter_integration() { + let mut raw = value_from_file(&example_directory().join("disk-index-filter.json")); + prefix_search_directories(&mut raw, &root_directory()); + + let tempdir = tempfile::tempdir().unwrap(); + let input_path = tempdir.path().join("disk-index-filter.json"); + save_to_file(&input_path, &raw); + let output_path = tempdir.path().join("output.json"); + + let command = Commands::Run { + input_file: input_path.to_owned(), + output_file: output_path.to_owned(), + dry_run: false, + allow_debug: true, + }; + let cli = Cli::from_commands(command, true); + let mut output = Memory::new(); + let result = cli.run(&mut output); + let output_str = String::from_utf8(output.into_inner()).unwrap(); + println!("output = {}", output_str); + result.expect("disk-index-filter run failed"); + + assert!(output_path.exists()); + let results: Vec = load_from_file(&output_path); + assert_eq!(results.len(), num_jobs(&raw)); + } #[test] fn graph_index_filter_integration_with_gt_compute() { diff --git a/diskann-disk/src/search/plan.rs b/diskann-disk/src/search/plan.rs index 716314d81..4ace8ad17 100644 --- a/diskann-disk/src/search/plan.rs +++ b/diskann-disk/src/search/plan.rs @@ -5,10 +5,9 @@ //! Top-level disk search plan. //! -//! Replaces the previous `(vector_filter, is_flat_search, adaptive_l)` -//! parameter triple on `DiskIndexSearcher::search` with a sum type, making -//! invalid combinations (flat scan with adaptive L, inline filter without a -//! predicate) unrepresentable. +//! Encodes the algorithm + filter combination for `DiskIndexSearcher::search` +//! as a sum type so that invalid combinations (flat scan with adaptive L, +//! inline filter without a predicate) are unrepresentable at the API boundary. use diskann::graph::search::AdaptiveL; @@ -28,26 +27,12 @@ pub type SearchPredicate<'a> = Box bool + Send + Sync + 'a>; /// * `InlineFilter` — label-filtered graph search; the predicate is consulted /// at visit time (not just during rerank). `adaptive_l = Some(_)` grows the /// beam mid-search if the observed match specificity is low. -/// -/// `AdaptiveL` is unreachable except via `InlineFilter` — flat scan doesn't -/// benefit from beam widening, and plain graph search without inline tracking -/// can't compute specificity. `InlineFilter` always carries a predicate; an -/// inline filter with no filter degrades to plain Knn with extra bookkeeping -/// (verified — same top-k, slower), so the variant requires one explicitly. + pub enum SearchPlan<'a> { - /// Brute-force linear scan. `Some(p)` applies `p` inline; `None` scans - /// every vector (recall baseline). FlatScan { filter: Option> }, - /// Plain greedy beam search. The optional post-filter is applied during - /// reranking via `RerankAndFilter`; traversal is identical to the - /// unfiltered case. Graph { filter: Option> }, - /// Inline label-filtered graph search. The predicate is consulted at - /// visit time (`QueryLabelProvider::on_visit`) and again at rerank time. - /// `adaptive_l = Some(_)` enables mid-search beam widening based on - /// observed match specificity. InlineFilter { predicate: SearchPredicate<'a>, adaptive_l: Option, @@ -98,20 +83,6 @@ impl<'a> SearchPlan<'a> { adaptive_l, } } - - /// Borrow the predicate carried by the plan, if any. - /// - /// `FlatScan { None }` and `Graph { None }` return `None` (accept-all, - /// consumers short-circuit). `FlatScan { Some(p) }`, `Graph { Some(p) }`, - /// and `InlineFilter { predicate: p, .. }` all return `Some(p)`. - pub fn predicate(&self) -> Option<&(dyn Fn(&u32) -> bool + Send + Sync)> { - match self { - SearchPlan::FlatScan { filter: None } | SearchPlan::Graph { filter: None } => None, - SearchPlan::FlatScan { filter: Some(p) } - | SearchPlan::Graph { filter: Some(p) } - | SearchPlan::InlineFilter { predicate: p, .. } => Some(p.as_ref()), - } - } } #[cfg(test)] @@ -122,30 +93,36 @@ mod tests { fn flat_no_filter_constructor() { let plan = SearchPlan::flat(); assert!(matches!(plan, SearchPlan::FlatScan { filter: None })); - assert!(plan.predicate().is_none()); } #[test] fn flat_filtered_constructor() { let plan = SearchPlan::flat_filtered(|id| *id == 5); - let p = plan.predicate().expect("FlatScan { Some } exposes predicate"); - assert!(p(&5)); - assert!(!p(&4)); + match &plan { + SearchPlan::FlatScan { filter: Some(p) } => { + assert!(p(&5)); + assert!(!p(&4)); + } + _ => panic!("expected FlatScan with filter"), + } } #[test] fn graph_no_filter_constructor() { let plan = SearchPlan::graph(); assert!(matches!(plan, SearchPlan::Graph { filter: None })); - assert!(plan.predicate().is_none()); } #[test] fn graph_filtered_constructor() { let plan = SearchPlan::graph_filtered(|id| *id == 7); - let p = plan.predicate().expect("Graph { Some } exposes predicate"); - assert!(p(&7)); - assert!(!p(&6)); + match &plan { + SearchPlan::Graph { filter: Some(p) } => { + assert!(p(&7)); + assert!(!p(&6)); + } + _ => panic!("expected Graph with filter"), + } } #[test] @@ -153,15 +130,14 @@ mod tests { let plan = SearchPlan::inline_filter(|id| *id == 3, None); match &plan { SearchPlan::InlineFilter { - adaptive_l: None, .. - } => {} + predicate, + adaptive_l: None, + } => { + assert!(predicate(&3)); + assert!(!predicate(&2)); + } _ => panic!("expected InlineFilter with adaptive_l = None"), } - let p = plan - .predicate() - .expect("InlineFilter always exposes a predicate"); - assert!(p(&3)); - assert!(!p(&2)); } #[test] diff --git a/diskann-disk/src/search/provider/disk_provider.rs b/diskann-disk/src/search/provider/disk_provider.rs index 0f39177eb..32a902ea4 100644 --- a/diskann-disk/src/search/provider/disk_provider.rs +++ b/diskann-disk/src/search/provider/disk_provider.rs @@ -214,11 +214,14 @@ where { // This needs to be Arc instead of Rc because DiskSearchStrategy has "Send" trait bound, though this is not expected to be shared across threads. io_tracker: IOTracker, - /// Borrowed predicate from `SearchPlan::predicate()`. `None` means - /// accept-all — downstream consumers short-circuit via `.map_or(true, ...)` - /// and skip the dyn-fn call entirely. Fn param is `u32` because - /// `VectorIdType = u32` is enforced by trait bounds throughout this provider. - vector_filter: Option<&'a (dyn Fn(&u32) -> bool + Send + Sync)>, + /// Borrowed predicate consumed *only* by `default_post_processor()` → + /// `RerankAndFilter`. Other variants (`FlatScan`, `InlineFilter`) filter + /// earlier in their pipelines and install `None` here to avoid a + /// redundant second application. `None` means accept-all — the post- + /// processor short-circuits via `.map_or(true, ...)` and skips the + /// dyn-fn call entirely. Fn param is `u32` because `VectorIdType = u32` + /// is enforced by trait bounds throughout this provider. + postprocess_filter: Option<&'a (dyn Fn(&u32) -> bool + Send + Sync)>, /// The vertex provider factory is used to create the vertex provider for each search instance. vertex_provider_factory: &'a ProviderFactory, @@ -277,10 +280,8 @@ impl<'a> RerankAndFilter<'a> { } } -/// Adapter exposing the existing disk-side `&dyn Fn(&u32) -> bool` predicate -/// as a `QueryLabelProvider` for `InlineFilterSearch`. Lives entirely -/// inside `filter_search` — public `search()` keeps its existing predicate -/// type at the boundary. +/// Adapter exposing the disk-side `&dyn Fn(&u32) -> bool` predicate as a +/// `QueryLabelProvider` for `InlineFilterSearch`. struct PredicateLabelProvider<'a> { predicate: &'a (dyn Fn(&u32) -> bool + Send + Sync), } @@ -402,7 +403,7 @@ where type Processor = RerankAndFilter<'this>; fn default_post_processor(&'this self) -> Self::Processor { - RerankAndFilter::new(self.vector_filter) + RerankAndFilter::new(self.postprocess_filter) } } @@ -776,16 +777,18 @@ where /// Helper method to create a DiskSearchStrategy with common parameters. /// - /// `vector_filter = None` means accept-all and propagates as `None` - /// through the strategy and downstream consumers, so the unfiltered hot - /// path pays no dyn-fn call per node. + /// `postprocess_filter = None` means accept-all at rerank time. Callers pass + /// `None` on paths that filter earlier in their pipeline (`FlatScan` at + /// scan time, `InlineFilter` at visit time) to avoid re-applying the + /// predicate redundantly; only the plain `Graph` path installs a + /// predicate here fn search_strategy<'a>( &'a self, - vector_filter: Option<&'a (dyn Fn(&Data::VectorIdType) -> bool + Send + Sync)>, + postprocess_filter: Option<&'a (dyn Fn(&Data::VectorIdType) -> bool + Send + Sync)>, ) -> DiskSearchStrategy<'a, Data, ProviderFactory> { DiskSearchStrategy { io_tracker: IOTracker::default(), - vector_filter, + postprocess_filter, vertex_provider_factory: &self.vertex_provider_factory, scratch_pool: &self.scratch_pool, } @@ -902,11 +905,7 @@ where /// Perform a search on the disk index. /// return the list of nearest neighbors and associated data. - /// - /// The algorithm + filter combination is encoded by `plan` — see - /// [`SearchPlan`] for the available variants. The plan replaces the - /// previous `(vector_filter, is_flat_search, adaptive_l)` parameter - /// triple and makes invalid combinations unrepresentable. + pub fn search( &self, query: &[Data::VectorDataType], @@ -955,11 +954,6 @@ where /// Perform a raw search on the disk index. /// This is a lower-level API that allows more control over the search parameters and output buffers. - /// - /// Dispatches on `plan` variants: `FlatScan` → linear scan; `Graph` → - /// plain `Knn` with the optional post-filter applied during rerank; - /// `InlineFilter` → `filter_search` (`InlineFilterSearch` with optional - /// `AdaptiveL`). #[allow(clippy::too_many_arguments)] pub(crate) fn search_internal( &self, @@ -979,43 +973,55 @@ where &mut associated_data[..k_value], ); - // `None` predicate propagates through strategy and downstream - // consumers — they short-circuit before the dyn-fn call, so the - // unfiltered hot path pays no per-node closure cost. - let predicate = plan.predicate(); - - let strategy = self.search_strategy(predicate); let timer = Instant::now(); let k = k_value; let l = search_list_size as usize; - let stats = match plan { - SearchPlan::FlatScan { .. } => self.runtime.block_on(self.flat_search( - &strategy, - query, - predicate, - l, - &mut result_output_buffer, - ))?, - SearchPlan::Graph { .. } => { + + // * `FlatScan` — `flat_search` filters the scan iterator at + // construction; non-matching IDs never enter `best`. + // * `Graph` — plain greedy traversal doesn't consult any + // predicate, if predicate presented `RerankAndFilter` will filter out non-matching nodes. + // * `InlineFilter` — `InlineFilterSearch` only forwards `Accept` + // nodes into `matched_results`, no filtering in post-processing. + let (strategy, stats) = match plan { + SearchPlan::FlatScan { filter } => { + let strategy = self.search_strategy(None); + let stats = self.runtime.block_on(self.flat_search( + &strategy, + query, + filter.as_deref(), + l, + &mut result_output_buffer, + ))?; + (strategy, stats) + } + SearchPlan::Graph { filter } => { + let strategy = self.search_strategy(filter.as_deref()); let knn_search = Knn::new(k, l, beam_width)?; - self.runtime.block_on(self.index.search( + let stats = self.runtime.block_on(self.index.search( knn_search, &strategy, &DefaultContext, query, &mut result_output_buffer, - ))? + ))?; + (strategy, stats) } - SearchPlan::InlineFilter { predicate: p, adaptive_l } => { + SearchPlan::InlineFilter { + predicate, + adaptive_l, + } => { + let strategy = self.search_strategy(None); let knn_search = Knn::new(k, l, beam_width)?; - self.runtime.block_on(self.filter_search( + let stats = self.runtime.block_on(self.filter_search( &strategy, query, knn_search, - p.as_ref(), + predicate.as_ref(), adaptive_l.clone(), &mut result_output_buffer, - ))? + ))?; + (strategy, stats) } }; query_stats.total_comparisons = stats.cmps; From 0bdbe636116bab1fa637032e586613dcbc6440db Mon Sep 17 00:00:00 2001 From: yaohongdeng Date: Thu, 18 Jun 2026 11:22:00 +0800 Subject: [PATCH 05/14] Adapt InlineFilterSearch to new API --- .../src/search/provider/disk_provider.rs | 71 +++++++++++-------- 1 file changed, 41 insertions(+), 30 deletions(-) diff --git a/diskann-disk/src/search/provider/disk_provider.rs b/diskann-disk/src/search/provider/disk_provider.rs index 32a902ea4..a994642c0 100644 --- a/diskann-disk/src/search/provider/disk_provider.rs +++ b/diskann-disk/src/search/provider/disk_provider.rs @@ -18,8 +18,8 @@ use diskann::{ error::IntoANNResult, graph::{ self, + ext::labeled::{self, QueryLabelProvider}, glue::{self, DefaultPostProcessor, SearchPostProcess, SearchStrategy}, - index::QueryLabelProvider, search::{AdaptiveL, InlineFilterSearch, Knn}, search_output_buffer, DiskANNIndex, }, @@ -212,8 +212,8 @@ where Data: GraphDataType, ProviderFactory: VertexProviderFactory, { - // This needs to be Arc instead of Rc because DiskSearchStrategy has "Send" trait bound, though this is not expected to be shared across threads. - io_tracker: IOTracker, + // Borrowed from `search_internal` so the strategy can be passed by value + io_tracker: &'a IOTracker, /// Borrowed predicate consumed *only* by `default_post_processor()` → /// `RerankAndFilter`. Other variants (`FlatScan`, `InlineFilter`) filter /// earlier in their pipelines and install `None` here to avoid a @@ -784,10 +784,11 @@ where /// predicate here fn search_strategy<'a>( &'a self, + io_tracker: &'a IOTracker, postprocess_filter: Option<&'a (dyn Fn(&Data::VectorIdType) -> bool + Send + Sync)>, ) -> DiskSearchStrategy<'a, Data, ProviderFactory> { DiskSearchStrategy { - io_tracker: IOTracker::default(), + io_tracker, postprocess_filter, vertex_provider_factory: &self.vertex_provider_factory, scratch_pool: &self.scratch_pool, @@ -882,9 +883,9 @@ where /// Reuses the same `DiskAccessor` surface as the plain `Knn` graph path: /// `start_point_distances` and `expand_beam`, both of which call /// `pq_distances` internally. - async fn filter_search( + async fn filter_search<'a, OB>( &self, - strategy: &DiskSearchStrategy<'_, Data, ProviderFactory>, + strategy: DiskSearchStrategy<'a, Data, ProviderFactory>, query: &[Data::VectorDataType], knn: Knn, vector_filter: &(dyn Fn(&u32) -> bool + Send + Sync), @@ -897,9 +898,11 @@ where let label_provider = PredicateLabelProvider { predicate: vector_filter, }; - let search = InlineFilterSearch::new(knn, &label_provider, adaptive_l); + + let filtered_strategy = labeled::Filtered::new(strategy, &label_provider); + let search = InlineFilterSearch::new(knn, adaptive_l); self.index - .search(search, strategy, &DefaultContext, query, output) + .search(search, &filtered_strategy, &DefaultContext, query, output) .await } @@ -977,62 +980,64 @@ where let k = k_value; let l = search_list_size as usize; + let io_tracker = IOTracker::default(); + // * `FlatScan` — `flat_search` filters the scan iterator at // construction; non-matching IDs never enter `best`. // * `Graph` — plain greedy traversal doesn't consult any // predicate, if predicate presented `RerankAndFilter` will filter out non-matching nodes. // * `InlineFilter` — `InlineFilterSearch` only forwards `Accept` // nodes into `matched_results`, no filtering in post-processing. - let (strategy, stats) = match plan { + let stats = match plan { SearchPlan::FlatScan { filter } => { - let strategy = self.search_strategy(None); - let stats = self.runtime.block_on(self.flat_search( + let strategy = self.search_strategy(&io_tracker, None); + self.runtime.block_on(self.flat_search( &strategy, query, filter.as_deref(), l, &mut result_output_buffer, - ))?; - (strategy, stats) + ))? } SearchPlan::Graph { filter } => { - let strategy = self.search_strategy(filter.as_deref()); + let strategy = self.search_strategy(&io_tracker, filter.as_deref()); let knn_search = Knn::new(k, l, beam_width)?; - let stats = self.runtime.block_on(self.index.search( + self.runtime.block_on(self.index.search( knn_search, &strategy, &DefaultContext, query, &mut result_output_buffer, - ))?; - (strategy, stats) + ))? } SearchPlan::InlineFilter { predicate, adaptive_l, } => { - let strategy = self.search_strategy(None); + // Strategy is moved by value into `filter_search` so that the + // `labeled::Filtered` wrapper can own it; `io_tracker` keeps + // its counters reachable from this scope. + let strategy = self.search_strategy(&io_tracker, None); let knn_search = Knn::new(k, l, beam_width)?; - let stats = self.runtime.block_on(self.filter_search( - &strategy, + self.runtime.block_on(self.filter_search( + strategy, query, knn_search, predicate.as_ref(), adaptive_l.clone(), &mut result_output_buffer, - ))?; - (strategy, stats) + ))? } }; query_stats.total_comparisons = stats.cmps; query_stats.search_hops = stats.hops; query_stats.total_execution_time_us = timer.elapsed().as_micros(); - query_stats.io_time_us = IOTracker::time(&strategy.io_tracker.io_time_us) as u128; - query_stats.total_io_operations = strategy.io_tracker.io_count() as u32; - query_stats.total_vertices_loaded = strategy.io_tracker.io_count() as u32; + query_stats.io_time_us = IOTracker::time(&io_tracker.io_time_us) as u128; + query_stats.total_io_operations = io_tracker.io_count() as u32; + query_stats.total_vertices_loaded = io_tracker.io_count() as u32; query_stats.query_pq_preprocess_time_us = - IOTracker::time(&strategy.io_tracker.preprocess_time_us) as u128; + IOTracker::time(&io_tracker.preprocess_time_us) as u128; query_stats.cpu_time_us = query_stats.total_execution_time_us - query_stats.io_time_us - query_stats.query_pq_preprocess_time_us; @@ -1604,7 +1609,8 @@ mod disk_provider_tests { &mut distances, &mut associated_data, ); - let strategy = search_engine.search_strategy(None); + let io_tracker = IOTracker::default(); + let strategy = search_engine.search_strategy(&io_tracker, None); let mut search_record = VisitedSearchRecord::new(0); let search_params = Knn::new(10, 10, Some(4)).unwrap(); let recorded_search = @@ -1725,7 +1731,8 @@ mod disk_provider_tests { &mut distances, &mut associated_data, ); - let strategy = search_engine.search_strategy(None); + let io_tracker = IOTracker::default(); + let strategy = search_engine.search_strategy(&io_tracker, None); // Create diverse search parameters with attribute provider let diverse_params = DiverseSearchParams::new( @@ -1772,7 +1779,10 @@ mod disk_provider_tests { &mut distances2, &mut associated_data2, ); - let strategy2 = search_engine.search_strategy(&|_| true); + let io_tracker2 = IOTracker::default(); + // Old API passed `&|_| true` (accept-all); after the Option refactor + // that's structurally equivalent to `None`. + let strategy2 = search_engine.search_strategy(&io_tracker2, None); let search_params2 = Knn::new(return_list_size as usize, search_list_size as usize, None).unwrap(); @@ -2180,7 +2190,8 @@ mod disk_provider_tests { &mut associated_data, ); - let strategy = search_engine.search_strategy(None); + let io_tracker = IOTracker::default(); + let strategy = search_engine.search_strategy(&io_tracker, None); let mut search_record = VisitedSearchRecord::new(0); let search_params = Knn::new(10, 10, Some(4)).unwrap(); From be79ce58e8cd4c00e34d41acd4c74428fcb49c7a Mon Sep 17 00:00:00 2001 From: yaohongdeng Date: Thu, 18 Jun 2026 11:52:26 +0800 Subject: [PATCH 06/14] rename SearchPlan to SearchMode --- diskann-disk/src/search/mod.rs | 2 +- .../src/search/provider/disk_provider.rs | 46 +++++++++---------- .../src/search/{plan.rs => search_mode.rs} | 40 ++++++++-------- diskann-tools/src/utils/search_disk_index.rs | 16 +++---- 4 files changed, 52 insertions(+), 52 deletions(-) rename diskann-disk/src/search/{plan.rs => search_mode.rs} (81%) diff --git a/diskann-disk/src/search/mod.rs b/diskann-disk/src/search/mod.rs index e01983c58..8c3eb8274 100644 --- a/diskann-disk/src/search/mod.rs +++ b/diskann-disk/src/search/mod.rs @@ -5,7 +5,7 @@ //! Model module containing data structures, providers, and traits for disk index operations -pub mod plan; +pub mod search_mode; pub mod pq; pub mod provider; pub mod traits; diff --git a/diskann-disk/src/search/provider/disk_provider.rs b/diskann-disk/src/search/provider/disk_provider.rs index a994642c0..3d7315e0b 100644 --- a/diskann-disk/src/search/provider/disk_provider.rs +++ b/diskann-disk/src/search/provider/disk_provider.rs @@ -43,7 +43,7 @@ use tracing::debug; use crate::{ data_model::{CachingStrategy, GraphHeader}, search::{ - plan::SearchPlan, + search_mode::SearchMode, provider::disk_vertex_provider_factory::DiskVertexProviderFactory, traits::{VertexProvider, VertexProviderFactory}, }, @@ -915,7 +915,7 @@ where return_list_size: u32, search_list_size: u32, beam_width: Option, - plan: SearchPlan<'_>, + mode: SearchMode<'_>, ) -> ANNResult> { let mut query_stats = QueryStatistics::default(); let mut indices = vec![0u32; return_list_size as usize]; @@ -932,7 +932,7 @@ where &mut indices, &mut distances, &mut associated_data, - &plan, + &mode, )?; let mut search_result = SearchResult { @@ -968,7 +968,7 @@ where indices: &mut [u32], distances: &mut [f32], associated_data: &mut [Data::AssociatedDataType], - plan: &SearchPlan<'_>, + mode: &SearchMode<'_>, ) -> ANNResult { let mut result_output_buffer = search_output_buffer::IdDistanceAssociatedData::new( &mut indices[..k_value], @@ -988,8 +988,8 @@ where // predicate, if predicate presented `RerankAndFilter` will filter out non-matching nodes. // * `InlineFilter` — `InlineFilterSearch` only forwards `Accept` // nodes into `matched_results`, no filtering in post-processing. - let stats = match plan { - SearchPlan::FlatScan { filter } => { + let stats = match mode { + SearchMode::FlatScan { filter } => { let strategy = self.search_strategy(&io_tracker, None); self.runtime.block_on(self.flat_search( &strategy, @@ -999,7 +999,7 @@ where &mut result_output_buffer, ))? } - SearchPlan::Graph { filter } => { + SearchMode::Graph { filter } => { let strategy = self.search_strategy(&io_tracker, filter.as_deref()); let knn_search = Knn::new(k, l, beam_width)?; self.runtime.block_on(self.index.search( @@ -1010,11 +1010,11 @@ where &mut result_output_buffer, ))? } - SearchPlan::InlineFilter { + SearchMode::InlineFilter { predicate, adaptive_l, } => { - // Strategy is moved by value into `filter_search` so that the + // Strategy is passed by value into `filter_search` so that the // `labeled::Filtered` wrapper can own it; `io_tracker` keeps // its counters reachable from this scope. let strategy = self.search_strategy(&io_tracker, None); @@ -1419,7 +1419,7 @@ mod disk_provider_tests { &mut indices, &mut distances, &mut associated_data, - &SearchPlan::graph(), + &SearchMode::graph(), ); // Calculate the range of the truth_result for this query @@ -1466,7 +1466,7 @@ mod disk_provider_tests { .for_each_in_pool(pool.as_ref(), |(i, query)| { let result = params .index_search_engine - .search(query, params.k as u32, params.l as u32, beam_width, SearchPlan::graph()) + .search(query, params.k as u32, params.l as u32, beam_width, SearchMode::graph()) .unwrap(); let indices: Vec = result.results.iter().map(|item| item.vertex_id).collect(); let associated_data: Vec = @@ -1576,7 +1576,7 @@ mod disk_provider_tests { &mut indices, &mut distances, &mut associated_data, - &SearchPlan::graph(), + &SearchMode::graph(), ); assert!(result.is_err()); @@ -1645,7 +1645,7 @@ mod disk_provider_tests { return_list_size, search_list_size, Some(4), - SearchPlan::graph(), + SearchMode::graph(), ); assert!(result.is_ok(), "Expected search to succeed"); let search_result = result.unwrap(); @@ -1978,13 +1978,13 @@ mod disk_provider_tests { let mut distances = vec![0f32; 10]; let mut associated_data = vec![(); 10]; - // Build the same `SearchPlan` twice. `vector_filter` is a `fn` pointer - // (Copy), so each call reconstructs a fresh plan with the same filter. - let make_plan = || -> SearchPlan<'static> { + // Build the same `SearchMode` twice. `vector_filter` is a `fn` pointer + // (Copy), so each call reconstructs a fresh mode with the same filter. + let make_mode = || -> SearchMode<'static> { if is_flat_search { - SearchPlan::flat_filtered(vector_filter) + SearchMode::flat_filtered(vector_filter) } else { - SearchPlan::graph_filtered(vector_filter) + SearchMode::graph_filtered(vector_filter) } }; @@ -1997,7 +1997,7 @@ mod disk_provider_tests { &mut indices, &mut distances, &mut associated_data, - &make_plan(), + &make_mode(), ); assert!(result.is_ok(), "Expected search to succeed"); @@ -2017,7 +2017,7 @@ mod disk_provider_tests { 10, 10, None, // beam_width - make_plan(), + make_mode(), ); assert!(result_with_filter.is_ok(), "Expected search to succeed"); @@ -2079,7 +2079,7 @@ mod disk_provider_tests { let query = vec![0.1f32; 128]; let plain = search_engine - .search(&query, 10, 10, None, SearchPlan::graph()) + .search(&query, 10, 10, None, SearchMode::graph()) .expect("plain Knn must succeed"); let inline_no_filter = search_engine @@ -2088,7 +2088,7 @@ mod disk_provider_tests { 10, 10, None, - SearchPlan::inline_filter( + SearchMode::inline_filter( |_| true, Some(AdaptiveL::new(5, 16.0).expect("valid AdaptiveL")), ), @@ -2133,7 +2133,7 @@ mod disk_provider_tests { 10, 10, None, - SearchPlan::inline_filter( + SearchMode::inline_filter( predicate, Some(AdaptiveL::new(5, 16.0).expect("valid AdaptiveL")), ), diff --git a/diskann-disk/src/search/plan.rs b/diskann-disk/src/search/search_mode.rs similarity index 81% rename from diskann-disk/src/search/plan.rs rename to diskann-disk/src/search/search_mode.rs index 4ace8ad17..f476c007a 100644 --- a/diskann-disk/src/search/plan.rs +++ b/diskann-disk/src/search/search_mode.rs @@ -3,7 +3,7 @@ * Licensed under the MIT license. */ -//! Top-level disk search plan. +//! Top-level disk search mode. //! //! Encodes the algorithm + filter combination for `DiskIndexSearcher::search` //! as a sum type so that invalid combinations (flat scan with adaptive L, @@ -17,7 +17,7 @@ use diskann::graph::search::AdaptiveL; /// on the disk path by construction). pub type SearchPredicate<'a> = Box bool + Send + Sync + 'a>; -/// Top-level disk search plan. +/// Top-level disk search mode. /// /// Three variants encode the algorithm + filter combination: /// @@ -28,7 +28,7 @@ pub type SearchPredicate<'a> = Box bool + Send + Sync + 'a>; /// at visit time (not just during rerank). `adaptive_l = Some(_)` grows the /// beam mid-search if the observed match specificity is low. -pub enum SearchPlan<'a> { +pub enum SearchMode<'a> { FlatScan { filter: Option> }, Graph { filter: Option> }, @@ -39,7 +39,7 @@ pub enum SearchPlan<'a> { }, } -impl<'a> SearchPlan<'a> { +impl<'a> SearchMode<'a> { /// Flat scan over all vectors. Recall baseline. pub fn flat() -> Self { Self::FlatScan { filter: None } @@ -91,15 +91,15 @@ mod tests { #[test] fn flat_no_filter_constructor() { - let plan = SearchPlan::flat(); - assert!(matches!(plan, SearchPlan::FlatScan { filter: None })); + let mode = SearchMode::flat(); + assert!(matches!(mode, SearchMode::FlatScan { filter: None })); } #[test] fn flat_filtered_constructor() { - let plan = SearchPlan::flat_filtered(|id| *id == 5); - match &plan { - SearchPlan::FlatScan { filter: Some(p) } => { + let mode = SearchMode::flat_filtered(|id| *id == 5); + match &mode { + SearchMode::FlatScan { filter: Some(p) } => { assert!(p(&5)); assert!(!p(&4)); } @@ -109,15 +109,15 @@ mod tests { #[test] fn graph_no_filter_constructor() { - let plan = SearchPlan::graph(); - assert!(matches!(plan, SearchPlan::Graph { filter: None })); + let mode = SearchMode::graph(); + assert!(matches!(mode, SearchMode::Graph { filter: None })); } #[test] fn graph_filtered_constructor() { - let plan = SearchPlan::graph_filtered(|id| *id == 7); - match &plan { - SearchPlan::Graph { filter: Some(p) } => { + let mode = SearchMode::graph_filtered(|id| *id == 7); + match &mode { + SearchMode::Graph { filter: Some(p) } => { assert!(p(&7)); assert!(!p(&6)); } @@ -127,9 +127,9 @@ mod tests { #[test] fn inline_filter_constructor_without_adaptive_l() { - let plan = SearchPlan::inline_filter(|id| *id == 3, None); - match &plan { - SearchPlan::InlineFilter { + let mode = SearchMode::inline_filter(|id| *id == 3, None); + match &mode { + SearchMode::InlineFilter { predicate, adaptive_l: None, } => { @@ -143,9 +143,9 @@ mod tests { #[test] fn inline_filter_constructor_with_adaptive_l() { let adaptive = AdaptiveL::new(5, 16.0).expect("valid AdaptiveL"); - let plan = SearchPlan::inline_filter(|id| *id == 11, Some(adaptive)); - match &plan { - SearchPlan::InlineFilter { + let mode = SearchMode::inline_filter(|id| *id == 11, Some(adaptive)); + match &mode { + SearchMode::InlineFilter { adaptive_l: Some(_), .. } => {} diff --git a/diskann-tools/src/utils/search_disk_index.rs b/diskann-tools/src/utils/search_disk_index.rs index 14df6f1fc..f8b07e429 100644 --- a/diskann-tools/src/utils/search_disk_index.rs +++ b/diskann-tools/src/utils/search_disk_index.rs @@ -9,7 +9,7 @@ use diskann::utils::IntoUsize; use diskann_disk::{ data_model::{CachingStrategy, GraphDataType}, search::{ - plan::SearchPlan, + search_mode::SearchMode, provider::{ disk_provider::DiskIndexSearcher, disk_vertex_provider_factory::DiskVertexProviderFactory, @@ -250,17 +250,17 @@ where (((((_cmp, query), vector_filter), query_result_id), query_result_dist), stats), result_count, )| { - // Construct the plan from the CLI-driven + // Construct the mode from the CLI-driven // `(is_flat_search, has_filter)` pair. CLI doesn't expose // AdaptiveL yet, so `InlineFilter` is unreachable here. let has_filter = parameters.vector_filters_file.is_some(); - let plan: SearchPlan<'_> = match (parameters.is_flat_search, has_filter) { - (true, false) => SearchPlan::flat(), - (true, true) => SearchPlan::flat_filtered(move |vid: &u32| { + let mode: SearchMode<'_> = match (parameters.is_flat_search, has_filter) { + (true, false) => SearchMode::flat(), + (true, true) => SearchMode::flat_filtered(move |vid: &u32| { vector_filter.contains(vid) }), - (false, false) => SearchPlan::graph(), - (false, true) => SearchPlan::graph_filtered(move |vid: &u32| { + (false, false) => SearchMode::graph(), + (false, true) => SearchMode::graph_filtered(move |vid: &u32| { vector_filter.contains(vid) }), }; @@ -270,7 +270,7 @@ where parameters.recall_at, l, Some(parameters.beam_width as usize), - plan, + mode, ); match result { From f86d7bcb4317051a705beb5579ed0926ee635820 Mon Sep 17 00:00:00 2001 From: yaohongdeng Date: Thu, 18 Jun 2026 17:18:10 +0800 Subject: [PATCH 07/14] update SearchMode in benchmark --- diskann-benchmark/src/disk_index/search.rs | 16 ++++++++-------- diskann-disk/src/build/builder/core.rs | 2 +- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/diskann-benchmark/src/disk_index/search.rs b/diskann-benchmark/src/disk_index/search.rs index ba45357d8..2306cbc38 100644 --- a/diskann-benchmark/src/disk_index/search.rs +++ b/diskann-benchmark/src/disk_index/search.rs @@ -14,7 +14,7 @@ use diskann_benchmark_runner::{files::InputFile, utils::MicroSeconds}; use diskann_disk::{ data_model::{AdHoc, CachingStrategy}, search::{ - plan::SearchPlan, + search_mode::SearchMode, provider::{ disk_provider::DiskIndexSearcher, disk_vertex_provider_factory::DiskVertexProviderFactory, @@ -268,16 +268,16 @@ where zipped.for_each_in_pool( pool.as_ref(), |(((((q, vf), id_chunk), dist_chunk), stats), rc)| { - // Construct the plan from the JSON-driven + // Construct the mode from the JSON-driven // `(is_flat_search, has_filter)` pair. JSON config doesn't // expose AdaptiveL yet, so `InlineFilter` is unreachable here. let has_filter = search_params.vector_filters_file.is_some(); - let plan: SearchPlan<'_> = match (search_params.is_flat_search, has_filter) { - (true, false) => SearchPlan::flat(), - (true, true) => SearchPlan::flat_filtered(move |vid: &u32| vf.contains(vid)), - (false, false) => SearchPlan::graph(), + let mode: SearchMode<'_> = match (search_params.is_flat_search, has_filter) { + (true, false) => SearchMode::flat(), + (true, true) => SearchMode::flat_filtered(move |vid: &u32| vf.contains(vid)), + (false, false) => SearchMode::graph(), (false, true) => { - SearchPlan::graph_filtered(move |vid: &u32| vf.contains(vid)) + SearchMode::graph_filtered(move |vid: &u32| vf.contains(vid)) } }; @@ -286,7 +286,7 @@ where search_params.recall_at, l, Some(search_params.beam_width), - plan, + mode, ) { Ok(search_result) => { *stats = search_result.stats.query_statistics; diff --git a/diskann-disk/src/build/builder/core.rs b/diskann-disk/src/build/builder/core.rs index 6696877e7..37882bb8a 100644 --- a/diskann-disk/src/build/builder/core.rs +++ b/diskann-disk/src/build/builder/core.rs @@ -1089,7 +1089,7 @@ pub(crate) mod disk_index_builder_tests { &mut indices, &mut distances, &mut associated_data, - &crate::search::plan::SearchPlan::graph(), + &crate::search::search_mode::SearchMode::graph(), ); diskann_providers::test_utils::assert_top_k_exactly_match( From d1dcec30b7580ea0bf136de349c627afa57f7f98 Mon Sep 17 00:00:00 2001 From: yaohongdeng Date: Thu, 18 Jun 2026 19:24:13 +0800 Subject: [PATCH 08/14] fix format issue --- diskann-benchmark/src/disk_index/search.rs | 6 ++-- diskann-benchmark/src/main.rs | 2 +- diskann-disk/src/search/mod.rs | 2 +- .../src/search/provider/disk_provider.rs | 35 ++++++++++++------- diskann-disk/src/search/search_mode.rs | 9 +++-- diskann-tools/src/utils/search_disk_index.rs | 14 ++++---- 6 files changed, 39 insertions(+), 29 deletions(-) diff --git a/diskann-benchmark/src/disk_index/search.rs b/diskann-benchmark/src/disk_index/search.rs index 2306cbc38..718c738b2 100644 --- a/diskann-benchmark/src/disk_index/search.rs +++ b/diskann-benchmark/src/disk_index/search.rs @@ -14,11 +14,11 @@ use diskann_benchmark_runner::{files::InputFile, utils::MicroSeconds}; use diskann_disk::{ data_model::{AdHoc, CachingStrategy}, search::{ - search_mode::SearchMode, provider::{ disk_provider::DiskIndexSearcher, disk_vertex_provider_factory::DiskVertexProviderFactory, }, + search_mode::SearchMode, }, storage::disk_index_reader::DiskIndexReader, utils::{instrumentation::PerfLogger, statistics, AlignedFileReaderFactory, QueryStatistics}, @@ -276,9 +276,7 @@ where (true, false) => SearchMode::flat(), (true, true) => SearchMode::flat_filtered(move |vid: &u32| vf.contains(vid)), (false, false) => SearchMode::graph(), - (false, true) => { - SearchMode::graph_filtered(move |vid: &u32| vf.contains(vid)) - } + (false, true) => SearchMode::graph_filtered(move |vid: &u32| vf.contains(vid)), }; match searcher.search( diff --git a/diskann-benchmark/src/main.rs b/diskann-benchmark/src/main.rs index 609fee7ba..aa9a1db41 100644 --- a/diskann-benchmark/src/main.rs +++ b/diskann-benchmark/src/main.rs @@ -586,7 +586,7 @@ mod tests { let raw = value_from_file(&example_directory().join("graph-index-inline-filter.json")); run_integration_test(raw); } - + /// Filtered disk search end-to-end: drives the disk-index backend through /// `disk-index-filter.json` #[test] diff --git a/diskann-disk/src/search/mod.rs b/diskann-disk/src/search/mod.rs index 8c3eb8274..b377148de 100644 --- a/diskann-disk/src/search/mod.rs +++ b/diskann-disk/src/search/mod.rs @@ -5,7 +5,7 @@ //! Model module containing data structures, providers, and traits for disk index operations -pub mod search_mode; pub mod pq; pub mod provider; +pub mod search_mode; pub mod traits; diff --git a/diskann-disk/src/search/provider/disk_provider.rs b/diskann-disk/src/search/provider/disk_provider.rs index 3d7315e0b..a86c972be 100644 --- a/diskann-disk/src/search/provider/disk_provider.rs +++ b/diskann-disk/src/search/provider/disk_provider.rs @@ -43,8 +43,8 @@ use tracing::debug; use crate::{ data_model::{CachingStrategy, GraphHeader}, search::{ - search_mode::SearchMode, provider::disk_vertex_provider_factory::DiskVertexProviderFactory, + search_mode::SearchMode, traits::{VertexProvider, VertexProviderFactory}, }, storage::{api::AsyncDiskLoadContext, disk_index_reader::DiskIndexReader}, @@ -207,6 +207,11 @@ where /// from local data structures). Moving these components to the search strategy allows /// DiskProvider to satisfy 'static constraints while enabling flexible per-search /// resource management. +/// Borrowed predicate used internally by the disk search pipeline. +/// Spelled out here to keep the field/parameter signatures under +/// `clippy::type_complexity`'s default threshold. +type PostprocessFilter<'a> = &'a (dyn Fn(&u32) -> bool + Send + Sync); + pub struct DiskSearchStrategy<'a, Data, ProviderFactory> where Data: GraphDataType, @@ -221,7 +226,7 @@ where /// processor short-circuits via `.map_or(true, ...)` and skips the /// dyn-fn call entirely. Fn param is `u32` because `VectorIdType = u32` /// is enforced by trait bounds throughout this provider. - postprocess_filter: Option<&'a (dyn Fn(&u32) -> bool + Send + Sync)>, + postprocess_filter: Option>, /// The vertex provider factory is used to create the vertex provider for each search instance. vertex_provider_factory: &'a ProviderFactory, @@ -271,11 +276,11 @@ impl IOTracker { pub struct RerankAndFilter<'a> { /// `None` means accept-all; `post_process` short-circuits before invoking /// the closure, so the unfiltered hot path pays no dyn-fn cost per node. - filter: Option<&'a (dyn Fn(&u32) -> bool + Send + Sync)>, + filter: Option>, } impl<'a> RerankAndFilter<'a> { - fn new(filter: Option<&'a (dyn Fn(&u32) -> bool + Send + Sync)>) -> Self { + fn new(filter: Option>) -> Self { Self { filter } } } @@ -288,7 +293,8 @@ struct PredicateLabelProvider<'a> { impl std::fmt::Debug for PredicateLabelProvider<'_> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("PredicateLabelProvider").finish_non_exhaustive() + f.debug_struct("PredicateLabelProvider") + .finish_non_exhaustive() } } @@ -332,7 +338,7 @@ where .map(|n| n.id) // `None` short-circuits to `true` — no dyn-fn call when there's // no predicate set, which is the common unfiltered hot path. - .filter(|id| self.filter.map_or(true, |f| f(id))) + .filter(|id| self.filter.is_none_or(|f| f(id))) .filter_map(|n| { if let Some(entry) = accessor.scratch.distance_cache.get(&n) { Some(Ok::<((u32, _), f32), ANNError>(((n, entry.1), entry.0))) @@ -378,7 +384,7 @@ where ) -> Result { DiskAccessor::new( provider, - &self.io_tracker, + self.io_tracker, query, self.vertex_provider_factory, self.scratch_pool, @@ -785,7 +791,7 @@ where fn search_strategy<'a>( &'a self, io_tracker: &'a IOTracker, - postprocess_filter: Option<&'a (dyn Fn(&Data::VectorIdType) -> bool + Send + Sync)>, + postprocess_filter: Option>, ) -> DiskSearchStrategy<'a, Data, ProviderFactory> { DiskSearchStrategy { io_tracker, @@ -840,8 +846,8 @@ where // `None` short-circuits to `true` — no dyn-fn call per node on the // unfiltered (recall-baseline) path. - let mut iter = (0..provider.num_points as u32) - .filter(|id| vector_filter.map_or(true, |f| f(id))); + let mut iter = + (0..provider.num_points as u32).filter(|id| vector_filter.is_none_or(|f| f(id))); loop { id_buffer.clear(); id_buffer.extend(iter.by_ref().take(batch_size)); @@ -898,7 +904,7 @@ where let label_provider = PredicateLabelProvider { predicate: vector_filter, }; - + let filtered_strategy = labeled::Filtered::new(strategy, &label_provider); let search = InlineFilterSearch::new(knn, adaptive_l); self.index @@ -908,7 +914,6 @@ where /// Perform a search on the disk index. /// return the list of nearest neighbors and associated data. - pub fn search( &self, query: &[Data::VectorDataType], @@ -2096,7 +2101,11 @@ mod disk_provider_tests { .expect("inline filter with accept-all predicate must succeed"); let plain_ids: Vec = plain.results.iter().map(|r| r.vertex_id).collect(); - let inline_ids: Vec = inline_no_filter.results.iter().map(|r| r.vertex_id).collect(); + let inline_ids: Vec = inline_no_filter + .results + .iter() + .map(|r| r.vertex_id) + .collect(); assert_eq!( plain.stats.result_count, inline_no_filter.stats.result_count, diff --git a/diskann-disk/src/search/search_mode.rs b/diskann-disk/src/search/search_mode.rs index f476c007a..58c606a0e 100644 --- a/diskann-disk/src/search/search_mode.rs +++ b/diskann-disk/src/search/search_mode.rs @@ -27,11 +27,14 @@ pub type SearchPredicate<'a> = Box bool + Send + Sync + 'a>; /// * `InlineFilter` — label-filtered graph search; the predicate is consulted /// at visit time (not just during rerank). `adaptive_l = Some(_)` grows the /// beam mid-search if the observed match specificity is low. - pub enum SearchMode<'a> { - FlatScan { filter: Option> }, + FlatScan { + filter: Option>, + }, - Graph { filter: Option> }, + Graph { + filter: Option>, + }, InlineFilter { predicate: SearchPredicate<'a>, diff --git a/diskann-tools/src/utils/search_disk_index.rs b/diskann-tools/src/utils/search_disk_index.rs index f8b07e429..5ca7e6e52 100644 --- a/diskann-tools/src/utils/search_disk_index.rs +++ b/diskann-tools/src/utils/search_disk_index.rs @@ -9,11 +9,11 @@ use diskann::utils::IntoUsize; use diskann_disk::{ data_model::{CachingStrategy, GraphDataType}, search::{ - search_mode::SearchMode, provider::{ disk_provider::DiskIndexSearcher, disk_vertex_provider_factory::DiskVertexProviderFactory, }, + search_mode::SearchMode, }, storage::disk_index_reader::DiskIndexReader, utils::{ @@ -256,13 +256,13 @@ where let has_filter = parameters.vector_filters_file.is_some(); let mode: SearchMode<'_> = match (parameters.is_flat_search, has_filter) { (true, false) => SearchMode::flat(), - (true, true) => SearchMode::flat_filtered(move |vid: &u32| { - vector_filter.contains(vid) - }), + (true, true) => { + SearchMode::flat_filtered(move |vid: &u32| vector_filter.contains(vid)) + } (false, false) => SearchMode::graph(), - (false, true) => SearchMode::graph_filtered(move |vid: &u32| { - vector_filter.contains(vid) - }), + (false, true) => { + SearchMode::graph_filtered(move |vid: &u32| vector_filter.contains(vid)) + } }; let result = searcher.search( From bf11a17e5cbd2266027a62ca0b97b783e8d856df Mon Sep 17 00:00:00 2001 From: yaohongdeng Date: Tue, 30 Jun 2026 16:18:16 +0800 Subject: [PATCH 09/14] Eliminate inner dyn-Fn dispatch in InlineFilter label provider --- .../src/search/provider/disk_provider.rs | 36 ++++--------------- diskann-disk/src/search/search_mode.rs | 29 ++++++++++++--- 2 files changed, 32 insertions(+), 33 deletions(-) diff --git a/diskann-disk/src/search/provider/disk_provider.rs b/diskann-disk/src/search/provider/disk_provider.rs index 00379e8cf..89cf5d4fb 100644 --- a/diskann-disk/src/search/provider/disk_provider.rs +++ b/diskann-disk/src/search/provider/disk_provider.rs @@ -306,25 +306,6 @@ impl<'a> RerankAndFilter<'a> { } } -/// Adapter exposing the disk-side `&dyn Fn(&u32) -> bool` predicate as a -/// `QueryLabelProvider` for `InlineFilterSearch`. -struct PredicateLabelProvider<'a> { - predicate: &'a (dyn Fn(&u32) -> bool + Send + Sync), -} - -impl std::fmt::Debug for PredicateLabelProvider<'_> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("PredicateLabelProvider") - .finish_non_exhaustive() - } -} - -impl QueryLabelProvider for PredicateLabelProvider<'_> { - fn is_match(&self, vec_id: u32) -> bool { - (self.predicate)(&vec_id) - } -} - impl<'a> DeterminantDiversityAndFilter<'a> { pub fn new(filter: Option>, params: DeterminantDiversityParams) -> Self { Self { filter, params } @@ -1024,10 +1005,11 @@ where /// is grown mid-query if the observed match specificity is low (see /// `diskann::graph::search::AdaptiveL`). /// - /// The disk-side `&dyn Fn(&u32) -> bool` predicate is adapted to the - /// `QueryLabelProvider` interface that `InlineFilterSearch` consumes - /// via a stack-allocated `PredicateLabelProvider` shim — no allocation, - /// no lifetime threading past this function body. + /// The label-provider trait object is built once in + /// `SearchMode::inline_filter` from a generic adapter, so each filter + /// evaluation costs exactly one indirect dispatch (through the + /// `&dyn QueryLabelProvider` boundary required by `labeled::Filtered`), + /// not two. /// /// Reuses the same `DiskAccessor` surface as the plain `Knn` graph path: /// `start_point_distances` and `expand_beam`, both of which call @@ -1037,18 +1019,14 @@ where strategy: DiskSearchStrategy<'a, Data, ProviderFactory>, query: &[Data::VectorDataType], knn: Knn, - vector_filter: &(dyn Fn(&u32) -> bool + Send + Sync), + label_provider: &(dyn QueryLabelProvider + 'a), adaptive_l: Option, output: &mut OB, ) -> ANNResult where OB: search_output_buffer::SearchOutputBuffer<(u32, Data::AssociatedDataType)> + Send, { - let label_provider = PredicateLabelProvider { - predicate: vector_filter, - }; - - let filtered_strategy = labeled::Filtered::new(strategy, &label_provider); + let filtered_strategy = labeled::Filtered::new(strategy, label_provider); let search = InlineFilterSearch::new(knn, adaptive_l); self.index .search(search, &filtered_strategy, &DefaultContext, query, output) diff --git a/diskann-disk/src/search/search_mode.rs b/diskann-disk/src/search/search_mode.rs index 8f970b32a..e81556cbf 100644 --- a/diskann-disk/src/search/search_mode.rs +++ b/diskann-disk/src/search/search_mode.rs @@ -9,6 +9,7 @@ //! as a sum type so that invalid combinations (flat scan with adaptive L, //! inline filter without a predicate) are unrepresentable at the API boundary. +use diskann::graph::ext::labeled::QueryLabelProvider; use diskann::graph::search::AdaptiveL; use diskann_providers::model::graph::provider::DeterminantDiversityParams; @@ -42,7 +43,7 @@ pub enum SearchMode<'a> { }, InlineFilter { - predicate: SearchPredicate<'a>, + predicate: Box + 'a>, adaptive_l: Option, }, @@ -87,12 +88,32 @@ impl<'a> SearchMode<'a> { /// Inline label-filtered graph search. `adaptive_l = Some(_)` enables /// mid-search beam widening; `None` runs inline tracking only (no /// resizing). + /// + /// The closure is wrapped in a generic adapter (`FnLabelProvider`) + /// that implements `QueryLabelProvider`. pub fn inline_filter(predicate: F, adaptive_l: Option) -> Self where F: Fn(&u32) -> bool + Send + Sync + 'a, { + struct FnLabelProvider(F); + + impl std::fmt::Debug for FnLabelProvider { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("FnLabelProvider").finish_non_exhaustive() + } + } + + impl QueryLabelProvider for FnLabelProvider + where + F: Fn(&u32) -> bool + Send + Sync, + { + fn is_match(&self, vec_id: u32) -> bool { + (self.0)(&vec_id) + } + } + Self::InlineFilter { - predicate: Box::new(predicate), + predicate: Box::new(FnLabelProvider(predicate)), adaptive_l, } } @@ -168,8 +189,8 @@ mod tests { predicate, adaptive_l: None, } => { - assert!(predicate(&3)); - assert!(!predicate(&2)); + assert!(predicate.is_match(3)); + assert!(!predicate.is_match(2)); } _ => panic!("expected InlineFilter with adaptive_l = None"), } From 631bdcb07a7d9ef18b26c84e3a10dc19487f767a Mon Sep 17 00:00:00 2001 From: yaohongdeng Date: Wed, 1 Jul 2026 11:28:18 +0800 Subject: [PATCH 10/14] Add PostporcessStrategy for more clear code logic for None means AcceptAll --- .../src/search/provider/disk_provider.rs | 117 ++++++++++-------- diskann-disk/src/search/search_mode.rs | 12 +- 2 files changed, 71 insertions(+), 58 deletions(-) diff --git a/diskann-disk/src/search/provider/disk_provider.rs b/diskann-disk/src/search/provider/disk_provider.rs index 89cf5d4fb..f7987ad60 100644 --- a/diskann-disk/src/search/provider/disk_provider.rs +++ b/diskann-disk/src/search/provider/disk_provider.rs @@ -218,6 +218,20 @@ where /// `clippy::type_complexity`'s default threshold. type PostprocessFilter<'a> = &'a (dyn Fn(&u32) -> bool + Send + Sync); +/// Encodes whether to accept all candidates at rerank time or apply a +/// specific predicate. Used by `RerankAndFilter` and +/// `DeterminantDiversityAndFilter` instead of `Option` +/// so call sites are self-documenting without relying on comments to +/// explain what `None` means. +#[derive(Clone, Copy)] +pub enum PostprocessStrategy<'a> { + /// Accept every candidate — no predicate is called. Used by `FlatScan` + /// (filtered at scan time) and `InlineFilter` (filtered at visit time). + AcceptAll, + /// Apply the given predicate; non-matching candidates are dropped. + Apply(PostprocessFilter<'a>), +} + pub struct DiskSearchStrategy<'a, Data, ProviderFactory> where Data: GraphDataType, @@ -225,14 +239,10 @@ where { // Borrowed from `search_internal` so the strategy can be passed by value io_tracker: &'a IOTracker, - /// Borrowed predicate consumed *only* by `default_post_processor()` → - /// `RerankAndFilter`. Other variants (`FlatScan`, `InlineFilter`) filter - /// earlier in their pipelines and install `None` here to avoid a - /// redundant second application. `None` means accept-all — the post- - /// processor short-circuits via `.map_or(true, ...)` and skips the - /// dyn-fn call entirely. Fn param is `u32` because `VectorIdType = u32` - /// is enforced by trait bounds throughout this provider. - postprocess_filter: Option>, + /// Consumed only by `default_post_processor()` → `RerankAndFilter`. + /// `FlatScan` and `InlineFilter` filter earlier in their pipelines and + /// pass `AcceptAll` here to avoid a redundant second pass. + postprocess_filter: PostprocessStrategy<'a>, /// The vertex provider factory is used to create the vertex provider for each search instance. vertex_provider_factory: &'a ProviderFactory, @@ -280,14 +290,12 @@ impl IOTracker { #[derive(Clone, Copy)] pub struct RerankAndFilter<'a> { - /// `None` means accept-all; `post_process` short-circuits before invoking - /// the closure, so the unfiltered hot path pays no dyn-fn cost per node. - filter: Option>, + filter: PostprocessStrategy<'a>, } #[derive(Clone, Copy)] pub struct DeterminantDiversityAndFilter<'a> { - filter: Option>, + filter: PostprocessStrategy<'a>, params: DeterminantDiversityParams, } @@ -301,13 +309,13 @@ pub enum DiskSearchPostProcessor<'a> { } impl<'a> RerankAndFilter<'a> { - pub fn new(filter: Option>) -> Self { + pub fn new(filter: PostprocessStrategy<'a>) -> Self { Self { filter } } } impl<'a> DeterminantDiversityAndFilter<'a> { - pub fn new(filter: Option>, params: DeterminantDiversityParams) -> Self { + pub fn new(filter: PostprocessStrategy<'a>, params: DeterminantDiversityParams) -> Self { Self { filter, params } } } @@ -342,20 +350,27 @@ where let provider = accessor.provider; let mut uncached_ids = Vec::new(); - let mut reranked = candidates - .map(|n| n.id) - // `None` short-circuits to `true` — no dyn-fn call when there's - // no predicate set, which is the common unfiltered hot path. - .filter(|id| self.filter.is_none_or(|f| f(id))) - .filter_map(|n| { + let mut reranked = { + let mut process = |n: u32| { if let Some(entry) = accessor.scratch.distance_cache.get(&n) { Some(Ok::<((u32, _), f32), ANNError>(((n, entry.1), entry.0))) } else { uncached_ids.push(n); None } - }) - .collect::, _>>()?; + }; + match self.filter { + PostprocessStrategy::AcceptAll => candidates + .map(|n| n.id) + .filter_map(&mut process) + .collect::, _>>()?, + PostprocessStrategy::Apply(f) => candidates + .map(|n| n.id) + .filter(|id| f(id)) + .filter_map(&mut process) + .collect::, _>>()?, + } + }; if !uncached_ids.is_empty() { ensure_vertex_loaded(&mut accessor.scratch.vertex_provider, &uncached_ids)?; for n in &uncached_ids { @@ -404,12 +419,13 @@ where let provider = accessor.provider; let query_f32 = Data::VectorDataType::as_f32(query).map_err(Into::into)?; - let candidate_ids: Vec = candidates - .map(|candidate| candidate.id) - // `None` short-circuits to accept-all; only the `Some` arm - // pays the dyn-fn call per candidate. - .filter(|id| self.filter.is_none_or(|f| f(id))) - .collect(); + let candidate_ids: Vec = match self.filter { + PostprocessStrategy::AcceptAll => candidates.map(|candidate| candidate.id).collect(), + PostprocessStrategy::Apply(f) => candidates + .map(|candidate| candidate.id) + .filter(|id| f(id)) + .collect(), + }; if candidate_ids.is_empty() { return Ok(0); @@ -905,17 +921,11 @@ where }) } - /// Helper method to create a DiskSearchStrategy with common parameters. - /// - /// `postprocess_filter = None` means accept-all at rerank time. Callers pass - /// `None` on paths that filter earlier in their pipeline (`FlatScan` at - /// scan time, `InlineFilter` at visit time) to avoid re-applying the - /// predicate redundantly; only the plain `Graph` path installs a - /// predicate here + /// Helper method to create a `DiskSearchStrategy` with common parameters. fn search_strategy<'a>( &'a self, io_tracker: &'a IOTracker, - postprocess_filter: Option>, + postprocess_filter: PostprocessStrategy<'a>, ) -> DiskSearchStrategy<'a, Data, ProviderFactory> { DiskSearchStrategy { io_tracker, @@ -1119,7 +1129,7 @@ where // as the post-processor over the L candidate pool. let stats = match mode { SearchMode::FlatScan { filter } => { - let strategy = self.search_strategy(&io_tracker, None); + let strategy = self.search_strategy(&io_tracker, PostprocessStrategy::AcceptAll); self.runtime.block_on(self.flat_search( &strategy, query, @@ -1129,7 +1139,12 @@ where ))? } SearchMode::Graph { filter } => { - let strategy = self.search_strategy(&io_tracker, filter.as_deref()); + let strategy = self.search_strategy( + &io_tracker, + filter + .as_deref() + .map_or(PostprocessStrategy::AcceptAll, PostprocessStrategy::Apply), + ); let knn_search = Knn::new(k, l, beam_width)?; self.runtime.block_on(self.index.search( knn_search, @@ -1139,20 +1154,17 @@ where &mut result_output_buffer, ))? } - SearchMode::InlineFilter { - predicate, - adaptive_l, - } => { + SearchMode::InlineFilter { filter, adaptive_l } => { // Strategy is passed by value into `filter_search` so that the // `labeled::Filtered` wrapper can own it; `io_tracker` keeps // its counters reachable from this scope. - let strategy = self.search_strategy(&io_tracker, None); + let strategy = self.search_strategy(&io_tracker, PostprocessStrategy::AcceptAll); let knn_search = Knn::new(k, l, beam_width)?; self.runtime.block_on(self.filter_search( strategy, query, knn_search, - predicate.as_ref(), + filter.as_ref(), adaptive_l.clone(), &mut result_output_buffer, ))? @@ -1161,10 +1173,13 @@ where // Strategy installs the filter so `RerankAndFilter` would also // honor it, but the active post-processor here is the // diversity selector built from `DiskSearchPostProcessor`. - let strategy = self.search_strategy(&io_tracker, filter.as_deref()); + let postprocess_config = filter + .as_deref() + .map_or(PostprocessStrategy::AcceptAll, PostprocessStrategy::Apply); + let strategy = self.search_strategy(&io_tracker, postprocess_config); let knn_search = Knn::new(k, l, beam_width)?; let processor = DiskSearchPostProcessor::DeterminantDiversity( - DeterminantDiversityAndFilter::new(filter.as_deref(), *params), + DeterminantDiversityAndFilter::new(postprocess_config, *params), ); self.runtime.block_on(self.index.search_with( knn_search, @@ -1763,7 +1778,7 @@ mod disk_provider_tests { &mut associated_data, ); let io_tracker = IOTracker::default(); - let strategy = search_engine.search_strategy(&io_tracker, None); + let strategy = search_engine.search_strategy(&io_tracker, PostprocessStrategy::AcceptAll); let mut search_record = VisitedSearchRecord::new(0); let search_params = Knn::new(10, 10, Some(4)).unwrap(); let recorded_search = @@ -1999,7 +2014,7 @@ mod disk_provider_tests { &mut associated_data, ); let io_tracker = IOTracker::default(); - let strategy = search_engine.search_strategy(&io_tracker, None); + let strategy = search_engine.search_strategy(&io_tracker, PostprocessStrategy::AcceptAll); // Create diverse search parameters with attribute provider let diverse_params = DiverseSearchParams::new( @@ -2047,9 +2062,7 @@ mod disk_provider_tests { &mut associated_data2, ); let io_tracker2 = IOTracker::default(); - // Old API passed `&|_| true` (accept-all); after the Option refactor - // that's structurally equivalent to `None`. - let strategy2 = search_engine.search_strategy(&io_tracker2, None); + let strategy2 = search_engine.search_strategy(&io_tracker2, PostprocessStrategy::AcceptAll); let search_params2 = Knn::new(return_list_size as usize, search_list_size as usize, None).unwrap(); @@ -2462,7 +2475,7 @@ mod disk_provider_tests { ); let io_tracker = IOTracker::default(); - let strategy = search_engine.search_strategy(&io_tracker, None); + let strategy = search_engine.search_strategy(&io_tracker, PostprocessStrategy::AcceptAll); let mut search_record = VisitedSearchRecord::new(0); let search_params = Knn::new(10, 10, Some(4)).unwrap(); diff --git a/diskann-disk/src/search/search_mode.rs b/diskann-disk/src/search/search_mode.rs index e81556cbf..49f7462f9 100644 --- a/diskann-disk/src/search/search_mode.rs +++ b/diskann-disk/src/search/search_mode.rs @@ -43,7 +43,7 @@ pub enum SearchMode<'a> { }, InlineFilter { - predicate: Box + 'a>, + filter: Box + 'a>, adaptive_l: Option, }, @@ -90,7 +90,7 @@ impl<'a> SearchMode<'a> { /// resizing). /// /// The closure is wrapped in a generic adapter (`FnLabelProvider`) - /// that implements `QueryLabelProvider`. + /// that implements `QueryLabelProvider`. pub fn inline_filter(predicate: F, adaptive_l: Option) -> Self where F: Fn(&u32) -> bool + Send + Sync + 'a, @@ -113,7 +113,7 @@ impl<'a> SearchMode<'a> { } Self::InlineFilter { - predicate: Box::new(FnLabelProvider(predicate)), + filter: Box::new(FnLabelProvider(predicate)), adaptive_l, } } @@ -186,11 +186,11 @@ mod tests { let mode = SearchMode::inline_filter(|id| *id == 3, None); match &mode { SearchMode::InlineFilter { - predicate, + filter, adaptive_l: None, } => { - assert!(predicate.is_match(3)); - assert!(!predicate.is_match(2)); + assert!(filter.is_match(3)); + assert!(!filter.is_match(2)); } _ => panic!("expected InlineFilter with adaptive_l = None"), } From 8c06b031c5d6562547a20e46aa124a94e64d5730 Mon Sep 17 00:00:00 2001 From: yaohongdeng Date: Wed, 1 Jul 2026 14:02:23 +0800 Subject: [PATCH 11/14] fix format issue --- diskann-tools/src/utils/search_disk_index.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/diskann-tools/src/utils/search_disk_index.rs b/diskann-tools/src/utils/search_disk_index.rs index e4ec93e21..c01bbc8b7 100644 --- a/diskann-tools/src/utils/search_disk_index.rs +++ b/diskann-tools/src/utils/search_disk_index.rs @@ -257,7 +257,10 @@ where // Construct the mode from the CLI-driven // `(is_flat_search, has_filter)` pair. CLI doesn't expose // AdaptiveL yet, so `InlineFilter` is unreachable here. - let mode: SearchMode<'_> = match (parameters.is_flat_search, parameters.filter_bitmap_file.is_some()) { + let mode: SearchMode<'_> = match ( + parameters.is_flat_search, + parameters.filter_bitmap_file.is_some(), + ) { (true, false) => SearchMode::flat(), (true, true) => { SearchMode::flat_filtered(move |vid: &u32| vector_filter.contains(vid)) From aeda98208119b88d1315264a57194b2d3d359dab Mon Sep 17 00:00:00 2001 From: yaohongdeng Date: Wed, 1 Jul 2026 14:16:04 +0800 Subject: [PATCH 12/14] Add a json deserializable struct in benchmark module for SearchMode --- diskann-benchmark/src/disk_index/search.rs | 30 ++---- diskann-benchmark/src/inputs/disk.rs | 101 +++++++++++++++++++++ 2 files changed, 109 insertions(+), 22 deletions(-) diff --git a/diskann-benchmark/src/disk_index/search.rs b/diskann-benchmark/src/disk_index/search.rs index 6e79e81e3..3c3952ad1 100644 --- a/diskann-benchmark/src/disk_index/search.rs +++ b/diskann-benchmark/src/disk_index/search.rs @@ -272,28 +272,14 @@ where pool.as_ref(), |(((((q, vf), id_chunk), dist_chunk), stats), rc)| { // Construct the SearchMode from the JSON-driven - // (is_flat_search, has_filter, post_processor) triple. - // Flat scan ignores `post_processor` — determinant-diversity - // is a graph-traversal-only post-processing step. + // `adaptive_l` is now encapsulated in `DiskSearchMode`, so the + // benchmark only supplies the per-query filter and post-processor. let has_filter = search_params.vector_filters_file.is_some(); - let post = search_params.post_processor.as_ref(); - let mode: SearchMode<'_> = match (search_params.is_flat_search, has_filter, post) { - (true, false, _) => SearchMode::flat(), - (true, true, _) => SearchMode::flat_filtered(move |vid: &u32| vf.contains(vid)), - (false, false, None) => SearchMode::graph(), - (false, true, None) => { - SearchMode::graph_filtered(move |vid: &u32| vf.contains(vid)) - } - (false, false, Some(TopkPostProcessor::DeterminantDiversity(params))) => { - SearchMode::diverse_graph(*params) - } - (false, true, Some(TopkPostProcessor::DeterminantDiversity(params))) => { - SearchMode::diverse_graph_filtered( - move |vid: &u32| vf.contains(vid), - *params, - ) - } - }; + let mode: SearchMode<'_> = search_params.search_mode.search_mode( + has_filter, + vf, + search_params.post_processor.as_ref(), + ); match searcher.search( q, @@ -366,7 +352,7 @@ where num_threads: search_params.num_threads, beam_width: search_params.beam_width, recall_at: search_params.recall_at, - is_flat_search: search_params.is_flat_search, + is_flat_search: search_params.search_mode.is_flat_search, distance: search_params.distance, uses_vector_filters: search_params.vector_filters_file.is_some(), num_nodes_to_cache: search_params.num_nodes_to_cache, diff --git a/diskann-benchmark/src/inputs/disk.rs b/diskann-benchmark/src/inputs/disk.rs index f6b4cacac..d3e73a91f 100644 --- a/diskann-benchmark/src/inputs/disk.rs +++ b/diskann-benchmark/src/inputs/disk.rs @@ -6,8 +6,15 @@ use std::{fmt, num::NonZeroUsize, path::Path}; use anyhow::Context; +#[cfg(feature = "disk-index")] +use std::collections::HashSet; + +#[cfg(feature = "disk-index")] +use diskann::graph; use diskann_benchmark_runner::{files::InputFile, utils::datatype::DataType, Checker}; #[cfg(feature = "disk-index")] +use diskann_disk::search::search_mode::SearchMode; +#[cfg(feature = "disk-index")] use diskann_disk::QuantizationType; use diskann_providers::storage::{get_compressed_pq_file, get_disk_index_file, get_pq_pivot_file}; use serde::{Deserialize, Serialize}; @@ -17,6 +24,9 @@ use crate::{ utils::SimilarityMeasure, }; +#[cfg(feature = "disk-index")] +use crate::inputs::graph_index::AdaptiveL; + ////////////// // Registry // ////////////// @@ -62,6 +72,80 @@ pub(crate) struct DiskIndexBuild { pub(crate) save_path: String, } +#[cfg(feature = "disk-index")] +#[derive(Debug, Serialize, Deserialize)] +pub(crate) struct DiskSearchMode { + pub(crate) is_flat_search: bool, + #[serde(default)] + pub(crate) adaptive_l: Option, +} + +#[cfg(feature = "disk-index")] +impl DiskSearchMode { + pub(crate) fn search_mode<'a>( + &'a self, + has_vector_filters: bool, + vector_filter: &'a HashSet, + post_processor: Option<&TopkPostProcessor>, + ) -> SearchMode<'a> { + let adaptive_l = self.adaptive_l.as_ref().map(|adaptive_l| { + graph::search::AdaptiveL::new(adaptive_l.sample_count.into(), adaptive_l.scale_factor) + .expect("validated adaptive L must construct") + }); + + match ( + self.is_flat_search, + has_vector_filters, + post_processor, + adaptive_l, + ) { + (true, false, _, _) => SearchMode::flat(), + (true, true, _, _) => { + SearchMode::flat_filtered(move |vid: &u32| vector_filter.contains(vid)) + } + (false, false, Some(TopkPostProcessor::DeterminantDiversity(params)), _) => { + SearchMode::diverse_graph(*params) + } + (false, true, Some(TopkPostProcessor::DeterminantDiversity(params)), _) => { + SearchMode::diverse_graph_filtered( + move |vid: &u32| vector_filter.contains(vid), + *params, + ) + } + (false, false, None, Some(adaptive_l)) => { + SearchMode::inline_filter(|_| true, Some(adaptive_l)) + } + (false, true, None, Some(adaptive_l)) => SearchMode::inline_filter( + move |vid: &u32| vector_filter.contains(vid), + Some(adaptive_l), + ), + (false, false, None, None) => SearchMode::graph(), + (false, true, None, None) => { + SearchMode::graph_filtered(move |vid: &u32| vector_filter.contains(vid)) + } + } + } + + pub(crate) fn validate(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { + if let Some(adaptive_l) = self.adaptive_l.as_mut() { + adaptive_l.validate(checker)?; + } + Ok(()) + } +} + +#[cfg(feature = "disk-index")] +impl fmt::Display for DiskSearchMode { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let base = if self.is_flat_search { "flat" } else { "graph" }; + if self.adaptive_l.is_some() { + write!(f, "{} + adaptive-l", base) + } else { + write!(f, "{}", base) + } + } +} + /// Search phase configuration #[derive(Debug, Deserialize, Serialize)] pub(crate) struct DiskSearchPhase { @@ -71,6 +155,9 @@ pub(crate) struct DiskSearchPhase { pub(crate) beam_width: usize, pub(crate) search_list: Vec, pub(crate) recall_at: u32, + #[cfg(feature = "disk-index")] + pub(crate) search_mode: DiskSearchMode, + #[cfg(not(feature = "disk-index"))] pub(crate) is_flat_search: bool, pub(crate) distance: SimilarityMeasure, pub(crate) vector_filters_file: Option, @@ -181,6 +268,11 @@ impl DiskSearchPhase { vf.resolve(checker).context("invalid vector_filters_file")?; } + #[cfg(feature = "disk-index")] + self.search_mode + .validate(checker) + .context("invalid disk search mode")?; + // basic numeric sanity checks if self.search_list.is_empty() { anyhow::bail!("search_list must have at least one value"); @@ -250,6 +342,12 @@ impl Example for DiskIndexOperation { beam_width: 16, recall_at: 10, num_threads: 8, + #[cfg(feature = "disk-index")] + search_mode: DiskSearchMode { + is_flat_search: false, + adaptive_l: None, + }, + #[cfg(not(feature = "disk-index"))] is_flat_search: false, distance: SimilarityMeasure::SquaredL2, vector_filters_file: None, @@ -367,6 +465,9 @@ impl DiskSearchPhase { write_field!(f, "Beam Width", self.beam_width)?; write_field!(f, "Recall@", self.recall_at)?; write_field!(f, "Threads", self.num_threads)?; + #[cfg(feature = "disk-index")] + write_field!(f, "Search Mode", self.search_mode)?; + #[cfg(not(feature = "disk-index"))] write_field!(f, "Flat Search", self.is_flat_search)?; write_field!(f, "Distance", self.distance)?; match &self.vector_filters_file { From 1f773210409be0212208973a6c98000cceff9273 Mon Sep 17 00:00:00 2001 From: yaohongdeng Date: Wed, 1 Jul 2026 14:21:03 +0800 Subject: [PATCH 13/14] remove unused import --- diskann-benchmark/src/disk_index/search.rs | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/diskann-benchmark/src/disk_index/search.rs b/diskann-benchmark/src/disk_index/search.rs index 3c3952ad1..a1e84e79b 100644 --- a/diskann-benchmark/src/disk_index/search.rs +++ b/diskann-benchmark/src/disk_index/search.rs @@ -36,10 +36,7 @@ use serde::{Deserialize, Serialize}; use crate::{ disk_index::json_spancollector::JsonSpanCollector, - inputs::{ - disk::{DiskIndexLoad, DiskSearchPhase}, - post_processor::TopkPostProcessor, - }, + inputs::disk::{DiskIndexLoad, DiskSearchPhase}, utils::{datafiles, SimilarityMeasure}, }; From 285344747681dff6fd24f966f65885968a242ed5 Mon Sep 17 00:00:00 2001 From: yaohongdeng Date: Wed, 1 Jul 2026 14:33:05 +0800 Subject: [PATCH 14/14] Add tolerance for old is_flat_search input in benchmark --- diskann-benchmark/src/inputs/disk.rs | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/diskann-benchmark/src/inputs/disk.rs b/diskann-benchmark/src/inputs/disk.rs index d3e73a91f..7ed521a87 100644 --- a/diskann-benchmark/src/inputs/disk.rs +++ b/diskann-benchmark/src/inputs/disk.rs @@ -73,7 +73,7 @@ pub(crate) struct DiskIndexBuild { } #[cfg(feature = "disk-index")] -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Default)] pub(crate) struct DiskSearchMode { pub(crate) is_flat_search: bool, #[serde(default)] @@ -156,7 +156,13 @@ pub(crate) struct DiskSearchPhase { pub(crate) search_list: Vec, pub(crate) recall_at: u32, #[cfg(feature = "disk-index")] + #[serde(default)] pub(crate) search_mode: DiskSearchMode, + // Backward compatibility for older benchmark inputs that used + // `is_flat_search` directly at the search-phase level. + #[cfg(feature = "disk-index")] + #[serde(default, skip_serializing)] + pub(crate) is_flat_search: Option, #[cfg(not(feature = "disk-index"))] pub(crate) is_flat_search: bool, pub(crate) distance: SimilarityMeasure, @@ -268,6 +274,11 @@ impl DiskSearchPhase { vf.resolve(checker).context("invalid vector_filters_file")?; } + #[cfg(feature = "disk-index")] + if let Some(is_flat_search) = self.is_flat_search { + self.search_mode.is_flat_search = is_flat_search; + } + #[cfg(feature = "disk-index")] self.search_mode .validate(checker) @@ -347,6 +358,8 @@ impl Example for DiskIndexOperation { is_flat_search: false, adaptive_l: None, }, + #[cfg(feature = "disk-index")] + is_flat_search: None, #[cfg(not(feature = "disk-index"))] is_flat_search: false, distance: SimilarityMeasure::SquaredL2,