From f1275a126694e4efce6e29eb29c57b1977d03091 Mon Sep 17 00:00:00 2001 From: Jordan Maples Date: Mon, 29 Jun 2026 14:50:03 -0700 Subject: [PATCH 1/2] Make BfTreeProvider generic over vertex id type The provider hardcoded `u32` vertex ids, capping any bf-tree-backed index at ~4.29B vectors. That ceiling is at odds with the crate's purpose of supporting larger-than-memory, billion-scale-and-beyond datasets. Introduce a `BfTreeId` trait (`VectorId + IntoUsize` plus index-construction) implemented for `u32` and `u64`, and thread a new id type parameter `I` through `BfTreeProvider` and all of its accessors, strategies, and save/load impls. The default `I = u32` keeps every existing caller (benchmark harness, tests) compiling and behaving identically, while `u64` is now available opt-in for indexes that exceed the 32-bit range. `iter()`/`IntoIterator` now map `0..total` through `I::from_index` instead of relying on `Range` (a generic `Range` would require the unstable `Step` trait). `VectorProvider::starting_points` is likewise generic over the id type. The on-disk neighbor format follows the id width and is self- consistent on load (the width is not persisted, mirroring the metric). A new `test_quantized_index_search_u64_ids` builds and searches a `BfTreeProvider<_, _, u64>` index end-to-end to prove the u64 path is functional, not merely compilable. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- diskann-bftree/src/id.rs | 54 +++++ diskann-bftree/src/lib.rs | 3 + diskann-bftree/src/provider.rs | 380 +++++++++++++++++++++------------ diskann-bftree/src/vectors.rs | 8 +- 4 files changed, 310 insertions(+), 135 deletions(-) create mode 100644 diskann-bftree/src/id.rs diff --git a/diskann-bftree/src/id.rs b/diskann-bftree/src/id.rs new file mode 100644 index 000000000..af8cd3bfc --- /dev/null +++ b/diskann-bftree/src/id.rs @@ -0,0 +1,54 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Vertex-id abstraction for the bf-tree provider. + +use diskann::utils::{IntoUsize, VectorId}; + +/// Identifier type usable as a `BfTreeProvider` vertex id. +/// +/// This bundles the bounds the core algorithm requires of an id ([`VectorId`]) +/// with the ability to convert *to* `usize` ([`IntoUsize`], used to key the +/// per-vector stores) and *from* a zero-based index. The provider mints ids +/// densely from `0..total`, so it needs a way to build an `I` from a counter. +/// +/// Implemented for `u32` (the default, capping at ~4.29B vertices) and `u64` +/// (for billion-scale-and-beyond, larger-than-memory datasets). On a 64-bit +/// target `u64` covers every representable `usize`, so its conversions never +/// fail. +pub trait BfTreeId: VectorId + IntoUsize { + /// Build an id from a zero-based index, truncating on overflow. + /// + /// Only call this for indices already known to fit (e.g. ids drawn from + /// `0..total`, which the provider guarantees fit by construction). + fn from_index(index: usize) -> Self; + + /// Build an id from a zero-based index, returning `None` if it does not fit. + fn try_from_index(index: usize) -> Option; +} + +impl BfTreeId for u32 { + #[inline(always)] + fn from_index(index: usize) -> Self { + index as u32 + } + + #[inline(always)] + fn try_from_index(index: usize) -> Option { + u32::try_from(index).ok() + } +} + +impl BfTreeId for u64 { + #[inline(always)] + fn from_index(index: usize) -> Self { + index as u64 + } + + #[inline(always)] + fn try_from_index(index: usize) -> Option { + u64::try_from(index).ok() + } +} diff --git a/diskann-bftree/src/lib.rs b/diskann-bftree/src/lib.rs index 334e0dd82..f0d40e58b 100644 --- a/diskann-bftree/src/lib.rs +++ b/diskann-bftree/src/lib.rs @@ -9,6 +9,7 @@ //! [`DataProvider`](diskann::provider::DataProvider) trait, enabling indexes that can //! transparently spill to disk for datasets larger than available memory. +pub mod id; pub mod neighbors; pub mod provider; pub mod quant; @@ -16,6 +17,8 @@ pub mod vectors; mod locks; +pub use id::BfTreeId; + // Accessors pub use provider::{ AsVectorDtype, BfTreePaths, BfTreeProvider, BfTreeProviderParameters, CreateQuantProvider, diff --git a/diskann-bftree/src/provider.rs b/diskann-bftree/src/provider.rs index 839115f09..03d58ac0c 100644 --- a/diskann-bftree/src/provider.rs +++ b/diskann-bftree/src/provider.rs @@ -32,7 +32,7 @@ use diskann::{ }, neighbor::Neighbor, provider::{DataProvider, DefaultContext, Delete, ElementStatus, HasId, NoopGuard, SetElement}, - utils::{IntoUsize, VectorRepr}, + utils::VectorRepr, ANNError, ANNResult, }; use diskann_utils::{ @@ -45,7 +45,7 @@ use super::{ neighbors::{NeighborAccessor, NeighborProvider}, quant::QuantVectorProvider, vectors::VectorProvider, - AccessError, NoStore, + AccessError, BfTreeId, NoStore, }; use crate::locks::StripedLocks; use diskann_providers::model::graph::provider::async_::distances::UnwrapErr; @@ -177,9 +177,10 @@ use diskann_providers::storage::{LoadWith, SaveWith, StorageReadProvider, Storag /// quantizer, /// ); /// ``` -pub struct BfTreeProvider +pub struct BfTreeProvider where T: VectorRepr, + I: BfTreeId, { // The quant vector store. If `Q == NoStore`, the quantized operations are disabled. // @@ -191,7 +192,7 @@ where // Provider that holds the graph structure as neighbors of vectors. // - pub(crate) neighbor_provider: NeighborProvider, + pub(crate) neighbor_provider: NeighborProvider, // The metric to use for distances // @@ -265,9 +266,10 @@ pub struct BfTreeProviderParameters { pub use_snapshot: bool, } -impl BfTreeProvider +impl BfTreeProvider where T: VectorRepr, + I: BfTreeId, { /// Construct a new data provider from empty. Callers of this are required to manually set start /// points before performing search tasks. @@ -360,7 +362,7 @@ where // until BF-tree API is improved to handle `exists` queries. let mut scratch = provider.neighbor_provider.scratch(&provider.locks); for i in 0..params.max_points { - let vector_id = i as u32; + let vector_id = I::from_index(i); scratch.write_neighbors(vector_id, &[])?; } } @@ -378,13 +380,13 @@ where // } /// Return a vector of starting points. - pub fn starting_points(&self) -> ANNResult> { + pub fn starting_points(&self) -> ANNResult> { self.full_vectors.starting_points() } /// An iterator over all ids including start points (even if they are deleted). - pub fn iter(&self) -> std::ops::Range { - 0..(self.full_vectors.total() as u32) + pub fn iter(&self) -> std::iter::Map, fn(usize) -> I> { + (0..self.full_vectors.total()).map(I::from_index as fn(usize) -> I) } pub fn num_start_points(&self) -> usize { @@ -412,9 +414,10 @@ where } } -impl BfTreeProvider +impl BfTreeProvider where T: VectorRepr, + I: BfTreeId, { /// Return the number of vector reads for full-precision and quant-vectors respectively /// @@ -426,9 +429,10 @@ where } } -impl BfTreeProvider +impl BfTreeProvider where T: VectorRepr, + I: BfTreeId, { /// Return the number of vector reads for full-precision and quant-vectors respectively /// @@ -467,10 +471,11 @@ impl DeleteQuant for NoStore { /// [`InplaceDeleteStrategy::get_delete_element`]) *after* the delete has already been committed. /// Use [`InplaceDeleteMethod::OneHop`] or [`InplaceDeleteMethod::TwoHopAndOneHop`] instead, /// as these strategies only require neighbor topology (which remains accessible). -impl Delete for BfTreeProvider +impl Delete for BfTreeProvider where T: VectorRepr, Q: AsyncFriendly + DeleteQuant, + I: BfTreeId, { fn release( &self, @@ -486,12 +491,12 @@ where gid: &Self::ExternalId, ) -> impl std::future::Future> + Send { let id = *gid; - let _guard = self.locks.lock(id as usize); + let _guard = self.locks.lock(id.into_usize()); // Only delete vector data here. Neighbor adjacency cleanup (zeroing the // deleted vertex's edge list and patching neighbors-of-neighbors) is // handled by `DiskANNIndex::inplace_delete` → `drop_adj_list`. - self.full_vectors.delete_vector(id as usize); - self.quant_vectors.delete_vector(id as usize); + self.full_vectors.delete_vector(id.into_usize()); + self.quant_vectors.delete_vector(id.into_usize()); std::future::ready(Ok(())) } @@ -522,12 +527,13 @@ where /// Allow `&BfTreeProvider` to implement `IntoIter` /// -impl IntoIterator for &BfTreeProvider +impl IntoIterator for &BfTreeProvider where T: VectorRepr, + I: BfTreeId, { - type Item = u32; - type IntoIter = std::ops::Range; + type Item = I; + type IntoIter = std::iter::Map, fn(usize) -> I>; fn into_iter(self) -> Self::IntoIter { self.iter() } @@ -565,12 +571,13 @@ impl CreateQuantProvider for Poly { } } -impl BfTreeProvider +impl BfTreeProvider where T: VectorRepr, Q: AsyncFriendly, + I: BfTreeId, { - pub fn neighbors(&self) -> &NeighborProvider { + pub fn neighbors(&self) -> &NeighborProvider { &self.neighbor_provider } } @@ -579,27 +586,28 @@ where // Data Provider // /////////////////// -impl DataProvider for BfTreeProvider +impl DataProvider for BfTreeProvider where T: VectorRepr, Q: AsyncFriendly, + I: BfTreeId, { type Context = DefaultContext; // The `BfTreeProvider` uses the identity map for IDs. // - type InternalId = u32; + type InternalId = I; // The `BfTreeProvider` uses the identity map for IDs. // - type ExternalId = u32; + type ExternalId = I; // Use a general error type for now. // type Error = ANNError; // No insert-ID recovery. - type Guard = NoopGuard; + type Guard = NoopGuard; // Translate an external id to its corresponding internal id. // @@ -622,12 +630,13 @@ where } } -impl HasId for BfTreeProvider +impl HasId for BfTreeProvider where T: VectorRepr, Q: AsyncFriendly, + I: BfTreeId, { - type Id = u32; + type Id = I; } //////////////// @@ -636,9 +645,10 @@ where /// Assign to both the full-precision and quant vector stores /// -impl SetElement<&[T]> for BfTreeProvider +impl SetElement<&[T]> for BfTreeProvider where T: VectorRepr, + I: BfTreeId, { type SetError = ANNError; @@ -651,7 +661,7 @@ where fn set_element( &self, _context: &Self::Context, - id: &u32, + id: &I, element: &[T], ) -> impl Future> + Send { let _guard = self.locks.lock(id.into_usize()); @@ -676,9 +686,10 @@ where /// Assign to just the full-precision store /// -impl SetElement<&[T]> for BfTreeProvider +impl SetElement<&[T]> for BfTreeProvider where T: VectorRepr, + I: BfTreeId, { type SetError = ANNError; @@ -687,7 +698,7 @@ where fn set_element( &self, _context: &Self::Context, - id: &u32, + id: &I, element: &[T], ) -> impl Future> + Send { let _guard = self.locks.lock(id.into_usize()); @@ -732,12 +743,13 @@ pub trait StartPoint { /// /// This implementation sets both the full-precision and quantized vectors for each /// start point, as well as initializing empty neighbor lists. -impl StartPoint for BfTreeProvider +impl StartPoint for BfTreeProvider where T: VectorRepr, + I: BfTreeId, { fn set_start_points(&self, _hidden: Hidden, start_points: MatrixView<'_, T>) -> ANNResult<()> { - let start_point_ids = self.full_vectors.starting_points()?; + let start_point_ids: Vec = self.full_vectors.starting_points()?; if start_points.nrows() != start_point_ids.len() { return Err(ANNError::log_async_index_error(format!( "expected start_points to contain `{}` rows, instead it has {}", @@ -763,12 +775,13 @@ where /// /// This implementation sets the full-precision vectors for each start point /// and initializes empty neighbor lists. -impl StartPoint for BfTreeProvider +impl StartPoint for BfTreeProvider where T: VectorRepr, + I: BfTreeId, { fn set_start_points(&self, _hidden: Hidden, start_points: MatrixView<'_, T>) -> ANNResult<()> { - let start_point_ids = self.full_vectors.starting_points()?; + let start_point_ids: Vec = self.full_vectors.starting_points()?; if start_points.nrows() != start_point_ids.len() { return Err(ANNError::log_async_index_error(format!( "expected start_points to contain `{}` rows, instead it has {}", @@ -800,25 +813,27 @@ where /// * [`Accessor`] for the [`BfTreeProvider`]. /// * [`BuildQueryComputer`]. /// -pub struct FullAccessor<'a, T, Q> +pub struct FullAccessor<'a, T, Q, I> where T: VectorRepr, Q: AsyncFriendly, + I: BfTreeId, { /// The host provider. - provider: &'a BfTreeProvider, + provider: &'a BfTreeProvider, /// The fused query-distance computer. computer: T::QueryDistance, /// A buffer to store retrieved elements. element: Box<[T]>, } -impl<'a, T, Q> FullAccessor<'a, T, Q> +impl<'a, T, Q, I> FullAccessor<'a, T, Q, I> where T: VectorRepr, Q: AsyncFriendly, + I: BfTreeId, { - pub(crate) fn new(provider: &'a BfTreeProvider, query: &[T]) -> Self { + pub(crate) fn new(provider: &'a BfTreeProvider, query: &[T]) -> Self { Self { provider, computer: T::query_distance(query, provider.metric), @@ -828,7 +843,7 @@ where } } - fn get_distance(&mut self, id: u32) -> Result { + fn get_distance(&mut self, id: I) -> Result { self.provider .full_vectors .get_vector_into(id.into_usize(), &mut self.element) @@ -836,20 +851,22 @@ where } } -impl HasId for FullAccessor<'_, T, Q> +impl HasId for FullAccessor<'_, T, Q, I> where T: VectorRepr, Q: AsyncFriendly, + I: BfTreeId, { - type Id = u32; + type Id = I; } -impl glue::SearchAccessor for FullAccessor<'_, T, Q> +impl glue::SearchAccessor for FullAccessor<'_, T, Q, I> where T: VectorRepr, Q: AsyncFriendly, + I: BfTreeId, { - fn starting_points(&self) -> impl Future>> { + fn starting_points(&self) -> impl Future>> { std::future::ready(self.provider.starting_points()) } @@ -904,23 +921,25 @@ where /// /// * [`Accessor`] for the `BfTreeProvider`. /// -pub struct QuantAccessor<'a, T> +pub struct QuantAccessor<'a, T, I> where T: VectorRepr, + I: BfTreeId, { - provider: &'a BfTreeProvider, + provider: &'a BfTreeProvider, /// The fused query-distance computer. computer: super::quant::QuantQueryComputer, /// A buffer to store retrieved elements. element: Box<[u8]>, } -impl<'a, T> QuantAccessor<'a, T> +impl<'a, T, I> QuantAccessor<'a, T, I> where T: VectorRepr, + I: BfTreeId, { pub(crate) fn new( - provider: &'a BfTreeProvider, + provider: &'a BfTreeProvider, query: &[T], ) -> ANNResult { let computer = provider.quant_vectors.query_computer(query)?; @@ -933,7 +952,7 @@ where }) } - fn get_distance(&mut self, id: u32) -> Result { + fn get_distance(&mut self, id: I) -> Result { match self .provider .quant_vectors @@ -948,18 +967,20 @@ where } } -impl HasId for QuantAccessor<'_, T> +impl HasId for QuantAccessor<'_, T, I> where T: VectorRepr, + I: BfTreeId, { - type Id = u32; + type Id = I; } -impl glue::SearchAccessor for QuantAccessor<'_, T> +impl glue::SearchAccessor for QuantAccessor<'_, T, I> where T: VectorRepr, + I: BfTreeId, { - fn starting_points(&self) -> impl Future>> { + fn starting_points(&self) -> impl Future>> { std::future::ready(self.provider.starting_points()) } @@ -1009,25 +1030,27 @@ where /////////////////////// /// A [`glue::PruneAccessor`] for full-precision vectors in the `BfTreeProvider`. -pub struct FullPruneAccessor<'a, T, Q> +pub struct FullPruneAccessor<'a, T, Q, I> where T: VectorRepr, Q: AsyncFriendly, + I: BfTreeId, { - provider: &'a BfTreeProvider, - neighbors: NeighborAccessor<'a, u32>, - set: map::Map, map::Ref<[T]>>, + provider: &'a BfTreeProvider, + neighbors: NeighborAccessor<'a, I>, + set: map::Map, map::Ref<[T]>>, distance: T::Distance, } -impl<'a, T, Q> FullPruneAccessor<'a, T, Q> +impl<'a, T, Q, I> FullPruneAccessor<'a, T, Q, I> where T: VectorRepr, Q: AsyncFriendly, + I: BfTreeId, { fn new( - provider: &'a BfTreeProvider, - set: map::Map, map::Ref<[T]>>, + provider: &'a BfTreeProvider, + set: map::Map, map::Ref<[T]>>, ) -> Self { Self { provider, @@ -1038,23 +1061,25 @@ where } } -impl HasId for FullPruneAccessor<'_, T, Q> +impl HasId for FullPruneAccessor<'_, T, Q, I> where T: VectorRepr, Q: AsyncFriendly, + I: BfTreeId, { - type Id = u32; + type Id = I; } -impl<'q, T, Q> glue::PruneAccessor for FullPruneAccessor<'q, T, Q> +impl<'q, T, Q, I> glue::PruneAccessor for FullPruneAccessor<'q, T, Q, I> where T: VectorRepr, Q: AsyncFriendly, + I: BfTreeId, { type ElementRef<'a> = &'a [T]; type View<'a> - = map::View<'a, u32, Box<[T]>, map::Ref<[T]>> + = map::View<'a, I, Box<[T]>, map::Ref<[T]>> where Self: 'a; @@ -1064,7 +1089,7 @@ where Self: 'a; type Neighbors<'a> - = diskann::provider::Neighbors<'a, NeighborAccessor<'q, u32>> + = diskann::provider::Neighbors<'a, NeighborAccessor<'q, I>> where Self: 'a; @@ -1077,7 +1102,7 @@ where { let mut buf: Option> = None; - let view = self.set.fill(itr, |i: u32| -> ANNResult<_> { + let view = self.set.fill(itr, |i: I| -> ANNResult<_> { let mut b = match buf.take() { Some(b) => b, None => std::iter::repeat_n(T::default(), self.provider.dim()).collect(), @@ -1111,22 +1136,24 @@ where //////////////////////// /// A [`glue::PruneAccessor`] for quantized vectors in the `BfTreeProvider`. -pub struct QuantPruneAccessor<'a, T> +pub struct QuantPruneAccessor<'a, T, I> where T: VectorRepr, + I: BfTreeId, { - provider: &'a BfTreeProvider, - neighbors: NeighborAccessor<'a, u32>, - set: map::Map, + provider: &'a BfTreeProvider, + neighbors: NeighborAccessor<'a, I>, + set: map::Map, distance: UnwrapErr, } -impl<'a, T> QuantPruneAccessor<'a, T> +impl<'a, T, I> QuantPruneAccessor<'a, T, I> where T: VectorRepr, + I: BfTreeId, { fn new( - provider: &'a BfTreeProvider, + provider: &'a BfTreeProvider, capacity: usize, ) -> ANNResult { let distance = provider @@ -1143,21 +1170,23 @@ where } } -impl HasId for QuantPruneAccessor<'_, T> +impl HasId for QuantPruneAccessor<'_, T, I> where T: VectorRepr, + I: BfTreeId, { - type Id = u32; + type Id = I; } -impl<'q, T> glue::PruneAccessor for QuantPruneAccessor<'q, T> +impl<'q, T, I> glue::PruneAccessor for QuantPruneAccessor<'q, T, I> where T: VectorRepr, + I: BfTreeId, { type ElementRef<'a> = Opaque<'a>; type View<'a> - = map::View<'a, u32, Owned> + = map::View<'a, I, Owned> where Self: 'a; @@ -1167,7 +1196,7 @@ where Self: 'a; type Neighbors<'a> - = diskann::provider::Neighbors<'a, NeighborAccessor<'q, u32>> + = diskann::provider::Neighbors<'a, NeighborAccessor<'q, I>> where Self: 'a; @@ -1181,7 +1210,7 @@ where let mut buf: Option> = None; let bytes = self.provider.quant_vectors.quantizer.bytes(); - let view = self.set.fill(itr, |i: u32| -> ANNResult<_> { + let view = self.set.fill(itr, |i: I| -> ANNResult<_> { let mut b = match buf.take() { Some(b) => b, None => std::iter::repeat_n(0, bytes).collect(), @@ -1233,17 +1262,18 @@ impl<'short> diskann_utils::Reborrow<'short> for Owned { /// Perform a search entirely in the full-precision space. /// /// Starting points are not filtered out of the final results. -impl<'a, T, Q> SearchStrategy<'a, BfTreeProvider, &'a [T]> for FullPrecision +impl<'a, T, Q, I> SearchStrategy<'a, BfTreeProvider, &'a [T]> for FullPrecision where T: VectorRepr, Q: AsyncFriendly, + I: BfTreeId, { - type SearchAccessor = FullAccessor<'a, T, Q>; + type SearchAccessor = FullAccessor<'a, T, Q, I>; type SearchAccessorError = Infallible; fn search_accessor( &'a self, - provider: &'a BfTreeProvider, + provider: &'a BfTreeProvider, _context: &'a DefaultContext, query: &'a [T], ) -> Result { @@ -1251,26 +1281,28 @@ where } } -impl<'a, T, Q> DefaultPostProcessor<'a, BfTreeProvider, &'a [T]> for FullPrecision +impl<'a, T, Q, I> DefaultPostProcessor<'a, BfTreeProvider, &'a [T]> for FullPrecision where T: VectorRepr, Q: AsyncFriendly, + I: BfTreeId, { default_post_processor!(glue::Pipeline); } // Pruning -impl PruneStrategy> for FullPrecision +impl PruneStrategy> for FullPrecision where T: VectorRepr, Q: AsyncFriendly, + I: BfTreeId, { - type PruneAccessor<'a> = FullPruneAccessor<'a, T, Q>; + type PruneAccessor<'a> = FullPruneAccessor<'a, T, Q, I>; type PruneAccessorError = diskann::error::Infallible; fn prune_accessor<'a>( &'a self, - provider: &'a BfTreeProvider, + provider: &'a BfTreeProvider, _context: &'a DefaultContext, capacity: usize, ) -> Result, Self::PruneAccessorError> { @@ -1279,10 +1311,11 @@ where } } -impl<'a, T, Q> InsertStrategy<'a, BfTreeProvider, &'a [T]> for FullPrecision +impl<'a, T, Q, I> InsertStrategy<'a, BfTreeProvider, &'a [T]> for FullPrecision where T: VectorRepr, Q: AsyncFriendly, + I: BfTreeId, { type PruneStrategy = Self; fn prune_strategy(&self) -> Self::PruneStrategy { @@ -1290,13 +1323,14 @@ where } } -impl MultiInsertStrategy, B> for FullPrecision +impl MultiInsertStrategy, B> for FullPrecision where T: VectorRepr, Q: AsyncFriendly, + I: BfTreeId, B: for<'a> Batch = &'a [T]> + Debug, { - type Seed = map::Builder>; + type Seed = map::Builder>; type FinishError = diskann::error::Infallible; type PruneStrategy = Self; type InsertStrategy = Self; @@ -1307,13 +1341,13 @@ where fn finish( &self, - _provider: &BfTreeProvider, + _provider: &BfTreeProvider, _ctx: &DefaultContext, batch: &std::sync::Arc, ids: Itr, ) -> impl std::future::Future> + Send where - Itr: ExactSizeIterator + Send, + Itr: ExactSizeIterator + Send, { let overlay = map::Overlay::from_batch(batch.clone(), ids); let builder = map::Builder::new(map::Capacity::Default).with_overlay(overlay); @@ -1322,11 +1356,11 @@ where fn seeded_prune_accessor<'a>( &'a self, - provider: &'a BfTreeProvider, + provider: &'a BfTreeProvider, _context: &'a DefaultContext, seed: &'a Self::Seed, capacity: usize, - ) -> ANNResult> { + ) -> ANNResult> { let set = seed.clone().build(capacity); Ok(FullPruneAccessor::new(provider, set)) } @@ -1340,16 +1374,17 @@ where /// [`InplaceDeleteMethod::TwoHopAndOneHop`]. It is **not compatible** with /// [`InplaceDeleteMethod::VisitedAndTopK`] because `BfTreeProvider` performs hard deletes — /// the vector data is erased before `get_delete_element` is called, causing it to fail. -impl InplaceDeleteStrategy> for FullPrecision +impl InplaceDeleteStrategy> for FullPrecision where T: VectorRepr, Q: AsyncFriendly, + I: BfTreeId, { type DeleteElementError = ANNError; type DeleteElement<'a> = &'a [T]; type DeleteElementGuard = Box<[T]>; type PruneStrategy = Self; - type DeleteSearchAccessor<'a> = FullAccessor<'a, T, Q>; + type DeleteSearchAccessor<'a> = FullAccessor<'a, T, Q, I>; type SearchPostProcessor = CopyIds; type SearchStrategy = Self; fn search_strategy(&self) -> Self::SearchStrategy { @@ -1366,9 +1401,9 @@ where async fn get_delete_element<'a>( &'a self, - provider: &'a BfTreeProvider, + provider: &'a BfTreeProvider, _context: &'a DefaultContext, - id: u32, + id: I, ) -> Result { use diskann::error::ErrorExt; let elt = provider @@ -1383,16 +1418,17 @@ where /// Perform a search entirely in the quantized space. /// /// Starting points are not filtered out of the final results. -impl<'a, T> SearchStrategy<'a, BfTreeProvider, &'a [T]> for Quantized +impl<'a, T, I> SearchStrategy<'a, BfTreeProvider, &'a [T]> for Quantized where T: VectorRepr, + I: BfTreeId, { - type SearchAccessor = QuantAccessor<'a, T>; + type SearchAccessor = QuantAccessor<'a, T, I>; type SearchAccessorError = ANNError; fn search_accessor( &'a self, - provider: &'a BfTreeProvider, + provider: &'a BfTreeProvider, _context: &'a DefaultContext, query: &'a [T], ) -> Result { @@ -1400,16 +1436,19 @@ where } } -impl<'a, T> DefaultPostProcessor<'a, BfTreeProvider, &'a [T]> for Quantized +impl<'a, T, I> DefaultPostProcessor<'a, BfTreeProvider, &'a [T]> + for Quantized where T: VectorRepr, + I: BfTreeId, { default_post_processor!(glue::Pipeline); } -impl<'a, T> InsertStrategy<'a, BfTreeProvider, &'a [T]> for Quantized +impl<'a, T, I> InsertStrategy<'a, BfTreeProvider, &'a [T]> for Quantized where T: VectorRepr, + I: BfTreeId, { type PruneStrategy = Self; fn prune_strategy(&self) -> Self::PruneStrategy { @@ -1417,9 +1456,10 @@ where } } -impl MultiInsertStrategy, B> for Quantized +impl MultiInsertStrategy, B> for Quantized where T: VectorRepr, + I: BfTreeId, B: glue::Batch, B: for<'a> Batch = &'a [T]> + Debug, { @@ -1434,24 +1474,24 @@ where fn finish( &self, - _provider: &BfTreeProvider, + _provider: &BfTreeProvider, _ctx: &DefaultContext, _batch: &std::sync::Arc, _ids: Itr, ) -> impl std::future::Future> + Send where - Itr: ExactSizeIterator + Send, + Itr: ExactSizeIterator + Send, { std::future::ready(Ok(())) } fn seeded_prune_accessor<'a>( &'a self, - provider: &'a BfTreeProvider, + provider: &'a BfTreeProvider, _context: &'a DefaultContext, _seed: &'a (), capacity: usize, - ) -> ANNResult> { + ) -> ANNResult> { QuantPruneAccessor::new(provider, capacity) } } @@ -1462,15 +1502,16 @@ where /// /// Same constraint as [`FullPrecision`]'s impl: not compatible with /// [`InplaceDeleteMethod::VisitedAndTopK`] due to hard deletes. -impl InplaceDeleteStrategy> for Quantized +impl InplaceDeleteStrategy> for Quantized where T: VectorRepr, + I: BfTreeId, { type DeleteElementError = ANNError; type DeleteElement<'a> = &'a [T]; type DeleteElementGuard = Box<[T]>; type PruneStrategy = Self; - type DeleteSearchAccessor<'a> = QuantAccessor<'a, T>; + type DeleteSearchAccessor<'a> = QuantAccessor<'a, T, I>; type SearchPostProcessor = Rerank; type SearchStrategy = Self; fn search_strategy(&self) -> Self::SearchStrategy { @@ -1487,9 +1528,9 @@ where async fn get_delete_element<'a>( &'a self, - provider: &'a BfTreeProvider, + provider: &'a BfTreeProvider, _context: &'a DefaultContext, - id: u32, + id: I, ) -> Result { use diskann::error::ErrorExt; provider @@ -1501,16 +1542,17 @@ where } // Pruning -impl PruneStrategy> for Quantized +impl PruneStrategy> for Quantized where T: VectorRepr, + I: BfTreeId, { - type PruneAccessor<'a> = QuantPruneAccessor<'a, T>; + type PruneAccessor<'a> = QuantPruneAccessor<'a, T, I>; type PruneAccessorError = ANNError; fn prune_accessor<'a>( &'a self, - provider: &'a BfTreeProvider, + provider: &'a BfTreeProvider, _context: &'a DefaultContext, capacity: usize, ) -> Result, Self::PruneAccessorError> { @@ -1522,28 +1564,29 @@ where #[derive(Debug, Default, Clone, Copy)] pub struct Rerank; -impl<'a, T> glue::SearchPostProcess, &[T]> for Rerank +impl<'a, T, I> glue::SearchPostProcess, &[T]> for Rerank where T: VectorRepr, + I: BfTreeId, { type Error = ANNError; - fn post_process( + fn post_process( &self, - accessor: &mut QuantAccessor<'a, T>, + accessor: &mut QuantAccessor<'a, T, I>, query: &[T], - candidates: I, + candidates: Itr, output: &mut B, ) -> impl Future> + Send where - I: Iterator> + Send, - B: SearchOutputBuffer + Send + ?Sized, + Itr: Iterator> + Send, + B: SearchOutputBuffer + Send + ?Sized, { use diskann::error::ErrorExt; let provider = accessor.provider; let f = T::distance(provider.metric, Some(provider.full_vectors.dim())); - let mut reranked: Vec<(u32, f32)> = Vec::new(); + let mut reranked: Vec<(I, f32)> = Vec::new(); for n in candidates { match provider .full_vectors @@ -1719,9 +1762,10 @@ fn load_bftree(snapshot_path: std::path::PathBuf, use_snapshot: bool) -> Result< // Serialization // ////////////////////// -impl SaveWith for BfTreeProvider +impl SaveWith for BfTreeProvider where T: VectorRepr, + I: BfTreeId, { type Ok = usize; type Error = ANNError; @@ -1784,9 +1828,10 @@ where } } -impl LoadWith for BfTreeProvider +impl LoadWith for BfTreeProvider where T: VectorRepr, + I: BfTreeId, { type Error = ANNError; @@ -1822,10 +1867,8 @@ where BfTreePaths::neighbors_bftree(&saved_params.prefix), saved_params.use_snapshot, )?; - let neighbor_provider = NeighborProvider::::new_from_bftree( - saved_params.max_degree, - adjacency_list_index, - )?; + let neighbor_provider = + NeighborProvider::::new_from_bftree(saved_params.max_degree, adjacency_list_index)?; Ok(Self { quant_vectors: NoStore, @@ -1839,9 +1882,10 @@ where } } -impl SaveWith for BfTreeProvider +impl SaveWith for BfTreeProvider where T: VectorRepr, + I: BfTreeId, { type Ok = usize; type Error = ANNError; @@ -1929,9 +1973,10 @@ where } } -impl LoadWith for BfTreeProvider +impl LoadWith for BfTreeProvider where T: VectorRepr, + I: BfTreeId, { type Error = ANNError; @@ -1971,10 +2016,8 @@ where BfTreePaths::neighbors_bftree(&saved_params.prefix), saved_params.use_snapshot, )?; - let neighbor_provider = NeighborProvider::::new_from_bftree( - saved_params.max_degree, - adjacency_list_index, - )?; + let neighbor_provider = + NeighborProvider::::new_from_bftree(saved_params.max_degree, adjacency_list_index)?; let filename = BfTreePaths::quant_data_bin(&saved_params.prefix); let mut reader = storage.open_reader(&filename)?; @@ -2102,6 +2145,79 @@ mod tests { assert_eq!(neighbors[0].id, 3); } + /// Build and search a quantized index whose vertex ids are `u64` rather than the + /// default `u32`. This exercises the full generic pipeline end-to-end (set_element, + /// `u64`-keyed neighbor storage via `bytemuck`, greedy search, and rerank) to prove + /// that the `BfTreeProvider<_, _, u64>` path is functional and not merely compilable. + #[tokio::test] + async fn test_quantized_index_search_u64_ids() { + let start_point = Matrix::new(Init(|| 0.0f32), 1, 5); + let dim = 5; + let logical_max_degree = 6; + let physical_max_degree = (logical_max_degree as f32 * 1.3) as u32; + let metric = Metric::L2; + + let provider: BfTreeProvider = BfTreeProvider::new( + BfTreeProviderParameters { + max_points: 20, + num_start_points: NonZeroUsize::new(1).unwrap(), + dim, + metric, + max_degree: physical_max_degree, + vector_provider_config: Config::default(), + quant_vector_provider_config: Config::default(), + neighbor_list_provider_config: Config::default(), + graph_params: None, + use_snapshot: false, + }, + start_point.as_view(), + create_test_quantizer(5), + ) + .unwrap(); + + let index_config = graph::config::Builder::new_with( + logical_max_degree as usize, + graph::config::MaxDegree::new(physical_max_degree as usize), + 10, + metric.into(), + |_| {}, + ) + .build() + .unwrap(); + + let index = Arc::new(DiskANNIndex::new(index_config, provider, None)); + let ctx = &DefaultContext; + + for i in 0u64..15 { + let point = vec![i as f32; 5]; + index + .insert(&Quantized, ctx, &i, point.as_slice()) + .await + .unwrap(); + } + + let query = vec![3.0; 5]; + let params = Knn::new(5, 10, None).unwrap(); + + let mut neighbors = vec![Neighbor::::default(); 5]; + let res = index + .search( + params, + &Quantized, + &DefaultContext, + query.as_slice(), + &mut BackInserter::new(neighbors.as_mut_slice()), + ) + .await + .unwrap(); + + assert_eq!( + res.result_count, 5, + "there are 15 points and we're asking for 5, we expect 5" + ); + assert_eq!(neighbors[0].id, 3u64); + } + #[tokio::test] async fn test_quantized_index_multi_insert_search() { let index = create_quant_index(); @@ -2364,7 +2480,7 @@ mod tests { // Iterator // - assert_eq!((&provider).into_iter(), 0..(10 + 2)); + assert!((&provider).into_iter().eq(0u32..(10 + 2))); let iter = provider.iter(); diff --git a/diskann-bftree/src/vectors.rs b/diskann-bftree/src/vectors.rs index 661ffac08..e5d1beb8c 100644 --- a/diskann-bftree/src/vectors.rs +++ b/diskann-bftree/src/vectors.rs @@ -95,11 +95,13 @@ impl VectorProvider { /// Return a vector of vector Ids of the starting points /// #[inline(always)] - pub fn starting_points(&self) -> ANNResult> { + pub fn starting_points(&self) -> ANNResult> { (self.max_vectors..self.total()) .map(|i| { - u32::try_from(i).map_err(|_| { - ANNError::log_index_error(format_args!("start point id {i} exceeds u32::MAX")) + I::try_from_index(i).ok_or_else(|| { + ANNError::log_index_error(format_args!( + "start point id {i} exceeds the id type's maximum" + )) }) }) .collect() From c861d224eb48364b0e3da3b2857bd490366b95d3 Mon Sep 17 00:00:00 2001 From: Jordan Maples Date: Tue, 30 Jun 2026 13:23:47 -0700 Subject: [PATCH 2/2] Harden u64 bftree provider path Reinforce the newly-reachable u64 id path surfaced by review: - neighbors: size record key/value buffers by size_of::() instead of hardcoded size_of::(), so u64 providers reserve correct record sizes. - id: add validate_id_capacity to reject vertex counts that overflow the id type, and guard new_empty with it. - provider: persist id_width in SavedParams (back-compat default of 4) and validate it on load, so a u64 index can't be silently loaded as u32. - tests: add u64 high-bit id coverage (ids/neighbors > u32::MAX, low-bit collision check) plus a u64 save/load round-trip, id-width mismatch rejection, and capacity-guard rejection. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- diskann-bftree/src/id.rs | 20 ++++ diskann-bftree/src/neighbors.rs | 43 ++++++++- diskann-bftree/src/provider.rs | 158 ++++++++++++++++++++++++++++++++ 3 files changed, 219 insertions(+), 2 deletions(-) diff --git a/diskann-bftree/src/id.rs b/diskann-bftree/src/id.rs index af8cd3bfc..06aab68f1 100644 --- a/diskann-bftree/src/id.rs +++ b/diskann-bftree/src/id.rs @@ -6,6 +6,7 @@ //! Vertex-id abstraction for the bf-tree provider. use diskann::utils::{IntoUsize, VectorId}; +use diskann::{ANNError, ANNResult}; /// Identifier type usable as a `BfTreeProvider` vertex id. /// @@ -52,3 +53,22 @@ impl BfTreeId for u64 { u64::try_from(index).ok() } } + +/// Validate that a provider holding `total` ids can represent every id in `0..total`. +/// +/// `BfTreeProvider::iter` mints ids densely via the infallible (truncating) +/// [`BfTreeId::from_index`]; callers must guarantee the range fits in `I`. This check +/// enforces that guarantee up front (at construction and load) so the truncating +/// conversion can never silently wrap a real id. +pub(crate) fn validate_id_capacity(total: usize) -> ANNResult<()> { + if let Some(last) = total.checked_sub(1) { + if I::try_from_index(last).is_none() { + return Err(ANNError::log_index_error(format!( + "provider capacity of {total} ids exceeds the maximum representable by the \ + {}-byte vertex id type", + std::mem::size_of::() + ))); + } + } + Ok(()) +} diff --git a/diskann-bftree/src/neighbors.rs b/diskann-bftree/src/neighbors.rs index 117e9c31b..2cec196e2 100644 --- a/diskann-bftree/src/neighbors.rs +++ b/diskann-bftree/src/neighbors.rs @@ -35,8 +35,11 @@ impl HasId for NeighborProvider { impl NeighborProvider { /// Create a new instance based on bf-tree Config directly. pub fn new_with_config(max_degree: u32, config: Config) -> ANNResult { - let key_size = std::mem::size_of::(); - let value_size = (max_degree as usize + 1) * std::mem::size_of::(); + // Records are keyed by an `I`-width id and store a `dim`-cell `I`-width value + // (`dim == max_degree + 1`), so size the validation by the actual id width + // rather than assuming `u32`. + let key_size = std::mem::size_of::(); + let value_size = (max_degree as usize + 1) * std::mem::size_of::(); crate::validate_record_size("neighbor_provider", &config, key_size, value_size)?; let adj_list_index = BfTree::with_config(config, None).map_err(ConfigError)?; @@ -407,6 +410,42 @@ mod tests { } } + /// Exercise the `u64` id path beyond the `u32` range: vertex ids and neighbor + /// values above `u32::MAX` must round-trip, and ids that share their low 32 bits + /// must not collide (proving the bf-tree key uses the full 8-byte width). + #[tokio::test] + async fn test_u64_high_bit_ids() { + let locks = Arc::new(StripedLocks::new()); + let neighbor_provider = + NeighborProvider::::new_with_config(6, Config::default()).unwrap(); + let mut scratch = neighbor_provider.scratch(&locks); + + let high: u64 = (u32::MAX as u64) + 1; + let big_id: u64 = high + 7; + let big_neighbors: Vec = vec![high, high + 1, u64::MAX, 3]; + scratch.write_neighbors(big_id, &big_neighbors).unwrap(); + + let mut result = AdjacencyList::with_capacity(10); + neighbor_provider + .get_neighbors(big_id, &mut result) + .unwrap(); + assert_eq!(&*big_neighbors, &*result); + + // `low` and `low | (1 << 32)` share their low 32 bits; with an 8-byte key they + // are distinct entries. A truncating 4-byte key would alias them. + let low: u64 = 1; + let aliased: u64 = low | (1u64 << 32); + scratch.write_neighbors(low, &[10, 11]).unwrap(); + scratch.write_neighbors(aliased, &[20, 21]).unwrap(); + + neighbor_provider.get_neighbors(low, &mut result).unwrap(); + assert_eq!(&[10u64, 11], &*result); + neighbor_provider + .get_neighbors(aliased, &mut result) + .unwrap(); + assert_eq!(&[20u64, 21], &*result); + } + /// Test corner cases of appending to neighbor list #[tokio::test] async fn test_neighbor_accessors() { diff --git a/diskann-bftree/src/provider.rs b/diskann-bftree/src/provider.rs index 03d58ac0c..f3c675937 100644 --- a/diskann-bftree/src/provider.rs +++ b/diskann-bftree/src/provider.rs @@ -301,6 +301,8 @@ where .quant_vector_provider_config .use_snapshot(params.use_snapshot); + crate::id::validate_id_capacity::(params.max_points + params.num_start_points.get())?; + Ok(Self { quant_vectors: quant_precursor.create(params.quant_vector_provider_config)?, full_vectors: VectorProvider::new_with_config( @@ -1638,6 +1640,31 @@ pub struct SavedParams { /// Whether CPR snapshot support was enabled. #[serde(default)] pub use_snapshot: bool, + /// Width in bytes of the vertex id type the index was built with + /// (`size_of::()`). Defaults to 4 (`u32`) for indexes saved before this + /// field existed, all of which used `u32` ids. + #[serde(default = "default_id_width")] + pub id_width: usize, +} + +/// Default [`SavedParams::id_width`] for legacy indexes (all of which were `u32`). +fn default_id_width() -> usize { + std::mem::size_of::() +} + +/// Validate, on load, that the persisted id metadata is compatible with the id type +/// `I` the caller is loading as: the width must match what the index was built with, +/// and the capacity must fit in `I`. +fn validate_loaded_id_params(saved_params: &SavedParams) -> ANNResult<()> { + let expected = std::mem::size_of::(); + if saved_params.id_width != expected { + return Err(ANNError::log_index_error(format!( + "index was built with {}-byte vertex ids but is being loaded with a {expected}-byte \ + id type; load it with the matching id type", + saved_params.id_width, + ))); + } + crate::id::validate_id_capacity::(saved_params.max_points + saved_params.frozen_points.get()) } /// The element type of the full-precision vectors stored in the index. @@ -1796,6 +1823,7 @@ where graph_params: self.graph_params.clone(), is_memory: self.full_vectors.config().is_memory_backend(), use_snapshot: self.use_snapshot, + id_width: std::mem::size_of::(), }; debug_assert_eq!( @@ -1849,6 +1877,8 @@ where })? }; + validate_loaded_id_params::(&saved_params)?; + let metric = Metric::from_str(&saved_params.metric) .map_err(|e| ANNError::log_index_error(format!("Failed to parse metric: {}", e)))?; @@ -1922,6 +1952,7 @@ where graph_params: self.graph_params.clone(), is_memory: self.full_vectors.config().is_memory_backend(), use_snapshot: self.use_snapshot, + id_width: std::mem::size_of::(), }; debug_assert_eq!( @@ -1994,6 +2025,8 @@ where })? }; + validate_loaded_id_params::(&saved_params)?; + let _quant_params = saved_params.quant_params.ok_or_else(|| { ANNError::log_index_error("Missing quant_params in saved params for quantized provider") })?; @@ -3199,4 +3232,129 @@ mod tests { panic!("NeighborProvider should succeed: {e}"); } } + + /// A `u64`-id provider must round-trip through save/load with the wider on-disk + /// neighbor format, and loading the same bytes with a mismatched id width must fail. + #[tokio::test] + async fn test_bf_tree_provider_save_load_u64_ids() { + let num_points = 16usize; + let dim = 4usize; + let max_degree = 16u32; + let num_start_points = NonZeroUsize::new(2).unwrap(); + let ctx = &DefaultContext; + + let temp_dir = tempdir().unwrap(); + let prefix = temp_dir + .path() + .join("u64_provider") + .to_string_lossy() + .to_string(); + + let mut vector_config = Config::new(BfTreePaths::vectors_bftree(&prefix), 1024 * 1024); + vector_config.storage_backend(bf_tree::StorageBackend::Std); + vector_config.use_snapshot(true); + + let mut neighbor_config = Config::new(BfTreePaths::neighbors_bftree(&prefix), 1024 * 1024); + neighbor_config.storage_backend(bf_tree::StorageBackend::Std); + neighbor_config.use_snapshot(true); + + let params = BfTreeProviderParameters { + max_points: num_points, + num_start_points, + dim, + metric: Metric::L2, + max_degree, + vector_provider_config: vector_config, + quant_vector_provider_config: Config::default(), + neighbor_list_provider_config: neighbor_config, + graph_params: None, + use_snapshot: true, + }; + + let start_points = Matrix::new(Init(|| 0.0f32), num_start_points.into(), dim); + let provider = + BfTreeProvider::::new(params, start_points.as_view(), NoStore) + .unwrap(); + + for i in 0..num_points { + let vector: Vec = (0..dim).map(|j| (i * dim + j) as f32 * 0.1).collect(); + provider + .set_element(ctx, &(i as u64), &vector) + .await + .unwrap(); + } + + let mut scratch = provider.neighbor_provider.scratch(&provider.locks); + for i in 0..num_points as u64 { + let neighbors: Vec = (0..std::cmp::min(i, max_degree as u64)) + .map(|j| (i + j) % num_points as u64) + .collect(); + scratch.write_neighbors(i, &neighbors).unwrap(); + } + drop(scratch); + + let storage = FileStorageProvider; + let save_dir = tempdir().unwrap(); + let save_prefix = save_dir + .path() + .join("saved_u64_provider") + .to_string_lossy() + .to_string(); + provider.save_with(&storage, &save_prefix).await.unwrap(); + + let loaded = BfTreeProvider::::load_with(&storage, &save_prefix) + .await + .unwrap(); + + for i in 0..num_points as u64 { + let mut original_list = AdjacencyList::new(); + let mut loaded_list = AdjacencyList::new(); + provider + .neighbor_provider + .get_neighbors(i, &mut original_list) + .unwrap(); + loaded + .neighbor_provider + .get_neighbors(i, &mut loaded_list) + .unwrap(); + assert_eq!(&*original_list, &*loaded_list, "neighbor mismatch at {i}"); + } + + // Loading a u64-saved index as u32 must be rejected, not silently misinterpreted. + let mismatch = BfTreeProvider::::load_with(&storage, &save_prefix).await; + assert!( + mismatch.is_err(), + "loading a u64 index with a u32 id type must fail" + ); + } + + /// Constructing a provider whose vertex count cannot be represented by the id type + /// must fail eagerly rather than silently truncate ids. + #[tokio::test] + async fn test_new_rejects_capacity_exceeding_id_type() { + let dim = 4usize; + let num_start_points = NonZeroUsize::new(1).unwrap(); + let start_points = Matrix::new(Init(|| 0.0f32), num_start_points.into(), dim); + + let params = BfTreeProviderParameters { + // Largest index would be u32::MAX + 1, which a u32 id cannot hold. + max_points: u32::MAX as usize + 2, + num_start_points, + dim, + metric: Metric::L2, + max_degree: 8, + vector_provider_config: Config::default(), + quant_vector_provider_config: Config::default(), + neighbor_list_provider_config: Config::default(), + graph_params: None, + use_snapshot: false, + }; + + let result = + BfTreeProvider::::new(params, start_points.as_view(), NoStore); + assert!( + result.is_err(), + "u32 provider must reject a capacity exceeding u32::MAX" + ); + } }