From d2fe014b6cab5cb45c91ef6baa1ebae8ff950be0 Mon Sep 17 00:00:00 2001 From: "Yuanyuan Tian (from Dev Box)" Date: Tue, 30 Jun 2026 14:59:46 +0800 Subject: [PATCH] test: improve unit-test coverage across crates Add targeted unit tests across diskann-bftree, diskann-label-filter, diskann-providers, diskann-quantization, diskann-tools, and vectorset to raise workspace coverage. Add tempfile dev-dependencies for file-IO tests. Add a codecov ignore for vectorset/src/main.rs (binary entry point, not unit-testable). --- .codecov.yml | 1 + Cargo.lock | 2 + diskann-bftree/src/lib.rs | 42 ++ diskann-bftree/src/neighbors.rs | 18 + diskann-bftree/src/quant.rs | 23 + diskann-bftree/src/vectors.rs | 40 ++ .../document_provider.rs | 218 ++++++++ .../encoded_attribute_accessor.rs | 84 +++ .../roaring_attribute_store.rs | 63 +++ .../encoded_document_accessor.rs | 75 +++ .../inline_beta_search/inline_beta_filter.rs | 46 ++ .../inline_beta_search/predicate_evaluator.rs | 91 ++++ .../src/kv_index/generic_index.rs | 52 ++ .../kv_index/inverted_index_provider_impl.rs | 102 ++++ .../kv_index/posting_list_accessor_impl.rs | 55 ++ .../src/kv_index/query_evaluator_impl.rs | 51 ++ diskann-label-filter/src/parser/ast.rs | 60 +++ .../src/parser/query_parser.rs | 134 +++++ .../src/set/roaring_set_provider.rs | 78 +++ .../src/stores/bftree_store.rs | 9 + .../src/traits/query_evaluator.rs | 47 ++ .../src/utils/flatten_utils.rs | 44 ++ .../src/utils/jsonl_reader.rs | 87 ++++ .../src/model/pq/fixed_chunk_pq_table.rs | 80 +++ .../src/model/pq/pq_construction.rs | 81 +++ diskann-providers/src/storage/pq_storage.rs | 61 +++ diskann-providers/src/utils/timer.rs | 9 + .../src/multi_vector/distance/factory.rs | 53 ++ diskann-tools/Cargo.toml | 1 + diskann-tools/src/bin/generate_minmax.rs | 60 +++ diskann-tools/src/bin/subsample_bin.rs | 84 +++ diskann-tools/src/utils/ground_truth.rs | 485 ++++++++++++++++++ diskann-tools/src/utils/search_index_utils.rs | 117 +++++ vectorset/Cargo.toml | 3 + vectorset/src/loader.rs | 414 ++++++++++----- 35 files changed, 2737 insertions(+), 133 deletions(-) diff --git a/.codecov.yml b/.codecov.yml index 0f2f9af31..7f1f35131 100644 --- a/.codecov.yml +++ b/.codecov.yml @@ -51,3 +51,4 @@ ignore: - "**/tests/**" - "**/benches/**" - "**/examples/**" + - "vectorset/src/main.rs" # binary entry point (Redis CLI), not unit-testable diff --git a/Cargo.lock b/Cargo.lock index 1ecde9f9f..aa1abc359 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -940,6 +940,7 @@ dependencies = [ "rstest", "serde", "serde_json", + "tempfile", "tracing", "tracing-subscriber", "vfs", @@ -3745,6 +3746,7 @@ dependencies = [ "rand 0.9.4", "redis", "serde", + "tempfile", "tokio", "toml 0.9.10+spec-1.1.0", ] diff --git a/diskann-bftree/src/lib.rs b/diskann-bftree/src/lib.rs index 07fc2fcaf..d35371049 100644 --- a/diskann-bftree/src/lib.rs +++ b/diskann-bftree/src/lib.rs @@ -160,3 +160,45 @@ impl Default for TestCallCount { Self::new() } } + +#[cfg(test)] +mod error_tests { + use super::*; + + #[test] + fn vector_unavailable_display_variants() { + let deleted = VectorUnavailable { + id: 7, + err: VectorError::Deleted, + }; + assert_eq!(deleted.to_string(), "vector 7 was deleted"); + + let not_found = VectorUnavailable { + id: 9, + err: VectorError::NotFound, + }; + assert_eq!(not_found.to_string(), "vector 9 not found"); + } + + #[test] + fn vector_unavailable_acknowledge_is_noop() { + let transient = VectorUnavailable { + id: 1, + err: VectorError::Deleted, + }; + // Acknowledging a transient deletion swallows it without producing an error. + transient.acknowledge("expected during traversal"); + } + + #[test] + fn vector_unavailable_escalate_produces_error() { + let transient = VectorUnavailable { + id: 3, + err: VectorError::NotFound, + }; + let escalated: ANNError = transient.escalate("lookup failed"); + let message = escalated.to_string(); + assert!(message.contains("vector 3 not found"), "got: {message}"); + assert!(message.contains("lookup failed"), "got: {message}"); + } +} diff --git a/diskann-bftree/src/neighbors.rs b/diskann-bftree/src/neighbors.rs index 142ab7663..66316dfd9 100644 --- a/diskann-bftree/src/neighbors.rs +++ b/diskann-bftree/src/neighbors.rs @@ -408,6 +408,24 @@ mod tests { assert!(neighbor_provider.get_neighbors(1, &mut result).is_err()); } + #[tokio::test] + async fn neighbor_error_paths() { + let provider = NeighborProvider::::new_with_config(6, Config::default()).unwrap(); + + // Reading a never-set id returns the NotFound error path. + let mut result = AdjacencyList::with_capacity(10); + assert!(provider.get_neighbors(42, &mut result).is_err()); + + // A neighbor list longer than the max degree is rejected. + let too_long: Vec = (0..=provider.max_degree()).collect(); + let mut buf = vec![0u32; provider.dim]; + assert!(provider.set_neighbors(7, &too_long, &mut buf).is_err()); + + // A write buffer shorter than `dim` is rejected. + let mut short_buf = vec![0u32; 1]; + assert!(provider.set_neighbors(7, &[1, 2], &mut short_buf).is_err()); + } + /// Test the interleaved and parallel traversal of the Bf-Tree /// by invoking the async accessors of the neighbor list provider #[tokio::test(flavor = "multi_thread", worker_threads = 5)] diff --git a/diskann-bftree/src/quant.rs b/diskann-bftree/src/quant.rs index e56d63b1f..a4850e227 100644 --- a/diskann-bftree/src/quant.rs +++ b/diskann-bftree/src/quant.rs @@ -295,6 +295,29 @@ mod tests { assert_eq!(quant_bytes, provider.quantizer.bytes()); } + #[tokio::test] + async fn get_vector_into_error_paths() { + let provider = create_test_provider(); + let expected = provider.quantizer.bytes(); + + // Wrong buffer length → Error. + let mut wrong = vec![0u8; expected + 1]; + match provider.get_vector_into(0, &mut wrong).unwrap_err() { + AccessError::Error(_) => {} + other => panic!("expected Error for wrong buffer length, got {other:?}"), + } + + // Unset id within an empty slot → NotFound transient. + let mut buffer = vec![0u8; expected]; + match provider.get_vector_into(99, &mut buffer).unwrap_err() { + AccessError::Transient(VectorUnavailable { + id, + err: VectorError::NotFound, + }) => assert_eq!(id, 99), + other => panic!("expected NotFound transient, got {other:?}"), + } + } + fn create_test_provider() -> QuantVectorProvider { let dim = 2; diff --git a/diskann-bftree/src/vectors.rs b/diskann-bftree/src/vectors.rs index 9c87b69ad..78ab8e74c 100644 --- a/diskann-bftree/src/vectors.rs +++ b/diskann-bftree/src/vectors.rs @@ -306,4 +306,44 @@ mod tests { let result = provider.get_vector_sync(0).unwrap(); assert_eq!(result, vec![1.0, 2.0, 3.0]); } + + #[test] + fn get_vector_into_wrong_buffer_dim_errors() { + let provider = VectorProvider::::new_with_config(5, 3, 0, Config::default()).unwrap(); + let mut buffer = [0.0f32; 2]; + let err = provider.get_vector_into(0, &mut buffer).unwrap_err(); + assert!(matches!(err, RankedError::Error(_))); + } + + #[test] + fn get_unset_vector_reports_not_found() { + let provider = VectorProvider::::new_with_config(5, 3, 0, Config::default()).unwrap(); + match provider.get_vector_sync(2).unwrap_err() { + RankedError::Transient(VectorUnavailable { + id, + err: VectorError::NotFound, + }) => assert_eq!(id, 2), + other => panic!("expected NotFound transient, got {other:?}"), + } + } + + #[test] + fn deleted_vector_reports_deleted() { + let provider = VectorProvider::::new_with_config(5, 3, 0, Config::default()).unwrap(); + provider.set_vector_sync(0, &[1.0, 2.0, 3.0]).unwrap(); + provider.delete_vector(0); + match provider.get_vector_sync(0).unwrap_err() { + RankedError::Transient(VectorUnavailable { + id, + err: VectorError::Deleted, + }) => assert_eq!(id, 0), + other => panic!("expected Deleted transient, got {other:?}"), + } + } + + #[test] + fn starting_points_and_getters() { + let provider = VectorProvider::::new_with_config(5, 3, 2, Config::default()).unwrap(); + assert_eq!(provider.starting_points().unwrap(), vec![5u32, 6u32]); + } } diff --git a/diskann-label-filter/src/encoded_attribute_provider/document_provider.rs b/diskann-label-filter/src/encoded_attribute_provider/document_provider.rs index 02634a0ce..426b1cc90 100644 --- a/diskann-label-filter/src/encoded_attribute_provider/document_provider.rs +++ b/diskann-label-filter/src/encoded_attribute_provider/document_provider.rs @@ -178,3 +178,221 @@ where self.status_by_internal_id(context, internal_id).await } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::attribute::{Attribute, AttributeValue}; + use crate::document::Document; + use crate::encoded_attribute_provider::roaring_attribute_store::RoaringAttributeStore; + use crate::traits::attribute_store::AttributeStore; + use diskann::provider::{DefaultContext, ElementStatus, Guard, NoopGuard}; + use std::collections::HashMap; + use std::sync::Mutex; + + /// Minimal in-memory `DataProvider` that records element statuses and uses + /// an identity external-to-internal id mapping. + #[derive(Default)] + struct MockProvider { + statuses: Mutex>, + released: Mutex>, + } + + impl MockProvider { + fn set_status(&self, id: u32, status: ElementStatus) { + self.statuses.lock().unwrap().insert(id, status); + } + } + + impl DataProvider for MockProvider { + type Context = DefaultContext; + type InternalId = u32; + type ExternalId = u32; + type Error = ANNError; + type Guard = NoopGuard; + + fn to_internal_id(&self, _ctx: &Self::Context, gid: &u32) -> Result { + Ok(*gid) + } + + fn to_external_id(&self, _ctx: &Self::Context, id: u32) -> Result { + Ok(id) + } + } + + impl SetElement for MockProvider { + type SetError = ANNError; + + async fn set_element( + &self, + _ctx: &Self::Context, + id: &u32, + _element: T, + ) -> Result, ANNError> { + self.set_status(*id, ElementStatus::Valid); + Ok(NoopGuard::new(*id)) + } + } + + impl Delete for MockProvider { + async fn delete(&self, _ctx: &Self::Context, gid: &u32) -> Result<(), ANNError> { + self.set_status(*gid, ElementStatus::Deleted); + Ok(()) + } + + async fn release(&self, _ctx: &Self::Context, id: u32) -> Result<(), ANNError> { + self.released.lock().unwrap().push(id); + Ok(()) + } + + async fn status_by_internal_id( + &self, + _ctx: &Self::Context, + id: u32, + ) -> Result { + Ok(self + .statuses + .lock() + .unwrap() + .get(&id) + .copied() + .unwrap_or(ElementStatus::Deleted)) + } + + async fn status_by_external_id( + &self, + ctx: &Self::Context, + gid: &u32, + ) -> Result { + let id = self.to_internal_id(ctx, gid)?; + self.status_by_internal_id(ctx, id).await + } + } + + /// Drive a future to completion on the current thread. The futures produced + /// here never suspend, so a simple poll loop is sufficient. + fn block_on(fut: F) -> F::Output { + use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker}; + fn raw() -> RawWaker { + fn no_op(_: *const ()) {} + fn clone(_: *const ()) -> RawWaker { + raw() + } + RawWaker::new( + std::ptr::null(), + &RawWakerVTable::new(clone, no_op, no_op, no_op), + ) + } + // SAFETY: the vtable functions are all no-ops operating on a null pointer. + let waker = unsafe { Waker::from_raw(raw()) }; + let mut cx = Context::from_waker(&waker); + let mut fut = std::pin::pin!(fut); + loop { + match fut.as_mut().poll(&mut cx) { + Poll::Ready(v) => return v, + Poll::Pending => std::hint::spin_loop(), + } + } + } + + type Dp = DocumentProvider>; + + fn make() -> Dp { + DocumentProvider::new(MockProvider::default(), RoaringAttributeStore::::new()) + } + + fn attrs() -> Vec { + vec![Attribute::from_value( + "category", + AttributeValue::String("electronics".to_owned()), + )] + } + + #[test] + fn id_translation_delegates_to_inner() { + let dp = make(); + let ctx = DefaultContext; + assert_eq!(dp.to_internal_id(&ctx, &9).unwrap(), 9); + assert_eq!(dp.to_external_id(&ctx, 9).unwrap(), 9); + } + + #[test] + fn attribute_accessor_is_available() { + let dp = make(); + assert!(dp.attribute_accessor().is_ok()); + } + + #[test] + fn set_element_marks_vector_and_attributes() { + let dp = make(); + let ctx = DefaultContext; + let vector = vec![1.0_f32, 2.0, 3.0]; + let doc = Document::new(&vector, attrs()); + + let guard = block_on(dp.set_element(&ctx, &1, doc)).unwrap(); + assert_eq!(guard.id(), 1); + + // Present in both the data store and the attribute store -> valid. + let status = block_on(dp.status_by_internal_id(&ctx, 1)).unwrap(); + assert_eq!(status, ElementStatus::Valid); + + // The external-id path resolves to the same status. + let status = block_on(dp.status_by_external_id(&ctx, &1)).unwrap(); + assert_eq!(status, ElementStatus::Valid); + } + + #[test] + fn status_of_unknown_id_is_deleted() { + let dp = make(); + let ctx = DefaultContext; + // Absent from both stores -> reported as deleted. + let status = block_on(dp.status_by_internal_id(&ctx, 42)).unwrap(); + assert_eq!(status, ElementStatus::Deleted); + } + + #[test] + fn status_errors_when_attribute_present_but_data_deleted() { + let dp = make(); + let ctx = DefaultContext; + // Attribute store knows the id, data store does not. + dp.attribute_store().set_element(&5, &attrs()).unwrap(); + let result = block_on(dp.status_by_internal_id(&ctx, 5)); + assert!(result.is_err()); + } + + #[test] + fn status_errors_when_data_valid_but_attribute_absent() { + let dp = make(); + let ctx = DefaultContext; + // Data store reports valid, attribute store has no record. + dp.inner_provider().set_status(6, ElementStatus::Valid); + let result = block_on(dp.status_by_internal_id(&ctx, 6)); + assert!(result.is_err()); + } + + #[test] + fn delete_delegates_to_inner_provider() { + let dp = make(); + let ctx = DefaultContext; + let vector = vec![0.0_f32]; + block_on(dp.set_element(&ctx, &1, Document::new(&vector, attrs()))).unwrap(); + + block_on(dp.delete(&ctx, &1)).unwrap(); + + // The inner provider now reports the id as deleted. + let status = block_on(dp.inner_provider().status_by_internal_id(&ctx, 1)).unwrap(); + assert_eq!(status, ElementStatus::Deleted); + } + + #[test] + fn release_removes_attributes_and_releases_slot() { + let dp = make(); + let ctx = DefaultContext; + let vector = vec![0.0_f32]; + block_on(dp.set_element(&ctx, &1, Document::new(&vector, attrs()))).unwrap(); + + block_on(dp.release(&ctx, 1)).unwrap(); + + assert!(dp.inner_provider().released.lock().unwrap().contains(&1)); + } +} diff --git a/diskann-label-filter/src/encoded_attribute_provider/encoded_attribute_accessor.rs b/diskann-label-filter/src/encoded_attribute_provider/encoded_attribute_accessor.rs index 23d30bf04..11175bf7e 100644 --- a/diskann-label-filter/src/encoded_attribute_provider/encoded_attribute_accessor.rs +++ b/diskann-label-filter/src/encoded_attribute_provider/encoded_attribute_accessor.rs @@ -124,3 +124,87 @@ where } } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::set::roaring_set_provider::RoaringTreemapSetProvider; + + fn provider_with_labels() -> Arc>> { + let mut sp = RoaringTreemapSetProvider::::new(); + // id 1 has labels {100, 200}; id 5 has label {300}. + sp.insert(&1u32, &100u64).unwrap(); + sp.insert(&1u32, &200u64).unwrap(); + sp.insert(&5u32, &300u64).unwrap(); + Arc::new(RwLock::new(sp)) + } + + #[test] + fn visit_labels_of_point_present_and_absent() { + let mut accessor = EncodedAttributeAccessor::new(provider_with_labels()); + + // Present id -> Some set with both labels. + let count = accessor + .visit_labels_of_point(1u32, |id, set| { + assert_eq!(id, 1u32); + let set = set.expect("expected a set for id 1"); + assert!(set.contains(100u64)); + assert!(set.contains(200u64)); + set.len() + }) + .unwrap(); + assert_eq!(count, 2u64); + + // Absent id -> None. + let is_none = accessor + .visit_labels_of_point(42u32, |_id, set| set.is_none()) + .unwrap(); + assert!(is_none); + } + + #[test] + fn visit_labels_of_points_iterates_all() { + let mut accessor = EncodedAttributeAccessor::new(provider_with_labels()); + + let mut seen: Vec<(u32, bool)> = Vec::new(); + accessor + .visit_labels_of_points([1u32, 5u32, 7u32], |id, set| { + seen.push((id, set.is_some())); + }) + .unwrap(); + + assert_eq!(seen, vec![(1u32, true), (5u32, true), (7u32, false)]); + } + + #[test] + fn visit_recovers_from_poisoned_lock() { + let provider = provider_with_labels(); + + // Poison the lock by panicking while holding the write guard. + let p2 = provider.clone(); + let _ = std::thread::spawn(move || { + let mut sp = p2.write().unwrap(); + sp.insert(&9u32, &900u64).unwrap(); + panic!("intentional panic to poison the lock"); + }) + .join(); + assert!(provider.is_poisoned()); + + // The accessor should still be able to read despite poisoning. + let mut accessor = EncodedAttributeAccessor::new(provider); + let present = accessor + .visit_labels_of_point(1u32, |_id, set| set.is_some()) + .unwrap(); + assert!(present); + + let mut count = 0usize; + accessor + .visit_labels_of_points([1u32, 5u32], |_id, set| { + if set.is_some() { + count += 1; + } + }) + .unwrap(); + assert_eq!(count, 2); + } +} diff --git a/diskann-label-filter/src/encoded_attribute_provider/roaring_attribute_store.rs b/diskann-label-filter/src/encoded_attribute_provider/roaring_attribute_store.rs index 3b62c35ec..6882e1313 100644 --- a/diskann-label-filter/src/encoded_attribute_provider/roaring_attribute_store.rs +++ b/diskann-label-filter/src/encoded_attribute_provider/roaring_attribute_store.rs @@ -190,3 +190,66 @@ where Ok(true) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::attribute::AttributeValue; + + fn attr(name: &str, value: &str) -> Attribute { + Attribute::from_value(name, AttributeValue::String(value.to_owned())) + } + + #[test] + fn set_element_then_id_exists_and_delete() { + let store = RoaringAttributeStore::::new(); + let attrs = vec![attr("color", "red"), attr("size", "large")]; + + // Insert a vector with two attributes. + assert!(store.set_element(&7u32, &attrs).unwrap()); + assert!(store.id_exists(&7u32).unwrap()); + + // A different id does not exist. + assert!(!store.id_exists(&8u32).unwrap()); + + // The attribute map should now contain the inserted attributes. + { + let map = store.attribute_map(); + let guard = map.read().unwrap(); + assert!(guard.get(&attr("color", "red")).is_some()); + assert!(guard.get(&attr("size", "large")).is_some()); + } + + // Deleting the vector removes it. + assert!(store.delete(&7u32).unwrap()); + assert!(!store.id_exists(&7u32).unwrap()); + } + + #[test] + fn set_element_with_no_attributes_is_error() { + let store = RoaringAttributeStore::::new(); + assert!(store.set_element(&1u32, &[]).is_err()); + } + + #[test] + fn delete_missing_id_returns_false() { + let store = RoaringAttributeStore::::new(); + assert!(!store.delete(&99u32).unwrap()); + } + + #[test] + fn set_element_twice_updates_existing_labels() { + let store = RoaringAttributeStore::::new(); + + assert!(store.set_element(&3u32, &[attr("k", "v1")]).unwrap()); + // Re-setting the same id exercises the "delete old labels" path. + assert!(store.set_element(&3u32, &[attr("k", "v2")]).unwrap()); + assert!(store.id_exists(&3u32).unwrap()); + } + + #[test] + fn attribute_accessor_is_constructible() { + let store = RoaringAttributeStore::::new(); + assert!(store.attribute_accessor().is_ok()); + } +} diff --git a/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs b/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs index 23532ecd2..4a64959ce 100644 --- a/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs +++ b/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs @@ -149,3 +149,78 @@ where self.inner_accessor.is_not_start_point() } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::attribute::{Attribute, AttributeValue}; + use crate::parser::ast::CompareOp; + use crate::set::SetProvider; + use serde_json::Value; + + /// Minimal inner accessor that only carries an id type. + struct MockAccessor; + + impl HasId for MockAccessor { + type Id = u64; + } + + /// Builds an accessor whose filter is `category == "electronics"`, with + /// vector `match_id` carrying the matching attribute and vector + /// `other_id` carrying an unrelated attribute. + fn build_accessor( + beta: f32, + match_id: u64, + other_id: u64, + ) -> EncodedDocumentAccessor { + let mut encoder = AttributeEncoder::new(); + let matching = Attribute::from_value( + "category".to_string(), + AttributeValue::String("electronics".to_string()), + ); + let unrelated = Attribute::from_value( + "category".to_string(), + AttributeValue::String("furniture".to_string()), + ); + let match_attr = encoder.insert(&matching); + let other_attr = encoder.insert(&unrelated); + let map = Arc::new(RwLock::new(encoder)); + + let mut provider: RoaringTreemapSetProvider = RoaringTreemapSetProvider::new(); + provider.insert(&match_id, &match_attr).unwrap(); + provider.insert(&other_id, &other_attr).unwrap(); + let attribute_accessor = EncodedAttributeAccessor::new(Arc::new(RwLock::new(provider))); + + let ast = ASTExpr::Compare { + field: "category".to_string(), + op: CompareOp::Eq(Value::String("electronics".to_string())), + }; + + EncodedDocumentAccessor::new(MockAccessor, attribute_accessor, map, &ast, beta).unwrap() + } + + #[test] + fn attributes_for_scales_distance_when_filter_matches() { + let mut accessor = build_accessor(0.5, 7, 8); + let scaled = accessor + .attributes_for(7, |computer, set| computer.apply(10.0, &set)) + .unwrap(); + assert_eq!(scaled, 5.0); + } + + #[test] + fn attributes_for_keeps_distance_when_filter_does_not_match() { + let mut accessor = build_accessor(0.5, 7, 8); + let unchanged = accessor + .attributes_for(8, |computer, set| computer.apply(10.0, &set)) + .unwrap(); + assert_eq!(unchanged, 10.0); + } + + #[test] + fn attributes_for_errors_when_point_has_no_labels() { + let mut accessor = build_accessor(0.5, 7, 8); + let result = accessor.attributes_for(99, |computer, set| computer.apply(10.0, &set)); + assert!(result.is_err()); + } +} diff --git a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs index 2e3c7497a..a3ef0d9e1 100644 --- a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs +++ b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs @@ -180,3 +180,49 @@ where .map_err(|e| e.into()) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::attribute::{Attribute, AttributeValue}; + use crate::encoded_attribute_provider::attribute_encoder::AttributeEncoder; + use crate::parser::ast::{ASTExpr, CompareOp}; + use serde_json::Value; + use std::sync::{Arc, RwLock}; + + /// Build an [`InlineBetaComputer`] whose filter is `category == "electronics"`, + /// returning the encoded attribute id that a matching point must carry. + fn build_computer(beta: f32) -> (InlineBetaComputer, u64) { + let mut encoder = AttributeEncoder::new(); + let attr = Attribute::from_value( + "category".to_string(), + AttributeValue::String("electronics".to_string()), + ); + let id = encoder.insert(&attr); + let map = Arc::new(RwLock::new(encoder)); + + let ast = ASTExpr::Compare { + field: "category".to_string(), + op: CompareOp::Eq(Value::String("electronics".to_string())), + }; + let filter_expr = EncodedFilterExpr::new(&ast, map).unwrap(); + (InlineBetaComputer::new(beta, filter_expr), id) + } + + #[test] + fn apply_scales_distance_when_predicate_matches() { + let (computer, id) = build_computer(0.5); + let mut attrs = RoaringTreemap::new(); + attrs.insert(id); + // Predicate matches -> distance is scaled by beta. + assert_eq!(computer.apply(10.0, &attrs), 5.0); + } + + #[test] + fn apply_returns_distance_when_predicate_does_not_match() { + let (computer, _id) = build_computer(0.5); + // Point carries no attributes -> predicate fails -> distance unchanged. + let attrs = RoaringTreemap::new(); + assert_eq!(computer.apply(10.0, &attrs), 10.0); + } +} diff --git a/diskann-label-filter/src/inline_beta_search/predicate_evaluator.rs b/diskann-label-filter/src/inline_beta_search/predicate_evaluator.rs index dc89473cf..30d87af16 100644 --- a/diskann-label-filter/src/inline_beta_search/predicate_evaluator.rs +++ b/diskann-label-filter/src/inline_beta_search/predicate_evaluator.rs @@ -91,3 +91,94 @@ where self.labels_of_point.contains(label_id) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::encoded_attribute_provider::ast_id_expr::ASTIdExpr; + use roaring::RoaringTreemap; + + fn labels(ids: &[u64]) -> RoaringTreemap { + let mut s = RoaringTreemap::new(); + for &id in ids { + s.insert(id); + } + s + } + + #[test] + fn terminal_checks_membership() { + let set = labels(&[1, 2, 3]); + let eval = PredicateEvaluator::new(&set); + assert!(ASTIdExpr::Terminal(2u64).accept(&eval).unwrap()); + assert!(!ASTIdExpr::Terminal(9u64).accept(&eval).unwrap()); + } + + #[test] + fn and_requires_all_members() { + let set = labels(&[1, 2]); + let eval = PredicateEvaluator::new(&set); + let all_present = + ASTIdExpr::And(vec![ASTIdExpr::Terminal(1u64), ASTIdExpr::Terminal(2u64)]); + assert!(all_present.accept(&eval).unwrap()); + let one_missing = + ASTIdExpr::And(vec![ASTIdExpr::Terminal(1u64), ASTIdExpr::Terminal(3u64)]); + assert!(!one_missing.accept(&eval).unwrap()); + } + + #[test] + fn empty_and_is_vacuously_true() { + let set = labels(&[]); + let eval = PredicateEvaluator::new(&set); + assert!(ASTIdExpr::::And(vec![]).accept(&eval).unwrap()); + } + + #[test] + fn or_requires_any_member() { + let set = labels(&[1]); + let eval = PredicateEvaluator::new(&set); + let any_present = ASTIdExpr::Or(vec![ASTIdExpr::Terminal(9u64), ASTIdExpr::Terminal(1u64)]); + assert!(any_present.accept(&eval).unwrap()); + let none_present = + ASTIdExpr::Or(vec![ASTIdExpr::Terminal(8u64), ASTIdExpr::Terminal(9u64)]); + assert!(!none_present.accept(&eval).unwrap()); + } + + #[test] + fn empty_or_is_false() { + let set = labels(&[1]); + let eval = PredicateEvaluator::new(&set); + assert!(!ASTIdExpr::::Or(vec![]).accept(&eval).unwrap()); + } + + #[test] + fn not_negates_inner() { + let set = labels(&[1]); + let eval = PredicateEvaluator::new(&set); + assert!(ASTIdExpr::Not(Box::new(ASTIdExpr::Terminal(9u64))) + .accept(&eval) + .unwrap()); + assert!(!ASTIdExpr::Not(Box::new(ASTIdExpr::Terminal(1u64))) + .accept(&eval) + .unwrap()); + } + + #[test] + fn nested_and_or_not_combination() { + let set = labels(&[1, 2, 3]); + let eval = PredicateEvaluator::new(&set); + // (1 AND 2) OR (NOT 5) -> both branches true, overall true + let expr = ASTIdExpr::Or(vec![ + ASTIdExpr::And(vec![ASTIdExpr::Terminal(1u64), ASTIdExpr::Terminal(2u64)]), + ASTIdExpr::Not(Box::new(ASTIdExpr::Terminal(5u64))), + ]); + assert!(expr.accept(&eval).unwrap()); + + // (1 AND 9) AND (NOT 2) -> first branch false, overall false + let expr2 = ASTIdExpr::And(vec![ + ASTIdExpr::And(vec![ASTIdExpr::Terminal(1u64), ASTIdExpr::Terminal(9u64)]), + ASTIdExpr::Not(Box::new(ASTIdExpr::Terminal(2u64))), + ]); + assert!(!expr2.accept(&eval).unwrap()); + } +} diff --git a/diskann-label-filter/src/kv_index/generic_index.rs b/diskann-label-filter/src/kv_index/generic_index.rs index 63bb943e4..4575c038e 100644 --- a/diskann-label-filter/src/kv_index/generic_index.rs +++ b/diskann-label-filter/src/kv_index/generic_index.rs @@ -601,4 +601,56 @@ mod tests { assert_eq!(LOCATION_SERIALIZE_KEY_LIST, "serialize_key_list"); assert_eq!(DATA_TYPE_POSTING_LIST, "posting_list"); } + + fn mem_index() -> TestIndex { + let store = BfTreeStore::memory().expect("open store"); + TestIndex::new(Arc::new(store)) + } + + #[test] + fn test_normalize_field_default_is_identity() { + let index = mem_index(); + assert_eq!(index.normalize_field("color"), "color"); + assert_eq!(index.normalize_field("a.b.c"), "a.b.c"); + } + + #[test] + fn test_with_field_normalizer_applies_custom_logic() { + let index = mem_index().with_field_normalizer(|f| format!("/{}", f.replace('.', "/"))); + assert_eq!(index.normalize_field("a.b.c"), "/a/b/c"); + assert_eq!(index.normalize_field("flat"), "/flat"); + } + + #[test] + fn test_get_or_empty_posting_list_missing_and_present() { + let index = mem_index(); + + // Missing key yields an empty posting list. + let empty = index.get_or_empty_posting_list(b"no-such-key").unwrap(); + assert_eq!(empty.len(), 0); + + // After storing a serialized posting list, it round-trips. + let mut pl = RoaringPostingList::empty(); + pl.insert(7); + pl.insert(42); + index.store().set(b"some-key", &pl.serialize()).unwrap(); + + let loaded = index.get_or_empty_posting_list(b"some-key").unwrap(); + assert_eq!(loaded.len(), 2); + assert!(loaded.contains(7)); + assert!(loaded.contains(42)); + } + + #[test] + fn test_get_or_empty_posting_list_corrupt_is_error() { + let index = mem_index(); + index.store().set(b"bad", b"not-a-posting-list").unwrap(); + + let result = index.get_or_empty_posting_list(b"bad"); + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + IndexError::Serialization { .. } + )); + } } diff --git a/diskann-label-filter/src/kv_index/inverted_index_provider_impl.rs b/diskann-label-filter/src/kv_index/inverted_index_provider_impl.rs index f5e4ff00f..968545ef8 100644 --- a/diskann-label-filter/src/kv_index/inverted_index_provider_impl.rs +++ b/diskann-label-filter/src/kv_index/inverted_index_provider_impl.rs @@ -169,3 +169,105 @@ where Ok(()) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::attribute::AttributeValue; + use crate::stores::bftree_store::BfTreeStore; + use crate::traits::key_codec::DefaultKeyCodec; + use crate::traits::posting_list_trait::RoaringPostingList; + use std::collections::HashMap; + use std::sync::Arc; + + type TestIndex = GenericIndex; + + fn index() -> TestIndex { + let store = BfTreeStore::memory().expect("memory store"); + TestIndex::new(Arc::new(store)) + } + + fn attrs(pairs: &[(&str, &str)]) -> Attributes { + pairs + .iter() + .map(|(f, v)| (f.to_string(), AttributeValue::String(v.to_string()))) + .collect::>() + } + + /// Read the posting list for a string field/value pair. + fn posting(index: &TestIndex, field: &str, value: &str) -> RoaringPostingList { + let codec = DefaultKeyCodec::default(); + let key = codec.encode_field_value_key(field, &AttributeValue::String(value.to_string())); + index.get_or_empty_posting_list(&key).unwrap() + } + + #[test] + fn insert_adds_document_to_posting_list() { + let mut idx = index(); + idx.insert(1, &attrs(&[("color", "red")])).unwrap(); + + let pl = posting(&idx, "color", "red"); + assert!(pl.contains(1)); + } + + #[test] + fn delete_removes_document_and_empties_posting_list() { + let mut idx = index(); + idx.insert(1, &attrs(&[("color", "red")])).unwrap(); + idx.delete(1).unwrap(); + + // Posting list became empty and was removed from the store. + assert_eq!(posting(&idx, "color", "red").len(), 0); + } + + #[test] + fn delete_missing_document_is_noop() { + let mut idx = index(); + // No reverse mapping exists -> early return, still Ok. + idx.delete(99).unwrap(); + } + + #[test] + fn delete_keeps_posting_list_with_remaining_documents() { + let mut idx = index(); + idx.insert(1, &attrs(&[("color", "red")])).unwrap(); + idx.insert(2, &attrs(&[("color", "red")])).unwrap(); + + idx.delete(1).unwrap(); + + let pl = posting(&idx, "color", "red"); + assert!(!pl.contains(1)); + assert!(pl.contains(2)); + } + + #[test] + fn update_moves_document_between_values() { + let mut idx = index(); + idx.insert(7, &attrs(&[("color", "red")])).unwrap(); + idx.update(7, &attrs(&[("color", "blue")])).unwrap(); + + assert!(!posting(&idx, "color", "red").contains(7)); + assert!(posting(&idx, "color", "blue").contains(7)); + } + + #[test] + fn batch_insert_delete_update_round_trip() { + let mut idx = index(); + + idx.batch_insert(&[ + (10, attrs(&[("color", "red")])), + (11, attrs(&[("color", "red")])), + ]) + .unwrap(); + assert!(posting(&idx, "color", "red").contains(10)); + assert!(posting(&idx, "color", "red").contains(11)); + + idx.batch_update(&[(10, attrs(&[("color", "green")]))]) + .unwrap(); + assert!(posting(&idx, "color", "green").contains(10)); + assert!(!posting(&idx, "color", "red").contains(10)); + + idx.batch_delete(&[11]).unwrap(); + assert!(!posting(&idx, "color", "red").contains(11)); + } +} diff --git a/diskann-label-filter/src/kv_index/posting_list_accessor_impl.rs b/diskann-label-filter/src/kv_index/posting_list_accessor_impl.rs index dd47d43fd..df415659c 100644 --- a/diskann-label-filter/src/kv_index/posting_list_accessor_impl.rs +++ b/diskann-label-filter/src/kv_index/posting_list_accessor_impl.rs @@ -47,3 +47,58 @@ where } } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::stores::bftree_store::BfTreeStore; + use crate::traits::key_codec::DefaultKeyCodec; + use crate::traits::posting_list_trait::RoaringPostingList; + use std::sync::Arc; + + type TestIndex = GenericIndex; + + fn index() -> TestIndex { + let store = BfTreeStore::memory().expect("memory store"); + TestIndex::new(Arc::new(store)) + } + + fn key(field: &str, value: &AttributeValue) -> Vec { + DefaultKeyCodec::default().encode_field_value_key(field, value) + } + + #[test] + fn get_posting_list_returns_stored_list() { + let idx = index(); + let value = AttributeValue::String("red".to_string()); + + let mut pl = RoaringPostingList::empty(); + pl.insert(3); + pl.insert(9); + idx.store() + .set(&key("color", &value), &pl.serialize()) + .unwrap(); + + let loaded = idx.get_posting_list("color", &value).unwrap().unwrap(); + assert!(loaded.contains(3)); + assert!(loaded.contains(9)); + } + + #[test] + fn get_posting_list_missing_returns_none() { + let idx = index(); + let value = AttributeValue::String("absent".to_string()); + assert!(idx.get_posting_list("color", &value).unwrap().is_none()); + } + + #[test] + fn get_posting_list_corrupt_is_error() { + let idx = index(); + let value = AttributeValue::String("bad".to_string()); + idx.store() + .set(&key("color", &value), b"not-valid") + .unwrap(); + + assert!(idx.get_posting_list("color", &value).is_err()); + } +} diff --git a/diskann-label-filter/src/kv_index/query_evaluator_impl.rs b/diskann-label-filter/src/kv_index/query_evaluator_impl.rs index 7d74ae9c4..38ce2cf31 100644 --- a/diskann-label-filter/src/kv_index/query_evaluator_impl.rs +++ b/diskann-label-filter/src/kv_index/query_evaluator_impl.rs @@ -867,4 +867,55 @@ mod tests { let result = index.evaluate_query(&expr).unwrap(); assert!(result.is_empty()); } + + #[test] + fn test_evaluate_query_integer_range_corrupt_value_is_error() { + let store = Arc::new(DummyKvStore::default()); + let mut index = + GenericIndex::::new(store.clone()); + + let mut attrs = Attributes::new(); + attrs.insert( + "age".to_string(), + AttributeValue::try_from(&json!(30)).unwrap(), + ); + index.insert(1, &attrs).unwrap(); + + // Inject corrupt posting-list bytes at an in-range integer key so the + // range scan's deserialize fails. + let codec = DefaultKeyCodec::default(); + let key = codec.encode_field_value_key("age", &AttributeValue::Integer(35)); + store.set(&key, b"not-a-posting-list").unwrap(); + + let expr = ASTExpr::Compare { + field: "age".to_string(), + op: CompareOp::Gte(20.0), + }; + assert!(index.evaluate_query(&expr).is_err()); + } + + #[test] + fn test_evaluate_query_float_range_corrupt_value_is_error() { + let store = Arc::new(DummyKvStore::default()); + let mut index = + GenericIndex::::new(store.clone()); + + let mut attrs = Attributes::new(); + attrs.insert( + "price".to_string(), + AttributeValue::try_from(&json!(25.0)).unwrap(), + ); + index.insert(1, &attrs).unwrap(); + + // Inject corrupt posting-list bytes at an in-range float key. + let codec = DefaultKeyCodec::default(); + let key = codec.encode_field_value_key("price", &AttributeValue::Real(30.0)); + store.set(&key, b"not-a-posting-list").unwrap(); + + let expr = ASTExpr::Compare { + field: "price".to_string(), + op: CompareOp::Gte(20.0), + }; + assert!(index.evaluate_query(&expr).is_err()); + } } diff --git a/diskann-label-filter/src/parser/ast.rs b/diskann-label-filter/src/parser/ast.rs index 251432b0e..e35ac2337 100644 --- a/diskann-label-filter/src/parser/ast.rs +++ b/diskann-label-filter/src/parser/ast.rs @@ -287,4 +287,64 @@ mod tests { let expected_nested = "AND(\n OR(\n age>30,\n age<20\n ),\n NOT(name==\"Admin\")\n)"; assert_eq!(nested_expr.to_string(), expected_nested); } + + #[test] + fn test_compare_op_display_all_variants() { + assert_eq!(CompareOp::Eq(json!(1)).to_string(), "=="); + assert_eq!(CompareOp::Ne(json!(1)).to_string(), "!="); + assert_eq!(CompareOp::Lt(1.0).to_string(), "<"); + assert_eq!(CompareOp::Lte(1.0).to_string(), "<="); + assert_eq!(CompareOp::Gt(1.0).to_string(), ">"); + assert_eq!(CompareOp::Gte(1.0).to_string(), ">="); + } + + #[test] + fn test_single_and_or_collapse() { + // A single-element AND/OR prints just the inner expression. + let inner = ASTExpr::Compare { + field: "x".to_string(), + op: CompareOp::Eq(json!(1)), + }; + assert_eq!(ASTExpr::And(vec![inner.clone()]).to_string(), "x==1"); + assert_eq!(ASTExpr::Or(vec![inner]).to_string(), "x==1"); + } + + #[test] + fn test_empty_and_or_print() { + assert_eq!(ASTExpr::And(vec![]).to_string(), "true"); + assert_eq!(ASTExpr::Or(vec![]).to_string(), "false"); + } + + #[test] + fn test_to_string_with_indent_custom() { + let expr = ASTExpr::And(vec![ + ASTExpr::Compare { + field: "a".to_string(), + op: CompareOp::Eq(json!(1)), + }, + ASTExpr::Compare { + field: "b".to_string(), + op: CompareOp::Eq(json!(2)), + }, + ]); + // Two-space indent is applied per nesting level. + let printed = expr.to_string_with_indent(" "); + assert_eq!(printed, "AND(\n a==1,\n b==2\n)"); + } + + #[test] + fn test_value_to_string_array_and_string_escaping() { + // Array values and embedded quotes are rendered by the print visitor. + let expr = ASTExpr::Compare { + field: "tags".to_string(), + op: CompareOp::Eq(json!(["a", "b"])), + }; + assert_eq!(expr.to_string(), "tags==[\"a\", \"b\"]"); + + let expr = ASTExpr::Compare { + field: "name".to_string(), + op: CompareOp::Eq(json!("a\"b")), + }; + assert_eq!(expr.to_string(), "name==\"a\\\"b\""); + } } diff --git a/diskann-label-filter/src/parser/query_parser.rs b/diskann-label-filter/src/parser/query_parser.rs index dcc196b99..014008295 100644 --- a/diskann-label-filter/src/parser/query_parser.rs +++ b/diskann-label-filter/src/parser/query_parser.rs @@ -812,4 +812,138 @@ mod tests { _ => panic!("Expected ParseFailure error, got: {:?}", result), } } + + #[test] + fn test_ne_operator() { + let filter = json!({"a": {"$ne": 1}}); + let ast = parse_query_filter(&filter).unwrap(); + assert!(matches!( + ast, + ASTExpr::Compare { field, op: CompareOp::Ne(v) } if field == "a" && v == json!(1) + )); + } + + #[test] + fn test_all_numeric_operators() { + for (key, build) in [ + ("$lt", "lt"), + ("$lte", "lte"), + ("$gt", "gt"), + ("$gte", "gte"), + ] { + let filter = json!({ "score": { key: 5 } }); + let ast = parse_query_filter(&filter) + .unwrap_or_else(|e| panic!("failed to parse {}: {:?}", key, e)); + let matched = match (build, &ast) { + ( + "lt", + ASTExpr::Compare { + op: CompareOp::Lt(n), + .. + }, + ) => *n == 5.0, + ( + "lte", + ASTExpr::Compare { + op: CompareOp::Lte(n), + .. + }, + ) => *n == 5.0, + ( + "gt", + ASTExpr::Compare { + op: CompareOp::Gt(n), + .. + }, + ) => *n == 5.0, + ( + "gte", + ASTExpr::Compare { + op: CompareOp::Gte(n), + .. + }, + ) => *n == 5.0, + _ => false, + }; + assert!( + matched, + "operator {} did not parse correctly: {:?}", + key, ast + ); + } + } + + #[test] + fn test_invalid_value_type_for_each_numeric_operator() { + for key in ["$lt", "$lte", "$gt", "$gte"] { + let filter = json!({ "score": { key: "nan" } }); + match parse_query_filter(&filter) { + Err(QueryFilterError::InvalidValueType(expected, _)) => { + assert_eq!(expected, "numeric"); + } + other => panic!("expected InvalidValueType for {}, got {:?}", key, other), + } + } + } + + #[test] + fn test_in_and_nin_unsupported() { + for key in ["$in", "$nin"] { + let filter = json!({ "tags": { key: [1, 2] } }); + match parse_query_filter(&filter) { + Err(QueryFilterError::UnsupportedComparisonOperator(op)) => { + assert_eq!(op, key); + } + other => panic!("expected unsupported operator for {}, got {:?}", key, other), + } + } + } + + #[test] + fn test_not_with_nested_error_propagates() { + // $not whose child has an invalid numeric value -> InvalidValueType bubbles up. + let filter = json!({"$not": {"a": {"$lt": "x"}}}); + assert!(matches!( + parse_query_filter(&filter), + Err(QueryFilterError::InvalidValueType(_, _)) + )); + } + + #[test] + fn test_get_value_by_path_on_non_object() { + let scalar = json!(42); + assert_eq!(get_value_by_path(&scalar, "a"), None); + // Traversing through a scalar mid-path returns None. + let label = json!({"a": 1}); + assert_eq!(get_value_by_path(&label, "a.b"), None); + } + + #[test] + fn test_query_filter_error_display_all_variants() { + let cases: Vec<(QueryFilterError, &str)> = vec![ + ( + QueryFilterError::NestingTooDeep { max_depth: 2 }, + "Maximum nesting depth of 2 exceeded", + ), + ( + QueryFilterError::UnsupportedLogicalOperator("$any".to_string()), + "Unsupported logical operator: $any", + ), + ( + QueryFilterError::UnsupportedComparisonOperator("$cmp".to_string()), + "Unsupported comparison operator: $cmp", + ), + ( + QueryFilterError::InvalidValueType("numeric".to_string(), "\"x\"".to_string()), + "Invalid value type: expected numeric, got \"x\"", + ), + ( + QueryFilterError::ParseFailure("boom".to_string()), + "Parse failure: boom", + ), + ]; + for (err, expected) in cases { + assert_eq!(format!("{}", err), expected); + } + } } diff --git a/diskann-label-filter/src/set/roaring_set_provider.rs b/diskann-label-filter/src/set/roaring_set_provider.rs index 4172a4527..cae95685f 100644 --- a/diskann-label-filter/src/set/roaring_set_provider.rs +++ b/diskann-label-filter/src/set/roaring_set_provider.rs @@ -288,6 +288,34 @@ mod tests { assert!(set2.contains(30)); } + #[test] + fn test_roaring_set_provider_exists_insert_values_clear() { + let mut provider: RoaringSetProvider = RoaringSetProvider::new(); + + // exists on empty provider + assert!(!provider.exists(&1).unwrap()); + + // insert_values adds multiple values at once + assert!(provider.insert_values(&1, &[10, 20, 30]).unwrap()); + + // exists after insertion + assert!(provider.exists(&1).unwrap()); + + let set1 = provider.get(&1).unwrap().unwrap(); + assert_eq!(set1.len(), 3); + assert!(set1.contains(10)); + assert!(set1.contains(20)); + assert!(set1.contains(30)); + + // re-inserting existing values reports not-all-newly-inserted + assert!(!provider.insert_values(&1, &[10, 20]).unwrap()); + + // clear removes everything + provider.clear().unwrap(); + assert_eq!(provider.count().unwrap(), 0usize); + assert!(!provider.exists(&1).unwrap()); + } + #[test] fn test_roaring_treemap_provider_delete() { let mut provider: RoaringTreemapSetProvider = RoaringTreemapSetProvider::new(); @@ -314,4 +342,54 @@ mod tests { // deleted entry returns empty set assert!(provider.get(&1).unwrap().is_none()); } + + #[test] + fn test_roaring_treemap_provider_exists_insert_values_clear() { + let mut provider: RoaringTreemapSetProvider = RoaringTreemapSetProvider::new(); + + assert!(!provider.exists(&1).unwrap()); + + let big: u64 = 1 << 50; + assert!(provider.insert_values(&1, &[10u64, 20u64, big]).unwrap()); + assert!(provider.exists(&1).unwrap()); + + let set1 = provider.get(&1).unwrap().unwrap(); + assert_eq!(set1.len(), 3); + assert!(set1.contains(big)); + + // re-inserting existing values reports not-all-newly-inserted + assert!(!provider.insert_values(&1, &[10u64]).unwrap()); + + // delete_from_set on a missing key returns false + assert!(!provider.delete_from_set(&99, &10u64).unwrap()); + + provider.clear().unwrap(); + assert_eq!(provider.count().unwrap(), 0usize); + assert!(!provider.exists(&1).unwrap()); + } + + #[test] + fn test_identity_hasher_specializations() { + let mut h = IdentityHasher::default(); + h.write_u8(5); + assert_eq!(h.finish(), 5); + h.write_u16(500); + assert_eq!(h.finish(), 500); + h.write_u32(70_000); + assert_eq!(h.finish(), 70_000); + h.write_u64(1u64 << 40); + assert_eq!(h.finish(), 1u64 << 40); + + // write_u128 folds the high and low halves with XOR. + h.write_u128((1u128 << 64) | 7); + assert_eq!(h.finish(), 1u64 ^ 7u64); + + // The byte fallback uses FNV-1a; verify it is deterministic. + let mut a = IdentityHasher::default(); + let mut b = IdentityHasher::default(); + a.write(&[1, 2, 3]); + b.write(&[1, 2, 3]); + assert_eq!(a.finish(), b.finish()); + assert_ne!(a.finish(), 0); + } } diff --git a/diskann-label-filter/src/stores/bftree_store.rs b/diskann-label-filter/src/stores/bftree_store.rs index 1257b239c..f0472ff7a 100644 --- a/diskann-label-filter/src/stores/bftree_store.rs +++ b/diskann-label-filter/src/stores/bftree_store.rs @@ -275,6 +275,15 @@ impl KvStore for BfTreeStore { mod tests { use super::*; + #[test] + fn error_conversions_and_display() { + let from_str: BfTreeStoreError = "boom".into(); + assert!(from_str.to_string().contains("boom")); + + let from_string: BfTreeStoreError = String::from("kaboom").into(); + assert!(from_string.to_string().contains("kaboom")); + } + #[test] fn small_key_small_value() { let store = BfTreeStore::memory().unwrap(); diff --git a/diskann-label-filter/src/traits/query_evaluator.rs b/diskann-label-filter/src/traits/query_evaluator.rs index 25e416555..3ef212c6b 100644 --- a/diskann-label-filter/src/traits/query_evaluator.rs +++ b/diskann-label-filter/src/traits/query_evaluator.rs @@ -72,3 +72,50 @@ pub trait QueryEvaluator { Ok(bs.len()) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::traits::posting_list_trait::{PostingList, RoaringPostingList}; + use crate::CompareOp; + + /// A minimal evaluator that always returns a fixed posting list, used to + /// exercise the default `is_match` and `count_matches` methods of the trait. + struct FixedEvaluator { + list: RoaringPostingList, + } + + impl QueryEvaluator for FixedEvaluator { + type Error = ::Error; + type PostingList = RoaringPostingList; + type DocId = usize; + + fn evaluate_query( + &self, + _query_expr: &ASTExpr, + ) -> std::result::Result { + Ok(self.list.clone()) + } + } + + fn sample_expr() -> ASTExpr { + ASTExpr::Compare { + field: "field".to_owned(), + op: CompareOp::Eq(serde_json::json!("value")), + } + } + + #[test] + fn default_is_match_and_count_matches() { + let mut list = RoaringPostingList::empty(); + list.insert(1); + list.insert(2); + list.insert(3); + let evaluator = FixedEvaluator { list }; + + let expr = sample_expr(); + assert!(evaluator.is_match(2, &expr).unwrap()); + assert!(!evaluator.is_match(9, &expr).unwrap()); + assert_eq!(evaluator.count_matches(&expr).unwrap(), 3); + } +} diff --git a/diskann-label-filter/src/utils/flatten_utils.rs b/diskann-label-filter/src/utils/flatten_utils.rs index 96b04d0df..87d7ad92f 100644 --- a/diskann-label-filter/src/utils/flatten_utils.rs +++ b/diskann-label-filter/src/utils/flatten_utils.rs @@ -394,4 +394,48 @@ mod tests { &AttributeValue::String("test".into()) ); } + + #[test] + fn test_flatten_remaining_wrappers() { + let value = json!({"user": {"name": "John"}}); + + // underscore_notation preset + let map = + flatten_json_pointers_map_with_config(&value, &FlattenConfig::underscore_notation()); + assert_eq!( + map.get("_user_name").unwrap(), + &AttributeValue::String("John".into()) + ); + + // json_pointer preset matches the default "/" separator + let map = flatten_json_pointers_map_with_config(&value, &FlattenConfig::json_pointer()); + assert_eq!( + map.get("/user/name").unwrap(), + &AttributeValue::String("John".into()) + ); + + // Vec-returning custom-separator wrapper + let vec = flatten_json_pointers_with_separator(&value, "-"); + assert_eq!( + vec, + vec![( + "-user-name".to_string(), + AttributeValue::String("John".into()) + )] + ); + + // Map-returning default wrapper + let map = flatten_json_pointers_map(&value); + assert_eq!( + map.get("/user/name").unwrap(), + &AttributeValue::String("John".into()) + ); + + // Map-returning custom-separator wrapper + let map = flatten_json_pointers_map_with_separator(&value, "_"); + assert_eq!( + map.get("_user_name").unwrap(), + &AttributeValue::String("John".into()) + ); + } } diff --git a/diskann-label-filter/src/utils/jsonl_reader.rs b/diskann-label-filter/src/utils/jsonl_reader.rs index 0b376b8a9..811c13e36 100644 --- a/diskann-label-filter/src/utils/jsonl_reader.rs +++ b/diskann-label-filter/src/utils/jsonl_reader.rs @@ -334,4 +334,91 @@ mod tests { assert_eq!(parsed_queries.len(), 2); // Additional validation of the parsed expressions would be done in the parser tests } + + fn write_file(contents: &str) -> (tempfile::TempDir, std::path::PathBuf) { + let dir = tempdir().unwrap(); + let path = dir.path().join("data.jsonl"); + let mut file = File::create(&path).unwrap(); + write!(file, "{}", contents).unwrap(); + (dir, path) + } + + #[test] + fn test_missing_file_is_io_error() { + let err = read_baselabels("does_not_exist_12345.jsonl").unwrap_err(); + assert!(matches!(err, JsonlReadError::IoError(_))); + // Display path for IoError. + assert!(format!("{}", err).starts_with("IO error:")); + } + + #[test] + fn test_malformed_label_line_is_parse_error() { + let (_dir, path) = write_file("not valid json\n"); + let err = read_baselabels(&path).unwrap_err(); + match err { + JsonlReadError::ParseError(msg) => assert!(msg.contains("line 1")), + other => panic!("expected ParseError, got {:?}", other), + } + } + + #[test] + fn test_malformed_query_line_is_parse_error() { + let (_dir, path) = write_file("{ broken\n"); + assert!(matches!( + read_queries(&path).unwrap_err(), + JsonlReadError::ParseError(_) + )); + } + + #[test] + fn test_empty_ground_truth_file() { + let (_dir, path) = write_file(""); + match read_ground_truth(&path).unwrap_err() { + JsonlReadError::ParseError(msg) => assert!(msg.contains("empty")), + other => panic!("expected ParseError, got {:?}", other), + } + } + + #[test] + fn test_ground_truth_bad_metadata_and_bad_result() { + // Bad metadata line. + let (_d1, p1) = write_file("not-json\n"); + match read_ground_truth(&p1).unwrap_err() { + JsonlReadError::ParseError(msg) => assert!(msg.contains("metadata")), + other => panic!("expected metadata ParseError, got {:?}", other), + } + + // Good metadata, bad result line. + let (_d2, p2) = + write_file("{\"distance_func\": \"l2\", \"query_num\": 1}\nbroken-result\n"); + match read_ground_truth(&p2).unwrap_err() { + JsonlReadError::ParseError(msg) => assert!(msg.contains("result line")), + other => panic!("expected result ParseError, got {:?}", other), + } + } + + #[test] + fn test_read_and_parse_queries_propagates_filter_error() { + // A valid query line whose filter uses an unsupported operator. + let (_dir, path) = write_file("{\"query_id\": 3, \"filter\": {\"a\": {\"$bogus\": 1}}}\n"); + match read_and_parse_queries(&path).unwrap_err() { + JsonlReadError::ParseError(msg) => assert!(msg.contains("query ID 3")), + other => panic!("expected ParseError, got {:?}", other), + } + } + + #[test] + fn test_error_display_and_from_conversions() { + let json_err: JsonlReadError = serde_json::from_str::("oops") + .unwrap_err() + .into(); + assert!(matches!(json_err, JsonlReadError::JsonError(_))); + assert!(format!("{}", json_err).starts_with("JSON parsing error:")); + + let io_err: JsonlReadError = io::Error::other("boom").into(); + assert!(format!("{}", io_err).starts_with("IO error:")); + + let parse_err = JsonlReadError::ParseError("bad".to_string()); + assert_eq!(format!("{}", parse_err), "Parse error: bad"); + } } diff --git a/diskann-providers/src/model/pq/fixed_chunk_pq_table.rs b/diskann-providers/src/model/pq/fixed_chunk_pq_table.rs index 2973db8fd..4e812fc46 100644 --- a/diskann-providers/src/model/pq/fixed_chunk_pq_table.rs +++ b/diskann-providers/src/model/pq/fixed_chunk_pq_table.rs @@ -1106,6 +1106,86 @@ mod fixed_chunk_pq_table_test { pq_table.populate_chunk_distances(&query_vec, &mut aligned_pq_table_dist_scratch); assert!(result.is_err()); } + + /// Build a small two-chunk table with two pivots for the direct-distance tests. + /// + /// dim = 4, offsets = [0, 2, 4], pivot 0 = [1,2 | 3,4], pivot 1 = [5,6 | 7,8]. + fn small_table() -> FixedChunkPQTable { + let dim = 4; + let offsets = vec![0, 2, 4]; + let pq_table = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]; + FixedChunkPQTable::new(dim, pq_table.into(), offsets.into()).unwrap() + } + + #[test] + fn test_direct_distance_methods() { + let table = small_table(); + let query = [1.0f32, 1.0, 1.0, 1.0]; + // chunk0 -> pivot0 dims[0..2] = [1,2], chunk1 -> pivot1 dims[2..4] = [7,8] + let base = [0u8, 1u8]; + let reconstructed = [1.0f32, 2.0, 7.0, 8.0]; + let max_relative = 1.0e-6; + + assert_relative_eq!( + table.l2_distance(&query, &base), + SquaredL2::evaluate(&query[..], &reconstructed[..]), + max_relative = max_relative + ); + // `InnerProduct::evaluate` returns the negated inner product, while + // `inner_product_raw` returns the raw (positive) inner product. + let ip_eval: f32 = InnerProduct::evaluate(&query[..], &reconstructed[..]); + assert_relative_eq!( + table.inner_product_raw(&query, &base), + -ip_eval, + max_relative = max_relative + ); + // `inner_product` negates the raw inner product (matching the `evaluate` convention). + assert_relative_eq!( + table.inner_product(&query, &base), + ip_eval, + max_relative = max_relative + ); + // `cosine_distance` follows the same `1 - similarity` convention as `Cosine::evaluate`. + let cos = table.cosine_distance(&query, &base); + assert_relative_eq!( + cos, + distance::Cosine::evaluate(&query[..], &reconstructed[..]), + max_relative = max_relative + ); + // `cosine_normalized_distance` delegates to `cosine_distance`. + assert_relative_eq!( + table.cosine_normalized_distance(&query, &base), + cos, + max_relative = max_relative + ); + } + + #[test] + fn test_pq_dist_lookup_single() { + // 2 chunks, 3 centers each, row-major [chunk0 | chunk1]. + let num_centers = 3; + let precomputed = [10.0f32, 11.0, 12.0, 20.0, 21.0, 22.0]; + let coords = [1u8, 0u8]; + // distances[0*3 + 1] + distances[1*3 + 0] = 11 + 20 = 31. + assert_eq!( + pq_dist_lookup_single(&coords, &precomputed, num_centers), + 31.0 + ); + } + + #[test] + fn test_compress_into_and_inflate_roundtrip() { + let table = small_table(); + // A vector exactly equal to pivot 0 must select centroid index 0 for both chunks. + let full = [1.0f32, 2.0, 3.0, 4.0]; + let mut codes = [255u8; 2]; + table.compress_into(&full[..], &mut codes[..]).unwrap(); + assert_eq!(codes, [0, 0]); + + // Inflating those codes recovers pivot 0's values. + let inflated = table.inflate_vector(&codes); + assert_eq!(inflated, vec![1.0, 2.0, 3.0, 4.0]); + } } #[cfg(test)] mod pq_index_prune_query_test { diff --git a/diskann-providers/src/model/pq/pq_construction.rs b/diskann-providers/src/model/pq/pq_construction.rs index 4e0dd769e..6e062e6be 100644 --- a/diskann-providers/src/model/pq/pq_construction.rs +++ b/diskann-providers/src/model/pq/pq_construction.rs @@ -716,6 +716,87 @@ mod pq_test { assert_eq!(full_pivot_data.len(), 16); } + #[test] + fn generate_pq_pivots_membuf_error_paths() { + let num_train = 5; + let dim = 8; + let num_centers = 2; + let num_pq_chunks = 2; + let train_data: Vec = vec![1.0; num_train * dim]; + let args = + GeneratePivotArguments::new(num_train, dim, num_centers, num_pq_chunks, 5).unwrap(); + let pool = create_thread_pool_for_test(); + + // Wrong full_pivot_data size (not num_centers * dim). + let mut bad_pivot = vec![0.0f32; num_centers * dim + 1]; + let mut offsets = vec![0usize; num_pq_chunks + 1]; + assert!( + generate_pq_pivots_from_membuf( + &args, + &train_data, + &mut offsets, + &mut bad_pivot, + &mut crate::utils::create_rnd_in_tests(), + &mut false, + pool.as_ref(), + ) + .is_err() + ); + + // Wrong offsets buffer size (not num_pq_chunks + 1). + let mut full_pivot = vec![0.0f32; num_centers * dim]; + let mut bad_offsets = vec![0usize; num_pq_chunks]; + assert!( + generate_pq_pivots_from_membuf( + &args, + &train_data, + &mut bad_offsets, + &mut full_pivot, + &mut crate::utils::create_rnd_in_tests(), + &mut false, + pool.as_ref(), + ) + .is_err() + ); + + // Cancellation requested by caller. + let mut full_pivot = vec![0.0f32; num_centers * dim]; + let mut offsets = vec![0usize; num_pq_chunks + 1]; + assert!( + generate_pq_pivots_from_membuf( + &args, + &train_data, + &mut offsets, + &mut full_pivot, + &mut crate::utils::create_rnd_in_tests(), + &mut true, + pool.as_ref(), + ) + .is_err() + ); + } + + #[test] + fn generate_pq_data_from_pivots_from_membuf_invalid_pivots_errors() { + let dim = 4; + let num_pivots = 2; + // pivot_data length does not equal num_pivots * dim, so MatrixView::try_from fails. + let pivot_data = vec![0.0f32; num_pivots * dim + 1]; + let offsets = vec![0usize, 2, dim]; + let vector_data = vec![1.0f32; dim]; + let mut pq_out = vec![0u8; offsets.len() - 1]; + assert!( + generate_pq_data_from_pivots_from_membuf( + &vector_data, + &pivot_data, + num_pivots, + &offsets, + &mut pq_out, + ) + .is_err() + ); + } + #[test] fn read_pivot_metadata_existing_test() { // no real data except pivot data. diff --git a/diskann-providers/src/storage/pq_storage.rs b/diskann-providers/src/storage/pq_storage.rs index a2d49391b..7987f117d 100644 --- a/diskann-providers/src/storage/pq_storage.rs +++ b/diskann-providers/src/storage/pq_storage.rs @@ -405,6 +405,19 @@ mod pq_storage_tests { PQStorage::new(PQ_PIVOT_PATH, PQ_COMPRESSED_PATH, Some(DATA_FILE)); } + #[test] + fn get_data_path_test() { + // With a data path, `get_data_path` returns it and the compressed path getter works. + let with_data = PQStorage::new(PQ_PIVOT_PATH, PQ_COMPRESSED_PATH, Some(DATA_FILE)); + assert_eq!(with_data.get_data_path().unwrap(), DATA_FILE); + assert_eq!(with_data.get_compressed_data_path(), PQ_COMPRESSED_PATH); + + // Without a data path, `get_data_path` returns a config error. + let without_data = PQStorage::new(PQ_PIVOT_PATH, PQ_COMPRESSED_PATH, None); + assert!(without_data.get_data_path().is_err()); + assert_eq!(without_data.get_compressed_data_path(), PQ_COMPRESSED_PATH); + } + #[test] fn write_compressed_pivot_metadata_test() { let storage_provider = VirtualStorageProvider::new_memory(); @@ -469,6 +482,54 @@ mod pq_storage_tests { assert_eq!(chunk_offsets.len(), 2); } + #[test] + fn load_existing_pivot_data_dimension_mismatch_errors() { + let storage_provider = VirtualStorageProvider::new_memory(); + let pivot_path = "/mismatch_pivots.bin"; + + let num_centers = 3; + let dim = 4; + let pivots: Vec = (0..num_centers * dim).map(|i| i as f32).collect(); + let chunk_offsets = vec![0, 2, dim]; + + let pq_storage = PQStorage::new(pivot_path, PQ_COMPRESSED_PATH, None); + pq_storage + .write_pivot_data( + &pivots, + None, + &chunk_offsets, + num_centers, + dim, + &storage_provider, + ) + .unwrap(); + + // Wrong number of centers triggers the pivots dimension-mismatch error. + assert!( + pq_storage + .load_existing_pivot_data(&2, &(num_centers + 1), &dim, &storage_provider) + .is_err() + ); + + // Wrong chunk count triggers the chunk-offsets dimension-mismatch error. + assert!( + pq_storage + .load_existing_pivot_data(&99, &num_centers, &dim, &storage_provider) + .is_err() + ); + } + + #[test] + fn load_pq_pivots_bin_missing_file_errors() { + let storage_provider = VirtualStorageProvider::new_memory(); + let pq_storage = PQStorage::new("/missing_pivots.bin", PQ_COMPRESSED_PATH, None); + assert!( + pq_storage + .load_pq_pivots_bin("/missing_pivots.bin", 2, &storage_provider) + .is_err() + ); + } + /// Write pivot data with `centroid = None`, read it back via /// `load_existing_pivot_data`, and verify the pivots are unchanged and the /// centroid is all zeros. diff --git a/diskann-providers/src/utils/timer.rs b/diskann-providers/src/utils/timer.rs index b81f7c9db..9fa0dafcf 100644 --- a/diskann-providers/src/utils/timer.rs +++ b/diskann-providers/src/utils/timer.rs @@ -221,4 +221,13 @@ mod timer_tests { let peak_memory_usage = timer.get_peak_memory_usage(); assert!(peak_memory_usage >= 0.0); } + + #[test] + fn test_elapsed_seconds_and_gcycles() { + let timer = Timer::new(); + assert!(timer.elapsed_seconds() >= 0.0); + // elapsed_gcycles returns 0.0 when cycle counters are unavailable, + // and a non-negative value otherwise. + assert!(timer.elapsed_gcycles() >= 0.0); + } } diff --git a/diskann-quantization/src/multi_vector/distance/factory.rs b/diskann-quantization/src/multi_vector/distance/factory.rs index 5dcd4b8cd..2b23ba633 100644 --- a/diskann-quantization/src/multi_vector/distance/factory.rs +++ b/diskann-quantization/src/multi_vector/distance/factory.rs @@ -617,6 +617,59 @@ mod tests { } } + /// Exercise every [`MaxSimIsa`] variant through the explicit match arms in + /// `MaxSimElement::build` (the existing agreement tests only drive `Auto` + /// and `Reference`). Supported ISAs must build + compute; unsupported ones + /// must surface a [`NotSupported`] naming the requested ISA. + fn check_all_isas(label: &str) + where + T: MaxSimElement + FromF32, + InnerProduct: for<'a, 'b> PureDistanceFunction<&'a [T], &'b [T], f32>, + { + let query_data = make_test_data::(3 * 8, 8, 0); + let doc_data = make_test_data::(2 * 8, 8, 1); + let query = make_mat(&query_data, 3, 8); + let doc = make_mat(&doc_data, 2, 8); + + for isa in [ + MaxSimIsa::Auto, + MaxSimIsa::Scalar, + MaxSimIsa::X86_64_V3, + MaxSimIsa::X86_64_V4, + MaxSimIsa::Neon, + MaxSimIsa::Reference, + ] { + let result = build_max_sim::(isa, query, BoxErase); + if isa.is_available() { + let kernel = result + .unwrap_or_else(|e| panic!("{label}{isa} available but build failed: {e}")); + assert_eq!(kernel.nrows(), 3, "{label}{isa} nrows"); + let mut scores = vec![0.0f32; 3]; + kernel + .compute_max_sim(doc, &mut scores) + .unwrap_or_else(|e| panic!("{label}{isa} compute failed: {e:?}")); + } else { + let err = result + .err() + .unwrap_or_else(|| panic!("{label}{isa} should be unsupported")); + assert_eq!(err.isa, isa, "{label}{isa} NotSupported.isa mismatch"); + assert!(!err.reason.is_empty(), "{label}{isa} reason empty"); + // Exercise the Display impl for the error path. + assert!(err.to_string().contains(&isa.to_string())); + } + } + } + + #[test] + fn all_isas_f32() { + check_all_isas::("f32 "); + } + + #[test] + fn all_isas_f16() { + check_all_isas::("f16 "); + } + macro_rules! test_matches_fallback { ($mod_name:ident, $ty:ty, $tol:expr, $label:literal) => { mod $mod_name { diff --git a/diskann-tools/Cargo.toml b/diskann-tools/Cargo.toml index 573b0f05c..95620837c 100644 --- a/diskann-tools/Cargo.toml +++ b/diskann-tools/Cargo.toml @@ -39,6 +39,7 @@ serde_json.workspace = true [dev-dependencies] rstest.workspace = true vfs = { workspace = true } +tempfile.workspace = true diskann-providers = { workspace = true, default-features = false, features = [ "virtual_storage", ] } diff --git a/diskann-tools/src/bin/generate_minmax.rs b/diskann-tools/src/bin/generate_minmax.rs index 7ced9e2ef..f064ee2e5 100644 --- a/diskann-tools/src/bin/generate_minmax.rs +++ b/diskann-tools/src/bin/generate_minmax.rs @@ -193,3 +193,63 @@ where Ok(()) } + +#[cfg(test)] +mod tests { + use super::*; + use diskann_utils::io::Metadata; + use std::io::Write; + use tempfile::TempDir; + + /// Write a standard `.bin` file (8-byte header + row-major f32 data). + fn write_f32_bin(path: &str, npts: usize, dim: usize) { + let mut f = File::create(path).unwrap(); + Metadata::new(npts, dim).unwrap().write(&mut f).unwrap(); + let data: Vec = (0..npts * dim).map(|i| (i % 17) as f32).collect(); + f.write_all(bytemuck::cast_slice(&data)).unwrap(); + f.flush().unwrap(); + } + + #[test] + fn dispatch_rejects_unsupported_bit_width() { + let err = + dispatch_process_file::(3, "unused_in.bin", "unused_out.bin", 1, 1.0).unwrap_err(); + assert!(err.to_string().contains("Unsupported bit width")); + } + + #[test] + fn process_file_quantizes_and_writes_header() { + let dir = TempDir::new().unwrap(); + let input = dir.path().join("in.bin"); + let output = dir.path().join("out.bin"); + let input_str = input.to_string_lossy().to_string(); + let output_str = output.to_string_lossy().to_string(); + + let npts = 16; + let dim = 8; + write_f32_bin(&input_str, npts, dim); + + process_file::<4, f32>(&input_str, &output_str, 42, 1.0).unwrap(); + + let mut r = File::open(&output).unwrap(); + let meta = Metadata::read(&mut r).unwrap(); + assert_eq!(meta.npoints(), npts); + // Output dim equals the per-vector byte count, which must be non-zero. + assert!(meta.ndims() > 0); + } + + #[test] + fn dispatch_process_file_supports_valid_widths() { + let dir = TempDir::new().unwrap(); + let input = dir.path().join("in.bin"); + let input_str = input.to_string_lossy().to_string(); + write_f32_bin(&input_str, 8, 8); + + for bits in [1u8, 2, 4, 8] { + let output = dir.path().join(format!("out_{}.bin", bits)); + let output_str = output.to_string_lossy().to_string(); + dispatch_process_file::(bits, &input_str, &output_str, 1, 1.0).unwrap(); + assert!(output.exists(), "output missing for {} bits", bits); + } + } +} diff --git a/diskann-tools/src/bin/subsample_bin.rs b/diskann-tools/src/bin/subsample_bin.rs index 6612ea91b..3e2a0a08f 100644 --- a/diskann-tools/src/bin/subsample_bin.rs +++ b/diskann-tools/src/bin/subsample_bin.rs @@ -115,3 +115,87 @@ fn main() -> Result<()> { DataType::Fp16 => run_for_type::(&args), } } + +#[cfg(test)] +mod tests { + use super::*; + use diskann_utils::io::Metadata; + use tempfile::TempDir; + + /// Write a standard `.bin` file (8-byte header + row-major f32 data). + fn write_f32_bin(path: &std::path::Path, npts: usize, dim: usize) { + let mut f = std::fs::File::create(path).unwrap(); + Metadata::new(npts, dim).unwrap().write(&mut f).unwrap(); + let data: Vec = (0..npts * dim).map(|i| i as f32).collect(); + f.write_all(bytemuck::cast_slice(&data)).unwrap(); + f.flush().unwrap(); + } + + fn args_for(input: PathBuf, output: PathBuf, probability: f64) -> Args { + Args { + data_type: DataType::Float, + base_bin_file: input, + sampled_output_file: output, + sampling_probability: probability, + random_seed: Some(7), + } + } + + #[test] + fn create_rng_is_deterministic_with_seed() { + let mut a = create_rng(Some(123)); + let mut b = create_rng(Some(123)); + let dist = StandardUniform; + let x: u64 = dist.sample(&mut a); + let y: u64 = dist.sample(&mut b); + assert_eq!(x, y); + } + + #[test] + fn create_rng_without_seed_produces_values() { + // Smoke test the seedless branch (non-deterministic, just exercise it). + let mut rng = create_rng(None); + let _: u64 = StandardUniform.sample(&mut rng); + } + + #[test] + fn run_for_type_samples_all_with_probability_one() { + let dir = TempDir::new().unwrap(); + let input = dir.path().join("in.bin"); + let output = dir.path().join("out.bin"); + write_f32_bin(&input, 12, 4); + + run_for_type::(&args_for(input, output.clone(), 1.0)).unwrap(); + + let mut r = std::fs::File::open(&output).unwrap(); + let (npts, dim) = Metadata::read(&mut r).unwrap().into_dims(); + assert_eq!(npts, 12); + assert_eq!(dim, 4); + } + + #[test] + fn run_for_type_samples_none_with_probability_zero() { + let dir = TempDir::new().unwrap(); + let input = dir.path().join("in.bin"); + let output = dir.path().join("out.bin"); + write_f32_bin(&input, 12, 4); + + run_for_type::(&args_for(input, output.clone(), 0.0)).unwrap(); + + let mut r = std::fs::File::open(&output).unwrap(); + let (npts, dim) = Metadata::read(&mut r).unwrap().into_dims(); + assert_eq!(npts, 0); + assert_eq!(dim, 4); + } + + #[test] + fn run_for_type_rejects_out_of_range_probability() { + let dir = TempDir::new().unwrap(); + let input = dir.path().join("in.bin"); + let output = dir.path().join("out.bin"); + write_f32_bin(&input, 4, 2); + + let err = run_for_type::(&args_for(input, output, 1.5)).unwrap_err(); + assert!(err.to_string().contains("sampling_probability")); + } +} diff --git a/diskann-tools/src/utils/ground_truth.rs b/diskann-tools/src/utils/ground_truth.rs index 98e16713a..2f845b4b7 100644 --- a/diskann-tools/src/utils/ground_truth.rs +++ b/diskann-tools/src/utils/ground_truth.rs @@ -643,3 +643,488 @@ where Ok(neighbor_queues) } + +#[cfg(test)] +mod tests { + use super::*; + use diskann_disk::data_model::AdHoc; + use diskann_providers::storage::VirtualStorageProvider; + use std::io::Read; + + type GraphDataF32 = AdHoc; + + /// Write a `.bin` vector file: 8-byte header (npts, dim) then row-major f32 data. + fn write_vectors( + provider: &impl StorageWriteProvider, + path: &str, + dim: usize, + rows: &[Vec], + ) { + let mut w = provider.create_for_write(path).unwrap(); + Metadata::new(rows.len(), dim) + .unwrap() + .write(&mut w) + .unwrap(); + for row in rows { + w.write_all(cast_slice::(row)).unwrap(); + } + w.flush().unwrap(); + } + + /// Write a range-truthset (vector filters) file: header (npts, total_ids), + /// then `npts` i32 counts, then concatenated u32 ids. + fn write_filters(provider: &impl StorageWriteProvider, path: &str, rows: &[Vec]) { + let total: usize = rows.iter().map(|r| r.len()).sum(); + let mut w = provider.create_for_write(path).unwrap(); + Metadata::new(rows.len(), total) + .unwrap() + .write(&mut w) + .unwrap(); + for row in rows { + w.write_all(&(row.len() as i32).to_le_bytes()).unwrap(); + } + for row in rows { + w.write_all(cast_slice::(row)).unwrap(); + } + w.flush().unwrap(); + } + + /// Read back a standard ground-truth file: returns (npts, dim, ids, distances). + fn read_ground_truth( + provider: &impl StorageReadProvider, + path: &str, + ) -> (usize, usize, Vec, Vec) { + let mut f = provider.open_reader(path).unwrap(); + let (npts, dim) = Metadata::read(&mut f).unwrap().into_dims(); + let mut id_bytes = vec![0u8; npts * dim * size_of::()]; + f.read_exact(&mut id_bytes).unwrap(); + let mut dist_bytes = vec![0u8; npts * dim * size_of::()]; + f.read_exact(&mut dist_bytes).unwrap(); + ( + npts, + dim, + cast_slice::(&id_bytes).to_vec(), + cast_slice::(&dist_bytes).to_vec(), + ) + } + + #[test] + fn test_compute_ground_truth_basic() { + let provider = VirtualStorageProvider::new_memory(); + // 4 base points on a line; single query at origin. + write_vectors( + &provider, + "/base.bin", + 2, + &[ + vec![0.0, 0.0], + vec![1.0, 0.0], + vec![2.0, 0.0], + vec![10.0, 0.0], + ], + ); + write_vectors(&provider, "/query.bin", 2, &[vec![0.0, 0.0]]); + + compute_ground_truth_from_datafiles::( + &provider, + Metric::L2, + "/base.bin", + "/query.bin", + "/gt.bin", + None, + 2, + None, + None, + None, + None, + None, + ) + .unwrap(); + + let (npts, dim, ids, dists) = read_ground_truth(&provider, "/gt.bin"); + assert_eq!(npts, 1); + assert_eq!(dim, 2); + // Closest two points to origin are ids 0 and 1. + assert_eq!(ids, vec![0, 1]); + assert_eq!(dists[0], 0.0); + assert_eq!(dists[1], 1.0); + } + + #[test] + fn test_compute_ground_truth_with_skip_base() { + let provider = VirtualStorageProvider::new_memory(); + write_vectors( + &provider, + "/base.bin", + 2, + &[vec![0.0, 0.0], vec![1.0, 0.0], vec![2.0, 0.0]], + ); + write_vectors(&provider, "/query.bin", 2, &[vec![0.0, 0.0]]); + + // Skip the first base point; nearest remaining is id 1. + compute_ground_truth_from_datafiles::( + &provider, + Metric::L2, + "/base.bin", + "/query.bin", + "/gt.bin", + None, + 1, + None, + Some(1), + None, + None, + None, + ) + .unwrap(); + + let (_, _, ids, _) = read_ground_truth(&provider, "/gt.bin"); + // After skipping the first base point, remaining points are re-indexed + // from 0, so the nearest (original [1,0]) is reported as id 0. + assert_eq!(ids, vec![0]); + } + + #[test] + fn test_compute_ground_truth_with_insert_file() { + let provider = VirtualStorageProvider::new_memory(); + write_vectors(&provider, "/base.bin", 2, &[vec![5.0, 0.0]]); + // The inserted vector is the true nearest neighbor; it gets id 1. + write_vectors(&provider, "/insert.bin", 2, &[vec![0.0, 0.0]]); + write_vectors(&provider, "/query.bin", 2, &[vec![0.0, 0.0]]); + + compute_ground_truth_from_datafiles::( + &provider, + Metric::L2, + "/base.bin", + "/query.bin", + "/gt.bin", + None, + 1, + Some("/insert.bin"), + None, + None, + None, + None, + ) + .unwrap(); + + let (_, _, ids, dists) = read_ground_truth(&provider, "/gt.bin"); + assert_eq!(ids, vec![1]); + assert_eq!(dists[0], 0.0); + } + + #[test] + fn test_compute_ground_truth_with_vector_filters() { + let provider = VirtualStorageProvider::new_memory(); + write_vectors( + &provider, + "/base.bin", + 2, + &[ + vec![0.0, 0.0], + vec![1.0, 0.0], + vec![2.0, 0.0], + vec![3.0, 0.0], + ], + ); + write_vectors(&provider, "/query.bin", 2, &[vec![0.0, 0.0]]); + // Restrict the single query to base ids {2, 3}. + write_filters(&provider, "/filters.bin", &[vec![2, 3]]); + + compute_ground_truth_from_datafiles::( + &provider, + Metric::L2, + "/base.bin", + "/query.bin", + "/gt.bin", + Some("/filters.bin"), + 2, + None, + None, + None, + None, + None, + ) + .unwrap(); + + // Filtered output uses the range-search format: header then per-query + // counts then ids. Nearest allowed point to origin is id 2. + let mut f = provider.open_reader("/gt.bin").unwrap(); + let (num_queries, _) = Metadata::read(&mut f).unwrap().into_dims(); + assert_eq!(num_queries, 1); + let mut count_bytes = vec![0u8; num_queries * size_of::()]; + f.read_exact(&mut count_bytes).unwrap(); + let counts = cast_slice::(&count_bytes); + assert_eq!(counts, &[2]); + let mut id_bytes = vec![0u8; 2 * size_of::()]; + f.read_exact(&mut id_bytes).unwrap(); + let ids = cast_slice::(&id_bytes); + assert_eq!(ids[0], 2); + } + + #[test] + fn test_error_only_base_labels_provided() { + let provider = VirtualStorageProvider::new_memory(); + write_vectors(&provider, "/base.bin", 2, &[vec![0.0, 0.0]]); + write_vectors(&provider, "/query.bin", 2, &[vec![0.0, 0.0]]); + + let err = compute_ground_truth_from_datafiles::( + &provider, + Metric::L2, + "/base.bin", + "/query.bin", + "/gt.bin", + None, + 1, + None, + None, + None, + Some("/base_labels.txt"), + None, + ) + .unwrap_err(); + assert!(err + .details + .contains("must be provided or both must be not provided")); + } + + #[test] + fn test_error_base_labels_and_vector_filters() { + let provider = VirtualStorageProvider::new_memory(); + write_vectors(&provider, "/base.bin", 2, &[vec![0.0, 0.0]]); + write_vectors(&provider, "/query.bin", 2, &[vec![0.0, 0.0]]); + + let err = compute_ground_truth_from_datafiles::( + &provider, + Metric::L2, + "/base.bin", + "/query.bin", + "/gt.bin", + Some("/filters.bin"), + 1, + None, + None, + None, + Some("/base_labels.txt"), + Some("/query_labels.txt"), + ) + .unwrap_err(); + assert!(err + .details + .contains("base_file_labels and vector_filters_file cannot be provided")); + } + + /// Write a multivec `.bin` file: header (num_points, dim, total_results), + /// then `num_points` u32 per-point vector counts, then concatenated + /// row-major f32 data (each point is `count * dim` values). + fn write_multivec( + provider: &impl StorageWriteProvider, + path: &str, + dim: usize, + points: &[Vec>], + ) { + let total: usize = points.iter().map(|p| p.len()).sum(); + let mut w = provider.create_for_write(path).unwrap(); + w.write_all(&(points.len() as u32).to_le_bytes()).unwrap(); + w.write_all(&(dim as u32).to_le_bytes()).unwrap(); + w.write_all(&(total as u32).to_le_bytes()).unwrap(); + for p in points { + w.write_all(&(p.len() as u32).to_le_bytes()).unwrap(); + } + for p in points { + for row in p { + w.write_all(cast_slice::(row)).unwrap(); + } + } + w.flush().unwrap(); + } + + #[test] + fn test_compute_multivec_ground_truth_basic() { + let provider = VirtualStorageProvider::new_memory(); + // Three base "points", each a single vector on a line. + write_multivec( + &provider, + "/mbase.bin", + 2, + &[ + vec![vec![0.0, 0.0]], + vec![vec![1.0, 0.0]], + vec![vec![5.0, 0.0]], + ], + ); + // One query point with a single vector at the origin. + write_multivec(&provider, "/mquery.bin", 2, &[vec![vec![0.0, 0.0]]]); + + compute_multivec_ground_truth_from_datafiles::( + &provider, + Metric::L2, + MultivecAggregationMethod::MinPairwise, + "/mbase.bin", + "/mquery.bin", + "/mgt.bin", + 2, + None, + None, + ) + .unwrap(); + + let (npts, dim, ids, _) = read_ground_truth(&provider, "/mgt.bin"); + assert_eq!(npts, 1); + assert_eq!(dim, 2); + // Closest two base points to origin are ids 0 and 1. + assert_eq!(ids, vec![0, 1]); + } + + #[test] + fn test_compute_multivec_aggregation_methods() { + for method in [ + MultivecAggregationMethod::AveragePairwise, + MultivecAggregationMethod::MinPairwise, + MultivecAggregationMethod::AvgofMins, + ] { + let provider = VirtualStorageProvider::new_memory(); + // Base points are multi-vector sets. + write_multivec( + &provider, + "/mbase.bin", + 2, + &[ + vec![vec![0.0, 0.0], vec![0.5, 0.0]], + vec![vec![9.0, 0.0], vec![10.0, 0.0]], + ], + ); + write_multivec(&provider, "/mquery.bin", 2, &[vec![vec![0.0, 0.0]]]); + + compute_multivec_ground_truth_from_datafiles::( + &provider, + Metric::L2, + method.clone(), + "/mbase.bin", + "/mquery.bin", + "/mgt.bin", + 1, + None, + None, + ) + .unwrap(); + + let (_, _, ids, _) = read_ground_truth(&provider, "/mgt.bin"); + // The nearest base point to the origin is always id 0. + assert_eq!(ids, vec![0], "method {:?} picked wrong neighbor", method); + } + } + + #[test] + fn test_compute_multivec_error_only_one_label_file() { + let provider = VirtualStorageProvider::new_memory(); + write_multivec(&provider, "/mbase.bin", 2, &[vec![vec![0.0, 0.0]]]); + write_multivec(&provider, "/mquery.bin", 2, &[vec![vec![0.0, 0.0]]]); + + let err = compute_multivec_ground_truth_from_datafiles::( + &provider, + Metric::L2, + MultivecAggregationMethod::MinPairwise, + "/mbase.bin", + "/mquery.bin", + "/mgt.bin", + 1, + Some("/base_labels.txt"), + None, + ) + .unwrap_err(); + assert!(err + .details + .contains("must be provided or both must be not provided")); + } + + #[test] + fn test_aggregation_method_from_str() { + assert!(matches!( + "average_pairwise".parse::(), + Ok(MultivecAggregationMethod::AveragePairwise) + )); + assert!(matches!( + "MIN_PAIRWISE".parse::(), + Ok(MultivecAggregationMethod::MinPairwise) + )); + assert!(matches!( + "avg_of_mins".parse::(), + Ok(MultivecAggregationMethod::AvgofMins) + )); + + let err = "nope".parse::().unwrap_err(); + assert_eq!( + format!("{}", err), + "Invalid format for Aggregation Method: nope" + ); + } + + #[test] + fn test_compute_multivec_ground_truth_with_labels() { + let provider = VirtualStorageProvider::new_memory(); + // 3 base points, 2 query points (single vector each). + write_multivec( + &provider, + "/mbase.bin", + 2, + &[ + vec![vec![0.0, 0.0]], + vec![vec![1.0, 0.0]], + vec![vec![2.0, 0.0]], + ], + ); + write_multivec( + &provider, + "/mquery.bin", + 2, + &[vec![vec![0.0, 0.0]], vec![vec![2.0, 0.0]]], + ); + + // Base/query label files live on the real filesystem (read via std::fs). + let dir = tempfile::TempDir::new().unwrap(); + let base_labels = dir.path().join("base.jsonl"); + let query_labels = dir.path().join("query.jsonl"); + { + let mut f = std::fs::File::create(&base_labels).unwrap(); + writeln!(f, r#"{{"doc_id": 0, "g": "a"}}"#).unwrap(); + writeln!(f, r#"{{"doc_id": 1, "g": "b"}}"#).unwrap(); + writeln!(f, r#"{{"doc_id": 2, "g": "a"}}"#).unwrap(); + } + { + let mut f = std::fs::File::create(&query_labels).unwrap(); + writeln!(f, r#"{{"query_id": 0, "filter": {{"g": {{"$eq": "a"}}}}}}"#).unwrap(); + writeln!(f, r#"{{"query_id": 1, "filter": {{"g": {{"$eq": "a"}}}}}}"#).unwrap(); + } + + compute_multivec_ground_truth_from_datafiles::( + &provider, + Metric::L2, + MultivecAggregationMethod::MinPairwise, + "/mbase.bin", + "/mquery.bin", + "/mgt_range.bin", + 2, + Some(base_labels.to_str().unwrap()), + Some(query_labels.to_str().unwrap()), + ) + .unwrap(); + + // Range-search GT format: header (num_queries, total_neighbors), + // then num_queries u32 queue sizes, then total_neighbors u32 ids. + let mut f = provider.open_reader("/mgt_range.bin").unwrap(); + let (num_queries, total_neighbors) = Metadata::read(&mut f).unwrap().into_dims(); + assert_eq!(num_queries, 2); + + let mut size_bytes = vec![0u8; num_queries * size_of::()]; + f.read_exact(&mut size_bytes).unwrap(); + let sizes = cast_slice::(&size_bytes).to_vec(); + assert_eq!(sizes.iter().sum::() as usize, total_neighbors); + + let mut id_bytes = vec![0u8; total_neighbors * size_of::()]; + f.read_exact(&mut id_bytes).unwrap(); + let ids = cast_slice::(&id_bytes).to_vec(); + // Only base docs with g=="a" (ids 0 and 2) may appear in the results. + assert!(ids.iter().all(|&id| id == 0 || id == 2), "ids: {:?}", ids); + } +} diff --git a/diskann-tools/src/utils/search_index_utils.rs b/diskann-tools/src/utils/search_index_utils.rs index 176f2e1ec..295c346cd 100644 --- a/diskann-tools/src/utils/search_index_utils.rs +++ b/diskann-tools/src/utils/search_index_utils.rs @@ -774,4 +774,121 @@ mod test_search_index_utils { assert_eq!(recall.get_k(), 5); assert_eq!(recall.get_n(), 5); } + + use diskann_providers::storage::{StorageWriteProvider, VirtualStorageProvider}; + use std::io::Write as _; + + /// Write a truthset bin file: 8-byte header (npts, dim), then `npts*dim` + /// u32 ids, optionally followed by `npts*dim` f32 distances. + fn write_truthset( + provider: &impl StorageWriteProvider, + path: &str, + npts: usize, + dim: usize, + ids: &[u32], + dists: Option<&[f32]>, + ) { + let mut w = provider.create_for_write(path).unwrap(); + Metadata::new(npts, dim).unwrap().write(&mut w).unwrap(); + w.write_all(cast_slice::(ids)).unwrap(); + if let Some(dists) = dists { + w.write_all(cast_slice::(dists)).unwrap(); + } + w.flush().unwrap(); + } + + #[test] + fn test_load_truthset_with_distances() { + let provider = VirtualStorageProvider::new_memory(); + let path = "/truthset_with_dists.bin"; + let ids: Vec = vec![0, 1, 2, 3, 4, 5]; + let dists: Vec = vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0]; + write_truthset(&provider, path, 2, 3, &ids, Some(&dists)); + + let ts = load_truthset(&provider, path).unwrap(); + assert_eq!(ts.index_num_points, 2); + assert_eq!(ts.index_dimension, 3); + assert_eq!(ts.index_nodes, ids); + assert_eq!(ts.distances, Some(dists)); + } + + #[test] + fn test_load_truthset_ids_only() { + let provider = VirtualStorageProvider::new_memory(); + let path = "/truthset_ids_only.bin"; + let ids: Vec = vec![10, 11, 12, 13]; + write_truthset(&provider, path, 2, 2, &ids, None); + + let ts = load_truthset(&provider, path).unwrap(); + assert_eq!(ts.index_num_points, 2); + assert_eq!(ts.index_dimension, 2); + assert_eq!(ts.index_nodes, ids); + assert!(ts.distances.is_none()); + } + + #[test] + fn test_load_truthset_size_mismatch_errors() { + let provider = VirtualStorageProvider::new_memory(); + let path = "/truthset_bad_size.bin"; + // Header claims 2x3 but only one id follows -> neither expected size. + let mut w = provider.create_for_write(path).unwrap(); + Metadata::new(2usize, 3usize) + .unwrap() + .write(&mut w) + .unwrap(); + w.write_all(cast_slice::(&[0u32])).unwrap(); + w.flush().unwrap(); + drop(w); + + let err = match load_truthset(&provider, path) { + Ok(_) => panic!("expected size mismatch error"), + Err(e) => e, + }; + assert!(format!("{}", err).contains("File size mismatch")); + } + + /// Write a range-truthset bin file: 8-byte header (npts, total_ids), + /// then `npts` i32 per-query counts, then the concatenated u32 ids. + fn write_range_truthset(provider: &impl StorageWriteProvider, path: &str, rows: &[Vec]) { + let total_ids: usize = rows.iter().map(|r| r.len()).sum(); + let mut w = provider.create_for_write(path).unwrap(); + Metadata::new(rows.len(), total_ids) + .unwrap() + .write(&mut w) + .unwrap(); + for row in rows { + w.write_all(&(row.len() as i32).to_le_bytes()).unwrap(); + } + for row in rows { + w.write_all(cast_slice::(row)).unwrap(); + } + w.flush().unwrap(); + } + + #[test] + fn test_load_range_truthset() { + let provider = VirtualStorageProvider::new_memory(); + let path = "/range_truthset.bin"; + let rows = vec![vec![1u32, 2, 3], vec![4], vec![7, 8]]; + write_range_truthset(&provider, path, &rows); + + let ts = load_range_truthset(&provider, path).unwrap(); + assert_eq!(ts.index_num_points, 3); + assert_eq!(ts.index_dimensions, vec![3, 1, 2]); + assert_eq!(ts.index_nodes, rows); + assert!(ts.distances.is_none()); + } + + #[test] + fn test_load_vector_filters() { + let provider = VirtualStorageProvider::new_memory(); + let path = "/vector_filters.bin"; + let rows = vec![vec![1u32, 2, 2, 3], vec![5]]; + write_range_truthset(&provider, path, &rows); + + let filters = load_vector_filters(&provider, path).unwrap(); + assert_eq!(filters.len(), 2); + assert_eq!(filters[0], HashSet::from([1, 2, 3])); + assert_eq!(filters[1], HashSet::from([5])); + } } diff --git a/vectorset/Cargo.toml b/vectorset/Cargo.toml index 7d5df0ba6..ca768a0be 100644 --- a/vectorset/Cargo.toml +++ b/vectorset/Cargo.toml @@ -27,3 +27,6 @@ tokio = { workspace = true, features = [ "sync", ] } toml = "0.9.7" + +[dev-dependencies] +tempfile.workspace = true diff --git a/vectorset/src/loader.rs b/vectorset/src/loader.rs index 9e0f7a856..0150ba351 100644 --- a/vectorset/src/loader.rs +++ b/vectorset/src/loader.rs @@ -1,133 +1,281 @@ -/* - * Copyright (c) Microsoft Corporation. - * Licensed under the MIT license. - */ - -use anyhow::{Result, anyhow}; -use std::{collections::HashMap, marker::PhantomData, mem, path::Path, sync::Arc}; -use tokio::{fs::File, io::AsyncReadExt, sync::Mutex}; - -const BATCH_SIZE: usize = 1024; - -/// Loads base or query vectors from a given path and allow iteration over them. -pub struct DatasetLoader { - file: Mutex<(File, usize)>, - num_vectors: usize, - dim: usize, - type_: PhantomData, -} - -impl DatasetLoader { - pub async fn new + Clone>(path: P) -> Result> { - let path = path.as_ref().to_path_buf(); - - // Calculate total vectors in all paths - let mut file = File::open(&path).await?; - let num_vectors = file.read_i32_le().await? as usize; - let dim = file.read_i32_le().await? as usize; - - Ok(Arc::new(Self { - file: Mutex::new((file, 0)), - num_vectors, - dim, - type_: PhantomData, - })) - } - - pub fn dim(&self) -> usize { - self.dim - } - - pub fn len(&self) -> usize { - self.num_vectors - } - - pub fn batch_size(&self) -> usize { - BATCH_SIZE - } - - /// Load the next vectors into `buffer`. - /// - /// Returns (count, first_id) where `count` is the number of vectors loaded - /// and `first_id` is the id of the first vector. - pub async fn next(&self, buffer: &mut Vec) -> Result<(usize, usize)> { - let mut f = self.file.lock().await; - - let mut count; - let mut first_id; - loop { - first_id = f.1; - if f.1 >= self.num_vectors { - buffer.clear(); - return Ok((0, first_id)); - } - - buffer.resize(BATCH_SIZE * self.dim, T::zeroed()); - - let mut buf: &mut [u8] = bytemuck::cast_slice_mut::(&mut *buffer); - while let bytes_read = f.0.read(buf).await? - && bytes_read > 0 - { - buf = &mut buf[bytes_read..]; - } - - let elements_left = buf.len() / mem::size_of::(); - if !buf.is_empty() && !elements_left.is_multiple_of(self.dim) { - return Err(anyhow!("unexpected EOF")); - } - - count = BATCH_SIZE - elements_left / self.dim; - - if count == 0 { - continue; - } - - break; - } - - f.1 += count; - - Ok((count, first_id)) - } - - /// Load the entire dataset into a Vec. - pub async fn load>(path: P) -> Result>>> { - let mut file = File::open(path).await?; - let num_vectors = file.read_i32_le().await? as usize; - let dim = file.read_i32_le().await? as usize; - - let mut vectors = Vec::with_capacity(num_vectors); - for _ in 0..num_vectors { - let mut v = vec![T::zeroed(); dim]; - file.read_exact(bytemuck::cast_slice_mut(&mut v)).await?; - vectors.push(v); - } - Ok(Arc::new(vectors)) - } - - pub async fn load_groundtruth>( - path: P, - ) -> Result>>> { - let mut file = File::open(&path).await?; - - let num_queries = file.read_i32_le().await? as usize; - let num_neighbors = file.read_i32_le().await? as usize; - - let mut nbuf = vec![0u32; num_queries * num_neighbors]; - let mut dbuf = vec![0f32; num_queries * num_neighbors]; - file.read_exact(bytemuck::cast_slice_mut(&mut nbuf)).await?; - file.read_exact(bytemuck::cast_slice_mut(&mut dbuf)).await?; - - let id_dists: Vec<(u32, f32)> = nbuf.iter().copied().zip(dbuf.iter().copied()).collect(); - - let mut map = HashMap::with_capacity(num_queries); - - for i in 0..num_queries { - let start = i * num_neighbors; - let end = start + num_neighbors; - map.insert(i as u32, id_dists[start..end].to_vec()); - } - - Ok(Arc::new(map)) - } -} +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +use anyhow::{Result, anyhow}; +use std::{collections::HashMap, marker::PhantomData, mem, path::Path, sync::Arc}; +use tokio::{fs::File, io::AsyncReadExt, sync::Mutex}; + +const BATCH_SIZE: usize = 1024; + +/// Loads base or query vectors from a given path and allow iteration over them. +pub struct DatasetLoader { + file: Mutex<(File, usize)>, + num_vectors: usize, + dim: usize, + type_: PhantomData, +} + +impl DatasetLoader { + pub async fn new + Clone>(path: P) -> Result> { + let path = path.as_ref().to_path_buf(); + + // Calculate total vectors in all paths + let mut file = File::open(&path).await?; + let num_vectors = file.read_i32_le().await? as usize; + let dim = file.read_i32_le().await? as usize; + + Ok(Arc::new(Self { + file: Mutex::new((file, 0)), + num_vectors, + dim, + type_: PhantomData, + })) + } + + pub fn dim(&self) -> usize { + self.dim + } + + pub fn len(&self) -> usize { + self.num_vectors + } + + pub fn batch_size(&self) -> usize { + BATCH_SIZE + } + + /// Load the next vectors into `buffer`. + /// + /// Returns (count, first_id) where `count` is the number of vectors loaded + /// and `first_id` is the id of the first vector. + pub async fn next(&self, buffer: &mut Vec) -> Result<(usize, usize)> { + let mut f = self.file.lock().await; + + let mut count; + let mut first_id; + loop { + first_id = f.1; + if f.1 >= self.num_vectors { + buffer.clear(); + return Ok((0, first_id)); + } + + buffer.resize(BATCH_SIZE * self.dim, T::zeroed()); + + let mut buf: &mut [u8] = bytemuck::cast_slice_mut::(&mut *buffer); + while let bytes_read = f.0.read(buf).await? + && bytes_read > 0 + { + buf = &mut buf[bytes_read..]; + } + + let elements_left = buf.len() / mem::size_of::(); + if !buf.is_empty() && !elements_left.is_multiple_of(self.dim) { + return Err(anyhow!("unexpected EOF")); + } + + count = BATCH_SIZE - elements_left / self.dim; + + if count == 0 { + continue; + } + + break; + } + + f.1 += count; + + Ok((count, first_id)) + } + + /// Load the entire dataset into a Vec. + pub async fn load>(path: P) -> Result>>> { + let mut file = File::open(path).await?; + let num_vectors = file.read_i32_le().await? as usize; + let dim = file.read_i32_le().await? as usize; + + let mut vectors = Vec::with_capacity(num_vectors); + for _ in 0..num_vectors { + let mut v = vec![T::zeroed(); dim]; + file.read_exact(bytemuck::cast_slice_mut(&mut v)).await?; + vectors.push(v); + } + Ok(Arc::new(vectors)) + } + + pub async fn load_groundtruth>( + path: P, + ) -> Result>>> { + let mut file = File::open(&path).await?; + + let num_queries = file.read_i32_le().await? as usize; + let num_neighbors = file.read_i32_le().await? as usize; + + let mut nbuf = vec![0u32; num_queries * num_neighbors]; + let mut dbuf = vec![0f32; num_queries * num_neighbors]; + file.read_exact(bytemuck::cast_slice_mut(&mut nbuf)).await?; + file.read_exact(bytemuck::cast_slice_mut(&mut dbuf)).await?; + + let id_dists: Vec<(u32, f32)> = nbuf.iter().copied().zip(dbuf.iter().copied()).collect(); + + let mut map = HashMap::with_capacity(num_queries); + + for i in 0..num_queries { + let start = i * num_neighbors; + let end = start + num_neighbors; + map.insert(i as u32, id_dists[start..end].to_vec()); + } + + Ok(Arc::new(map)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::io::Write; + use std::path::PathBuf; + use tempfile::TempDir; + + /// Write a `.bin` dataset file (i32 num_vectors, i32 dim, then row-major data). + fn write_bin( + dir: &TempDir, + name: &str, + dim: usize, + rows: &[Vec], + ) -> PathBuf { + let path = dir.path().join(name); + let mut f = std::fs::File::create(&path).unwrap(); + f.write_all(&(rows.len() as i32).to_le_bytes()).unwrap(); + f.write_all(&(dim as i32).to_le_bytes()).unwrap(); + for row in rows { + f.write_all(bytemuck::cast_slice(row)).unwrap(); + } + path + } + + /// Write raw bytes after a header (used for malformed-file tests). + fn write_raw(dir: &TempDir, name: &str, num_vectors: i32, dim: i32, body: &[u8]) -> PathBuf { + let path = dir.path().join(name); + let mut f = std::fs::File::create(&path).unwrap(); + f.write_all(&num_vectors.to_le_bytes()).unwrap(); + f.write_all(&dim.to_le_bytes()).unwrap(); + f.write_all(body).unwrap(); + path + } + + #[tokio::test] + async fn new_reads_header() { + let dir = TempDir::new().unwrap(); + let rows = vec![vec![1.0f32, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]]; + let path = write_bin(&dir, "base.bin", 2, &rows); + + let loader = DatasetLoader::::new(&path).await.unwrap(); + assert_eq!(loader.len(), 3); + assert_eq!(loader.dim(), 2); + assert_eq!(loader.batch_size(), BATCH_SIZE); + } + + #[tokio::test] + async fn next_reads_all_vectors_in_order() { + let dir = TempDir::new().unwrap(); + let rows = vec![vec![1.0f32, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]]; + let path = write_bin(&dir, "base.bin", 2, &rows); + let loader = DatasetLoader::::new(&path).await.unwrap(); + + let mut buf = Vec::new(); + let (count, first_id) = loader.next(&mut buf).await.unwrap(); + assert_eq!(count, 3); + assert_eq!(first_id, 0); + assert_eq!(&buf[..count * 2], &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]); + + // Subsequent call yields nothing and clears the buffer. + let (count, first_id) = loader.next(&mut buf).await.unwrap(); + assert_eq!(count, 0); + assert_eq!(first_id, 3); + assert!(buf.is_empty()); + } + + #[tokio::test] + async fn next_spans_multiple_batches() { + let dir = TempDir::new().unwrap(); + let total = BATCH_SIZE + 5; + let rows: Vec> = (0..total).map(|i| vec![i as f32]).collect(); + let path = write_bin(&dir, "base.bin", 1, &rows); + let loader = DatasetLoader::::new(&path).await.unwrap(); + + let mut buf = Vec::new(); + let (count, first_id) = loader.next(&mut buf).await.unwrap(); + assert_eq!(count, BATCH_SIZE); + assert_eq!(first_id, 0); + assert_eq!(buf[0], 0.0); + + let (count, first_id) = loader.next(&mut buf).await.unwrap(); + assert_eq!(count, 5); + assert_eq!(first_id, BATCH_SIZE); + assert_eq!(buf[0], BATCH_SIZE as f32); + + let (count, _) = loader.next(&mut buf).await.unwrap(); + assert_eq!(count, 0); + } + + #[tokio::test] + async fn next_on_empty_dataset() { + let dir = TempDir::new().unwrap(); + let path = write_bin::(&dir, "empty.bin", 4, &[]); + let loader = DatasetLoader::::new(&path).await.unwrap(); + + let mut buf = Vec::new(); + let (count, first_id) = loader.next(&mut buf).await.unwrap(); + assert_eq!(count, 0); + assert_eq!(first_id, 0); + assert!(buf.is_empty()); + } + + #[tokio::test] + async fn next_errors_on_truncated_vector() { + let dir = TempDir::new().unwrap(); + // Header claims 2 vectors of dim 3, but only 4 f32 values follow + // (one full vector plus a partial second one). + let body: Vec = [1.0f32, 2.0, 3.0, 4.0] + .iter() + .flat_map(|v| v.to_le_bytes()) + .collect(); + let path = write_raw(&dir, "bad.bin", 2, 3, &body); + let loader = DatasetLoader::::new(&path).await.unwrap(); + + let mut buf = Vec::new(); + assert!(loader.next(&mut buf).await.is_err()); + } + + #[tokio::test] + async fn load_returns_all_vectors() { + let dir = TempDir::new().unwrap(); + let rows = vec![vec![1.0f32, 2.0, 3.0], vec![4.0, 5.0, 6.0]]; + let path = write_bin(&dir, "base.bin", 3, &rows); + + let loaded = DatasetLoader::::load(&path).await.unwrap(); + assert_eq!(*loaded, rows); + } + + #[tokio::test] + async fn load_groundtruth_builds_map() { + let dir = TempDir::new().unwrap(); + let num_queries = 2i32; + let num_neighbors = 2i32; + let ids: [u32; 4] = [10, 11, 20, 21]; + let dists: [f32; 4] = [0.1, 0.2, 0.3, 0.4]; + + let mut body = Vec::new(); + body.extend(bytemuck::cast_slice::(&ids)); + body.extend(bytemuck::cast_slice::(&dists)); + let path = write_raw(&dir, "gt.bin", num_queries, num_neighbors, &body); + + let map = DatasetLoader::::load_groundtruth(&path).await.unwrap(); + assert_eq!(map.len(), 2); + assert_eq!(map[&0], vec![(10, 0.1), (11, 0.2)]); + assert_eq!(map[&1], vec![(20, 0.3), (21, 0.4)]); + } +}