diff --git a/diskann/src/graph/misc.rs b/diskann/src/graph/misc.rs
index 067bd2d21..25f5679e3 100644
--- a/diskann/src/graph/misc.rs
+++ b/diskann/src/graph/misc.rs
@@ -33,30 +33,18 @@ pub enum InplaceDeleteMethod {
// Parameters for diverse search
#[cfg(feature = "experimental_diversity_search")]
-#[derive(Clone, Debug)]
-pub struct DiverseSearchParams
-where
- P: crate::neighbor::AttributeValueProvider,
-{
+#[derive(Clone, Copy, Debug)]
+pub struct DiverseSearchParams {
pub diverse_attribute_id: usize,
pub diverse_results_k: usize,
- pub attribute_provider: std::sync::Arc
,
}
#[cfg(feature = "experimental_diversity_search")]
-impl
DiverseSearchParams
-where
- P: crate::neighbor::AttributeValueProvider,
-{
- pub fn new(
- diverse_attribute_id: usize,
- diverse_results_k: usize,
- attribute_provider: std::sync::Arc
,
- ) -> Self {
+impl DiverseSearchParams {
+ pub fn new(diverse_attribute_id: usize, diverse_results_k: usize) -> Self {
Self {
diverse_attribute_id,
diverse_results_k,
- attribute_provider,
}
}
}
diff --git a/diskann/src/graph/search/diverse_search.rs b/diskann/src/graph/search/diverse_search.rs
index bc5b7de47..622ce7b17 100644
--- a/diskann/src/graph/search/diverse_search.rs
+++ b/diskann/src/graph/search/diverse_search.rs
@@ -18,7 +18,7 @@ use crate::{
index::{DiskANNIndex, SearchStats},
search_output_buffer::SearchOutputBuffer,
},
- neighbor::{AttributeValueProvider, DiverseNeighborQueue, NeighborQueue},
+ neighbor::{DiverseId, DiverseNeighborQueue, NeighborQueue},
provider::DataProvider,
};
@@ -26,22 +26,16 @@ use crate::{
///
/// Returns results that are diverse across a specified attribute.
#[derive(Debug)]
-pub struct Diverse
-where
- P: AttributeValueProvider,
-{
+pub struct Diverse {
/// Base k-NN search parameters.
inner: Knn,
/// Diversity-specific parameters.
- diverse_params: DiverseSearchParams
,
+ diverse_params: DiverseSearchParams,
}
-impl
Diverse
-where
- P: AttributeValueProvider,
-{
+impl Diverse {
/// Create new diverse search parameters.
- pub fn new(inner: Knn, diverse_params: DiverseSearchParams
) -> Self {
+ pub fn new(inner: Knn, diverse_params: DiverseSearchParams) -> Self {
Self {
inner,
diverse_params,
@@ -56,7 +50,7 @@ where
/// Returns a reference to the diversity-specific parameters.
#[inline]
- pub fn diverse_params(&self) -> &DiverseSearchParams
{
+ pub fn diverse_params(&self) -> &DiverseSearchParams {
&self.diverse_params
}
@@ -64,17 +58,15 @@ where
fn create_scratch(
&self,
index: &DiskANNIndex,
- ) -> SearchScratch>
+ ) -> SearchScratch>
where
DP: DataProvider,
- P: AttributeValueProvider,
+ DP::InternalId: DiverseId,
{
- let attribute_provider = self.diverse_params.attribute_provider.clone();
let diverse_queue = DiverseNeighborQueue::new(
self.inner.l_value().get(),
self.inner.k_value(),
self.diverse_params.diverse_results_k,
- attribute_provider,
);
SearchScratch {
@@ -92,12 +84,12 @@ where
}
}
-impl<'a, DP, S, T, P> Search<'a, DP, S, T> for Diverse
+impl<'a, DP, S, T> Search<'a, DP, S, T> for Diverse
where
DP: DataProvider,
+ DP::InternalId: DiverseId,
T: Copy + Send + Sync,
S: SearchStrategy<'a, DP, T, SearchAccessor: SearchAccessor>,
- P: AttributeValueProvider,
{
type Output = SearchStats;
diff --git a/diskann/src/neighbor/diverse_priority_queue.rs b/diskann/src/neighbor/diverse_priority_queue.rs
index 93f355500..2cec871b2 100644
--- a/diskann/src/neighbor/diverse_priority_queue.rs
+++ b/diskann/src/neighbor/diverse_priority_queue.rs
@@ -8,7 +8,6 @@ use std::{
fmt::{Debug, Display},
hash::Hash,
num::NonZeroUsize,
- sync::Arc,
};
use crate::neighbor::{
@@ -25,7 +24,23 @@ pub trait Attribute: Hash + Eq + Copy + Default + Debug + Display + Send + Sync
// Blanket implementation: any type satisfying these bounds automatically implements Attribute
impl Attribute for T where T: Hash + Eq + Copy + Default + Debug + Display + Send + Sync {}
-/// A wrapper type for (VectorIdType, attribute) tuples to implement required traits
+/// Trait for neighbor ids that carry their own diversity attribute.
+///
+/// Instead of looking up attributes through an external provider, the attribute
+/// is obtained directly from the id. Ids that do not have an attribute return
+/// `None` and are skipped by [`DiverseNeighborQueue`].
+pub trait DiverseId: NeighborPriorityQueueIdType + Hash {
+ /// The attribute value type carried by this id.
+ type Attribute: Attribute;
+
+ /// Get the attribute value carried by this id, or `None` if it has none.
+ fn attribute(&self) -> Option;
+}
+
+/// A wrapper type pairing a vector id with an attribute value.
+///
+/// This is a convenience [`DiverseId`] implementation for cases where the
+/// attribute is stored alongside the vector id.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct VectorIdWithAttribute
where
@@ -41,7 +56,7 @@ where
I: NeighborPriorityQueueIdType,
A: Attribute,
{
- fn new(id: I, attribute: A) -> Self {
+ pub fn new(id: I, attribute: A) -> Self {
Self { id, attribute }
}
}
@@ -56,48 +71,52 @@ where
}
}
+impl DiverseId for VectorIdWithAttribute
+where
+ I: NeighborPriorityQueueIdType + Hash,
+ A: Attribute,
+{
+ type Attribute = A;
+
+ fn attribute(&self) -> Option {
+ Some(self.attribute)
+ }
+}
+
/// A diverse neighbor priority queue that wraps a standard NeighborPriorityQueue
/// and delegates all operations to it. This provides a foundation for implementing
/// diversity-aware search algorithms while maintaining the same interface.
///
-/// This struct serves as a wrapper around NeighborPriorityQueue and can be extended
-/// in the future to implement diversity constraints or other specialized behaviors.
+/// The queue is generic over an id type `I` that carries its own diversity
+/// attribute via the [`DiverseId`] trait, so no external attribute provider is
+/// required.
#[derive(Debug, Clone)]
-pub struct DiverseNeighborQueue
+pub struct DiverseNeighborQueue
where
- P: AttributeValueProvider,
+ I: DiverseId,
{
- /// The underlying priority queue that handles all core operations
- /// Stores VectorIdWithAttribute which contains (VectorIdType, attribute_value)
- global_queue: NeighborPriorityQueue>,
- /// Map from attribute_id to local neighbor priority queue
- local_queue_map: HashMap>,
- /// Attribute value provider for managing diversity attributes
- attribute_provider: Arc,
+ /// The underlying priority queue that handles all core operations.
+ global_queue: NeighborPriorityQueue,
+ /// Map from attribute value to local neighbor priority queue.
+ local_queue_map: HashMap>,
/// The calculated diverse_results_l for local queues
diverse_results_l: usize,
/// The target number of diverse results (k_value for diversity)
diverse_results_k: usize,
}
-impl DiverseNeighborQueue
+impl DiverseNeighborQueue
where
- P: AttributeValueProvider,
+ I: DiverseId,
{
/// Create a new DiverseNeighborQueue with the specified capacity.
///
/// This will implicitly set `l_value` to the provided capacity.
- pub fn new(
- l_value: usize,
- k_value: NonZeroUsize,
- diverse_results_k: usize,
- attribute_provider: Arc
,
- ) -> Self {
+ pub fn new(l_value: usize, k_value: NonZeroUsize, diverse_results_k: usize) -> Self {
let diverse_results_l = diverse_results_k * l_value / k_value.get();
Self {
global_queue: NeighborPriorityQueue::new(l_value),
local_queue_map: HashMap::new(),
- attribute_provider,
diverse_results_l,
diverse_results_k,
}
@@ -133,29 +152,29 @@ where
// Step 2: Compact global queue using the filter
if !removed_items.is_empty() {
self.global_queue
- .retain(|neighbor| !removed_items.contains(&neighbor.id.id));
+ .retain(|neighbor| !removed_items.contains(&neighbor.id));
}
}
}
-impl
NeighborQueue for DiverseNeighborQueue
+impl NeighborQueue for DiverseNeighborQueue
where
- P: AttributeValueProvider,
+ I: DiverseId,
{
type Iter<'a>
- = BestCandidatesIterator<'a, P::Id, Self>
+ = BestCandidatesIterator<'a, I, Self>
where
Self: 'a,
- P::Id: 'a;
+ I: 'a;
- fn insert(&mut self, nbr: Neighbor) {
- // Get the attribute value for the current neighbor.
+ fn insert(&mut self, nbr: Neighbor) {
+ // Get the attribute value carried by the neighbor's id.
// We explicitly skip neighbors without attributes (returning None) rather than using
// unwrap_or_default(), because using a default value would conflate "missing attribute"
// with "attribute value 0" (or whatever the default is). This could violate diversity
// constraints by incorrectly grouping neighbors without attributes together with
// neighbors that legitimately have the default attribute value.
- let Some(attribute_value) = self.attribute_provider.get(nbr.id) else {
+ let Some(attribute_value) = nbr.id.attribute() else {
return;
};
@@ -168,71 +187,52 @@ where
let local_queue_full = local_queue.is_full();
let global_queue_full = self.global_queue.is_full();
- // Create a neighbor with VectorIdWithAttribute for global_queue
- let nbr_with_attribute = Neighbor::new(
- VectorIdWithAttribute::new(nbr.id, attribute_value),
- nbr.distance,
- );
-
if !local_queue_full && !global_queue_full {
// Case 1: Both local queue and global queue have space
local_queue.insert(nbr);
- self.global_queue.insert(nbr_with_attribute);
+ self.global_queue.insert(nbr);
} else if local_queue_full {
// Case 2: Local queue is full
if nbr.distance < local_queue.get(self.diverse_results_l - 1).distance {
// Get the worst neighbor in the local queue
let worst_neighbor = local_queue.get(self.diverse_results_l - 1);
- // Create the corresponding neighbor with attribute for removal from global queue
- let worst_neighbor_with_attribute = Neighbor::new(
- VectorIdWithAttribute::new(worst_neighbor.id, attribute_value),
- worst_neighbor.distance,
- );
// Remove worst neighbor from global queue using the remove method
- self.global_queue.remove(worst_neighbor_with_attribute);
+ self.global_queue.remove(worst_neighbor);
// Insert new neighbor into both queues
local_queue.insert(nbr);
- self.global_queue.insert(nbr_with_attribute);
+ self.global_queue.insert(nbr);
}
} else if !local_queue_full && global_queue_full {
// Case 3: Local queue has space but global queue is full
let l_size = self.global_queue.search_l();
if nbr.distance < self.global_queue.get(l_size - 1).distance {
let worst_global = self.global_queue.get(l_size - 1);
- // Extract the attribute from VectorIdWithAttribute
- let attribute_of_worst_global = worst_global.id.attribute;
+ // The attribute of the worst global neighbor comes from its id.
+ let attribute_of_worst_global = worst_global.id.attribute();
// Insert new neighbor into both queues
local_queue.insert(nbr);
- self.global_queue.insert(nbr_with_attribute);
+ self.global_queue.insert(nbr);
// Remove worst neighbor from its local queue
- if let Some(local_queue) = self.local_queue_map.get_mut(&attribute_of_worst_global)
+ if let Some(attribute_of_worst_global) = attribute_of_worst_global
+ && let Some(local_queue) =
+ self.local_queue_map.get_mut(&attribute_of_worst_global)
{
- let worst_neighbor_without_attribute =
- Neighbor::new(worst_global.id.id, worst_global.distance);
- local_queue.remove(worst_neighbor_without_attribute);
+ local_queue.remove(worst_global);
}
}
}
}
- fn get(&self, index: usize) -> Neighbor {
- let neighbor_with_attribute = self.global_queue.get(index);
- Neighbor::new(
- neighbor_with_attribute.id.id,
- neighbor_with_attribute.distance,
- )
+ fn get(&self, index: usize) -> Neighbor {
+ self.global_queue.get(index)
}
- fn closest_notvisited(&mut self) -> Option> {
- let neighbor_with_attribute = self.global_queue.closest_notvisited()?;
- Some(Neighbor::new(
- neighbor_with_attribute.id.id,
- neighbor_with_attribute.distance,
- ))
+ fn closest_notvisited(&mut self) -> Option> {
+ self.global_queue.closest_notvisited()
}
fn has_notvisited_node(&self) -> bool {
@@ -256,89 +256,68 @@ where
self.local_queue_map.clear();
}
- fn iter(&self) -> BestCandidatesIterator<'_, P::Id, Self> {
+ fn iter(&self) -> BestCandidatesIterator<'_, I, Self> {
let sz = self.global_queue.search_l().min(self.global_queue.size());
BestCandidatesIterator::new(sz, self)
}
}
-/// Trait for providing attribute values for vector IDs.
-/// Implementations of this trait can be used with diverse search to retrieve
-/// attribute values for vectors during search operations.
-pub trait AttributeValueProvider: crate::provider::HasId + Send + Sync + std::fmt::Debug {
- type Value: Attribute;
-
- /// Get the attribute value for a given vector ID.
- ///
- /// # Arguments
- /// * `id` - The vector ID
- ///
- /// # Returns
- /// * `Option` - The attribute value if it exists, None otherwise
- fn get(&self, id: Self::Id) -> Option;
-}
-
#[cfg(test)]
mod diverse_priority_queue_test {
use super::*;
- /// A test attribute value provider that stores attribute values for vector IDs.
- /// This is a simple in-memory store using a HashMap for testing purposes.
- #[derive(Debug, Clone)]
- struct TestAttributeValueProvider {
- /// Map from vector_id to attribute value
- attributes: HashMap,
+ /// A test id type that carries an optional attribute value, used to exercise
+ /// both the "has attribute" and "missing attribute" code paths.
+ #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
+ struct TestId {
+ id: u32,
+ attribute: Option,
}
- impl TestAttributeValueProvider {
- /// Create a new empty TestAttributeValueProvider.
- fn new() -> Self {
- Self {
- attributes: HashMap::new(),
- }
+ impl TestId {
+ fn new(id: u32, attribute: Option) -> Self {
+ Self { id, attribute }
}
+ }
- /// Insert an attribute value for a given vector ID.
- fn insert(&mut self, vector_id: u32, attribute_value: u32) {
- self.attributes.insert(vector_id, attribute_value);
+ impl Display for TestId {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(f, "{}", self.id)
}
}
- impl crate::provider::HasId for TestAttributeValueProvider {
- type Id = u32;
+ impl DiverseId for TestId {
+ type Attribute = u32;
+
+ fn attribute(&self) -> Option {
+ self.attribute
+ }
}
- impl AttributeValueProvider for TestAttributeValueProvider {
- type Value = u32;
+ // Type alias for tests to make them more readable
+ type TestDiverseQueue = DiverseNeighborQueue;
- fn get(&self, id: Self::Id) -> Option {
- self.attributes.get(&id).copied()
- }
+ /// Build a neighbor whose id carries the attribute `id / 3`.
+ ///
+ /// This mirrors the original test setup where vectors 0-2 had attribute 0,
+ /// 3-5 had attribute 1, and so on.
+ fn nbr(id: u32, distance: f32) -> Neighbor {
+ Neighbor::new(TestId::new(id, Some(id / 3)), distance)
}
- impl Default for TestAttributeValueProvider {
- fn default() -> Self {
- Self::new()
- }
+ /// Build a neighbor whose id carries an explicit attribute value.
+ fn nbr_attr(id: u32, attribute: u32, distance: f32) -> Neighbor {
+ Neighbor::new(TestId::new(id, Some(attribute)), distance)
}
- // Type alias for tests to make them more readable
- type TestDiverseQueue = DiverseNeighborQueue;
-
- /// Helper function to create a test attribute provider wrapped in Arc
- fn create_test_attribute_provider() -> Arc {
- let mut provider = TestAttributeValueProvider::new();
- // Set up attributes: vectors 0-2 have attribute 0, vectors 3-5 have attribute 1, etc.
- for i in 0..20 {
- provider.insert(i, i / 3);
- }
- Arc::new(provider)
+ /// Build a neighbor whose id has no attribute.
+ fn nbr_none(id: u32, distance: f32) -> Neighbor {
+ Neighbor::new(TestId::new(id, None), distance)
}
#[test]
fn test_new() {
- let attribute_provider = create_test_attribute_provider();
- let queue = TestDiverseQueue::new(10, NonZeroUsize::new(5).unwrap(), 5, attribute_provider);
+ let queue = TestDiverseQueue::new(10, NonZeroUsize::new(5).unwrap(), 5);
assert_eq!(queue.size(), 0);
assert_eq!(queue.capacity(), 10);
@@ -348,14 +327,12 @@ mod diverse_priority_queue_test {
#[test]
fn test_insert_single_attribute() {
- let attribute_provider = create_test_attribute_provider();
- let mut queue =
- TestDiverseQueue::new(10, NonZeroUsize::new(5).unwrap(), 5, attribute_provider);
+ let mut queue = TestDiverseQueue::new(10, NonZeroUsize::new(5).unwrap(), 5);
// Insert neighbors with IDs 0, 1, 2 (all have attribute 0)
- queue.insert(Neighbor::new(0, 1.0));
- queue.insert(Neighbor::new(1, 0.5));
- queue.insert(Neighbor::new(2, 1.5));
+ queue.insert(nbr(0, 1.0));
+ queue.insert(nbr(1, 0.5));
+ queue.insert(nbr(2, 1.5));
assert_eq!(queue.size(), 3);
assert_eq!(queue.local_queue_map.len(), 1);
@@ -364,14 +341,12 @@ mod diverse_priority_queue_test {
#[test]
fn test_insert_multiple_attributes() {
- let attribute_provider = create_test_attribute_provider();
- let mut queue =
- TestDiverseQueue::new(10, NonZeroUsize::new(5).unwrap(), 5, attribute_provider);
+ let mut queue = TestDiverseQueue::new(10, NonZeroUsize::new(5).unwrap(), 5);
// Insert neighbors with different attributes
- queue.insert(Neighbor::new(0, 1.0)); // attribute 0
- queue.insert(Neighbor::new(3, 0.8)); // attribute 1
- queue.insert(Neighbor::new(6, 1.2)); // attribute 2
+ queue.insert(nbr(0, 1.0)); // attribute 0
+ queue.insert(nbr(3, 0.8)); // attribute 1
+ queue.insert(nbr(6, 1.2)); // attribute 2
assert_eq!(queue.size(), 3);
assert_eq!(queue.local_queue_map.len(), 3);
@@ -382,111 +357,94 @@ mod diverse_priority_queue_test {
#[test]
fn test_insert_maintains_order() {
- let attribute_provider = create_test_attribute_provider();
- let mut queue =
- TestDiverseQueue::new(10, NonZeroUsize::new(5).unwrap(), 5, attribute_provider);
+ let mut queue = TestDiverseQueue::new(10, NonZeroUsize::new(5).unwrap(), 5);
- queue.insert(Neighbor::new(0, 1.0));
- queue.insert(Neighbor::new(1, 0.5));
- queue.insert(Neighbor::new(2, 1.5));
+ queue.insert(nbr(0, 1.0));
+ queue.insert(nbr(1, 0.5));
+ queue.insert(nbr(2, 1.5));
// Neighbors should be sorted by distance
- assert_eq!(queue.get(0).id, 1); // distance 0.5
- assert_eq!(queue.get(1).id, 0); // distance 1.0
- assert_eq!(queue.get(2).id, 2); // distance 1.5
+ assert_eq!(queue.get(0).id.id, 1); // distance 0.5
+ assert_eq!(queue.get(1).id.id, 0); // distance 1.0
+ assert_eq!(queue.get(2).id.id, 2); // distance 1.5
}
#[test]
fn test_insert_local_queue_full() {
- let mut attribute_provider = TestAttributeValueProvider::new();
- // All IDs 10-15 have the same attribute (attribute 0)
- for i in 10..=15 {
- attribute_provider.insert(i, 0);
- }
// l_value=20, k_value=20, diverse_results_k=3 => diverse_results_l = 3 * 20 / 20 = 3
- let mut queue = TestDiverseQueue::new(
- 20,
- NonZeroUsize::new(20).unwrap(),
- 3,
- Arc::new(attribute_provider),
- );
+ // All IDs 10-13 share attribute 0.
+ let mut queue = TestDiverseQueue::new(20, NonZeroUsize::new(20).unwrap(), 3);
// Fill up the local queue for attribute 0 (diverse_results_l = 3)
- queue.insert(Neighbor::new(10, 1.0));
- queue.insert(Neighbor::new(11, 0.8));
- queue.insert(Neighbor::new(12, 1.2));
+ queue.insert(nbr_attr(10, 0, 1.0));
+ queue.insert(nbr_attr(11, 0, 0.8));
+ queue.insert(nbr_attr(12, 0, 1.2));
assert_eq!(queue.size(), 3);
assert_eq!(queue.local_queue_map[&0].size(), 3);
// Try to insert a better neighbor with same attribute (different ID)
- queue.insert(Neighbor::new(13, 0.5)); // Better distance, should replace worst
+ queue.insert(nbr_attr(13, 0, 0.5)); // Better distance, should replace worst
assert_eq!(queue.size(), 3); // Size should remain same
- assert_eq!(queue.get(0).id, 13); // Best is now the new one with distance 0.5
+ assert_eq!(queue.get(0).id.id, 13); // Best is now the new one with distance 0.5
}
#[test]
fn test_insert_inner_queue_full() {
- let attribute_provider = create_test_attribute_provider();
- let mut queue =
- TestDiverseQueue::new(3, NonZeroUsize::new(5).unwrap(), 5, attribute_provider);
+ let mut queue = TestDiverseQueue::new(3, NonZeroUsize::new(5).unwrap(), 5);
// Fill up inner queue (capacity = 3)
- queue.insert(Neighbor::new(0, 1.0)); // attribute 0
- queue.insert(Neighbor::new(3, 0.8)); // attribute 1
- queue.insert(Neighbor::new(6, 1.2)); // attribute 2
+ queue.insert(nbr(0, 1.0)); // attribute 0
+ queue.insert(nbr(3, 0.8)); // attribute 1
+ queue.insert(nbr(6, 1.2)); // attribute 2
assert_eq!(queue.size(), 3);
// Insert a better neighbor with a new attribute
- queue.insert(Neighbor::new(9, 0.5)); // attribute 3, better distance
+ queue.insert(nbr(9, 0.5)); // attribute 3, better distance
assert_eq!(queue.size(), 3);
- assert_eq!(queue.get(0).id, 9); // Best should be the new one
+ assert_eq!(queue.get(0).id.id, 9); // Best should be the new one
}
#[test]
fn test_get() {
- let attribute_provider = create_test_attribute_provider();
- let mut queue =
- TestDiverseQueue::new(10, NonZeroUsize::new(5).unwrap(), 5, attribute_provider);
+ let mut queue = TestDiverseQueue::new(10, NonZeroUsize::new(5).unwrap(), 5);
- queue.insert(Neighbor::new(0, 1.0));
- queue.insert(Neighbor::new(1, 0.5));
+ queue.insert(nbr(0, 1.0));
+ queue.insert(nbr(1, 0.5));
- let nbr = queue.get(0);
- assert_eq!(nbr.id, 1);
- assert_eq!(nbr.distance, 0.5);
+ let n = queue.get(0);
+ assert_eq!(n.id.id, 1);
+ assert_eq!(n.distance, 0.5);
- let nbr = queue.get(1);
- assert_eq!(nbr.id, 0);
- assert_eq!(nbr.distance, 1.0);
+ let n = queue.get(1);
+ assert_eq!(n.id.id, 0);
+ assert_eq!(n.distance, 1.0);
}
#[test]
fn test_closest_notvisited() {
- let attribute_provider = create_test_attribute_provider();
- let mut queue =
- TestDiverseQueue::new(10, NonZeroUsize::new(5).unwrap(), 5, attribute_provider);
+ let mut queue = TestDiverseQueue::new(10, NonZeroUsize::new(5).unwrap(), 5);
- queue.insert(Neighbor::new(0, 1.0));
- queue.insert(Neighbor::new(1, 0.5));
- queue.insert(Neighbor::new(2, 1.5));
+ queue.insert(nbr(0, 1.0));
+ queue.insert(nbr(1, 0.5));
+ queue.insert(nbr(2, 1.5));
assert!(queue.has_notvisited_node());
- let nbr = queue.closest_notvisited().unwrap();
- assert_eq!(nbr.id, 1); // Best unvisited
- assert_eq!(nbr.distance, 0.5);
+ let n = queue.closest_notvisited().unwrap();
+ assert_eq!(n.id.id, 1); // Best unvisited
+ assert_eq!(n.distance, 0.5);
assert!(queue.has_notvisited_node());
- let nbr = queue.closest_notvisited().unwrap();
- assert_eq!(nbr.id, 0); // Next best
+ let n = queue.closest_notvisited().unwrap();
+ assert_eq!(n.id.id, 0); // Next best
- let nbr = queue.closest_notvisited().unwrap();
- assert_eq!(nbr.id, 2); // Last one
+ let n = queue.closest_notvisited().unwrap();
+ assert_eq!(n.id.id, 2); // Last one
assert!(!queue.has_notvisited_node());
assert!(queue.closest_notvisited().is_none());
@@ -494,13 +452,11 @@ mod diverse_priority_queue_test {
#[test]
fn test_has_notvisited_node() {
- let attribute_provider = create_test_attribute_provider();
- let mut queue =
- TestDiverseQueue::new(10, NonZeroUsize::new(5).unwrap(), 5, attribute_provider);
+ let mut queue = TestDiverseQueue::new(10, NonZeroUsize::new(5).unwrap(), 5);
assert!(!queue.has_notvisited_node());
- queue.insert(Neighbor::new(0, 1.0));
+ queue.insert(nbr(0, 1.0));
assert!(queue.has_notvisited_node());
assert!(queue.closest_notvisited().is_some());
@@ -510,42 +466,36 @@ mod diverse_priority_queue_test {
#[test]
fn test_size() {
- let attribute_provider = create_test_attribute_provider();
- let mut queue =
- TestDiverseQueue::new(10, NonZeroUsize::new(5).unwrap(), 5, attribute_provider);
+ let mut queue = TestDiverseQueue::new(10, NonZeroUsize::new(5).unwrap(), 5);
assert_eq!(queue.size(), 0);
- queue.insert(Neighbor::new(0, 1.0));
+ queue.insert(nbr(0, 1.0));
assert_eq!(queue.size(), 1);
- queue.insert(Neighbor::new(1, 0.5));
+ queue.insert(nbr(1, 0.5));
assert_eq!(queue.size(), 2);
}
#[test]
fn test_capacity() {
- let attribute_provider = create_test_attribute_provider();
- let queue = TestDiverseQueue::new(15, NonZeroUsize::new(5).unwrap(), 5, attribute_provider);
+ let queue = TestDiverseQueue::new(15, NonZeroUsize::new(5).unwrap(), 5);
assert_eq!(queue.capacity(), 15);
}
#[test]
fn test_search_l() {
- let attribute_provider = create_test_attribute_provider();
- let queue = TestDiverseQueue::new(20, NonZeroUsize::new(5).unwrap(), 5, attribute_provider);
+ let queue = TestDiverseQueue::new(20, NonZeroUsize::new(5).unwrap(), 5);
assert_eq!(queue.search_l(), 20);
}
#[test]
fn test_clear() {
- let attribute_provider = create_test_attribute_provider();
- let mut queue =
- TestDiverseQueue::new(10, NonZeroUsize::new(5).unwrap(), 5, attribute_provider);
+ let mut queue = TestDiverseQueue::new(10, NonZeroUsize::new(5).unwrap(), 5);
- queue.insert(Neighbor::new(0, 1.0));
- queue.insert(Neighbor::new(3, 0.5));
- queue.insert(Neighbor::new(6, 1.5));
+ queue.insert(nbr(0, 1.0));
+ queue.insert(nbr(3, 0.5));
+ queue.insert(nbr(6, 1.5));
assert_eq!(queue.size(), 3);
assert_eq!(queue.local_queue_map.len(), 3);
@@ -558,29 +508,25 @@ mod diverse_priority_queue_test {
#[test]
fn test_iter_candidates() {
- let attribute_provider = create_test_attribute_provider();
- let mut queue =
- TestDiverseQueue::new(10, NonZeroUsize::new(5).unwrap(), 5, attribute_provider);
+ let mut queue = TestDiverseQueue::new(10, NonZeroUsize::new(5).unwrap(), 5);
- queue.insert(Neighbor::new(0, 1.0));
- queue.insert(Neighbor::new(1, 0.5));
- queue.insert(Neighbor::new(2, 1.5));
+ queue.insert(nbr(0, 1.0));
+ queue.insert(nbr(1, 0.5));
+ queue.insert(nbr(2, 1.5));
let candidates: Vec<_> = queue.iter().collect();
assert_eq!(candidates.len(), 3);
- assert_eq!(candidates[0].id, 1);
- assert_eq!(candidates[1].id, 0);
- assert_eq!(candidates[2].id, 2);
+ assert_eq!(candidates[0].id.id, 1);
+ assert_eq!(candidates[1].id.id, 0);
+ assert_eq!(candidates[2].id.id, 2);
}
#[test]
fn test_inner_and_inner_mut() {
- let attribute_provider = create_test_attribute_provider();
- let mut queue =
- TestDiverseQueue::new(10, NonZeroUsize::new(5).unwrap(), 5, attribute_provider);
+ let mut queue = TestDiverseQueue::new(10, NonZeroUsize::new(5).unwrap(), 5);
- queue.insert(Neighbor::new(0, 1.0));
+ queue.insert(nbr(0, 1.0));
// Test direct access to global_queue
assert_eq!(queue.global_queue.size(), 1);
@@ -595,112 +541,55 @@ mod diverse_priority_queue_test {
let vid_attr = VectorIdWithAttribute::new(42u32, 7);
assert_eq!(vid_attr.id, 42);
assert_eq!(vid_attr.attribute, 7);
+ assert_eq!(vid_attr.attribute(), Some(7));
let formatted = format!("{}", vid_attr);
assert_eq!(formatted, "(42, 7)");
}
- #[test]
- fn test_attribute_value_provider() {
- let mut provider = TestAttributeValueProvider::new();
-
- assert_eq!(provider.get(0), None);
-
- provider.insert(0, 10);
- assert_eq!(provider.get(0), Some(10));
-
- provider.insert(5, 20);
- assert_eq!(provider.get(5), Some(20));
-
- // Update existing value
- provider.insert(0, 15);
- assert_eq!(provider.get(0), Some(15));
- }
-
- #[test]
- fn test_attribute_value_provider_default() {
- let provider = TestAttributeValueProvider::default();
- assert_eq!(provider.get(0), None);
- }
-
#[test]
fn test_diverse_queue_complex_scenario() {
- let attribute_provider = create_test_attribute_provider();
- let mut queue =
- TestDiverseQueue::new(10, NonZeroUsize::new(5).unwrap(), 3, attribute_provider);
+ let mut queue = TestDiverseQueue::new(10, NonZeroUsize::new(5).unwrap(), 3);
// Insert multiple neighbors with different attributes
- queue.insert(Neighbor::new(0, 1.0)); // attribute 0
- queue.insert(Neighbor::new(1, 0.5)); // attribute 0
- queue.insert(Neighbor::new(2, 1.5)); // attribute 0
- queue.insert(Neighbor::new(3, 0.8)); // attribute 1
- queue.insert(Neighbor::new(4, 1.2)); // attribute 1
- queue.insert(Neighbor::new(6, 0.7)); // attribute 2
+ queue.insert(nbr(0, 1.0)); // attribute 0
+ queue.insert(nbr(1, 0.5)); // attribute 0
+ queue.insert(nbr(2, 1.5)); // attribute 0
+ queue.insert(nbr(3, 0.8)); // attribute 1
+ queue.insert(nbr(4, 1.2)); // attribute 1
+ queue.insert(nbr(6, 0.7)); // attribute 2
assert_eq!(queue.size(), 6);
- // Try to add more to attribute 0 when its local queue is full (use unique ID 17)
- // ID 17 has attribute 5, so let's use ID 15 which has attribute 5 but we'll manually set attribute 0
- let mut attribute_provider_updated = TestAttributeValueProvider::new();
- for i in 0..20 {
- attribute_provider_updated.insert(i, i / 3);
- }
- attribute_provider_updated.insert(17, 0); // Set ID 17 to have attribute 0
-
- // Create a new queue with updated provider
- let mut queue2 = TestDiverseQueue::new(
- 10,
- NonZeroUsize::new(5).unwrap(),
- 3,
- Arc::new(attribute_provider_updated),
- );
- queue2.insert(Neighbor::new(0, 1.0)); // attribute 0
- queue2.insert(Neighbor::new(1, 0.5)); // attribute 0
- queue2.insert(Neighbor::new(2, 1.5)); // attribute 0
- queue2.insert(Neighbor::new(3, 0.8)); // attribute 1
- queue2.insert(Neighbor::new(4, 1.2)); // attribute 1
- queue2.insert(Neighbor::new(6, 0.7)); // attribute 2
+ // Insert ID 17 with attribute 0 and a better distance; it should become the best.
+ queue.insert(nbr_attr(17, 0, 0.3));
- // Now insert ID 17 with attribute 0 and better distance
- queue2.insert(Neighbor::new(17, 0.3)); // Should replace worst in attribute 0
-
- // Verify best neighbor is from the new insertion
- assert_eq!(queue2.get(0).id, 17);
- assert_eq!(queue2.get(0).distance, 0.3);
+ assert_eq!(queue.get(0).id.id, 17);
+ assert_eq!(queue.get(0).distance, 0.3);
}
#[test]
fn test_post_process() {
- let mut attribute_provider = TestAttributeValueProvider::new();
- // Set up attributes: IDs 0-2 have attribute 0, IDs 3-5 have attribute 1, IDs 6-8 have attribute 2
- for i in 0..9 {
- attribute_provider.insert(i, i / 3);
- }
-
// Create queue with l_value=20, k_value=5, diverse_results_k=2
- // This gives diverse_results_l = 2 * 20 / 5 = 8
- let mut queue = TestDiverseQueue::new(
- 20,
- NonZeroUsize::new(5).unwrap(),
- 2,
- Arc::new(attribute_provider),
- );
+ // This gives diverse_results_l = 2 * 20 / 5 = 8.
+ // IDs 0-2 have attribute 0, IDs 3-5 have attribute 1, IDs 6-8 have attribute 2.
+ let mut queue = TestDiverseQueue::new(20, NonZeroUsize::new(5).unwrap(), 2);
// Insert more than diverse_results_k items for each attribute
// Attribute 0
- queue.insert(Neighbor::new(0, 1.0));
- queue.insert(Neighbor::new(1, 0.5));
- queue.insert(Neighbor::new(2, 1.5));
+ queue.insert(nbr(0, 1.0));
+ queue.insert(nbr(1, 0.5));
+ queue.insert(nbr(2, 1.5));
// Attribute 1
- queue.insert(Neighbor::new(3, 0.8));
- queue.insert(Neighbor::new(4, 1.2));
- queue.insert(Neighbor::new(5, 0.6));
+ queue.insert(nbr(3, 0.8));
+ queue.insert(nbr(4, 1.2));
+ queue.insert(nbr(5, 0.6));
// Attribute 2
- queue.insert(Neighbor::new(6, 0.7));
- queue.insert(Neighbor::new(7, 1.1));
- queue.insert(Neighbor::new(8, 0.9));
+ queue.insert(nbr(6, 0.7));
+ queue.insert(nbr(7, 1.1));
+ queue.insert(nbr(8, 0.9));
// Before post_process, we should have all 9 items
assert_eq!(queue.size(), 9);
@@ -721,58 +610,44 @@ mod diverse_priority_queue_test {
// Verify the best items from each attribute are kept
// Attribute 0: best are ID 1 (0.5) and ID 0 (1.0), worst ID 2 (1.5) should be removed
- assert_eq!(queue.local_queue_map[&0].get(0).id, 1);
+ assert_eq!(queue.local_queue_map[&0].get(0).id.id, 1);
assert_eq!(queue.local_queue_map[&0].get(0).distance, 0.5);
- assert_eq!(queue.local_queue_map[&0].get(1).id, 0);
+ assert_eq!(queue.local_queue_map[&0].get(1).id.id, 0);
assert_eq!(queue.local_queue_map[&0].get(1).distance, 1.0);
// Attribute 1: best are ID 5 (0.6) and ID 3 (0.8), worst ID 4 (1.2) should be removed
- assert_eq!(queue.local_queue_map[&1].get(0).id, 5);
+ assert_eq!(queue.local_queue_map[&1].get(0).id.id, 5);
assert_eq!(queue.local_queue_map[&1].get(0).distance, 0.6);
- assert_eq!(queue.local_queue_map[&1].get(1).id, 3);
+ assert_eq!(queue.local_queue_map[&1].get(1).id.id, 3);
assert_eq!(queue.local_queue_map[&1].get(1).distance, 0.8);
// Attribute 2: best are ID 6 (0.7) and ID 8 (0.9), worst ID 7 (1.1) should be removed
- assert_eq!(queue.local_queue_map[&2].get(0).id, 6);
+ assert_eq!(queue.local_queue_map[&2].get(0).id.id, 6);
assert_eq!(queue.local_queue_map[&2].get(0).distance, 0.7);
- assert_eq!(queue.local_queue_map[&2].get(1).id, 8);
+ assert_eq!(queue.local_queue_map[&2].get(1).id.id, 8);
assert_eq!(queue.local_queue_map[&2].get(1).distance, 0.9);
// Verify global queue has the correct items in sorted order
- assert_eq!(queue.get(0).id, 1); // 0.5
- assert_eq!(queue.get(1).id, 5); // 0.6
- assert_eq!(queue.get(2).id, 6); // 0.7
- assert_eq!(queue.get(3).id, 3); // 0.8
- assert_eq!(queue.get(4).id, 8); // 0.9
- assert_eq!(queue.get(5).id, 0); // 1.0
+ assert_eq!(queue.get(0).id.id, 1); // 0.5
+ assert_eq!(queue.get(1).id.id, 5); // 0.6
+ assert_eq!(queue.get(2).id.id, 6); // 0.7
+ assert_eq!(queue.get(3).id.id, 3); // 0.8
+ assert_eq!(queue.get(4).id.id, 8); // 0.9
+ assert_eq!(queue.get(5).id.id, 0); // 1.0
}
#[test]
fn test_skip_neighbors_without_attributes() {
// Test that neighbors without attributes are silently skipped
- // rather than being conflated with attribute value 0
- let mut attribute_provider = TestAttributeValueProvider::new();
-
- // Set up some vectors with attributes
- attribute_provider.insert(0, 0); // ID 0 has attribute 0
- attribute_provider.insert(1, 0); // ID 1 has attribute 0
- attribute_provider.insert(2, 1); // ID 2 has attribute 1
- // ID 3 has no attribute (not in the map)
- attribute_provider.insert(4, 0); // ID 4 has attribute 0
-
- let mut queue = TestDiverseQueue::new(
- 10,
- NonZeroUsize::new(5).unwrap(),
- 5,
- Arc::new(attribute_provider),
- );
+ // rather than being conflated with attribute value 0.
+ let mut queue = TestDiverseQueue::new(10, NonZeroUsize::new(5).unwrap(), 5);
// Insert neighbors, including one without an attribute
- queue.insert(Neighbor::new(0, 1.0)); // Has attribute 0
- queue.insert(Neighbor::new(1, 0.5)); // Has attribute 0
- queue.insert(Neighbor::new(2, 0.8)); // Has attribute 1
- queue.insert(Neighbor::new(3, 0.3)); // No attribute - should be skipped
- queue.insert(Neighbor::new(4, 1.2)); // Has attribute 0
+ queue.insert(nbr_attr(0, 0, 1.0)); // Has attribute 0
+ queue.insert(nbr_attr(1, 0, 0.5)); // Has attribute 0
+ queue.insert(nbr_attr(2, 1, 0.8)); // Has attribute 1
+ queue.insert(nbr_none(3, 0.3)); // No attribute - should be skipped
+ queue.insert(nbr_attr(4, 0, 1.2)); // Has attribute 0
// Queue should only contain 4 items (ID 3 was skipped)
assert_eq!(queue.size(), 4, "Expected 4 items, ID 3 should be skipped");
@@ -792,7 +667,7 @@ mod diverse_priority_queue_test {
);
// Verify ID 3 (without attribute) is not in the queue
- let ids: Vec = queue.iter().map(|n| n.id).collect();
+ let ids: Vec = queue.iter().map(|n| n.id.id).collect();
assert!(!ids.contains(&3), "ID 3 should not be in the queue");
assert_eq!(
ids,
@@ -803,25 +678,12 @@ mod diverse_priority_queue_test {
#[test]
fn test_attribute_zero_vs_missing_attribute() {
- // Verify that attribute value 0 is distinct from missing attributes
- let mut attribute_provider = TestAttributeValueProvider::new();
-
- // ID 0 explicitly has attribute 0
- attribute_provider.insert(0, 0);
- // ID 1 has no attribute (missing)
- // ID 2 explicitly has attribute 0
- attribute_provider.insert(2, 0);
-
- let mut queue = TestDiverseQueue::new(
- 10,
- NonZeroUsize::new(5).unwrap(),
- 5,
- Arc::new(attribute_provider),
- );
+ // Verify that attribute value 0 is distinct from missing attributes.
+ let mut queue = TestDiverseQueue::new(10, NonZeroUsize::new(5).unwrap(), 5);
- queue.insert(Neighbor::new(0, 1.0)); // Has attribute 0
- queue.insert(Neighbor::new(1, 0.5)); // Missing attribute - should be skipped
- queue.insert(Neighbor::new(2, 0.8)); // Has attribute 0
+ queue.insert(nbr_attr(0, 0, 1.0)); // Has attribute 0
+ queue.insert(nbr_none(1, 0.5)); // Missing attribute - should be skipped
+ queue.insert(nbr_attr(2, 0, 0.8)); // Has attribute 0
// Only IDs 0 and 2 should be in the queue
assert_eq!(queue.size(), 2);
@@ -830,7 +692,7 @@ mod diverse_priority_queue_test {
assert_eq!(queue.local_queue_map[&0].size(), 2);
// Verify ID 1 is not in the queue
- let ids: Vec = queue.iter().map(|n| n.id).collect();
+ let ids: Vec = queue.iter().map(|n| n.id.id).collect();
assert_eq!(ids, vec![2, 0], "Queue should only contain IDs 2 and 0");
}
}
diff --git a/diskann/src/neighbor/mod.rs b/diskann/src/neighbor/mod.rs
index 391a5435e..7755c14d8 100644
--- a/diskann/src/neighbor/mod.rs
+++ b/diskann/src/neighbor/mod.rs
@@ -16,7 +16,7 @@ pub use queue::{NeighborPriorityQueue, NeighborPriorityQueueIdType, NeighborQueu
mod diverse_priority_queue;
#[cfg(feature = "experimental_diversity_search")]
pub use diverse_priority_queue::{
- Attribute, AttributeValueProvider, DiverseNeighborQueue, VectorIdWithAttribute,
+ Attribute, DiverseId, DiverseNeighborQueue, VectorIdWithAttribute,
};
//////////////