diff --git a/diskann/src/graph/misc.rs b/diskann/src/graph/misc.rs index 067bd2d21..25f5679e3 100644 --- a/diskann/src/graph/misc.rs +++ b/diskann/src/graph/misc.rs @@ -33,30 +33,18 @@ pub enum InplaceDeleteMethod { // Parameters for diverse search #[cfg(feature = "experimental_diversity_search")] -#[derive(Clone, Debug)] -pub struct DiverseSearchParams

-where - P: crate::neighbor::AttributeValueProvider, -{ +#[derive(Clone, Copy, Debug)] +pub struct DiverseSearchParams { pub diverse_attribute_id: usize, pub diverse_results_k: usize, - pub attribute_provider: std::sync::Arc

, } #[cfg(feature = "experimental_diversity_search")] -impl

DiverseSearchParams

-where - P: crate::neighbor::AttributeValueProvider, -{ - pub fn new( - diverse_attribute_id: usize, - diverse_results_k: usize, - attribute_provider: std::sync::Arc

, - ) -> Self { +impl DiverseSearchParams { + pub fn new(diverse_attribute_id: usize, diverse_results_k: usize) -> Self { Self { diverse_attribute_id, diverse_results_k, - attribute_provider, } } } diff --git a/diskann/src/graph/search/diverse_search.rs b/diskann/src/graph/search/diverse_search.rs index bc5b7de47..622ce7b17 100644 --- a/diskann/src/graph/search/diverse_search.rs +++ b/diskann/src/graph/search/diverse_search.rs @@ -18,7 +18,7 @@ use crate::{ index::{DiskANNIndex, SearchStats}, search_output_buffer::SearchOutputBuffer, }, - neighbor::{AttributeValueProvider, DiverseNeighborQueue, NeighborQueue}, + neighbor::{DiverseId, DiverseNeighborQueue, NeighborQueue}, provider::DataProvider, }; @@ -26,22 +26,16 @@ use crate::{ /// /// Returns results that are diverse across a specified attribute. #[derive(Debug)] -pub struct Diverse

-where - P: AttributeValueProvider, -{ +pub struct Diverse { /// Base k-NN search parameters. inner: Knn, /// Diversity-specific parameters. - diverse_params: DiverseSearchParams

, + diverse_params: DiverseSearchParams, } -impl

Diverse

-where - P: AttributeValueProvider, -{ +impl Diverse { /// Create new diverse search parameters. - pub fn new(inner: Knn, diverse_params: DiverseSearchParams

) -> Self { + pub fn new(inner: Knn, diverse_params: DiverseSearchParams) -> Self { Self { inner, diverse_params, @@ -56,7 +50,7 @@ where /// Returns a reference to the diversity-specific parameters. #[inline] - pub fn diverse_params(&self) -> &DiverseSearchParams

{ + pub fn diverse_params(&self) -> &DiverseSearchParams { &self.diverse_params } @@ -64,17 +58,15 @@ where fn create_scratch( &self, index: &DiskANNIndex, - ) -> SearchScratch> + ) -> SearchScratch> where DP: DataProvider, - P: AttributeValueProvider, + DP::InternalId: DiverseId, { - let attribute_provider = self.diverse_params.attribute_provider.clone(); let diverse_queue = DiverseNeighborQueue::new( self.inner.l_value().get(), self.inner.k_value(), self.diverse_params.diverse_results_k, - attribute_provider, ); SearchScratch { @@ -92,12 +84,12 @@ where } } -impl<'a, DP, S, T, P> Search<'a, DP, S, T> for Diverse

+impl<'a, DP, S, T> Search<'a, DP, S, T> for Diverse where DP: DataProvider, + DP::InternalId: DiverseId, T: Copy + Send + Sync, S: SearchStrategy<'a, DP, T, SearchAccessor: SearchAccessor>, - P: AttributeValueProvider, { type Output = SearchStats; diff --git a/diskann/src/neighbor/diverse_priority_queue.rs b/diskann/src/neighbor/diverse_priority_queue.rs index 93f355500..2cec871b2 100644 --- a/diskann/src/neighbor/diverse_priority_queue.rs +++ b/diskann/src/neighbor/diverse_priority_queue.rs @@ -8,7 +8,6 @@ use std::{ fmt::{Debug, Display}, hash::Hash, num::NonZeroUsize, - sync::Arc, }; use crate::neighbor::{ @@ -25,7 +24,23 @@ pub trait Attribute: Hash + Eq + Copy + Default + Debug + Display + Send + Sync // Blanket implementation: any type satisfying these bounds automatically implements Attribute impl Attribute for T where T: Hash + Eq + Copy + Default + Debug + Display + Send + Sync {} -/// A wrapper type for (VectorIdType, attribute) tuples to implement required traits +/// Trait for neighbor ids that carry their own diversity attribute. +/// +/// Instead of looking up attributes through an external provider, the attribute +/// is obtained directly from the id. Ids that do not have an attribute return +/// `None` and are skipped by [`DiverseNeighborQueue`]. +pub trait DiverseId: NeighborPriorityQueueIdType + Hash { + /// The attribute value type carried by this id. + type Attribute: Attribute; + + /// Get the attribute value carried by this id, or `None` if it has none. + fn attribute(&self) -> Option; +} + +/// A wrapper type pairing a vector id with an attribute value. +/// +/// This is a convenience [`DiverseId`] implementation for cases where the +/// attribute is stored alongside the vector id. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)] pub struct VectorIdWithAttribute where @@ -41,7 +56,7 @@ where I: NeighborPriorityQueueIdType, A: Attribute, { - fn new(id: I, attribute: A) -> Self { + pub fn new(id: I, attribute: A) -> Self { Self { id, attribute } } } @@ -56,48 +71,52 @@ where } } +impl DiverseId for VectorIdWithAttribute +where + I: NeighborPriorityQueueIdType + Hash, + A: Attribute, +{ + type Attribute = A; + + fn attribute(&self) -> Option { + Some(self.attribute) + } +} + /// A diverse neighbor priority queue that wraps a standard NeighborPriorityQueue /// and delegates all operations to it. This provides a foundation for implementing /// diversity-aware search algorithms while maintaining the same interface. /// -/// This struct serves as a wrapper around NeighborPriorityQueue and can be extended -/// in the future to implement diversity constraints or other specialized behaviors. +/// The queue is generic over an id type `I` that carries its own diversity +/// attribute via the [`DiverseId`] trait, so no external attribute provider is +/// required. #[derive(Debug, Clone)] -pub struct DiverseNeighborQueue

+pub struct DiverseNeighborQueue where - P: AttributeValueProvider, + I: DiverseId, { - /// The underlying priority queue that handles all core operations - /// Stores VectorIdWithAttribute which contains (VectorIdType, attribute_value) - global_queue: NeighborPriorityQueue>, - /// Map from attribute_id to local neighbor priority queue - local_queue_map: HashMap>, - /// Attribute value provider for managing diversity attributes - attribute_provider: Arc

, + /// The underlying priority queue that handles all core operations. + global_queue: NeighborPriorityQueue, + /// Map from attribute value to local neighbor priority queue. + local_queue_map: HashMap>, /// The calculated diverse_results_l for local queues diverse_results_l: usize, /// The target number of diverse results (k_value for diversity) diverse_results_k: usize, } -impl

DiverseNeighborQueue

+impl DiverseNeighborQueue where - P: AttributeValueProvider, + I: DiverseId, { /// Create a new DiverseNeighborQueue with the specified capacity. /// /// This will implicitly set `l_value` to the provided capacity. - pub fn new( - l_value: usize, - k_value: NonZeroUsize, - diverse_results_k: usize, - attribute_provider: Arc

, - ) -> Self { + pub fn new(l_value: usize, k_value: NonZeroUsize, diverse_results_k: usize) -> Self { let diverse_results_l = diverse_results_k * l_value / k_value.get(); Self { global_queue: NeighborPriorityQueue::new(l_value), local_queue_map: HashMap::new(), - attribute_provider, diverse_results_l, diverse_results_k, } @@ -133,29 +152,29 @@ where // Step 2: Compact global queue using the filter if !removed_items.is_empty() { self.global_queue - .retain(|neighbor| !removed_items.contains(&neighbor.id.id)); + .retain(|neighbor| !removed_items.contains(&neighbor.id)); } } } -impl

NeighborQueue for DiverseNeighborQueue

+impl NeighborQueue for DiverseNeighborQueue where - P: AttributeValueProvider, + I: DiverseId, { type Iter<'a> - = BestCandidatesIterator<'a, P::Id, Self> + = BestCandidatesIterator<'a, I, Self> where Self: 'a, - P::Id: 'a; + I: 'a; - fn insert(&mut self, nbr: Neighbor) { - // Get the attribute value for the current neighbor. + fn insert(&mut self, nbr: Neighbor) { + // Get the attribute value carried by the neighbor's id. // We explicitly skip neighbors without attributes (returning None) rather than using // unwrap_or_default(), because using a default value would conflate "missing attribute" // with "attribute value 0" (or whatever the default is). This could violate diversity // constraints by incorrectly grouping neighbors without attributes together with // neighbors that legitimately have the default attribute value. - let Some(attribute_value) = self.attribute_provider.get(nbr.id) else { + let Some(attribute_value) = nbr.id.attribute() else { return; }; @@ -168,71 +187,52 @@ where let local_queue_full = local_queue.is_full(); let global_queue_full = self.global_queue.is_full(); - // Create a neighbor with VectorIdWithAttribute for global_queue - let nbr_with_attribute = Neighbor::new( - VectorIdWithAttribute::new(nbr.id, attribute_value), - nbr.distance, - ); - if !local_queue_full && !global_queue_full { // Case 1: Both local queue and global queue have space local_queue.insert(nbr); - self.global_queue.insert(nbr_with_attribute); + self.global_queue.insert(nbr); } else if local_queue_full { // Case 2: Local queue is full if nbr.distance < local_queue.get(self.diverse_results_l - 1).distance { // Get the worst neighbor in the local queue let worst_neighbor = local_queue.get(self.diverse_results_l - 1); - // Create the corresponding neighbor with attribute for removal from global queue - let worst_neighbor_with_attribute = Neighbor::new( - VectorIdWithAttribute::new(worst_neighbor.id, attribute_value), - worst_neighbor.distance, - ); // Remove worst neighbor from global queue using the remove method - self.global_queue.remove(worst_neighbor_with_attribute); + self.global_queue.remove(worst_neighbor); // Insert new neighbor into both queues local_queue.insert(nbr); - self.global_queue.insert(nbr_with_attribute); + self.global_queue.insert(nbr); } } else if !local_queue_full && global_queue_full { // Case 3: Local queue has space but global queue is full let l_size = self.global_queue.search_l(); if nbr.distance < self.global_queue.get(l_size - 1).distance { let worst_global = self.global_queue.get(l_size - 1); - // Extract the attribute from VectorIdWithAttribute - let attribute_of_worst_global = worst_global.id.attribute; + // The attribute of the worst global neighbor comes from its id. + let attribute_of_worst_global = worst_global.id.attribute(); // Insert new neighbor into both queues local_queue.insert(nbr); - self.global_queue.insert(nbr_with_attribute); + self.global_queue.insert(nbr); // Remove worst neighbor from its local queue - if let Some(local_queue) = self.local_queue_map.get_mut(&attribute_of_worst_global) + if let Some(attribute_of_worst_global) = attribute_of_worst_global + && let Some(local_queue) = + self.local_queue_map.get_mut(&attribute_of_worst_global) { - let worst_neighbor_without_attribute = - Neighbor::new(worst_global.id.id, worst_global.distance); - local_queue.remove(worst_neighbor_without_attribute); + local_queue.remove(worst_global); } } } } - fn get(&self, index: usize) -> Neighbor { - let neighbor_with_attribute = self.global_queue.get(index); - Neighbor::new( - neighbor_with_attribute.id.id, - neighbor_with_attribute.distance, - ) + fn get(&self, index: usize) -> Neighbor { + self.global_queue.get(index) } - fn closest_notvisited(&mut self) -> Option> { - let neighbor_with_attribute = self.global_queue.closest_notvisited()?; - Some(Neighbor::new( - neighbor_with_attribute.id.id, - neighbor_with_attribute.distance, - )) + fn closest_notvisited(&mut self) -> Option> { + self.global_queue.closest_notvisited() } fn has_notvisited_node(&self) -> bool { @@ -256,89 +256,68 @@ where self.local_queue_map.clear(); } - fn iter(&self) -> BestCandidatesIterator<'_, P::Id, Self> { + fn iter(&self) -> BestCandidatesIterator<'_, I, Self> { let sz = self.global_queue.search_l().min(self.global_queue.size()); BestCandidatesIterator::new(sz, self) } } -/// Trait for providing attribute values for vector IDs. -/// Implementations of this trait can be used with diverse search to retrieve -/// attribute values for vectors during search operations. -pub trait AttributeValueProvider: crate::provider::HasId + Send + Sync + std::fmt::Debug { - type Value: Attribute; - - /// Get the attribute value for a given vector ID. - /// - /// # Arguments - /// * `id` - The vector ID - /// - /// # Returns - /// * `Option` - The attribute value if it exists, None otherwise - fn get(&self, id: Self::Id) -> Option; -} - #[cfg(test)] mod diverse_priority_queue_test { use super::*; - /// A test attribute value provider that stores attribute values for vector IDs. - /// This is a simple in-memory store using a HashMap for testing purposes. - #[derive(Debug, Clone)] - struct TestAttributeValueProvider { - /// Map from vector_id to attribute value - attributes: HashMap, + /// A test id type that carries an optional attribute value, used to exercise + /// both the "has attribute" and "missing attribute" code paths. + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)] + struct TestId { + id: u32, + attribute: Option, } - impl TestAttributeValueProvider { - /// Create a new empty TestAttributeValueProvider. - fn new() -> Self { - Self { - attributes: HashMap::new(), - } + impl TestId { + fn new(id: u32, attribute: Option) -> Self { + Self { id, attribute } } + } - /// Insert an attribute value for a given vector ID. - fn insert(&mut self, vector_id: u32, attribute_value: u32) { - self.attributes.insert(vector_id, attribute_value); + impl Display for TestId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.id) } } - impl crate::provider::HasId for TestAttributeValueProvider { - type Id = u32; + impl DiverseId for TestId { + type Attribute = u32; + + fn attribute(&self) -> Option { + self.attribute + } } - impl AttributeValueProvider for TestAttributeValueProvider { - type Value = u32; + // Type alias for tests to make them more readable + type TestDiverseQueue = DiverseNeighborQueue; - fn get(&self, id: Self::Id) -> Option { - self.attributes.get(&id).copied() - } + /// Build a neighbor whose id carries the attribute `id / 3`. + /// + /// This mirrors the original test setup where vectors 0-2 had attribute 0, + /// 3-5 had attribute 1, and so on. + fn nbr(id: u32, distance: f32) -> Neighbor { + Neighbor::new(TestId::new(id, Some(id / 3)), distance) } - impl Default for TestAttributeValueProvider { - fn default() -> Self { - Self::new() - } + /// Build a neighbor whose id carries an explicit attribute value. + fn nbr_attr(id: u32, attribute: u32, distance: f32) -> Neighbor { + Neighbor::new(TestId::new(id, Some(attribute)), distance) } - // Type alias for tests to make them more readable - type TestDiverseQueue = DiverseNeighborQueue; - - /// Helper function to create a test attribute provider wrapped in Arc - fn create_test_attribute_provider() -> Arc { - let mut provider = TestAttributeValueProvider::new(); - // Set up attributes: vectors 0-2 have attribute 0, vectors 3-5 have attribute 1, etc. - for i in 0..20 { - provider.insert(i, i / 3); - } - Arc::new(provider) + /// Build a neighbor whose id has no attribute. + fn nbr_none(id: u32, distance: f32) -> Neighbor { + Neighbor::new(TestId::new(id, None), distance) } #[test] fn test_new() { - let attribute_provider = create_test_attribute_provider(); - let queue = TestDiverseQueue::new(10, NonZeroUsize::new(5).unwrap(), 5, attribute_provider); + let queue = TestDiverseQueue::new(10, NonZeroUsize::new(5).unwrap(), 5); assert_eq!(queue.size(), 0); assert_eq!(queue.capacity(), 10); @@ -348,14 +327,12 @@ mod diverse_priority_queue_test { #[test] fn test_insert_single_attribute() { - let attribute_provider = create_test_attribute_provider(); - let mut queue = - TestDiverseQueue::new(10, NonZeroUsize::new(5).unwrap(), 5, attribute_provider); + let mut queue = TestDiverseQueue::new(10, NonZeroUsize::new(5).unwrap(), 5); // Insert neighbors with IDs 0, 1, 2 (all have attribute 0) - queue.insert(Neighbor::new(0, 1.0)); - queue.insert(Neighbor::new(1, 0.5)); - queue.insert(Neighbor::new(2, 1.5)); + queue.insert(nbr(0, 1.0)); + queue.insert(nbr(1, 0.5)); + queue.insert(nbr(2, 1.5)); assert_eq!(queue.size(), 3); assert_eq!(queue.local_queue_map.len(), 1); @@ -364,14 +341,12 @@ mod diverse_priority_queue_test { #[test] fn test_insert_multiple_attributes() { - let attribute_provider = create_test_attribute_provider(); - let mut queue = - TestDiverseQueue::new(10, NonZeroUsize::new(5).unwrap(), 5, attribute_provider); + let mut queue = TestDiverseQueue::new(10, NonZeroUsize::new(5).unwrap(), 5); // Insert neighbors with different attributes - queue.insert(Neighbor::new(0, 1.0)); // attribute 0 - queue.insert(Neighbor::new(3, 0.8)); // attribute 1 - queue.insert(Neighbor::new(6, 1.2)); // attribute 2 + queue.insert(nbr(0, 1.0)); // attribute 0 + queue.insert(nbr(3, 0.8)); // attribute 1 + queue.insert(nbr(6, 1.2)); // attribute 2 assert_eq!(queue.size(), 3); assert_eq!(queue.local_queue_map.len(), 3); @@ -382,111 +357,94 @@ mod diverse_priority_queue_test { #[test] fn test_insert_maintains_order() { - let attribute_provider = create_test_attribute_provider(); - let mut queue = - TestDiverseQueue::new(10, NonZeroUsize::new(5).unwrap(), 5, attribute_provider); + let mut queue = TestDiverseQueue::new(10, NonZeroUsize::new(5).unwrap(), 5); - queue.insert(Neighbor::new(0, 1.0)); - queue.insert(Neighbor::new(1, 0.5)); - queue.insert(Neighbor::new(2, 1.5)); + queue.insert(nbr(0, 1.0)); + queue.insert(nbr(1, 0.5)); + queue.insert(nbr(2, 1.5)); // Neighbors should be sorted by distance - assert_eq!(queue.get(0).id, 1); // distance 0.5 - assert_eq!(queue.get(1).id, 0); // distance 1.0 - assert_eq!(queue.get(2).id, 2); // distance 1.5 + assert_eq!(queue.get(0).id.id, 1); // distance 0.5 + assert_eq!(queue.get(1).id.id, 0); // distance 1.0 + assert_eq!(queue.get(2).id.id, 2); // distance 1.5 } #[test] fn test_insert_local_queue_full() { - let mut attribute_provider = TestAttributeValueProvider::new(); - // All IDs 10-15 have the same attribute (attribute 0) - for i in 10..=15 { - attribute_provider.insert(i, 0); - } // l_value=20, k_value=20, diverse_results_k=3 => diverse_results_l = 3 * 20 / 20 = 3 - let mut queue = TestDiverseQueue::new( - 20, - NonZeroUsize::new(20).unwrap(), - 3, - Arc::new(attribute_provider), - ); + // All IDs 10-13 share attribute 0. + let mut queue = TestDiverseQueue::new(20, NonZeroUsize::new(20).unwrap(), 3); // Fill up the local queue for attribute 0 (diverse_results_l = 3) - queue.insert(Neighbor::new(10, 1.0)); - queue.insert(Neighbor::new(11, 0.8)); - queue.insert(Neighbor::new(12, 1.2)); + queue.insert(nbr_attr(10, 0, 1.0)); + queue.insert(nbr_attr(11, 0, 0.8)); + queue.insert(nbr_attr(12, 0, 1.2)); assert_eq!(queue.size(), 3); assert_eq!(queue.local_queue_map[&0].size(), 3); // Try to insert a better neighbor with same attribute (different ID) - queue.insert(Neighbor::new(13, 0.5)); // Better distance, should replace worst + queue.insert(nbr_attr(13, 0, 0.5)); // Better distance, should replace worst assert_eq!(queue.size(), 3); // Size should remain same - assert_eq!(queue.get(0).id, 13); // Best is now the new one with distance 0.5 + assert_eq!(queue.get(0).id.id, 13); // Best is now the new one with distance 0.5 } #[test] fn test_insert_inner_queue_full() { - let attribute_provider = create_test_attribute_provider(); - let mut queue = - TestDiverseQueue::new(3, NonZeroUsize::new(5).unwrap(), 5, attribute_provider); + let mut queue = TestDiverseQueue::new(3, NonZeroUsize::new(5).unwrap(), 5); // Fill up inner queue (capacity = 3) - queue.insert(Neighbor::new(0, 1.0)); // attribute 0 - queue.insert(Neighbor::new(3, 0.8)); // attribute 1 - queue.insert(Neighbor::new(6, 1.2)); // attribute 2 + queue.insert(nbr(0, 1.0)); // attribute 0 + queue.insert(nbr(3, 0.8)); // attribute 1 + queue.insert(nbr(6, 1.2)); // attribute 2 assert_eq!(queue.size(), 3); // Insert a better neighbor with a new attribute - queue.insert(Neighbor::new(9, 0.5)); // attribute 3, better distance + queue.insert(nbr(9, 0.5)); // attribute 3, better distance assert_eq!(queue.size(), 3); - assert_eq!(queue.get(0).id, 9); // Best should be the new one + assert_eq!(queue.get(0).id.id, 9); // Best should be the new one } #[test] fn test_get() { - let attribute_provider = create_test_attribute_provider(); - let mut queue = - TestDiverseQueue::new(10, NonZeroUsize::new(5).unwrap(), 5, attribute_provider); + let mut queue = TestDiverseQueue::new(10, NonZeroUsize::new(5).unwrap(), 5); - queue.insert(Neighbor::new(0, 1.0)); - queue.insert(Neighbor::new(1, 0.5)); + queue.insert(nbr(0, 1.0)); + queue.insert(nbr(1, 0.5)); - let nbr = queue.get(0); - assert_eq!(nbr.id, 1); - assert_eq!(nbr.distance, 0.5); + let n = queue.get(0); + assert_eq!(n.id.id, 1); + assert_eq!(n.distance, 0.5); - let nbr = queue.get(1); - assert_eq!(nbr.id, 0); - assert_eq!(nbr.distance, 1.0); + let n = queue.get(1); + assert_eq!(n.id.id, 0); + assert_eq!(n.distance, 1.0); } #[test] fn test_closest_notvisited() { - let attribute_provider = create_test_attribute_provider(); - let mut queue = - TestDiverseQueue::new(10, NonZeroUsize::new(5).unwrap(), 5, attribute_provider); + let mut queue = TestDiverseQueue::new(10, NonZeroUsize::new(5).unwrap(), 5); - queue.insert(Neighbor::new(0, 1.0)); - queue.insert(Neighbor::new(1, 0.5)); - queue.insert(Neighbor::new(2, 1.5)); + queue.insert(nbr(0, 1.0)); + queue.insert(nbr(1, 0.5)); + queue.insert(nbr(2, 1.5)); assert!(queue.has_notvisited_node()); - let nbr = queue.closest_notvisited().unwrap(); - assert_eq!(nbr.id, 1); // Best unvisited - assert_eq!(nbr.distance, 0.5); + let n = queue.closest_notvisited().unwrap(); + assert_eq!(n.id.id, 1); // Best unvisited + assert_eq!(n.distance, 0.5); assert!(queue.has_notvisited_node()); - let nbr = queue.closest_notvisited().unwrap(); - assert_eq!(nbr.id, 0); // Next best + let n = queue.closest_notvisited().unwrap(); + assert_eq!(n.id.id, 0); // Next best - let nbr = queue.closest_notvisited().unwrap(); - assert_eq!(nbr.id, 2); // Last one + let n = queue.closest_notvisited().unwrap(); + assert_eq!(n.id.id, 2); // Last one assert!(!queue.has_notvisited_node()); assert!(queue.closest_notvisited().is_none()); @@ -494,13 +452,11 @@ mod diverse_priority_queue_test { #[test] fn test_has_notvisited_node() { - let attribute_provider = create_test_attribute_provider(); - let mut queue = - TestDiverseQueue::new(10, NonZeroUsize::new(5).unwrap(), 5, attribute_provider); + let mut queue = TestDiverseQueue::new(10, NonZeroUsize::new(5).unwrap(), 5); assert!(!queue.has_notvisited_node()); - queue.insert(Neighbor::new(0, 1.0)); + queue.insert(nbr(0, 1.0)); assert!(queue.has_notvisited_node()); assert!(queue.closest_notvisited().is_some()); @@ -510,42 +466,36 @@ mod diverse_priority_queue_test { #[test] fn test_size() { - let attribute_provider = create_test_attribute_provider(); - let mut queue = - TestDiverseQueue::new(10, NonZeroUsize::new(5).unwrap(), 5, attribute_provider); + let mut queue = TestDiverseQueue::new(10, NonZeroUsize::new(5).unwrap(), 5); assert_eq!(queue.size(), 0); - queue.insert(Neighbor::new(0, 1.0)); + queue.insert(nbr(0, 1.0)); assert_eq!(queue.size(), 1); - queue.insert(Neighbor::new(1, 0.5)); + queue.insert(nbr(1, 0.5)); assert_eq!(queue.size(), 2); } #[test] fn test_capacity() { - let attribute_provider = create_test_attribute_provider(); - let queue = TestDiverseQueue::new(15, NonZeroUsize::new(5).unwrap(), 5, attribute_provider); + let queue = TestDiverseQueue::new(15, NonZeroUsize::new(5).unwrap(), 5); assert_eq!(queue.capacity(), 15); } #[test] fn test_search_l() { - let attribute_provider = create_test_attribute_provider(); - let queue = TestDiverseQueue::new(20, NonZeroUsize::new(5).unwrap(), 5, attribute_provider); + let queue = TestDiverseQueue::new(20, NonZeroUsize::new(5).unwrap(), 5); assert_eq!(queue.search_l(), 20); } #[test] fn test_clear() { - let attribute_provider = create_test_attribute_provider(); - let mut queue = - TestDiverseQueue::new(10, NonZeroUsize::new(5).unwrap(), 5, attribute_provider); + let mut queue = TestDiverseQueue::new(10, NonZeroUsize::new(5).unwrap(), 5); - queue.insert(Neighbor::new(0, 1.0)); - queue.insert(Neighbor::new(3, 0.5)); - queue.insert(Neighbor::new(6, 1.5)); + queue.insert(nbr(0, 1.0)); + queue.insert(nbr(3, 0.5)); + queue.insert(nbr(6, 1.5)); assert_eq!(queue.size(), 3); assert_eq!(queue.local_queue_map.len(), 3); @@ -558,29 +508,25 @@ mod diverse_priority_queue_test { #[test] fn test_iter_candidates() { - let attribute_provider = create_test_attribute_provider(); - let mut queue = - TestDiverseQueue::new(10, NonZeroUsize::new(5).unwrap(), 5, attribute_provider); + let mut queue = TestDiverseQueue::new(10, NonZeroUsize::new(5).unwrap(), 5); - queue.insert(Neighbor::new(0, 1.0)); - queue.insert(Neighbor::new(1, 0.5)); - queue.insert(Neighbor::new(2, 1.5)); + queue.insert(nbr(0, 1.0)); + queue.insert(nbr(1, 0.5)); + queue.insert(nbr(2, 1.5)); let candidates: Vec<_> = queue.iter().collect(); assert_eq!(candidates.len(), 3); - assert_eq!(candidates[0].id, 1); - assert_eq!(candidates[1].id, 0); - assert_eq!(candidates[2].id, 2); + assert_eq!(candidates[0].id.id, 1); + assert_eq!(candidates[1].id.id, 0); + assert_eq!(candidates[2].id.id, 2); } #[test] fn test_inner_and_inner_mut() { - let attribute_provider = create_test_attribute_provider(); - let mut queue = - TestDiverseQueue::new(10, NonZeroUsize::new(5).unwrap(), 5, attribute_provider); + let mut queue = TestDiverseQueue::new(10, NonZeroUsize::new(5).unwrap(), 5); - queue.insert(Neighbor::new(0, 1.0)); + queue.insert(nbr(0, 1.0)); // Test direct access to global_queue assert_eq!(queue.global_queue.size(), 1); @@ -595,112 +541,55 @@ mod diverse_priority_queue_test { let vid_attr = VectorIdWithAttribute::new(42u32, 7); assert_eq!(vid_attr.id, 42); assert_eq!(vid_attr.attribute, 7); + assert_eq!(vid_attr.attribute(), Some(7)); let formatted = format!("{}", vid_attr); assert_eq!(formatted, "(42, 7)"); } - #[test] - fn test_attribute_value_provider() { - let mut provider = TestAttributeValueProvider::new(); - - assert_eq!(provider.get(0), None); - - provider.insert(0, 10); - assert_eq!(provider.get(0), Some(10)); - - provider.insert(5, 20); - assert_eq!(provider.get(5), Some(20)); - - // Update existing value - provider.insert(0, 15); - assert_eq!(provider.get(0), Some(15)); - } - - #[test] - fn test_attribute_value_provider_default() { - let provider = TestAttributeValueProvider::default(); - assert_eq!(provider.get(0), None); - } - #[test] fn test_diverse_queue_complex_scenario() { - let attribute_provider = create_test_attribute_provider(); - let mut queue = - TestDiverseQueue::new(10, NonZeroUsize::new(5).unwrap(), 3, attribute_provider); + let mut queue = TestDiverseQueue::new(10, NonZeroUsize::new(5).unwrap(), 3); // Insert multiple neighbors with different attributes - queue.insert(Neighbor::new(0, 1.0)); // attribute 0 - queue.insert(Neighbor::new(1, 0.5)); // attribute 0 - queue.insert(Neighbor::new(2, 1.5)); // attribute 0 - queue.insert(Neighbor::new(3, 0.8)); // attribute 1 - queue.insert(Neighbor::new(4, 1.2)); // attribute 1 - queue.insert(Neighbor::new(6, 0.7)); // attribute 2 + queue.insert(nbr(0, 1.0)); // attribute 0 + queue.insert(nbr(1, 0.5)); // attribute 0 + queue.insert(nbr(2, 1.5)); // attribute 0 + queue.insert(nbr(3, 0.8)); // attribute 1 + queue.insert(nbr(4, 1.2)); // attribute 1 + queue.insert(nbr(6, 0.7)); // attribute 2 assert_eq!(queue.size(), 6); - // Try to add more to attribute 0 when its local queue is full (use unique ID 17) - // ID 17 has attribute 5, so let's use ID 15 which has attribute 5 but we'll manually set attribute 0 - let mut attribute_provider_updated = TestAttributeValueProvider::new(); - for i in 0..20 { - attribute_provider_updated.insert(i, i / 3); - } - attribute_provider_updated.insert(17, 0); // Set ID 17 to have attribute 0 - - // Create a new queue with updated provider - let mut queue2 = TestDiverseQueue::new( - 10, - NonZeroUsize::new(5).unwrap(), - 3, - Arc::new(attribute_provider_updated), - ); - queue2.insert(Neighbor::new(0, 1.0)); // attribute 0 - queue2.insert(Neighbor::new(1, 0.5)); // attribute 0 - queue2.insert(Neighbor::new(2, 1.5)); // attribute 0 - queue2.insert(Neighbor::new(3, 0.8)); // attribute 1 - queue2.insert(Neighbor::new(4, 1.2)); // attribute 1 - queue2.insert(Neighbor::new(6, 0.7)); // attribute 2 + // Insert ID 17 with attribute 0 and a better distance; it should become the best. + queue.insert(nbr_attr(17, 0, 0.3)); - // Now insert ID 17 with attribute 0 and better distance - queue2.insert(Neighbor::new(17, 0.3)); // Should replace worst in attribute 0 - - // Verify best neighbor is from the new insertion - assert_eq!(queue2.get(0).id, 17); - assert_eq!(queue2.get(0).distance, 0.3); + assert_eq!(queue.get(0).id.id, 17); + assert_eq!(queue.get(0).distance, 0.3); } #[test] fn test_post_process() { - let mut attribute_provider = TestAttributeValueProvider::new(); - // Set up attributes: IDs 0-2 have attribute 0, IDs 3-5 have attribute 1, IDs 6-8 have attribute 2 - for i in 0..9 { - attribute_provider.insert(i, i / 3); - } - // Create queue with l_value=20, k_value=5, diverse_results_k=2 - // This gives diverse_results_l = 2 * 20 / 5 = 8 - let mut queue = TestDiverseQueue::new( - 20, - NonZeroUsize::new(5).unwrap(), - 2, - Arc::new(attribute_provider), - ); + // This gives diverse_results_l = 2 * 20 / 5 = 8. + // IDs 0-2 have attribute 0, IDs 3-5 have attribute 1, IDs 6-8 have attribute 2. + let mut queue = TestDiverseQueue::new(20, NonZeroUsize::new(5).unwrap(), 2); // Insert more than diverse_results_k items for each attribute // Attribute 0 - queue.insert(Neighbor::new(0, 1.0)); - queue.insert(Neighbor::new(1, 0.5)); - queue.insert(Neighbor::new(2, 1.5)); + queue.insert(nbr(0, 1.0)); + queue.insert(nbr(1, 0.5)); + queue.insert(nbr(2, 1.5)); // Attribute 1 - queue.insert(Neighbor::new(3, 0.8)); - queue.insert(Neighbor::new(4, 1.2)); - queue.insert(Neighbor::new(5, 0.6)); + queue.insert(nbr(3, 0.8)); + queue.insert(nbr(4, 1.2)); + queue.insert(nbr(5, 0.6)); // Attribute 2 - queue.insert(Neighbor::new(6, 0.7)); - queue.insert(Neighbor::new(7, 1.1)); - queue.insert(Neighbor::new(8, 0.9)); + queue.insert(nbr(6, 0.7)); + queue.insert(nbr(7, 1.1)); + queue.insert(nbr(8, 0.9)); // Before post_process, we should have all 9 items assert_eq!(queue.size(), 9); @@ -721,58 +610,44 @@ mod diverse_priority_queue_test { // Verify the best items from each attribute are kept // Attribute 0: best are ID 1 (0.5) and ID 0 (1.0), worst ID 2 (1.5) should be removed - assert_eq!(queue.local_queue_map[&0].get(0).id, 1); + assert_eq!(queue.local_queue_map[&0].get(0).id.id, 1); assert_eq!(queue.local_queue_map[&0].get(0).distance, 0.5); - assert_eq!(queue.local_queue_map[&0].get(1).id, 0); + assert_eq!(queue.local_queue_map[&0].get(1).id.id, 0); assert_eq!(queue.local_queue_map[&0].get(1).distance, 1.0); // Attribute 1: best are ID 5 (0.6) and ID 3 (0.8), worst ID 4 (1.2) should be removed - assert_eq!(queue.local_queue_map[&1].get(0).id, 5); + assert_eq!(queue.local_queue_map[&1].get(0).id.id, 5); assert_eq!(queue.local_queue_map[&1].get(0).distance, 0.6); - assert_eq!(queue.local_queue_map[&1].get(1).id, 3); + assert_eq!(queue.local_queue_map[&1].get(1).id.id, 3); assert_eq!(queue.local_queue_map[&1].get(1).distance, 0.8); // Attribute 2: best are ID 6 (0.7) and ID 8 (0.9), worst ID 7 (1.1) should be removed - assert_eq!(queue.local_queue_map[&2].get(0).id, 6); + assert_eq!(queue.local_queue_map[&2].get(0).id.id, 6); assert_eq!(queue.local_queue_map[&2].get(0).distance, 0.7); - assert_eq!(queue.local_queue_map[&2].get(1).id, 8); + assert_eq!(queue.local_queue_map[&2].get(1).id.id, 8); assert_eq!(queue.local_queue_map[&2].get(1).distance, 0.9); // Verify global queue has the correct items in sorted order - assert_eq!(queue.get(0).id, 1); // 0.5 - assert_eq!(queue.get(1).id, 5); // 0.6 - assert_eq!(queue.get(2).id, 6); // 0.7 - assert_eq!(queue.get(3).id, 3); // 0.8 - assert_eq!(queue.get(4).id, 8); // 0.9 - assert_eq!(queue.get(5).id, 0); // 1.0 + assert_eq!(queue.get(0).id.id, 1); // 0.5 + assert_eq!(queue.get(1).id.id, 5); // 0.6 + assert_eq!(queue.get(2).id.id, 6); // 0.7 + assert_eq!(queue.get(3).id.id, 3); // 0.8 + assert_eq!(queue.get(4).id.id, 8); // 0.9 + assert_eq!(queue.get(5).id.id, 0); // 1.0 } #[test] fn test_skip_neighbors_without_attributes() { // Test that neighbors without attributes are silently skipped - // rather than being conflated with attribute value 0 - let mut attribute_provider = TestAttributeValueProvider::new(); - - // Set up some vectors with attributes - attribute_provider.insert(0, 0); // ID 0 has attribute 0 - attribute_provider.insert(1, 0); // ID 1 has attribute 0 - attribute_provider.insert(2, 1); // ID 2 has attribute 1 - // ID 3 has no attribute (not in the map) - attribute_provider.insert(4, 0); // ID 4 has attribute 0 - - let mut queue = TestDiverseQueue::new( - 10, - NonZeroUsize::new(5).unwrap(), - 5, - Arc::new(attribute_provider), - ); + // rather than being conflated with attribute value 0. + let mut queue = TestDiverseQueue::new(10, NonZeroUsize::new(5).unwrap(), 5); // Insert neighbors, including one without an attribute - queue.insert(Neighbor::new(0, 1.0)); // Has attribute 0 - queue.insert(Neighbor::new(1, 0.5)); // Has attribute 0 - queue.insert(Neighbor::new(2, 0.8)); // Has attribute 1 - queue.insert(Neighbor::new(3, 0.3)); // No attribute - should be skipped - queue.insert(Neighbor::new(4, 1.2)); // Has attribute 0 + queue.insert(nbr_attr(0, 0, 1.0)); // Has attribute 0 + queue.insert(nbr_attr(1, 0, 0.5)); // Has attribute 0 + queue.insert(nbr_attr(2, 1, 0.8)); // Has attribute 1 + queue.insert(nbr_none(3, 0.3)); // No attribute - should be skipped + queue.insert(nbr_attr(4, 0, 1.2)); // Has attribute 0 // Queue should only contain 4 items (ID 3 was skipped) assert_eq!(queue.size(), 4, "Expected 4 items, ID 3 should be skipped"); @@ -792,7 +667,7 @@ mod diverse_priority_queue_test { ); // Verify ID 3 (without attribute) is not in the queue - let ids: Vec = queue.iter().map(|n| n.id).collect(); + let ids: Vec = queue.iter().map(|n| n.id.id).collect(); assert!(!ids.contains(&3), "ID 3 should not be in the queue"); assert_eq!( ids, @@ -803,25 +678,12 @@ mod diverse_priority_queue_test { #[test] fn test_attribute_zero_vs_missing_attribute() { - // Verify that attribute value 0 is distinct from missing attributes - let mut attribute_provider = TestAttributeValueProvider::new(); - - // ID 0 explicitly has attribute 0 - attribute_provider.insert(0, 0); - // ID 1 has no attribute (missing) - // ID 2 explicitly has attribute 0 - attribute_provider.insert(2, 0); - - let mut queue = TestDiverseQueue::new( - 10, - NonZeroUsize::new(5).unwrap(), - 5, - Arc::new(attribute_provider), - ); + // Verify that attribute value 0 is distinct from missing attributes. + let mut queue = TestDiverseQueue::new(10, NonZeroUsize::new(5).unwrap(), 5); - queue.insert(Neighbor::new(0, 1.0)); // Has attribute 0 - queue.insert(Neighbor::new(1, 0.5)); // Missing attribute - should be skipped - queue.insert(Neighbor::new(2, 0.8)); // Has attribute 0 + queue.insert(nbr_attr(0, 0, 1.0)); // Has attribute 0 + queue.insert(nbr_none(1, 0.5)); // Missing attribute - should be skipped + queue.insert(nbr_attr(2, 0, 0.8)); // Has attribute 0 // Only IDs 0 and 2 should be in the queue assert_eq!(queue.size(), 2); @@ -830,7 +692,7 @@ mod diverse_priority_queue_test { assert_eq!(queue.local_queue_map[&0].size(), 2); // Verify ID 1 is not in the queue - let ids: Vec = queue.iter().map(|n| n.id).collect(); + let ids: Vec = queue.iter().map(|n| n.id.id).collect(); assert_eq!(ids, vec![2, 0], "Queue should only contain IDs 2 and 0"); } } diff --git a/diskann/src/neighbor/mod.rs b/diskann/src/neighbor/mod.rs index 391a5435e..7755c14d8 100644 --- a/diskann/src/neighbor/mod.rs +++ b/diskann/src/neighbor/mod.rs @@ -16,7 +16,7 @@ pub use queue::{NeighborPriorityQueue, NeighborPriorityQueueIdType, NeighborQueu mod diverse_priority_queue; #[cfg(feature = "experimental_diversity_search")] pub use diverse_priority_queue::{ - Attribute, AttributeValueProvider, DiverseNeighborQueue, VectorIdWithAttribute, + Attribute, DiverseId, DiverseNeighborQueue, VectorIdWithAttribute, }; //////////////