diff --git a/Cargo.lock b/Cargo.lock index 1ecde9f9f..5cd52b5d9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -665,6 +665,7 @@ dependencies = [ "diskann-utils", "diskann-vector", "diskann-wide", + "futures-util", "half", "indicatif", "itertools 0.13.0", @@ -929,6 +930,7 @@ dependencies = [ "diskann-quantization", "diskann-utils", "diskann-vector", + "futures-util", "half", "itertools 0.13.0", "num_cpus", @@ -940,6 +942,7 @@ dependencies = [ "rstest", "serde", "serde_json", + "tokio", "tracing", "tracing-subscriber", "vfs", diff --git a/diskann-benchmark/Cargo.toml b/diskann-benchmark/Cargo.toml index ce5018aad..fc8241bc8 100644 --- a/diskann-benchmark/Cargo.toml +++ b/diskann-benchmark/Cargo.toml @@ -26,6 +26,7 @@ serde = { workspace = true, features = ["derive"] } serde_json.workspace = true thiserror.workspace = true tokio = { workspace = true, features = ["rt-multi-thread"] } +futures-util = { workspace = true, features = ["alloc"] } diskann-vector.workspace = true diskann-wide.workspace = true diskann-label-filter.workspace = true diff --git a/diskann-benchmark/src/disk_index/build.rs b/diskann-benchmark/src/disk_index/build.rs index 9ab095cdf..629f949fb 100644 --- a/diskann-benchmark/src/disk_index/build.rs +++ b/diskann-benchmark/src/disk_index/build.rs @@ -129,7 +129,10 @@ where }; let start = std::time::Instant::now(); - disk_index.build()?; + let runtime = tokio::runtime::Builder::new_multi_thread() + .worker_threads(params.num_threads.max(1)) + .build()?; + runtime.block_on(disk_index.build())?; let total_time: MicroSeconds = start.elapsed().into(); drop(span); diff --git a/diskann-benchmark/src/disk_index/search.rs b/diskann-benchmark/src/disk_index/search.rs index 77acaa351..b9ee4e829 100644 --- a/diskann-benchmark/src/disk_index/search.rs +++ b/diskann-benchmark/src/disk_index/search.rs @@ -3,7 +3,6 @@ * Licensed under the MIT license. */ -use rayon::prelude::*; use std::{collections::HashSet, fmt, sync::atomic::AtomicBool, time::Instant}; use opentelemetry::{global, trace::Span, trace::Tracer}; @@ -20,14 +19,12 @@ use diskann_disk::{ utils::{instrumentation::PerfLogger, statistics, AlignedFileReaderFactory, QueryStatistics}, }; use diskann_providers::storage::StorageReadProvider; -use diskann_providers::{ - storage::{ - get_compressed_pq_file, get_disk_index_file, get_pq_pivot_file, FileStorageProvider, - }, - utils::{create_thread_pool, ParallelIteratorInPool}, +use diskann_providers::storage::{ + get_compressed_pq_file, get_disk_index_file, get_pq_pivot_file, FileStorageProvider, }; use diskann_tools::utils::{search_index_utils, KRecallAtN}; use diskann_utils::views::Matrix; +use futures_util::stream::{self, StreamExt}; use serde::{Deserialize, Serialize}; use crate::{ @@ -227,12 +224,11 @@ where &index_reader, vertex_provider_factory, search_params.distance.into(), - None, )?; logger.log_checkpoint("index_loaded"); - let pool = create_thread_pool(search_params.num_threads)?; + let runtime = tokio::runtime::Builder::new_current_thread().build()?; let mut search_results_per_l = Vec::with_capacity(search_params.search_list.len()); let has_any_search_failed = AtomicBool::new(false); @@ -253,59 +249,70 @@ where tracer.start(span_name) }; - let zipped = queries - .par_row_iter() - .zip(vector_filters.par_iter()) - .zip(result_ids.par_chunks_mut(search_params.recall_at as usize)) - .zip(result_dists.par_chunks_mut(search_params.recall_at as usize)) - .zip(statistics_vec.par_iter_mut()) - .zip(result_counts.par_iter_mut()); - - zipped.for_each_in_pool( - pool.as_ref(), - |(((((q, vf), id_chunk), dist_chunk), stats), rc)| { - let vector_filter = if search_params.vector_filters_file.is_none() { - None - } else { - Some(Box::new(move |vid: &u32| vf.contains(vid)) - as Box bool + Send + Sync>) - }; - - match searcher.search( - q, - search_params.recall_at, - l, - Some(search_params.beam_width), - vector_filter, - search_params.is_flat_search, - ) { - Ok(search_result) => { - *stats = search_result.stats.query_statistics; - *rc = search_result.results.len() as u32; - let actual_results = search_result - .results - .len() - .min(search_params.recall_at as usize); - for (i, result_item) in search_result - .results - .iter() - .take(actual_results) - .enumerate() - { - id_chunk[i] = result_item.vertex_id; - dist_chunk[i] = result_item.distance; - } + // Drive all queries concurrently on the caller-owned runtime, bounding the + // number of in-flight searches to `num_threads`. + let search_results: Vec<_> = runtime.block_on(async { + stream::iter(queries.row_iter().enumerate()) + .map(|(query_id, q)| { + let vf = &vector_filters[query_id]; + let searcher = &searcher; + async move { + let vector_filter = if search_params.vector_filters_file.is_none() { + None + } else { + Some(Box::new(move |vid: &u32| vf.contains(vid)) + as Box bool + Send + Sync>) + }; + + let result = searcher + .search( + q, + search_params.recall_at, + l, + Some(search_params.beam_width), + vector_filter, + search_params.is_flat_search, + ) + .await; + (query_id, result) } - Err(e) => { - eprintln!("Search failed for query: {:?}", e); - *rc = 0; - id_chunk.fill(0); - dist_chunk.fill(0.0); - has_any_search_failed.store(true, std::sync::atomic::Ordering::Release); + }) + .buffer_unordered(search_params.num_threads.max(1)) + .collect() + .await + }); + + for (query_id, result) in search_results { + let base = query_id * search_params.recall_at as usize; + let id_chunk = &mut result_ids[base..base + search_params.recall_at as usize]; + let dist_chunk = &mut result_dists[base..base + search_params.recall_at as usize]; + match result { + Ok(search_result) => { + statistics_vec[query_id] = search_result.stats.query_statistics; + result_counts[query_id] = search_result.results.len() as u32; + let actual_results = search_result + .results + .len() + .min(search_params.recall_at as usize); + for (i, result_item) in search_result + .results + .iter() + .take(actual_results) + .enumerate() + { + id_chunk[i] = result_item.vertex_id; + dist_chunk[i] = result_item.distance; } } - }, - ); + Err(e) => { + eprintln!("Search failed for query: {:?}", e); + result_counts[query_id] = 0; + id_chunk.fill(0); + dist_chunk.fill(0.0); + has_any_search_failed.store(true, std::sync::atomic::Ordering::Release); + } + } + } let total_time = start.elapsed(); if has_any_search_failed.load(std::sync::atomic::Ordering::Acquire) { diff --git a/diskann-disk/Cargo.toml b/diskann-disk/Cargo.toml index d49fafa17..ef7c3e39e 100644 --- a/diskann-disk/Cargo.toml +++ b/diskann-disk/Cargo.toml @@ -38,7 +38,7 @@ rand.workspace = true rayon.workspace = true serde = { workspace = true, features = ["derive"] } thiserror.workspace = true -tokio = { workspace = true, features = ["full"] } +tokio = { workspace = true, features = ["rt", "rt-multi-thread", "macros", "sync"] } tracing.workspace = true vfs = { workspace = true } diff --git a/diskann-disk/src/build/builder/build.rs b/diskann-disk/src/build/builder/build.rs index 046cbf8af..6c7be5ab6 100644 --- a/diskann-disk/src/build/builder/build.rs +++ b/diskann-disk/src/build/builder/build.rs @@ -42,7 +42,6 @@ use crate::{ }, inmem_builder::{load_inmem_index_builder, new_inmem_index_builder, InmemIndexBuilder}, quantizer::BuildQuantizer, - tokio::create_runtime, }, chunking::{ checkpoint::{ @@ -205,22 +204,7 @@ where ) } - pub fn build(&mut self) -> ANNResult<()> { - let runtime = create_runtime(self.index_configuration.num_threads)?; - runtime.block_on(async { - match self.build_internal().await { - Err(err) if err.kind() == ANNErrorKind::BuildInterrupted => { - info!( - "Index build was interrupted by continuation_checker, progress saved for resumption" - ); - Ok(()) // Return success for controlled interruptions - } - result => result, // Pass through any other result (Ok or Err) - } - }) - } - - async fn build_internal(&mut self) -> ANNResult<()> { + pub async fn build(&mut self) -> ANNResult<()> { let mut logger = PerfLogger::new_disk_index_build_logger(); let pool = create_thread_pool(self.index_configuration.num_threads)?; @@ -233,17 +217,30 @@ where self.index_configuration.num_threads ); - self.generate_compressed_data(pool.as_ref()).await?; - logger.log_checkpoint(DiskIndexBuildCheckpoint::PqConstruction); + let result: ANNResult<()> = async { + self.generate_compressed_data(pool.as_ref()).await?; + logger.log_checkpoint(DiskIndexBuildCheckpoint::PqConstruction); - self.build_inmem_index(pool.as_ref()).await?; - logger.log_checkpoint(DiskIndexBuildCheckpoint::InmemIndexBuild); + self.build_inmem_index(pool.as_ref()).await?; + logger.log_checkpoint(DiskIndexBuildCheckpoint::InmemIndexBuild); - // Use physical file to pass the memory index to the disk writer - self.create_disk_layout()?; - logger.log_checkpoint(DiskIndexBuildCheckpoint::DiskLayout); + // Use physical file to pass the memory index to the disk writer + self.create_disk_layout()?; + logger.log_checkpoint(DiskIndexBuildCheckpoint::DiskLayout); - Ok(()) + Ok(()) + } + .await; + + match result { + Err(err) if err.kind() == ANNErrorKind::BuildInterrupted => { + info!( + "Index build was interrupted by continuation_checker, progress saved for resumption" + ); + Ok(()) // Return success for controlled interruptions + } + result => result, // Pass through any other result (Ok or Err) + } } async fn generate_compressed_data(&mut self, pool: RayonThreadPoolRef<'_>) -> ANNResult<()> { diff --git a/diskann-disk/src/build/builder/core.rs b/diskann-disk/src/build/builder/core.rs index ed69058de..b957d6f64 100644 --- a/diskann-disk/src/build/builder/core.rs +++ b/diskann-disk/src/build/builder/core.rs @@ -825,7 +825,10 @@ pub(crate) mod disk_index_builder_tests { }?; let timer = Timer::new(); - disk_index.build()?; + let runtime = tokio::runtime::Builder::new_multi_thread() + .build() + .expect("failed to build tokio runtime"); + runtime.block_on(disk_index.build())?; println!("Indexing time: {} seconds", timer.elapsed().as_secs_f64()); Ok(()) @@ -1054,9 +1057,10 @@ pub(crate) mod disk_index_builder_tests { &index_reader, vertex_provider_factory, params.metric, - None, )?; + let runtime = tokio::runtime::Builder::new_current_thread().build()?; + let data = read_bin::(&mut storage_provider.open_reader(¶ms.data_path)?)?; let dim = data.ncols(); @@ -1080,7 +1084,7 @@ pub(crate) mod disk_index_builder_tests { let mut distances = vec![0f32; top_k]; let mut associated_data = vec![(); top_k]; - _ = search_engine.search_internal( + _ = runtime.block_on(search_engine.search_internal( query_data, top_k, search_l, @@ -1091,7 +1095,7 @@ pub(crate) mod disk_index_builder_tests { &mut associated_data, &|_| true, false, - ); + )); diskann_providers::test_utils::assert_top_k_exactly_match( q, >, &indices, &distances, top_k, diff --git a/diskann-disk/src/build/builder/mod.rs b/diskann-disk/src/build/builder/mod.rs index 9c22a0766..ae867d176 100644 --- a/diskann-disk/src/build/builder/mod.rs +++ b/diskann-disk/src/build/builder/mod.rs @@ -9,7 +9,6 @@ pub mod core; pub mod quantizer; pub mod inmem_builder; -pub mod tokio; #[cfg(test)] mod tests; diff --git a/diskann-disk/src/build/builder/tokio.rs b/diskann-disk/src/build/builder/tokio.rs deleted file mode 100644 index 4b66c37ce..000000000 --- a/diskann-disk/src/build/builder/tokio.rs +++ /dev/null @@ -1,118 +0,0 @@ -/* - * Copyright (c) Microsoft Corporation. - * Licensed under the MIT license. - */ - -use diskann::{ANNError, ANNResult}; - -/// Creates a new multi-threaded tokio runtime with the specified number of worker threads. -/// If `num_threads` is 0, it defaults to the number of logical CPUs. -pub fn create_runtime(num_threads: usize) -> ANNResult { - let mut builder = tokio::runtime::Builder::new_multi_thread(); - - if num_threads != 0 { - builder.worker_threads(num_threads); - } - - builder.build().map_err(|err| { - ANNError::log_index_error(format!("Failed to initialize tokio runtime: {}", err)) - }) -} - -#[cfg(test)] -mod tests { - use super::*; - - fn get_logical_cpu_count() -> usize { - std::thread::available_parallelism() - .map(|n| n.get()) - .unwrap_or(1) - } - - #[test] - fn test_create_runtime_with_zero_threads_no_panic() { - // This test ensures that passing 0 threads doesn't panic - // and properly defaults to the number of logical CPUs - let result = create_runtime(0); - - // Should not panic and should succeed - assert!(result.is_ok(), "create_runtime(0) should not panic or fail"); - - let runtime = result.unwrap(); - - // Verify the runtime was created successfully by executing a simple task - let result = runtime.block_on(async { tokio::spawn(async { 42 }).await }); - - assert!(result.is_ok(), "Runtime should be functional"); - assert_eq!(result.unwrap(), 42); - } - - #[test] - fn test_create_runtime_with_specific_threads() { - // Test that specifying a specific number of threads works - let result = create_runtime(2); - assert!(result.is_ok(), "create_runtime(2) should succeed"); - - let runtime = result.unwrap(); - - // Verify the runtime works - let result = runtime.block_on(async { tokio::spawn(async { "test" }).await }); - - assert!(result.is_ok(), "Runtime should be functional"); - assert_eq!(result.unwrap(), "test"); - } - - #[test] - fn test_create_runtime_with_one_thread() { - // Test edge case with 1 thread - let result = create_runtime(1); - assert!(result.is_ok(), "create_runtime(1) should succeed"); - - let runtime = result.unwrap(); - - // Verify the runtime works even with just 1 thread - let result = runtime.block_on(async { tokio::spawn(async { true }).await }); - - assert!( - result.is_ok(), - "Single-threaded runtime should be functional" - ); - assert!(result.unwrap()); - } - - #[test] - fn test_zero_threads_defaults_to_cpu_count() { - // Test that 0 threads actually uses the logical CPU count - let expected_cpu_count = get_logical_cpu_count(); - - // We can't directly inspect the runtime's thread count easily, - // but we can ensure it doesn't panic and works correctly - let result = create_runtime(0); - assert!( - result.is_ok(), - "create_runtime(0) should default to {} CPUs", - expected_cpu_count - ); - - let runtime = result.unwrap(); - - // Test that the runtime can handle multiple concurrent tasks - // which would fail if it only had 1 thread and we expected more - let result = runtime.block_on(async { - let tasks = (0..expected_cpu_count.min(4)) - .map(|i| tokio::spawn(async move { i * 2 })) - .collect::>(); - - let mut results = Vec::new(); - for task in tasks { - results.push(task.await.unwrap()); - } - results - }); - - assert!( - result.len() <= 4, - "Should handle concurrent tasks successfully" - ); - } -} diff --git a/diskann-disk/src/search/provider/disk_provider.rs b/diskann-disk/src/search/provider/disk_provider.rs index f2f24d5e1..365a9b168 100644 --- a/diskann-disk/src/search/provider/disk_provider.rs +++ b/diskann-disk/src/search/provider/disk_provider.rs @@ -36,7 +36,6 @@ use diskann_utils::object_pool::{ObjectPool, PoolOption, TryAsPooled}; use crate::search::pq::{quantizer_preprocess, PQData, PQScratch}; use diskann_vector::{distance::Metric, DistanceFunction}; -use tokio::runtime::Runtime; use tracing::debug; use crate::{ @@ -635,7 +634,6 @@ pub struct DiskIndexSearcher< ProviderFactory: VertexProviderFactory, { index: DiskANNIndex>, - runtime: Runtime, /// The vertex provider factory is used to create the vertex provider for each search instance. vertex_provider_factory: ProviderFactory, @@ -688,20 +686,13 @@ where /// * `disk_index_reader` - The disk index reader. /// * `vertex_provider_factory` - The vertex provider factory. /// * `metric` - Distance metric used for vector similarity calculations. - /// * `runtime` - Tokio runtime handle for executing async operations. pub fn new( num_threads: usize, search_io_limit: usize, disk_index_reader: &DiskIndexReader, vertex_provider_factory: ProviderFactory, metric: Metric, - runtime: Option, ) -> ANNResult { - let runtime = match runtime { - Some(rt) => rt, - None => tokio::runtime::Builder::new_current_thread().build()?, - }; - let graph_header = vertex_provider_factory.get_header()?; let metadata = graph_header.metadata(); let max_degree = graph_header.max_degree::()? as u32; @@ -739,7 +730,6 @@ where let index = DiskANNIndex::new(config, disk_provider, NonZeroUsize::new(num_threads)); Ok(Self { index, - runtime, vertex_provider_factory, scratch_pool, }) @@ -827,13 +817,13 @@ where /// Perform a search on the disk index. /// return the list of nearest neighbors and associated data. - pub fn search( + pub async fn search( &self, query: &[Data::VectorDataType], return_list_size: u32, search_list_size: u32, beam_width: Option, - vector_filter: Option>, + vector_filter: Option>, is_flat_search: bool, ) -> ANNResult> { let mut query_stats = QueryStatistics::default(); @@ -842,18 +832,20 @@ where let mut associated_data = vec![Data::AssociatedDataType::default(); return_list_size as usize]; - let stats = self.search_internal( - query, - return_list_size as usize, - search_list_size, - beam_width, - &mut query_stats, - &mut indices, - &mut distances, - &mut associated_data, - &vector_filter.unwrap_or(default_vector_filter::()), - is_flat_search, - )?; + let stats = self + .search_internal( + query, + return_list_size as usize, + search_list_size, + beam_width, + &mut query_stats, + &mut indices, + &mut distances, + &mut associated_data, + &vector_filter.unwrap_or(default_vector_filter::()), + is_flat_search, + ) + .await?; let mut search_result = SearchResult { results: Vec::with_capacity(return_list_size as usize), @@ -878,7 +870,7 @@ where /// Perform a raw search on the disk index. /// This is a lower-level API that allows more control over the search parameters and output buffers. #[allow(clippy::too_many_arguments)] - pub(crate) fn search_internal( + pub(crate) async fn search_internal( &self, query: &[Data::VectorDataType], k_value: usize, @@ -902,22 +894,25 @@ where let k = k_value; let l = search_list_size as usize; let stats = if is_flat_search { - self.runtime.block_on(self.flat_search( + self.flat_search( &strategy, query, vector_filter, l, &mut result_output_buffer, - ))? + ) + .await? } else { let knn_search = Knn::new(k, l, beam_width)?; - self.runtime.block_on(self.index.search( - knn_search, - &strategy, - &DefaultContext, - query, - &mut result_output_buffer, - ))? + self.index + .search( + knn_search, + &strategy, + &DefaultContext, + query, + &mut result_output_buffer, + ) + .await? }; query_stats.total_comparisons = stats.cmps; query_stats.search_hops = stats.hops; @@ -1204,11 +1199,6 @@ mod disk_provider_tests { { assert!(params.io_limit > 0); - let runtime = tokio::runtime::Builder::new_multi_thread() - .worker_threads(params.max_thread_num) - .build() - .unwrap(); - let disk_index_reader = DiskIndexReader::new( params.pq_pivot_file_path.to_string(), params.pq_compressed_file_path.to_string(), @@ -1231,7 +1221,6 @@ mod disk_provider_tests { &disk_index_reader, vertex_provider_factory, Metric::L2, - Some(runtime), ) .unwrap() } @@ -1291,6 +1280,8 @@ mod disk_provider_tests { load_query_result(params.storage_provider, params.truth_result_file_path); let pool = create_thread_pool(params.thread_num.into_usize()).unwrap(); + let runtime = tokio::runtime::Builder::new_multi_thread().build().unwrap(); + let handle = runtime.handle().clone(); queries .par_row_iter() .enumerate() @@ -1300,7 +1291,7 @@ mod disk_provider_tests { let mut distances = vec![0f32; 10]; let mut associated_data = vec![(); 10]; - let result = params.index_search_engine.search_internal( + let result = handle.block_on(params.index_search_engine.search_internal( query, params.k, params.l as u32, @@ -1311,7 +1302,7 @@ mod disk_provider_tests { &mut associated_data, &(|_| true), false, - ); + )); // Calculate the range of the truth_result for this query let truth_slice = &truth_result[i * params.k..(i + 1) * params.k]; @@ -1351,13 +1342,21 @@ mod disk_provider_tests { let truth_result = load_query_result(params.storage_provider, params.truth_result_file_path); let pool = create_thread_pool(params.thread_num.into_usize()).unwrap(); + let runtime = tokio::runtime::Builder::new_multi_thread().build().unwrap(); + let handle = runtime.handle().clone(); queries .par_row_iter() .enumerate() .for_each_in_pool(pool.as_ref(), |(i, query)| { - let result = params - .index_search_engine - .search(query, params.k as u32, params.l as u32, beam_width, None, false) + let result = handle + .block_on(params.index_search_engine.search( + query, + params.k as u32, + params.l as u32, + beam_width, + None, + false, + )) .unwrap(); let indices: Vec = result.results.iter().map(|item| item.vertex_id).collect(); let associated_data: Vec = @@ -1381,8 +1380,8 @@ mod disk_provider_tests { }); } - #[test] - fn test_disk_search_invalid_input() { + #[tokio::test] + async fn test_disk_search_invalid_input() { let storage_provider = Arc::new(VirtualStorageProvider::new_overlay(test_data_root())); let ctx = &DefaultContext; @@ -1458,25 +1457,27 @@ mod disk_provider_tests { let mut associated_data = vec![0u32; 10]; // Set L: {} to a value of at least K: - let result = search_engine.search_internal( - &query, - 10, - 10 - 1, - None, - &mut query_stats, - &mut indices, - &mut distances, - &mut associated_data, - &|_| true, - false, - ); + let result = search_engine + .search_internal( + &query, + 10, + 10 - 1, + None, + &mut query_stats, + &mut indices, + &mut distances, + &mut associated_data, + &|_| true, + false, + ) + .await; assert!(result.is_err()); assert_eq!(result.unwrap_err().kind(), ANNErrorKind::IndexError); } - #[test] - fn test_disk_search_beam_search() { + #[tokio::test] + async fn test_disk_search_beam_search() { let storage_provider = Arc::new(VirtualStorageProvider::new_overlay(test_data_root())); let search_engine = create_disk_index_searcher::( @@ -1507,14 +1508,15 @@ mod disk_provider_tests { let recorded_search = diskann::graph::search::RecordedKnn::new(search_params, &mut search_record); search_engine - .runtime - .block_on(search_engine.index.search( + .index + .search( recorded_search, &strategy, &DefaultContext, query_vector.as_slice(), &mut result_output_buffer, - )) + ) + .await .unwrap(); let ids = search_record @@ -1531,14 +1533,16 @@ mod disk_provider_tests { let return_list_size = 10; let search_list_size = 10; - let result = search_engine.search( - &query_vector, - return_list_size, - search_list_size, - Some(4), - None, - false, - ); + let result = search_engine + .search( + &query_vector, + return_list_size, + search_list_size, + Some(4), + None, + false, + ) + .await; assert!(result.is_ok(), "Expected search to succeed"); let search_result = result.unwrap(); assert_eq!( @@ -1554,8 +1558,8 @@ mod disk_provider_tests { } #[cfg(feature = "experimental_diversity_search")] - #[test] - fn test_disk_search_diversity_search() { + #[tokio::test] + async fn test_disk_search_diversity_search() { use diskann::graph::DiverseSearchParams; use diskann::neighbor::AttributeValueProvider; use std::collections::HashMap; @@ -1636,14 +1640,15 @@ mod disk_provider_tests { let diverse_search = diskann::graph::search::Diverse::new(search_params, diverse_params); let stats = search_engine - .runtime - .block_on(search_engine.index.search( + .index + .search( diverse_search, &strategy, &DefaultContext, query_vector.as_slice(), &mut result_output_buffer, - )) + ) + .await .unwrap(); // Verify that search was performed and returned some results @@ -1676,14 +1681,15 @@ mod disk_provider_tests { let diverse_search2 = diskann::graph::search::Diverse::new(search_params2, diverse_params); let stats = search_engine - .runtime - .block_on(search_engine.index.search( + .index + .search( diverse_search2, &strategy2, &DefaultContext, query_vector.as_slice(), &mut result_output_buffer2, - )) + ) + .await .unwrap(); // Verify results @@ -1821,7 +1827,8 @@ mod disk_provider_tests { vec![72, 170, 87, 0, 0, 0, 0, 0, 0, 0], vec![256709.69, 256712.5, 256760.08, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], )] - fn test_search_with_vector_filter( + #[tokio::test] + async fn test_search_with_vector_filter( #[case] vector_filter: fn(&u32) -> bool, #[case] is_flat_search: bool, #[case] expected_result_count: u32, @@ -1866,18 +1873,20 @@ mod disk_provider_tests { let mut distances = vec![0f32; 10]; let mut associated_data = vec![(); 10]; - let result = search_engine.search_internal( - &query, - 10, - 10, - None, // beam_width - &mut query_stats, - &mut indices, - &mut distances, - &mut associated_data, - &vector_filter, - is_flat_search, - ); + let result = search_engine + .search_internal( + &query, + 10, + 10, + None, // beam_width + &mut query_stats, + &mut indices, + &mut distances, + &mut associated_data, + &vector_filter, + is_flat_search, + ) + .await; assert!(result.is_ok(), "Expected search to succeed"); assert_eq!( @@ -1891,14 +1900,16 @@ mod disk_provider_tests { "Expected distances to match" ); - let result_with_filter = search_engine.search( - &query, - 10, - 10, - None, // beam_width - Some(Box::new(vector_filter)), - is_flat_search, - ); + let result_with_filter = search_engine + .search( + &query, + 10, + 10, + None, // beam_width + Some(Box::new(vector_filter)), + is_flat_search, + ) + .await; assert!(result_with_filter.is_ok(), "Expected search to succeed"); let result_with_filter_unwrapped = result_with_filter.unwrap(); @@ -1926,8 +1937,8 @@ mod disk_provider_tests { ); } - #[test] - fn test_beam_search_respects_io_limit() { + #[tokio::test] + async fn test_beam_search_respects_io_limit() { let io_limit = 11; // Set a small IO limit for testing let storage_provider = Arc::new(VirtualStorageProvider::new_overlay(test_data_root())); @@ -1961,14 +1972,15 @@ mod disk_provider_tests { let recorded_search = diskann::graph::search::RecordedKnn::new(search_params, &mut search_record); search_engine - .runtime - .block_on(search_engine.index.search( + .index + .search( recorded_search, &strategy, &DefaultContext, query_vector.as_slice(), &mut result_output_buffer, - )) + ) + .await .unwrap(); let visited_ids = search_record .visited diff --git a/diskann-disk/src/search/provider/disk_vertex_provider.rs b/diskann-disk/src/search/provider/disk_vertex_provider.rs index f83759533..7a1fa9ffc 100644 --- a/diskann-disk/src/search/provider/disk_vertex_provider.rs +++ b/diskann-disk/src/search/provider/disk_vertex_provider.rs @@ -365,7 +365,8 @@ mod disk_vertex_provider_tests { .expect("Failed to delete mem index associated data file"); } - disk_index.build().unwrap(); + let runtime = tokio::runtime::Builder::new_multi_thread().build().unwrap(); + runtime.block_on(disk_index.build()).unwrap(); // Assert that all data was kept in memory and no files were written to the disk. assert!(!storage_provider.exists(&mem_index_file_path)); diff --git a/diskann-tools/Cargo.toml b/diskann-tools/Cargo.toml index 573b0f05c..59379c848 100644 --- a/diskann-tools/Cargo.toml +++ b/diskann-tools/Cargo.toml @@ -35,6 +35,8 @@ anyhow.workspace = true itertools.workspace = true diskann-label-filter.workspace = true serde_json.workspace = true +tokio = { workspace = true, features = ["rt", "rt-multi-thread"] } +futures-util = { workspace = true, features = ["alloc"] } [dev-dependencies] rstest.workspace = true diff --git a/diskann-tools/src/utils/build_disk_index.rs b/diskann-tools/src/utils/build_disk_index.rs index d55a1f566..890acc6ba 100644 --- a/diskann-tools/src/utils/build_disk_index.rs +++ b/diskann-tools/src/utils/build_disk_index.rs @@ -169,7 +169,13 @@ where } let timer = Timer::new(); - disk_index.build()?; + let runtime = tokio::runtime::Builder::new_multi_thread() + .worker_threads(parameters.num_threads.max(1)) + .build() + .map_err(|err| { + ANNError::log_index_error(format!("Failed to initialize tokio runtime: {}", err)) + })?; + runtime.block_on(disk_index.build())?; let diff = timer.elapsed(); println!("Indexing time: {} seconds", diff.as_secs_f64()); diff --git a/diskann-tools/src/utils/search_disk_index.rs b/diskann-tools/src/utils/search_disk_index.rs index 8bbdb1c8f..feaa53cdf 100644 --- a/diskann-tools/src/utils/search_disk_index.rs +++ b/diskann-tools/src/utils/search_disk_index.rs @@ -17,16 +17,14 @@ use diskann_disk::{ QueryStatistics, }, }; +use diskann_providers::storage::{get_compressed_pq_file, get_pq_pivot_file}; use diskann_providers::storage::{StorageReadProvider, StorageWriteProvider}; -use diskann_providers::{ - storage::{get_compressed_pq_file, get_pq_pivot_file}, - utils::{create_thread_pool, ParallelIteratorInPool}, -}; use diskann_utils::{ io::{read_bin, write_bin}, views::MatrixView, }; use diskann_vector::distance::Metric; +use futures_util::stream::{self, StreamExt}; use opentelemetry::global::BoxedSpan; #[cfg(feature = "perf_test")] use opentelemetry::{ @@ -34,7 +32,6 @@ use opentelemetry::{ KeyValue, }; use ordered_float::OrderedFloat; -use rayon::prelude::*; use tracing::{error, info}; use crate::utils::{search_index_utils, CMDResult, CMDToolError, KRecallAtN}; @@ -154,7 +151,6 @@ where &index_reader, vertex_provider_factory, parameters.metric, - None, )?; logger.log_checkpoint("index_loaded"); @@ -198,12 +194,15 @@ where let mut query_result_ids: Vec> = vec![vec![]; parameters.l_vec.len()]; let mut query_result_dists: Vec> = vec![vec![]; parameters.l_vec.len()]; - let mut cmp_stats: Vec = vec![0; query_num]; let has_any_search_failed = AtomicBool::new(false); let mut best_recall = 0.0; - let pool = create_thread_pool(parameters.num_threads)?; + let runtime = tokio::runtime::Builder::new_current_thread() + .build() + .map_err(|e| CMDToolError { + details: format!("Failed to build tokio runtime: {e}"), + })?; for (test_id, &l) in parameters.l_vec.iter().enumerate() { if l < parameters.recall_at { @@ -221,15 +220,6 @@ where let mut statistics: Vec = vec![QueryStatistics::default(); query_num]; let mut result_counts: Vec = vec![0; query_num]; - let zipped = cmp_stats - .par_iter_mut() - .zip(queries.par_row_iter()) - .zip(vector_filters.par_iter()) - .zip(query_result_ids[test_id].par_chunks_mut(parameters.recall_at as usize)) - .zip(query_result_dists[test_id].par_chunks_mut(parameters.recall_at as usize)) - .zip(statistics.par_iter_mut()) - .zip(result_counts.par_iter_mut()); - let mut _span: BoxedSpan; #[cfg(feature = "perf_test")] { @@ -240,49 +230,61 @@ where } let test_start = Instant::now(); - zipped.for_each_in_pool( - pool.as_ref(), - |( - (((((_cmp, query), vector_filter), query_result_id), query_result_dist), stats), - result_count, - )| { - let vector_filter_function: Box bool + Send + Sync> = - if parameters.vector_filters_file.is_none() { - Box::new(|_: &u32| true) - } else { - Box::new(move |vector_id: &u32| vector_filter.contains(vector_id)) - }; - - let result = searcher.search( - query, - parameters.recall_at, - l, - Some(parameters.beam_width as usize), - Some(vector_filter_function), - parameters.is_flat_search, - ); - - match result { - Ok(search_result) => { - *result_count = search_result.stats.result_count; - *stats = search_result.stats.query_statistics; - search_result - .results - .iter() - .take(parameters.recall_at as usize) - .enumerate() - .for_each(|(i, item)| { - query_result_id[i] = item.vertex_id; - query_result_dist[i] = item.distance; - }); - } - Err(e) => { - error!("Error during search: {}", e); - has_any_search_failed.store(true, std::sync::atomic::Ordering::Release); + // Drive all queries concurrently on the caller-owned runtime, bounding the + // number of in-flight searches to `num_threads`. + let search_results: Vec<_> = runtime.block_on(async { + stream::iter(queries.row_iter().enumerate()) + .map(|(query_id, query)| { + let vector_filter = &vector_filters[query_id]; + let searcher = &searcher; + async move { + let vector_filter_function: Box bool + Send + Sync> = + if parameters.vector_filters_file.is_none() { + Box::new(|_: &u32| true) + } else { + Box::new(move |vector_id: &u32| vector_filter.contains(vector_id)) + }; + + let result = searcher + .search( + query, + parameters.recall_at, + l, + Some(parameters.beam_width as usize), + Some(vector_filter_function), + parameters.is_flat_search, + ) + .await; + (query_id, result) } + }) + .buffer_unordered(parameters.num_threads.max(1)) + .collect() + .await + }); + + for (query_id, result) in search_results { + match result { + Ok(search_result) => { + result_counts[query_id] = search_result.stats.result_count; + statistics[query_id] = search_result.stats.query_statistics; + let base = query_id * parameters.recall_at as usize; + search_result + .results + .iter() + .take(parameters.recall_at as usize) + .enumerate() + .for_each(|(i, item)| { + query_result_ids[test_id][base + i] = item.vertex_id; + query_result_dists[test_id][base + i] = item.distance; + }); } - }, - ); + Err(e) => { + error!("Error during search: {}", e); + has_any_search_failed.store(true, std::sync::atomic::Ordering::Release); + } + } + } let diff = test_start.elapsed(); let qps = query_num as f32 / diff.as_secs_f32();