Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -94,14 +94,52 @@ mod tests {

use super::*;

// Helper function to clean up checkpoint files after tests
fn clean_checkpoint_file(prefix: &str, identifier: u64) {
let checkpoint_file = format!("{prefix}_{identifier}.checkpoint");
if std::path::Path::new(&checkpoint_file).exists() {
// There is a possible race between checking that the file exists and removing
// the file here, but since this is test code, that is unlikely.
fs::remove_file(&checkpoint_file).unwrap();
}
/// A manager built for a prefix inside a freshly created temp directory must
/// not report the work as completed.
#[test]
fn test_has_completed_false_when_no_file() -> ANNResult<()> {
    let temp_dir = tempdir()?;
    let index_prefix = temp_dir
        .path()
        .join("nonexistent_index")
        .to_str()
        .unwrap()
        .to_string();
    // Nothing has been written under this prefix by the test before the query below.
    let manager = CheckpointRecordManagerWithFileStorage::new(&index_prefix, 999);
    assert!(!manager.has_completed()?);
    Ok(())
}

/// Recorded progress is visible to a new manager, and `mark_as_invalid`
/// makes the reported resumption point fall back to `Some(0)`.
#[test]
fn test_mark_as_invalid() -> ANNResult<()> {
    let temp_dir = tempdir()?;
    let prefix_path = temp_dir.path().join("test_invalid");
    let index_prefix = prefix_path.to_str().unwrap().to_string();
    let identifier = 77;

    // Record one completed stage, then partial progress (42) on the following stage.
    let mut writer = CheckpointRecordManagerWithFileStorage::new(&index_prefix, identifier);
    writer.update(Progress::Completed, WorkStage::QuantizeFPV)?;
    writer.update(Progress::Processed(42), WorkStage::InMemIndexBuild)?;

    // A second manager over the same prefix/identifier reports the recorded progress.
    let reader = CheckpointRecordManagerWithFileStorage::new(&index_prefix, identifier);
    let resumed_at = reader.get_resumption_point(WorkStage::QuantizeFPV)?;
    assert_eq!(resumed_at, Some(42));

    // Once the checkpoint is marked invalid, the resumption point is reported as 0.
    let mut invalidator = CheckpointRecordManagerWithFileStorage::new(&index_prefix, identifier);
    invalidator.mark_as_invalid()?;
    let after_invalidation = invalidator.get_resumption_point(WorkStage::QuantizeFPV)?;
    assert_eq!(after_invalidation, Some(0));

    Ok(())
}

#[test]
Expand All @@ -115,9 +153,6 @@ mod tests {
.to_string();
let identifier = 42;

// Clean up any existing files
clean_checkpoint_file(&index_prefix, identifier);

// Define a helper function to process a stage with interruption and resumption
fn process_stage(
index_prefix: &str,
Expand Down Expand Up @@ -176,9 +211,6 @@ mod tests {
let manager = CheckpointRecordManagerWithFileStorage::new(&index_prefix, identifier);
assert!(manager.has_completed()?);

// Clean up test files
clean_checkpoint_file(&index_prefix, identifier);

Ok(())
}
}
158 changes: 141 additions & 17 deletions diskann-disk/src/build/chunking/continuation/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,11 +125,8 @@ mod tests {
checker,
);

assert!(result.is_ok());
match result.unwrap() {
Progress::Completed => assert_eq!(processed, vec![1, 2, 3, 4, 5]),
_ => panic!("Expected Completed"),
}
assert!(matches!(result.unwrap(), Progress::Completed));
assert_eq!(processed, vec![1, 2, 3, 4, 5]);
}

#[test]
Expand All @@ -143,13 +140,145 @@ mod tests {
checker,
);

assert!(result.is_ok());
match result.unwrap() {
Progress::Completed => {}
_ => panic!("Expected Completed"),
assert!(matches!(result.unwrap(), Progress::Completed));
}

/// Test tracker that hands out `Continue` exactly `stop_after` times and
/// answers `Stop` on every request after that.
#[derive(Clone)]
struct StopAfterTracker {
    /// Number of `Continue` grants issued so far (shared between clones).
    count: std::sync::Arc<std::sync::Mutex<usize>>,
    /// Maximum number of `Continue` grants to issue.
    stop_after: usize,
}

impl ContinuationTrackerTrait for StopAfterTracker {
    fn get_continuation_grant(&self) -> ContinuationGrant {
        let mut granted = self.count.lock().unwrap();
        if *granted < self.stop_after {
            // Budget remaining: consume one grant.
            *granted += 1;
            return ContinuationGrant::Continue;
        }
        ContinuationGrant::Stop
    }
}

/// With a tracker that grants three continuations, only the first three
/// items are processed and the result reports `Processed(3)`.
#[test]
fn test_process_while_resource_is_available_stops_early() {
    let tracker = StopAfterTracker {
        count: std::sync::Arc::new(std::sync::Mutex::new(0)),
        stop_after: 3,
    };
    let items = vec![10, 20, 30, 40, 50];
    let mut processed = Vec::new();

    let result = process_while_resource_is_available(
        |item| {
            processed.push(item);
            Ok::<(), TestError>(())
        },
        items.into_iter(),
        Box::new(tracker),
    );

    // `Processed(n)` reports the number of items processed before the stop grant.
    // A literal pattern expresses the expected count directly; no guard is needed.
    assert!(matches!(result.unwrap(), Progress::Processed(3)));
    assert_eq!(processed, vec![10, 20, 30]);
}

/// Test tracker that answers `Yield(Duration::ZERO)` on the first request
/// and `Continue` on every request after that.
#[derive(Clone)]
struct YieldOnceThenContinueTracker {
    /// Set to `true` once the single yield has been handed out (shared between clones).
    yielded: std::sync::Arc<std::sync::Mutex<bool>>,
}

impl ContinuationTrackerTrait for YieldOnceThenContinueTracker {
    fn get_continuation_grant(&self) -> ContinuationGrant {
        let mut yielded = self.yielded.lock().unwrap();
        // Mark the yield as spent and look at the previous state in one step.
        let already_yielded = std::mem::replace(&mut *yielded, true);
        if already_yielded {
            // After yielding once, always continue
            ContinuationGrant::Continue
        } else {
            ContinuationGrant::Yield(std::time::Duration::ZERO)
        }
    }
}

/// A single yield grant must not prevent the remaining items from being processed.
#[test]
fn test_process_while_resource_is_available_yield_then_continue() {
    let yield_once = YieldOnceThenContinueTracker {
        yielded: std::sync::Arc::new(std::sync::Mutex::new(false)),
    };
    let mut seen = Vec::new();

    let outcome = process_while_resource_is_available(
        |item| {
            seen.push(item);
            Ok::<(), TestError>(())
        },
        vec![1, 2].into_iter(),
        Box::new(yield_once),
    );

    // The run finishes with every item handled, in order.
    assert!(matches!(outcome.unwrap(), Progress::Completed));
    assert_eq!(seen, vec![1, 2]);
}

/// An error returned by the action surfaces as an `Err` from the processing call.
#[test]
fn test_process_while_resource_is_available_action_error() {
    let always_continue = Box::new(NaiveContinuationTracker::default());

    // The action fails on the item `2` and succeeds on everything else.
    let outcome = process_while_resource_is_available(
        |item| match item {
            2 => Err(TestError),
            _ => Ok(()),
        },
        vec![1, 2, 3].into_iter(),
        always_continue,
    );

    assert!(outcome.is_err());
}

/// Async variant: with a tracker that grants two continuations, only the
/// first two items are processed and the result reports `Processed(2)`.
#[tokio::test]
async fn test_process_while_resource_is_available_async_stops_early() {
    let tracker = StopAfterTracker {
        count: std::sync::Arc::new(std::sync::Mutex::new(0)),
        stop_after: 2,
    };
    let items = vec![1, 2, 3, 4, 5];
    let processed = std::sync::Arc::new(tokio::sync::Mutex::new(Vec::new()));

    let result = process_while_resource_is_available_async(
        |item| {
            // Each future gets its own handle to the shared output vector.
            let processed = processed.clone();
            async move {
                processed.lock().await.push(item);
                Ok::<(), TestError>(())
            }
        },
        items.into_iter(),
        Box::new(tracker),
    )
    .await;

    // A literal pattern expresses the expected count directly; no guard is needed.
    assert!(matches!(result.unwrap(), Progress::Processed(2)));
    let processed = processed.lock().await;
    assert_eq!(*processed, vec![1, 2]);
}

#[tokio::test]
async fn test_process_while_resource_is_available_async_completes() {
let checker = Box::new(NaiveContinuationTracker::default());
Expand All @@ -169,13 +298,8 @@ mod tests {
)
.await;

assert!(result.is_ok());
match result.unwrap() {
Progress::Completed => {
let processed = processed.lock().await;
assert_eq!(*processed, vec![1, 2, 3]);
}
_ => panic!("Expected Completed"),
}
assert!(matches!(result.unwrap(), Progress::Completed));
let processed = processed.lock().await;
assert_eq!(*processed, vec![1, 2, 3]);
}
}
36 changes: 29 additions & 7 deletions diskann-disk/src/search/pq/pq_scratch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,24 +58,22 @@ impl PQScratch {
})
}

/// Copy the first `dim` elements of `query` into `query_scratch`.
/// Copy `query` into `query_scratch`.
///
/// `query` must already be in full-precision `f32` representation; quantized
/// inputs (e.g. `MinMaxElement`) should be decoded via `VectorRepr::as_f32`
/// at the caller boundary before invoking this method.
///
/// Accepts oversized `query` (only the first `dim` elements are used) for
/// backwards compatibility with callers that hold alignment-padded buffers.
/// Returns `DimensionMismatchError` if `query.len() < query_scratch.len()`.
/// Returns `DimensionMismatchError` if `query.len() != query_scratch.len()`.
pub fn set(&mut self, query: &[f32]) -> ANNResult<()> {
let dim = self.query_scratch.len();
if query.len() < dim {
if query.len() != dim {
return Err(ANNError::log_dimension_mismatch_error(format!(
"PQScratch::set: expected query of length >= {dim}, got {}",
"PQScratch::set: expected query of length {dim}, got {}",
query.len()
)));
}
self.query_scratch.copy_from_slice(&query[..dim]);
self.query_scratch.copy_from_slice(query);
Ok(())
}

Expand Down Expand Up @@ -128,4 +126,28 @@ mod tests {
assert_eq!(pq_scratch.query_scratch[i], query[i]);
});
}

/// `set` must fail with a dimension-mismatch error when the query has fewer
/// elements than the scratch dimension.
#[test]
fn test_pq_scratch_set_rejects_short_query() {
    let dim = 16;
    let mut scratch = PQScratch::new(64, dim, 4, 256).unwrap();

    // `1..dim` yields `dim - 1` values, i.e. one element short.
    let too_short = (1..dim).map(|i| i as f32).collect::<Vec<f32>>();
    let failure = scratch.set(&too_short).unwrap_err();

    assert_eq!(failure.kind(), diskann::ANNErrorKind::DimensionMismatchError);
    assert!(failure.to_string().contains("expected query of length"));
}

/// `set` must fail with a dimension-mismatch error when the query has more
/// elements than the scratch dimension.
#[test]
fn test_pq_scratch_set_rejects_oversized_query() {
    let dim = 8;
    let mut scratch = PQScratch::new(64, dim, 4, 256).unwrap();

    // `1..=dim + 10` yields `dim + 10` values, i.e. ten elements too many.
    let too_long = (1..=dim + 10).map(|i| i as f32).collect::<Vec<f32>>();
    let failure = scratch.set(&too_long).unwrap_err();

    assert_eq!(failure.kind(), diskann::ANNErrorKind::DimensionMismatchError);
    assert!(failure.to_string().contains("expected query of length"));
}
Comment thread
arrayka marked this conversation as resolved.
}
44 changes: 44 additions & 0 deletions diskann-disk/src/search/provider/disk_sector_graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -373,4 +373,48 @@ mod disk_sector_graph_test {
let data = &graph;
assert_eq!(data.len(), 512);
}

/// Reconfiguring to a larger batch size raises `max_n_batch_sector_read`
/// and resizes `sectors_data` accordingly.
#[test]
fn test_reconfigure_grows_buffer() {
    let file_reader = AlignedFileReaderFactory::new(test_index_path())
        .build()
        .unwrap();
    let mut sector_graph = test_initialize_disk_sector_graph(2, 1, file_reader);
    // Starting point established by the test initializer.
    assert_eq!(sector_graph.max_n_batch_sector_read, 4);

    // Ask for a larger batch; both the limit and the backing buffer must follow.
    sector_graph.reconfigure(16).unwrap();
    assert_eq!(sector_graph.max_n_batch_sector_read, 16);
    assert_eq!(sector_graph.sectors_data.len(), 16 * 64);
}

/// Reconfiguring to the current or a smaller batch size leaves both the
/// batch limit and the buffer length unchanged.
#[test]
fn test_reconfigure_noop_for_smaller_size() {
    let file_reader = AlignedFileReaderFactory::new(test_index_path())
        .build()
        .unwrap();
    let mut sector_graph = test_initialize_disk_sector_graph(2, 1, file_reader);
    let initial_len = sector_graph.sectors_data.len();

    // Same size as the current limit: nothing changes.
    sector_graph.reconfigure(4).unwrap();
    assert_eq!(sector_graph.max_n_batch_sector_read, 4);
    assert_eq!(sector_graph.sectors_data.len(), initial_len);

    // Smaller than the current limit: still nothing changes.
    sector_graph.reconfigure(2).unwrap();
    assert_eq!(sector_graph.max_n_batch_sector_read, 4);
    assert_eq!(sector_graph.sectors_data.len(), initial_len);
}

/// A header carrying a block size of 0 results in a graph whose `block_size`
/// equals `DEFAULT_DISK_SECTOR_LEN`.
#[test]
fn test_new_disk_sector_graph_zero_block_size_defaults() {
    let graph_metadata = GraphMetadata::new(1000, 32, 500, 32, 2, 20, 50, 1024, 256);
    // The second argument is the block size; 0 is expected to trigger the default.
    let zero_block_header = GraphHeader::new(graph_metadata, 0, GraphLayoutVersion::new(1, 0));
    let file_reader = AlignedFileReaderFactory::new(test_index_path())
        .build()
        .unwrap();

    let sector_graph = DiskSectorGraph::new(file_reader, &zero_block_header, 2).unwrap();
    assert_eq!(sector_graph.block_size, DEFAULT_DISK_SECTOR_LEN);
}
}
Loading
Loading