diff --git a/Cargo.lock b/Cargo.lock index 1ecde9f9f..8f6f7272e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -912,6 +912,16 @@ dependencies = [ "trybuild", ] +[[package]] +name = "diskann-record" +version = "0.54.0" +dependencies = [ + "anyhow", + "serde", + "serde_json", + "tempfile", +] + [[package]] name = "diskann-tools" version = "0.54.0" diff --git a/Cargo.toml b/Cargo.toml index b285a94f6..d95ee5a1f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,6 +22,7 @@ members = [ "diskann-benchmark-core", "diskann-benchmark-simd", "diskann-benchmark", + "diskann-record", "diskann-tools", "vectorset", "diskann-bftree", @@ -66,6 +67,7 @@ diskann-benchmark-runner = { path = "diskann-benchmark-runner", version = "0.54. diskann-benchmark-core = { path = "diskann-benchmark-core", version = "0.54.0" } diskann-tools = { path = "diskann-tools", version = "0.54.0" } diskann-bftree = {path = "diskann-bftree", version = "0.54.0" } +diskann-record = { path = "diskann-record", version = "0.54.0" } # External dependencies (shared versions) anyhow = "1.0.98" diff --git a/diskann-record/Cargo.toml b/diskann-record/Cargo.toml new file mode 100644 index 000000000..24bf6db2f --- /dev/null +++ b/diskann-record/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "diskann-record" +version.workspace = true +description.workspace = true +authors.workspace = true +repository.workspace = true +license.workspace = true +edition = "2024" + +[dependencies] +anyhow.workspace = true +serde = { workspace = true, features = ["derive"], optional = true } +serde_json = { workspace = true, optional = true } + +[features] +default = ["disk"] +serde = ["dep:serde"] +disk = ["serde", "dep:serde_json"] + +[dev-dependencies] +tempfile.workspace = true + +[lints] +workspace = true diff --git a/diskann-record/README.md b/diskann-record/README.md new file mode 100644 index 000000000..42dd56990 --- /dev/null +++ b/diskann-record/README.md @@ -0,0 +1,19 @@ +# DiskANN Record + +This crate provides a small framework for 
persisting structured Rust values (`struct`, `enum`, +`Option`, `Vec`, and primitives) as a manifest (can be serialized to JSON) plus a set +of side-car binary artifacts, and reloading them later. It can be used by `diskann` +providers and indexes to implement durable, consistent and backward-compatible checkpoints. + +Types describe how they map to a versioned record by implementing the `save::Save` and +`load::Load` traits; the `save_fields!` and `load_fields!` macros handle the +field-by-field plumbing for plain structs. Every record carries a `Version` so loaders +can detect schema changes and either upgrade or fall back through a probing chain. + +The goal is to allow crates like `diskann` to checkpoint their state without depending on +a particular serialization backend. Currently, this crate implements `Disk` and `Memory` backends +-- `Disk` for persistent storage and `Memory` for in-memory operations. +`Memory` pipes the output of its `SaveContext` directly into the input of its `LoadContext`, so it is not compatible with other backends. +A hypothetical `ObjectStore` backend could be implemented to be compatible with `Disk` backends to support caching and other advanced features. + +This crate has minimal dependencies by design. \ No newline at end of file diff --git a/diskann-record/src/backend/disk.rs b/diskann-record/src/backend/disk.rs new file mode 100644 index 000000000..914fc1138 --- /dev/null +++ b/diskann-record/src/backend/disk.rs @@ -0,0 +1,603 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Disk-backed save/load contexts. +//! +//! [`DiskSaveContext`] writes the manifest as JSON plus side-car artifact files into a +//! directory; [`DiskLoadContext`] reads them back. The two halves are independent and +//! communicate only through the filesystem. Both are available under the `disk` feature. 
+ +use std::{ + collections::{HashMap, HashSet}, + fs::File, + io::BufReader, + path::{Path, PathBuf}, + sync::Mutex, +}; + +use crate::{ + load::{self, LoadContext, Reader}, + save::{self, Handle, SaveContext, Value, Writer, delegate_write_and_seek}, +}; + +/// The disk-backed [`SaveContext`]. +/// +/// Holds the manifest directory, the manifest path, and the artifact file names registered +/// so far paired with whether their writer has finished. Lookup and insertion go through a +/// [`Mutex`] so that concurrent [`Save`](crate::save::Save) impls cannot accidentally hand +/// out the same artifact name twice. +/// +/// # Cleanup on failure +/// +/// Save can fail part-way, so the [`Drop`] impl ensures cleanup of any artifacts created +/// before the failure. +#[derive(Debug)] +pub(crate) struct DiskSaveContext { + dir: PathBuf, + metadata: PathBuf, + files: Mutex>, + committed: bool, +} + +#[derive(serde::Serialize)] +struct Final<'a> { + files: Vec<&'a str>, + value: &'a Value<'a>, +} + +impl DiskSaveContext { + /// Create a disk-backed save context targeting `dir` for side-car artifacts and + /// `metadata` for the manifest. Validates that `dir` is an actual directory. + /// + /// # Errors + /// + /// Returns [`save::Error`] if `dir` does not exist, cannot be inspected, or exists but + /// is not a directory. + pub(crate) fn new(dir: PathBuf, metadata: PathBuf) -> save::Result { + match std::fs::metadata(&dir) { + Ok(meta) if meta.is_dir() => {} + Ok(_) => { + return Err(save::Error::message(format!( + "path {} exists but is not a directory", + dir.display() + ))); + } + Err(err) => { + return Err(save::Error::new(err) + .context(format!("while validating path {}", dir.display()))); + } + } + + Ok(Self { + dir, + metadata, + files: Mutex::new(HashMap::new()), + committed: false, + }) + } + + /// Path of the temp manifest written by [`SaveContext::finish`] before the atomic + /// rename into [`Self::metadata`]. 
+ fn temp_metadata(&self) -> PathBuf { + let mut temp = self.metadata.clone().into_os_string(); + temp.push(".temp"); + PathBuf::from(temp) + } +} + +impl Drop for DiskSaveContext { + /// Best-effort cleanup for an uncommitted save. + fn drop(&mut self) { + if self.committed { + return; + } + let files = self + .files + .lock() + .unwrap_or_else(|poison| poison.into_inner()); + for name in files.keys() { + let _ = std::fs::remove_file(self.dir.join(name)); + } + let _ = std::fs::remove_file(self.temp_metadata()); + } +} + +impl SaveContext for DiskSaveContext { + type Output = (); + + fn write(&self, key: Option<&str>) -> save::Result> { + // When a human-readable hint is supplied it must be a simple relative file name. + // NOTE:: Absolute paths, parent traversal, and multi-component paths cannot produce a + // single, well-formed file name in the manifest directory, so they are ignored and + // treated as if no hint had been supplied. + let key = key.filter(|key| { + let mut components = std::path::Path::new(key).components(); + matches!(components.next(), Some(std::path::Component::Normal(_))) + && components.next().is_none() + }); + + let mut files = self + .files + .lock() + .unwrap_or_else(|poison| poison.into_inner()); + + // Prefix each artifact with the count of artifacts written so far, so that reusing + // the same `key` (or omitting it) still yields a unique file name. 
+ let name = match key { + Some(key) => format!("{:03}-{}", files.len(), key), + None => format!("{:03}", files.len()), + }; + + if files.contains_key(&name) { + return Err(save::Error::message(format!( + "generated artifact name {:?} collides with an existing artifact", + name, + ))); + } + let full = self.dir.join(&name); + if full.exists() { + return Err(save::Error::message(format!( + "file {} already exists", + full.display() + ))); + } + let file = std::fs::File::create_new(&full).map_err(|err| { + save::Error::new(err).context(format!("while creating new file {}", full.display())) + })?; + // Reserve the name as not-yet-finished; `FileWriter::finish` flips it to `true`, and + // `SaveContext::finish` reports any slot that was reserved but never finished. + files.insert(name.clone(), false); + Ok(Writer::new(FileWriter { file, parent: self }, name)) + } + + /// Finalize the manifest. + /// + /// Writes the manifest JSON atomically: serializes to a `.temp` file first, + /// then renames it into place. Fails if the temp file already exists (an in-flight + /// save is in progress, or a previous run aborted between rename steps). + /// + /// On failure, context is dropped without committing ==> [`Drop`] impl + /// removes the artifacts + temp manifest. Save is marked committed once + /// rename succeeds and artifacts are in place. 
+ fn finish(mut self, value: Value<'_>) -> save::Result<()> { + let temp = self.temp_metadata(); + { + let files = self + .files + .lock() + .unwrap_or_else(|poison| poison.into_inner()); + if let Some((name, _)) = files.iter().find(|(_, finished)| !**finished) { + return Err(save::Error::message(format!( + "artifact {:?} was reserved but never finished", + name, + ))); + } + let f = Final { + files: files.keys().map(|k| &**k).collect(), + value: &value, + }; + + // Fail if the temp file already exists + let buffer = std::fs::File::create_new(&temp).map_err(|err| { + if err.kind() == std::io::ErrorKind::AlreadyExists { + save::Error::message(format!( + "Temporary file {} already exists. Aborting!", + temp.display() + )) + } else { + save::Error::new(err).context(format!( + "while creating temp manifest file {}", + temp.display() + )) + } + })?; + + serde_json::to_writer_pretty(buffer, &f).map_err(|err| { + save::Error::new(err).context("while serializing manifest to JSON") + })?; + } + std::fs::rename(&temp, &self.metadata).map_err(|err| { + save::Error::new(err).context(format!( + "while renaming temp manifest {} to final path {}", + temp.display(), + self.metadata.display() + )) + })?; + // Manifest now in place, artifacts belong to a valid record + self.committed = true; + Ok(()) + } +} + +/// A file-backed [`WriterInner`](save::WriterInner) that streams bytes straight to disk. +/// +/// The bytes are persisted as they are written; [`WriterInner::finish`](save::WriterInner::finish) +/// only needs to mint the [`Handle`] and mark the artifact finished in its parent context +/// (the buffered bytes are already flushed into the file by [`Writer::finish`]). 
+#[derive(Debug)] +struct FileWriter<'a> { + file: File, + parent: &'a DiskSaveContext, +} + +impl save::WriterInner for FileWriter<'_> { + fn finish(self: Box, name: String) -> save::Result { + self.parent + .files + .lock() + .unwrap_or_else(|poison| poison.into_inner()) + .insert(name.clone(), true); + Ok(Handle::new(name)) + } +} + +delegate_write_and_seek!(file, FileWriter<'_>); + +/// The disk-backed [`LoadContext`]. +/// +/// Reads the manifest produced by [`DiskSaveContext`] and resolves side-car artifact +/// handles against the manifest directory. +#[derive(Debug)] +pub(crate) struct DiskLoadContext { + dir: PathBuf, + files: HashSet, + value: Value<'static>, +} + +#[derive(serde::Deserialize)] +struct FileRepr { + files: HashSet, + value: Value<'static>, +} + +impl DiskLoadContext { + pub(crate) fn new(metadata: &Path, dir: &Path) -> load::Result { + let file = std::fs::File::open(metadata).map_err(|e| { + load::Error::new(e).context(format!("while trying to open {}", metadata.display())) + })?; + + let reader = BufReader::new(file); + let repr: FileRepr = serde_json::from_reader(reader) + .map_err(|e| load::Error::new(e).context("could not deserialize manifest"))?; + + Ok(Self { + dir: dir.into(), + files: repr.files, + value: repr.value, + }) + } +} + +impl LoadContext for DiskLoadContext { + fn value(&self) -> load::Result<&Value<'_>> { + Ok(&self.value) + } + + fn read(&self, key: &str) -> load::Result> { + let key_as_path: &Path = key.as_ref(); + let mut components = key_as_path.components(); + match components.next() { + Some(std::path::Component::Normal(_)) if components.next().is_none() => {} + _ => { + return Err( + load::Error::from(load::error::Kind::MissingFile).context(format!( + "handle references file {:?} which escapes the manifest directory", + key, + )), + ); + } + } + if !self.files.contains(key_as_path) { + return Err( + load::Error::from(load::error::Kind::MissingFile).context(format!( + "handle references file {:?} which is not 
registered in the manifest", + key, + )), + ); + } + + let full = self.dir.join(key); + let file = std::fs::File::open(&full).map_err(|err| { + load::Error::new(err).context(format!("while opening artifact file {}", full.display())) + })?; + + Ok(Reader::new(file)) + } +} + +#[cfg(test)] +mod tests { + use std::path::{Path, PathBuf}; + + use super::*; + + #[test] + fn new_rejects_nonexistent_directory() { + let missing = PathBuf::from("does/not/exist/anywhere/at/all"); + let err = DiskSaveContext::new(missing, "meta.json".into()) + .expect_err("a nonexistent directory must be rejected"); + assert!(format!("{err}").contains("while validating path")); + } + + #[test] + fn new_rejects_file_as_directory() { + let dir = tempfile::tempdir().unwrap(); + let file = dir.path().join("not_a_dir"); + std::fs::write(&file, b"hi").unwrap(); + let err = DiskSaveContext::new(file, dir.path().join("meta.json")) + .expect_err("a file path must be rejected as a directory"); + assert!(format!("{err}").contains("is not a directory")); + } + + #[test] + fn write_ignores_path_separators_and_traversal() { + let dir = tempfile::tempdir().unwrap(); + let ctx = DiskSaveContext::new(dir.path().into(), dir.path().join("meta.json")).unwrap(); + for bad in ["sub/dir.bin", "../escape.bin", "/abs.bin"] { + let handle = SaveContext::write(&ctx, Some(bad)) + .expect("keys with path separators are treated as anonymous") + .finish() + .unwrap(); + let mut components = std::path::Path::new(handle.as_str()).components(); + assert!( + matches!(components.next(), Some(std::path::Component::Normal(_))) + && components.next().is_none(), + "generated name {:?} must be a single relative file name", + handle.as_str(), + ); + } + } + + #[test] + fn write_allows_duplicate_key() { + let dir = tempfile::tempdir().unwrap(); + let ctx = DiskSaveContext::new(dir.path().into(), dir.path().join("meta.json")).unwrap(); + let first = SaveContext::write(&ctx, Some("artifact.bin")) + .unwrap() + .finish() + .unwrap(); + 
let second = SaveContext::write(&ctx, Some("artifact.bin")) + .unwrap() + .finish() + .unwrap(); + assert_ne!( + first.as_str(), + second.as_str(), + "duplicate keys must be disambiguated by the count prefix" + ); + assert_eq!(first.as_str(), "000-artifact.bin"); + assert_eq!(second.as_str(), "001-artifact.bin"); + } + + #[test] + fn write_allows_anonymous_artifact() { + let dir = tempfile::tempdir().unwrap(); + let ctx = DiskSaveContext::new(dir.path().into(), dir.path().join("meta.json")).unwrap(); + let handle = SaveContext::write(&ctx, None).unwrap().finish().unwrap(); + assert!(!handle.as_str().is_empty()); + } + + #[test] + fn write_rejects_preexisting_file_on_disk() { + let dir = tempfile::tempdir().unwrap(); + let ctx = DiskSaveContext::new(dir.path().into(), dir.path().join("meta.json")).unwrap(); + // The first artifact is named with a `000` count prefix; pre-create that exact file + // on disk so the `full.exists()` guard rejects the allocation. + std::fs::write(dir.path().join("000-artifact.bin"), b"stale").unwrap(); + let err = SaveContext::write(&ctx, Some("artifact.bin")) + .expect_err("an artifact whose file already exists on disk must be rejected"); + assert!(format!("{err}").contains("already exists")); + } + + #[test] + fn finish_rejects_unfinished_artifact() { + let dir = tempfile::tempdir().unwrap(); + let ctx = DiskSaveContext::new(dir.path().into(), dir.path().join("meta.json")).unwrap(); + // Reserve an artifact slot but drop the writer without calling `finish`, leaving the + // slot marked as not-yet-finished. 
+ let writer = SaveContext::write(&ctx, Some("artifact.bin")).unwrap(); + drop(writer); + let err = ctx + .finish(Value::Null) + .expect_err("finish must fail when an artifact was reserved but never finished"); + assert!(format!("{err}").contains("was reserved but never finished")); + } + + #[test] + fn write_rejects_name_collision() { + let dir = tempfile::tempdir().unwrap(); + let ctx = DiskSaveContext::new(dir.path().into(), dir.path().join("meta.json")).unwrap(); + // The count prefix normally makes a collision impossible, so seed the bookkeeping map + // directly with the exact name the next `write` will generate (one entry => count 1 => + // `001-artifact.bin`). + ctx.files + .lock() + .unwrap() + .insert("001-artifact.bin".to_string(), true); + let err = SaveContext::write(&ctx, Some("artifact.bin")) + .expect_err("a generated name that is already registered must be rejected"); + assert!(format!("{err}").contains("collides with an existing artifact")); + } + + #[test] + fn write_reports_file_creation_failure() { + let dir = tempfile::tempdir().unwrap(); + let artifacts = dir.path().join("artifacts"); + std::fs::create_dir(&artifacts).unwrap(); + let ctx = DiskSaveContext::new(artifacts.clone(), dir.path().join("meta.json")).unwrap(); + // Remove the validated directory so `create_new` fails with a non-"exists" IO error + // (the `full.exists()` guard passes because the path is gone). + std::fs::remove_dir(&artifacts).unwrap(); + let err = SaveContext::write(&ctx, Some("artifact.bin")) + .expect_err("creating an artifact in a missing directory must fail"); + assert!(format!("{err}").contains("while creating new file")); + } + + #[test] + fn finish_reports_rename_failure() { + let dir = tempfile::tempdir().unwrap(); + // Make the final manifest path an existing directory: the `.temp` file is + // created and serialized fine, but renaming a file onto a directory fails. 
+ let metadata = dir.path().join("meta.json"); + std::fs::create_dir(&metadata).unwrap(); + let ctx = DiskSaveContext::new(dir.path().into(), metadata.clone()).unwrap(); + let err = ctx + .finish(Value::Null) + .expect_err("renaming the temp manifest onto a directory must fail"); + assert!(format!("{err}").contains("while renaming temp manifest")); + } + + #[test] + fn drop_without_finish_cleans_up_artifacts() { + let dir = tempfile::tempdir().unwrap(); + let metadata = dir.path().join("meta.json"); + let name = { + let ctx = DiskSaveContext::new(dir.path().into(), metadata.clone()).unwrap(); + let name = SaveContext::write(&ctx, Some("artifact.bin")) + .unwrap() + .finish() + .unwrap() + .as_str() + .to_owned(); + assert!(dir.path().join(&name).exists()); + name + // `ctx` is dropped here without ever being committed via `finish`. + }; + assert!( + !dir.path().join(&name).exists(), + "an uncommitted save must clean up the artifacts it created" + ); + } + + #[test] + fn failed_finish_cleans_up_artifacts_and_temp() { + let dir = tempfile::tempdir().unwrap(); + let metadata = dir.path().join("meta.json"); + let ctx = DiskSaveContext::new(dir.path().into(), metadata.clone()).unwrap(); + let name = SaveContext::write(&ctx, Some("artifact.bin")) + .unwrap() + .finish() + .unwrap() + .as_str() + .to_owned(); + + // Pre-create the temp manifest so `finish` aborts on the `create_new` collision. 
+ let temp = ctx.temp_metadata(); + std::fs::write(&temp, b"stale").unwrap(); + + let err = ctx + .finish(Value::Null) + .expect_err("finish must fail when the temp manifest already exists"); + assert!(format!("{err}").contains("already exists")); + + assert!( + !dir.path().join(&name).exists(), + "a failed finish must clean up the artifacts it created" + ); + assert!( + !metadata.exists(), + "a failed finish must not leave a committed manifest" + ); + } + + #[test] + fn committed_save_preserves_artifacts() { + let dir = tempfile::tempdir().unwrap(); + let metadata = dir.path().join("meta.json"); + let ctx = DiskSaveContext::new(dir.path().into(), metadata.clone()).unwrap(); + let name = SaveContext::write(&ctx, Some("artifact.bin")) + .unwrap() + .finish() + .unwrap() + .as_str() + .to_owned(); + + ctx.finish(Value::Null).unwrap(); + + assert!( + dir.path().join(&name).exists(), + "a committed save must keep its artifacts" + ); + assert!( + metadata.exists(), + "a committed save must write the manifest" + ); + assert!( + !ctx_temp(&metadata).exists(), + "a committed save must not leave a temp manifest" + ); + } + + fn ctx_temp(metadata: &Path) -> PathBuf { + let mut temp = metadata.to_owned().into_os_string(); + temp.push(".temp"); + PathBuf::from(temp) + } + + fn write_manifest(dir: &Path, files: &[&str]) -> PathBuf { + let manifest = serde_json::json!({ + "files": files, + "value": { "$version": "0.0" }, + }); + let metadata = dir.join("metadata.json"); + std::fs::write(&metadata, serde_json::to_vec(&manifest).unwrap()).unwrap(); + metadata + } + + #[test] + fn read_rejects_unregistered_file() { + let dir = tempfile::tempdir().unwrap(); + let metadata = write_manifest(dir.path(), &[]); + let ctx = DiskLoadContext::new(&metadata, dir.path()).unwrap(); + let Err(err) = ctx.read("artifact.bin") else { + panic!("an unregistered file must be rejected"); + }; + assert!(format!("{err}").contains("not registered in the manifest")); + } + + #[test] + fn 
read_rejects_escaping_handle() { + let dir = tempfile::tempdir().unwrap(); + // Register the escaping name so only the path-shape check can reject it. + let metadata = write_manifest(dir.path(), &["../escape.bin"]); + let ctx = DiskLoadContext::new(&metadata, dir.path()).unwrap(); + let Err(err) = ctx.read("../escape.bin") else { + panic!("a handle escaping the manifest directory must be rejected"); + }; + assert!(format!("{err}").contains("escapes the manifest directory")); + } + + #[test] + fn read_reports_missing_artifact_file() { + let dir = tempfile::tempdir().unwrap(); + // Register the artifact in the manifest but never create the file on disk, so the + // path/registration checks pass and only the file `open` fails. + let metadata = write_manifest(dir.path(), &["artifact.bin"]); + let ctx = DiskLoadContext::new(&metadata, dir.path()).unwrap(); + let Err(err) = ctx.read("artifact.bin") else { + panic!("a registered artifact missing from disk must be reported"); + }; + assert!(format!("{err}").contains("while opening artifact file")); + } + + #[test] + fn new_load_reports_missing_manifest() { + let dir = tempfile::tempdir().unwrap(); + let metadata = dir.path().join("does-not-exist.json"); + let err = DiskLoadContext::new(&metadata, dir.path()) + .expect_err("a missing manifest file must be reported"); + assert!(format!("{err}").contains("while trying to open")); + } + + #[test] + fn new_load_rejects_malformed_manifest() { + let dir = tempfile::tempdir().unwrap(); + let metadata = dir.path().join("metadata.json"); + std::fs::write(&metadata, b"this is not json").unwrap(); + let err = DiskLoadContext::new(&metadata, dir.path()) + .expect_err("a malformed manifest must be rejected"); + assert!(format!("{err}").contains("could not deserialize manifest")); + } +} diff --git a/diskann-record/src/backend/memory.rs b/diskann-record/src/backend/memory.rs new file mode 100644 index 000000000..8a980a078 --- /dev/null +++ b/diskann-record/src/backend/memory.rs @@ -0,0 
+1,330 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! In-memory save/load contexts. +//! +//! [`MemorySaveContext`] and [`MemoryContext`] mirror the disk-backed contexts but +//! keep the manifest value and every side-car artifact in memory. Saving through an +//! [`MemorySaveContext`] yields an [`MemoryContext`] (its `SaveContext::Output`) that can +//! be loaded directly via the `load` entry point. +//! +//! Unlike the disk path, this round trip never serializes through JSON: the manifest +//! [`Value`] is deep-copied to `'static` via [`Value::into_owned`] and side-car artifacts +//! are buffered in memory through a [`std::io::Cursor`]. As a result these contexts are +//! available regardless of the `disk` / `serde` features. + +use std::{collections::HashMap, io::Cursor, sync::Mutex}; + +use crate::{ + Value, + load::{self, LoadContext, Reader}, + save::{self, Handle, SaveContext, Writer, delegate_write_and_seek}, +}; + +/// A save-side `SaveContext` that keeps every side-car artifact and the committed +/// manifest value in memory. +/// +/// `SaveContext::finish` consumes the context and returns an [`MemoryContext`] ready +/// to be loaded with the `load` entry point. +/// +/// # Cleanup on failure +/// +/// Failures in the same process are automatically cleaned up. +#[derive(Debug, Default)] +pub struct MemorySaveContext { + files: Mutex>>>, +} + +impl MemorySaveContext { + /// Create an empty in-memory save context. + pub fn new() -> Self { + Self::default() + } +} + +impl SaveContext for MemorySaveContext { + type Output = MemoryContext; + + fn write(&self, key: Option<&str>) -> save::Result> { + // Mirror the disk context: a human-readable hint must be a simple relative file + // name. Absolute paths, parent traversal, and multi-component paths cannot produce + // a single, well-formed key, so they are ignored and treated as if no hint had been + // supplied. 
+ let key = key.filter(|key| { + let mut components = std::path::Path::new(key).components(); + matches!(components.next(), Some(std::path::Component::Normal(_))) + && components.next().is_none() + }); + + let mut files = self + .files + .lock() + .unwrap_or_else(|poison| poison.into_inner()); + + // Prefix each artifact with the count of artifacts written so far, so that reusing + // the same `key` (or omitting it) still yields a unique name. + let name = match key { + Some(key) => format!("{:03}-{}", files.len(), key), + None => format!("{:03}", files.len()), + }; + + if files.contains_key(&name) { + return Err(save::Error::message(format!( + "generated artifact name {:?} collides with an existing artifact", + name, + ))); + } + // Reserve the name with an empty slot so the count advances and concurrent writers + // cannot collide; the slot is filled with the real bytes by `Writer::finish`. A slot + // that is never filled is reported by `SaveContext::finish`. + files.insert(name.clone(), None); + drop(files); + + Ok(Writer::new( + MemoryWriter { + cursor: Cursor::new(Vec::new()), + parent: self, + }, + name, + )) + } + + fn finish(self, value: Value<'_>) -> save::Result { + let files = self + .files + .into_inner() + .unwrap_or_else(|poison| poison.into_inner()); + let files = files + .into_iter() + .map(|(name, bytes)| match bytes { + Some(bytes) => Ok((name, bytes)), + None => Err(save::Error::message(format!( + "artifact {:?} was reserved but never finished", + name, + ))), + }) + .collect::>>()?; + Ok(MemoryContext { + files, + value: value.into_owned(), + }) + } +} + +/// An in-memory [`WriterInner`](save::WriterInner) that buffers bytes in a [`Cursor`] and, +/// on finish, deposits the completed buffer into its parent context's file store. 
+#[derive(Debug)] +struct MemoryWriter<'a> { + cursor: Cursor>, + parent: &'a MemorySaveContext, +} + +impl save::WriterInner for MemoryWriter<'_> { + fn finish(self: Box, name: String) -> save::Result { + self.parent + .files + .lock() + .unwrap_or_else(|poison| poison.into_inner()) + .insert(name.clone(), Some(self.cursor.into_inner())); + Ok(Handle::new(name)) + } +} + +delegate_write_and_seek!(cursor, MemoryWriter<'_>); + +/// A load-side `LoadContext` backed entirely by in-memory buffers. +/// +/// Produced by [`MemorySaveContext`] via `SaveContext::finish`. Holds the committed +/// manifest [`Value`] and every side-car artifact as an in-memory byte buffer, so loading +/// never serializes through JSON or touches the filesystem. +#[derive(Debug)] +pub struct MemoryContext { + files: HashMap>, + value: Value<'static>, +} + +impl LoadContext for MemoryContext { + fn value(&self) -> load::Result<&Value<'_>> { + Ok(&self.value) + } + + fn read(&self, key: &str) -> load::Result> { + match self.files.get(key) { + Some(bytes) => Ok(Reader::new(Cursor::new(bytes.as_slice()))), + None => Err( + load::Error::from(load::error::Kind::MissingFile).context(format!( + "handle references artifact {:?} which is not registered in this context", + key, + )), + ), + } + } +} + +#[cfg(test)] +mod tests { + use std::io::{Read, Write}; + + use super::*; + use crate::{Version, load, save}; + + #[derive(Debug, PartialEq)] + struct Doc { + name: String, + blob: Vec, + } + + impl save::Save for Doc { + const VERSION: Version = Version::new(1, 0); + fn save(&self, context: save::Context<'_>) -> save::Result> { + let mut io = context.write(Some("blob.bin"))?; + io.write_all(&self.blob).map_err(save::Error::new)?; + let mut record = crate::save_fields!(self, context, [name]); + record.insert("blob", io.finish()?)?; + Ok(record) + } + } + + impl load::Load<'_> for Doc { + const VERSION: Version = Version::new(1, 0); + fn load(object: load::Object<'_>) -> load::Result { + 
crate::load_fields!(object, [name: String, blob: save::Handle]); + let mut io = object.read(&blob)?; + let mut blob = Vec::new(); + io.read_to_end(&mut blob).map_err(load::Error::new)?; + Ok(Self { name, blob }) + } + fn load_legacy(_: load::Object<'_>) -> load::Result { + Err(load::error::Kind::UnknownVersion.into()) + } + } + + #[test] + fn round_trips_without_serde_or_disk() { + let doc = Doc { + name: "example".to_owned(), + blob: vec![1, 2, 3, 4, 5], + }; + + let context = save::save(&doc, MemorySaveContext::new()).unwrap(); + let restored: Doc = load::load(&context).unwrap(); + + assert_eq!(doc, restored); + } + + #[test] + fn read_rejects_unregistered_artifact() { + let context = MemoryContext { + files: HashMap::new(), + value: Value::Null, + }; + let err = context + .read("missing.bin") + .err() + .expect("an unregistered artifact must be rejected"); + assert!(format!("{err}").contains("not registered in this context")); + } + + #[test] + fn write_rejects_name_collision() { + let ctx = MemorySaveContext::new(); + // The count prefix normally makes a collision impossible, so seed the bookkeeping map + // directly with the exact name the next `write` will generate (one entry => count 1 => + // `001-artifact.bin`). + ctx.files + .lock() + .unwrap() + .insert("001-artifact.bin".to_string(), None); + let err = SaveContext::write(&ctx, Some("artifact.bin")) + .expect_err("a generated name that is already registered must be rejected"); + assert!(format!("{err}").contains("collides with an existing artifact")); + } + + #[test] + fn finish_rejects_unfinished_artifact() { + let ctx = MemorySaveContext::new(); + // Reserve an artifact slot but drop the writer without calling `finish`, leaving the + // slot empty. 
+ let writer = SaveContext::write(&ctx, Some("artifact.bin")).unwrap(); + drop(writer); + let err = ctx + .finish(Value::Null) + .expect_err("finish must fail when an artifact was reserved but never finished"); + assert!(format!("{err}").contains("was reserved but never finished")); + } + + #[test] + fn write_names_anonymous_artifact_with_count_prefix() { + let ctx = MemorySaveContext::new(); + // Passing `None` as the hint exercises the count-only naming branch. + let handle = SaveContext::write(&ctx, None).unwrap().finish().unwrap(); + assert_eq!(handle.as_str(), "000"); + } + + #[test] + fn load_dispatches_to_load_legacy_on_version_mismatch() { + // Build an object whose version does not match `Doc::VERSION` (1.0) so the `Loadable` + // blanket dispatches to `Doc::load_legacy`, which refuses with `UnknownVersion`. + let value = save::Record::empty() + .into_value(Version::new(2, 0)) + .into_owned(); + let context = MemoryContext { + files: HashMap::new(), + value, + }; + let err = load::load::(&context) + .expect_err("a version mismatch must dispatch to load_legacy, which refuses"); + assert!(format!("{err}").contains("unknown version")); + } + + // Regression for the NaN float-field bug, demonstrated end-to-end through the in-memory + // backend (no serde / disk required). A struct carrying NaN-valued `f64` / `f32` fields is + // saved and reloaded; the reload previously FAILED with `NumberOutOfRange` because + // `Number::as_f64` / `as_f32` rejected NaN. 
+ #[derive(Debug)] + struct Floats { + finite: f64, + nan64: f64, + nan32: f32, + } + + impl save::Save for Floats { + const VERSION: Version = Version::new(1, 0); + fn save(&self, context: save::Context<'_>) -> save::Result> { + Ok(crate::save_fields!(self, context, [finite, nan64, nan32])) + } + } + + impl load::Load<'_> for Floats { + const VERSION: Version = Version::new(1, 0); + fn load(object: load::Object<'_>) -> load::Result { + crate::load_fields!(object, [finite: f64, nan64: f64, nan32: f32]); + Ok(Self { + finite, + nan64, + nan32, + }) + } + fn load_legacy(_: load::Object<'_>) -> load::Result { + Err(load::error::Kind::UnknownVersion.into()) + } + } + + #[test] + fn nan_float_fields_round_trip_in_memory() { + let value = Floats { + finite: 1.5, + nan64: f64::NAN, + nan32: f32::NAN, + }; + + let context = save::save(&value, MemorySaveContext::new()).unwrap(); + let restored: Floats = load::load(&context).expect("NaN float fields must reload"); + + assert_eq!(restored.finite, 1.5); + assert!(restored.nan64.is_nan(), "f64 NaN field lost on reload"); + assert!(restored.nan32.is_nan(), "f32 NaN field lost on reload"); + } +} diff --git a/diskann-record/src/backend/mod.rs b/diskann-record/src/backend/mod.rs new file mode 100644 index 000000000..639cd9fa9 --- /dev/null +++ b/diskann-record/src/backend/mod.rs @@ -0,0 +1,9 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +#[cfg(feature = "disk")] +pub(crate) mod disk; + +pub(crate) mod memory; diff --git a/diskann-record/src/lib.rs b/diskann-record/src/lib.rs new file mode 100644 index 000000000..79a907da3 --- /dev/null +++ b/diskann-record/src/lib.rs @@ -0,0 +1,815 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! # Versioned Save/Load for DiskANN +//! +//! This crate provides a small framework for persisting structured Rust values to disk +//! 
as a JSON manifest plus a set of side-car binary artifacts, and reloading them later. +//! It is the substrate used by `diskann` providers and indexes to implement durable +//! checkpoints. +//! +//! The model is: +//! +//! * Each [`save::Save`] / [`load::Load`] implementation describes how a single Rust type +//! maps to a [`save::Record`] (a versioned map of named fields). +//! * Field values are either [`save::Value`]s embedded directly in the manifest, or +//! [`save::Handle`]s pointing at side-car binary artifacts written via the +//! [`save::Context`]. +//! * Every record carries a [`Version`] so that loaders can detect schema changes and +//! either upgrade ([`load::Load::load_legacy`]) or fall back through a probing chain +//! (see [`load::Error::is_recoverable`]). +//! +//! # Entry Points +//! +//! - `save::save_to_disk` (requires the `disk` feature): Save a value to a directory +//! plus a manifest path. +//! - `load::load_from_disk` (requires the `disk` feature): Reload a value from a +//! manifest and its artifact directory. +//! +//! # Defining Save / Load +//! +//! User code is expected to implement [`save::Save`] and [`load::Load`] for the types it +//! wants to persist. For plain structs, the [`save_fields!`] and [`load_fields!`] macros +//! handle the field-by-field plumbing. See [`save`] and [`load`] for the relevant traits +//! and helpers. +//! +//! ## Example +//! +//! ```ignore +//! use diskann_record::{Version, save, load}; +//! +//! #[derive(Debug, PartialEq)] +//! struct Config { dim: usize, label: String } +//! +//! impl save::Save for Config { +//! const VERSION: Version = Version::new(0, 0); +//! fn save(&self, context: save::Context<'_>) -> save::Result> { +//! Ok(diskann_record::save_fields!(self, context, [dim, label])) +//! } +//! } +//! +//! impl load::Load<'_> for Config { +//! const VERSION: Version = Version::new(0, 0); +//! fn load(object: load::Object<'_>) -> load::Result { +//! 
diskann_record::load_fields!(object, [dim: usize, label: String]); +//! Ok(Self { dim, label }) +//! } +//! fn load_legacy(_: load::Object<'_>) -> load::Result { +//! Err(load::error::Kind::UnknownVersion.into()) +//! } +//! } +//! ``` +//! +//! # Wire Format +//! +//! The manifest is JSON. Every object carries a `$version` field; side-car artifacts are +//! referenced through `$handle` strings whose value is a file name relative to the +//! manifest directory. Keys beginning with `$` are reserved for framework metadata and +//! cannot be used as user field names (see [`is_reserved`]). +//! +//! # Platform Requirements +//! +//! `usize` and `isize` are serialized as 64-bit numbers. The crate statically asserts +//! that `usize::BITS == 64` to guarantee that the saver never produces values the +//! canonical wire width cannot represent. Loaders still range-check at runtime. +//! +//! # Error Handling +//! +//! Both [`save::Error`] and [`load::Error`] wrap [`anyhow::Error`] for rich context +//! chains. Load errors additionally carry a recoverable / critical bit, used by probing +//! call sites to decide whether to fall back to an alternative loader. See +//! [`load::error::Kind`] for the classification. + +mod number; +pub use number::Number; + +mod version; +pub use version::Version; + +mod value; +pub use value::{Handle, Keys, Record, Value, Versioned}; + +pub mod load; +pub mod save; + +mod backend; +pub use backend::memory::{MemoryContext, MemorySaveContext}; + +// Canonical wire width for `usize` and `isize` in manifests is 64 bits. Saving a value +// on a 64-bit platform and loading it on a 32-bit platform (or vice versa) could silently +// truncate values that exceed `u32::MAX` / `i32::MAX`. We therefore require a 64-bit +// platform at compile time. Loaders still range-check at runtime, but this check ensures +// the saver never emits values that the canonical width cannot represent. 
+const _: () = assert!( + usize::BITS == 64, + "diskann-record requires a 64-bit target: usize/isize MUST be 64 bits wide !!", +); + +/// Return `true` if `s` is a reserved manifest key. +/// +/// Keys beginning with `$` are reserved for framework metadata (e.g. `$version`, +/// `$handle`) and may not be used as user field names. Attempting to insert one via +/// [`save::Record::insert`] returns an error. +#[doc(hidden)] +pub const fn is_reserved(s: &str) -> bool { + if let Some(first) = s.as_bytes().first() + && *first == b"$"[0] + { + true + } else { + false + } +} + +/////////// +// Tests // +/////////// + +#[cfg(all(test, feature = "disk"))] +mod tests { + use super::*; + + use std::io::{Read, Write}; + use std::path::{Path, PathBuf}; + + #[derive(Debug, PartialEq)] + struct Test { + x: String, + y: f32, + enabled: bool, + inner: Inner, + // We write this as a binary file. + vector: Vec, + nickname: Option, + absent: Option, + } + + #[derive(Debug, PartialEq)] + struct Inner { + z: usize, + w: Vec, + flags: Vec, + maybe_count: Option, + maybe_missing: Option, + sparse: Vec>, + } + + impl save::Save for Inner { + const VERSION: Version = Version::new(0, 0); + fn save(&self, context: save::Context<'_>) -> save::Result> { + Ok(save_fields!( + self, + context, + [z, w, flags, maybe_count, maybe_missing, sparse] + )) + } + } + + impl save::Save for Test { + const VERSION: Version = Version::new(0, 0); + fn save(&self, context: save::Context<'_>) -> save::Result> { + // We save `x`, `y`, and `inner` directly into the manifest. + // The raw vector data we instead store in an auxiliary file. 
+ + let mut io = context.write(Some("auxiliary.bin"))?; + io.write_all(&self.vector).map_err(save::Error::new)?; + + let mut record = save_fields!(self, context, [x, y, enabled, inner, nickname, absent]); + record.insert("vector", io.finish()?)?; + Ok(record) + } + } + + impl load::Load<'_> for Test { + const VERSION: Version = Version::new(0, 0); + fn load(object: load::Object<'_>) -> load::Result { + load_fields!( + object, + [ + x, + y, + enabled, + inner, + nickname: Option, + absent: Option, + vector: save::Handle, + ] + ); + + let mut io = object.read(&vector)?; + let mut vector = Vec::new(); + io.read_to_end(&mut vector).unwrap(); + + Ok(Self { + x, + y, + enabled, + inner, + vector, + nickname, + absent, + }) + } + + fn load_legacy(_object: load::Object<'_>) -> load::Result { + panic!("nope!"); + } + } + + impl load::Load<'_> for Inner { + const VERSION: Version = Version::new(0, 0); + fn load(object: load::Object<'_>) -> load::Result { + load_fields!( + object, + [ + z, + w, + flags, + maybe_count: Option, + maybe_missing: Option, + sparse: Vec>, + ] + ); + Ok(Self { + z, + w, + flags, + maybe_count, + maybe_missing, + sparse, + }) + } + + fn load_legacy(_object: load::Object<'_>) -> load::Result { + panic!("nope!"); + } + } + + #[test] + fn round_trip_uses_isolated_temp_dir() -> anyhow::Result<()> { + let inner = Inner { + z: 10, + w: vec![-1, -2, -3], + flags: vec![true, false, true], + maybe_count: Some(42), + maybe_missing: None, + sparse: vec![Some(1), None, Some(-3), None], + }; + + let t = Test { + x: "hello".into(), + y: 5.0, + enabled: true, + inner, + vector: vec![0, 1, 2, 3, 4, 5], + nickname: Some("friend".into()), + absent: None, + }; + + // Keep the TempDir guard alive for the full round trip; Drop removes the + // manifest and auxiliary artifact after the assertion completes. 
+ let temp_dir = tempfile::tempdir()?; + let dir = temp_dir.path(); + let metadata = dir.join("metadata.json"); + + save::save_to_disk(&t, dir, &metadata)?; + let we_are_back: Test = load::load_from_disk(&metadata, dir)?; + + assert_eq!(t, we_are_back); + Ok(()) + } + + ///////////////////////// + // Enum support: round // + ///////////////////////// + + #[derive(Debug, PartialEq)] + enum Metric { + L2, + Cosine, + Weighted { weights: Vec }, + } + + impl save::Save for Metric { + const VERSION: Version = Version::new(0, 0); + fn save(&self, context: save::Context<'_>) -> save::Result> { + let mut record = save::Record::empty(); + match self { + Self::L2 => { + record.insert("L2", save::Value::Null)?; + } + Self::Cosine => { + record.insert("Cosine", save::Value::Null)?; + } + Self::Weighted { weights } => { + let payload = save_fields!(context, [weights]).into_value(Self::VERSION); + record.insert("Weighted", payload)?; + } + } + Ok(record) + } + } + + impl load::Load<'_> for Metric { + const VERSION: Version = Version::new(0, 0); + fn load(object: load::Object<'_>) -> load::Result { + match object.single_key()? { + "L2" => Ok(Self::L2), + "Cosine" => Ok(Self::Cosine), + "Weighted" => { + let inner = object + .child("Weighted")? 
+ .as_object() + .ok_or(load::error::Kind::TypeMismatch)?; + load_fields!(inner, [weights: Vec]); + Ok(Self::Weighted { weights }) + } + _ => Err(load::error::Kind::UnknownVariant.into()), + } + } + fn load_legacy(_object: load::Object<'_>) -> load::Result { + Err(load::error::Kind::UnknownVersion.into()) + } + } + + #[derive(Debug, PartialEq)] + struct MetricBag { + primary: Metric, + alternatives: Vec, + } + + impl save::Save for MetricBag { + const VERSION: Version = Version::new(0, 0); + fn save(&self, context: save::Context<'_>) -> save::Result> { + Ok(save_fields!(self, context, [primary, alternatives])) + } + } + + impl load::Load<'_> for MetricBag { + const VERSION: Version = Version::new(0, 0); + fn load(object: load::Object<'_>) -> load::Result { + load_fields!(object, [primary: Metric, alternatives: Vec]); + Ok(Self { + primary, + alternatives, + }) + } + fn load_legacy(_object: load::Object<'_>) -> load::Result { + panic!("nope!"); + } + } + + #[test] + fn enum_round_trip_through_disk() -> anyhow::Result<()> { + let bag = MetricBag { + primary: Metric::Weighted { + weights: vec![0.25, 0.5, 0.25], + }, + alternatives: vec![ + Metric::L2, + Metric::Cosine, + Metric::Weighted { weights: vec![1.0] }, + ], + }; + + let temp_dir = tempfile::tempdir()?; + let dir = temp_dir.path(); + let metadata = dir.join("metadata.json"); + + save::save_to_disk(&bag, dir, &metadata)?; + let restored: MetricBag = load::load_from_disk(&metadata, dir)?; + + assert_eq!(bag, restored); + Ok(()) + } + + #[derive(Debug, PartialEq)] + struct StructShape { + x: i32, + } + + impl save::Save for StructShape { + const VERSION: Version = Version::new(0, 0); + fn save(&self, context: save::Context<'_>) -> save::Result> { + Ok(save_fields!(self, context, [x])) + } + } + + impl load::Load<'_> for StructShape { + const VERSION: Version = Version::new(0, 0); + fn load(object: load::Object<'_>) -> load::Result { + load_fields!(object, [x: i32]); + Ok(Self { x }) + } + fn load_legacy(_object: 
load::Object<'_>) -> load::Result { + panic!("nope!"); + } + } + + #[derive(Debug, PartialEq)] + enum EnumShape { + Only { x: i32 }, + } + + impl save::Save for EnumShape { + const VERSION: Version = Version::new(0, 0); + fn save(&self, context: save::Context<'_>) -> save::Result> { + let mut record = save::Record::empty(); + match self { + Self::Only { x } => { + let payload = save_fields!(context, [x]).into_value(Self::VERSION); + record.insert("Only", payload)?; + } + } + Ok(record) + } + } + + impl load::Load<'_> for EnumShape { + const VERSION: Version = Version::new(0, 0); + fn load(object: load::Object<'_>) -> load::Result { + match object.single_key()? { + "Only" => { + let inner = object + .child("Only")? + .as_object() + .ok_or(load::error::Kind::TypeMismatch)?; + load_fields!(inner, [x: i32]); + Ok(Self::Only { x }) + } + _ => Err(load::error::Kind::UnknownVariant.into()), + } + } + fn load_legacy(_object: load::Object<'_>) -> load::Result { + panic!("nope!"); + } + } + + #[test] + fn loading_enum_as_struct_is_rejected() -> anyhow::Result<()> { + // Enum data has a single key "Only" whose payload is a versioned + // sub-object. Loading it as `StructShape` (which expects field `x`) + // surfaces `MissingField`. + let value = EnumShape::Only { x: 7 }; + let temp_dir = tempfile::tempdir()?; + let dir = temp_dir.path(); + let metadata = dir.join("metadata.json"); + + save::save_to_disk(&value, dir, &metadata)?; + let err = load::load_from_disk::(&metadata, dir) + .expect_err("loading enum data into a struct shape should fail"); + let msg = format!("{err}"); + assert!( + msg.contains("missing field"), + "expected MissingField error, got: {msg}" + ); + Ok(()) + } + + #[test] + fn loading_struct_as_enum_is_rejected() -> anyhow::Result<()> { + // Struct data has field `x`, which the enum loader sees as a candidate + // variant name. It doesn't match any arm, so we get `UnknownVariant`. 
+ let value = StructShape { x: 7 }; + let temp_dir = tempfile::tempdir()?; + let dir = temp_dir.path(); + let metadata = dir.join("metadata.json"); + + save::save_to_disk(&value, dir, &metadata)?; + let err = load::load_from_disk::(&metadata, dir) + .expect_err("loading struct data into an enum shape should fail"); + let msg = format!("{err}"); + assert!( + msg.contains("unknown variant"), + "expected UnknownVariant error, got: {msg}" + ); + Ok(()) + } + + ////////////////////////// + // Legacy version path // + ////////////////////////// + + // Sample record that requires a legacy version path (disk version older than loader version). + #[derive(Debug, PartialEq)] + struct Upgraded { + // Stored in the legacy (v0) record. + count: u32, + // Absent from the v0 record; reconstructed by `load_legacy` from `count`. + scaled: u32, + } + + impl save::Save for Upgraded { + // Write using an "old" schema: only `count` is written to disk. + const VERSION: Version = Version::new(0, 0); + fn save(&self, context: save::Context<'_>) -> save::Result> { + Ok(save_fields!(self, context, [count])) + } + } + + impl load::Load<'_> for Upgraded { + // New schema: differs from the `0.0` stamped on disk, forcing `load_legacy`. + const VERSION: Version = Version::new(1, 0); + fn load(_object: load::Object<'_>) -> load::Result { + panic!("matching-version load must not run for a legacy record"); + } + fn load_legacy(object: load::Object<'_>) -> load::Result { + // Upgrade a v0 record: derive `scaled` from the stored `count`. + load_fields!(object, [count: u32]); + Ok(Self { + count, // The original count value on disk + scaled: count * 10, // "default"/derived value after upgrade + }) + } + } + + #[test] + fn legacy_record_dispatches_to_load_legacy() -> anyhow::Result<()> { + // Save stamps the old `0.0` schema; the loader's `1.0` differs, so the + // round trip must flow through `load_legacy`, which upgrades the record. 
+ let temp_dir = tempfile::tempdir()?; + let dir = temp_dir.path(); + let metadata = dir.join("metadata.json"); + + save::save_to_disk( + &Upgraded { + count: 4, + scaled: 0, + }, + dir, + &metadata, + )?; + let restored: Upgraded = load::load_from_disk(&metadata, dir)?; + + assert_eq!( + restored, + Upgraded { + count: 4, + scaled: 40 + } + ); + Ok(()) + } + + // A record whose loader has no upgrade path for the older on-disk schema: + // `load_legacy` refuses with `UnknownVersion`. + #[derive(Debug, PartialEq)] + struct NoUpgrade { + value: i32, + } + + impl save::Save for NoUpgrade { + const VERSION: Version = Version::new(0, 0); + fn save(&self, context: save::Context<'_>) -> save::Result> { + Ok(save_fields!(self, context, [value])) + } + } + + impl load::Load<'_> for NoUpgrade { + const VERSION: Version = Version::new(2, 0); + fn load(_object: load::Object<'_>) -> load::Result { + panic!("matching-version load must not run for a legacy record"); + } + fn load_legacy(_object: load::Object<'_>) -> load::Result { + // Check if version.major is older than 1.0 + if _object.version().major < 1 { + return Err(load::error::Kind::UnknownVersion.into()); + } + panic!("should not reach this point"); + } + } + + #[test] + fn legacy_record_without_upgrade_path_is_rejected() -> anyhow::Result<()> { + let temp_dir = tempfile::tempdir()?; + let dir = temp_dir.path(); + let metadata = dir.join("metadata.json"); + + save::save_to_disk(&NoUpgrade { value: 7 }, dir, &metadata)?; + let err = load::load_from_disk::(&metadata, dir) + .expect_err("a legacy record with no upgrade path must fail"); + let msg = format!("{err}"); + assert!( + msg.contains("unknown version"), + "expected UnknownVersion error, got: {msg}" + ); + Ok(()) + } + + /////////////////////////////////// + // Built-in primitive round-trip // + /////////////////////////////////// + + // Covers the built-in `Loadable`/`Saveable` impls that the structural round-trip + // tests above don't reach: the wider integer 
widths and every `NonZero*` type. + #[derive(Debug, PartialEq)] + struct Primitives { + a: u16, + b: u32, + c: u64, + d: i16, + e: i32, + f: i64, + g: f64, + nz32: std::num::NonZeroU32, + nz64: std::num::NonZeroU64, + nzsize: std::num::NonZeroUsize, + } + + impl save::Save for Primitives { + const VERSION: Version = Version::new(0, 0); + fn save(&self, context: save::Context<'_>) -> save::Result> { + Ok(save_fields!( + self, + context, + [a, b, c, d, e, f, g, nz32, nz64, nzsize] + )) + } + } + + impl load::Load<'_> for Primitives { + const VERSION: Version = Version::new(0, 0); + fn load(object: load::Object<'_>) -> load::Result { + load_fields!( + object, + [ + a: u16, + b: u32, + c: u64, + d: i16, + e: i32, + f: i64, + g: f64, + nz32: std::num::NonZeroU32, + nz64: std::num::NonZeroU64, + nzsize: std::num::NonZeroUsize, + ] + ); + Ok(Self { + a, + b, + c, + d, + e, + f, + g, + nz32, + nz64, + nzsize, + }) + } + fn load_legacy(_object: load::Object<'_>) -> load::Result { + panic!("nope!"); + } + } + + #[test] + fn builtin_primitives_round_trip() -> anyhow::Result<()> { + let value = Primitives { + a: 4242, + b: 4_000_000_000, + c: 1 << 40, + d: -12345, + e: -2_000_000_000, + f: -(1 << 40), + g: -2.5e-9, + nz32: std::num::NonZeroU32::new(7).unwrap(), + nz64: std::num::NonZeroU64::new(1 << 50).unwrap(), + nzsize: std::num::NonZeroUsize::new(99).unwrap(), + }; + + let temp_dir = tempfile::tempdir()?; + let dir = temp_dir.path(); + let metadata = dir.join("metadata.json"); + + save::save_to_disk(&value, dir, &metadata)?; + let restored: Primitives = load::load_from_disk(&metadata, dir)?; + + assert_eq!(value, restored); + Ok(()) + } + + #[test] + fn nonzero_rejects_zero_on_load() -> anyhow::Result<()> { + // A hand-crafted manifest storing `0` in a `NonZeroU32` field must be rejected + // with `NumberOutOfRange` rather than producing an invalid value. 
+ #[derive(Debug)] + struct NzHolder { + _nz: std::num::NonZeroU32, + } + + impl load::Load<'_> for NzHolder { + const VERSION: Version = Version::new(0, 0); + fn load(object: load::Object<'_>) -> load::Result { + load_fields!(object, [nz: std::num::NonZeroU32]); + Ok(Self { _nz: nz }) + } + fn load_legacy(_object: load::Object<'_>) -> load::Result { + panic!("nope!"); + } + } + + let temp_dir = tempfile::tempdir()?; + let dir = temp_dir.path(); + let metadata = dir.join("metadata.json"); + let manifest = serde_json::json!({ + "files": [], + "value": { "$version": "0.0", "nz": 0 }, + }); + std::fs::write(&metadata, serde_json::to_vec(&manifest)?)?; + + let err = load::load_from_disk::(&metadata, dir) + .expect_err("zero stored in a NonZero field must be rejected"); + let msg = format!("{err}"); + assert!( + msg.contains("number out of range"), + "expected NumberOutOfRange error, got: {msg}" + ); + Ok(()) + } + + /////////////////////////////// + // Manifest directory escape // + /////////////////////////////// + + /// Minimal loadable type with a single handle field. Used by the + /// directory-escape tests below to drive `Object::read` against a + /// hand-crafted manifest. + #[derive(Debug)] + struct HandleOnly { + _blob: Vec, + } + + impl load::Load<'_> for HandleOnly { + const VERSION: Version = Version::new(0, 0); + fn load(object: load::Object<'_>) -> load::Result { + load_fields!(object, [blob: save::Handle]); + let mut io = object.read(&blob)?; + let mut buf = Vec::new(); + io.read_to_end(&mut buf).map_err(load::Error::new)?; + Ok(Self { _blob: buf }) + } + fn load_legacy(_object: load::Object<'_>) -> load::Result { + panic!("nope!"); + } + } + + /// Write a hand-crafted manifest into `dir` whose root object exposes a + /// `blob` field referencing `handle_target`. Returns the metadata path. 
+ fn write_handle_manifest(dir: &Path, handle_target: &str) -> std::io::Result { + let manifest = serde_json::json!({ + // Register the same target in `files` so the membership check + // would otherwise let it through — this isolates the new + // path-shape check as the thing rejecting the load. + "files": [handle_target], + "value": { + "$version": "0.0", + "blob": { "$handle": handle_target }, + }, + }); + let metadata = dir.join("metadata.json"); + std::fs::write(&metadata, serde_json::to_vec(&manifest)?)?; + Ok(metadata) + } + + #[test] + fn handle_with_parent_traversal_is_rejected() -> anyhow::Result<()> { + let temp_dir = tempfile::tempdir()?; + let dir = temp_dir.path(); + let metadata = write_handle_manifest(dir, "../escape.bin")?; + + let err = load::load_from_disk::(&metadata, dir) + .expect_err("handle escaping the manifest directory must be rejected"); + let msg = format!("{err}"); + assert!( + msg.contains("escapes the manifest directory"), + "expected manifest-escape rejection, got: {msg}" + ); + Ok(()) + } + + #[test] + fn handle_with_absolute_path_is_rejected() -> anyhow::Result<()> { + let temp_dir = tempfile::tempdir()?; + let dir = temp_dir.path(); + // Use a platform-appropriate absolute path. Both shapes should be + // rejected on their respective platforms; we test the native one. 
+ let absolute = if cfg!(windows) { + "C:\\Windows\\System32\\drivers\\etc\\hosts" + } else { + "/etc/passwd" + }; + let metadata = write_handle_manifest(dir, absolute)?; + + let err = load::load_from_disk::(&metadata, dir) + .expect_err("absolute-path handle must be rejected"); + let msg = format!("{err}"); + assert!( + msg.contains("escapes the manifest directory"), + "expected manifest-escape rejection, got: {msg}" + ); + Ok(()) + } +} diff --git a/diskann-record/src/load/context.rs b/diskann-record/src/load/context.rs new file mode 100644 index 000000000..5a4fd51ac --- /dev/null +++ b/diskann-record/src/load/context.rs @@ -0,0 +1,433 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Load-side context, object, array, and side-car reader. +//! +//! [`Context`] is the cheap, clonable handle threaded through every +//! [`super::Load::load`] / [`super::Loadable::load`] impl. From a [`Context`], loaders +//! ask: +//! +//! * [`Context::as_object`] / [`Object::field`] for nested records. +//! * [`Context::as_array`] / [`Array::iter`] for sequences. +//! * [`Context::as_str`] / [`Context::as_number`] / [`Context::as_bool`] / [`Context::is_null`] for scalars. +//! * [`Object::read`] for side-car artifacts referenced by a +//! [`save::Handle`](super::save::Handle). +//! +//! [`Reader`] implements [`std::io::Read`] and [`std::io::Seek`] over a side-car +//! artifact, regardless of the provider's backing store. + +use std::io::BufReader; + +use crate::{ + Number, Version, + load::{Loadable, Result, error}, + save, +}; + +/// The backing store for a load operation. +/// +/// A `LoadContext` supplies the root manifest [`save::Value`] ([`LoadContext::value`]) +/// and resolves side-car artifacts referenced by handles ([`LoadContext::read`]). The +/// concrete implementations live under `crate::backend`: a disk-backed context (under +/// the `disk` feature) and an in-memory context. Alternative implementations (e.g. 
a +/// virtual filesystem) can be supplied for testing. +/// +/// The generic [`load`](super::load) entry point is parameterized over this trait, and +/// [`Context`] / [`Object`] / `Array` borrow it as an object-safe `&dyn LoadContext` +/// so the load tree is agnostic to the concrete context type. +pub(crate) trait LoadContext { + /// The root value of the manifest. + /// + /// # Errors + /// + /// Returns [`Error`](crate::load::Error) if the root value cannot be produced. + fn value(&self) -> Result<&save::Value<'_>>; + + /// Open the side-car artifact identified by `key` for reading. + /// + /// # Errors + /// + /// Returns [`error::Kind::MissingFile`] if the file is not registered with this + /// context or if `key` escapes the manifest directory. + fn read(&self, key: &str) -> Result>; +} + +/// The backend-specific half of a [`Reader`]. +/// +/// Each [`LoadContext`] implementation supplies its own `ReaderInner` (e.g. an in-memory +/// cursor or an on-disk file). The blanket impl covers any type that is both +/// [`std::io::Read`] and [`std::io::Seek`]. +pub(crate) trait ReaderInner: std::io::Read + std::io::Seek {} + +impl ReaderInner for T where T: std::io::Read + std::io::Seek {} + +/// A borrowed reader over a side-car artifact. +/// +/// Produced by [`Object::read`]. Implements [`std::io::Read`] and [`std::io::Seek`] over +/// whatever backing store the `LoadContext` provides, so non-file-backed providers +/// (like an in-memory byte buffer) can supply an arbitrary seekable reader. +pub struct Reader<'a> { + io: BufReader>, +} + +impl<'a> Reader<'a> { + /// Build a reader over an arbitrary borrowed [`ReaderInner`] source. + /// + /// Used by [`LoadContext`] implementations to expose a side-car artifact backed by a + /// file or an in-memory [`std::io::Cursor`]. 
+ pub(crate) fn new(io: T) -> Self + where + T: ReaderInner + 'a, + { + Self { + io: BufReader::new(Box::new(io)), + } + } +} + +impl std::io::Read for Reader<'_> { + // Required method + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + self.io.read(buf) + } + + // Provided methods + fn read_vectored(&mut self, bufs: &mut [std::io::IoSliceMut<'_>]) -> std::io::Result { + self.io.read_vectored(bufs) + } + fn read_to_end(&mut self, buf: &mut Vec) -> std::io::Result { + self.io.read_to_end(buf) + } + fn read_to_string(&mut self, buf: &mut String) -> std::io::Result { + self.io.read_to_string(buf) + } + fn read_exact(&mut self, buf: &mut [u8]) -> std::io::Result<()> { + self.io.read_exact(buf) + } +} + +impl std::io::Seek for Reader<'_> { + fn seek(&mut self, pos: std::io::SeekFrom) -> std::io::Result { + self.io.seek(pos) + } + fn rewind(&mut self) -> std::io::Result<()> { + self.io.rewind() + } + fn stream_position(&mut self) -> std::io::Result { + self.io.stream_position() + } + fn seek_relative(&mut self, offset: i64) -> std::io::Result<()> { + self.io.seek_relative(offset) + } +} + +/////////////////////// +// User facing types // +/////////////////////// + +/// A cheap, clonable handle threaded through every load impl. +/// +/// Loaders use the `as_*` accessors to peek at the underlying [`save::Value`] kind +/// (e.g. [`Context::as_object`], [`Context::as_array`], [`Context::as_str`]) and +/// [`Context::load`] to recursively deserialize a nested value into a concrete type. +#[derive(Clone)] +pub struct Context<'a> { + inner: &'a dyn LoadContext, + value: &'a save::Value<'a>, +} + +impl<'a> Context<'a> { + pub(super) fn new(inner: &'a dyn LoadContext, value: &'a save::Value<'a>) -> Self { + Self { inner, value } + } + + fn context(&self) -> &'a dyn LoadContext { + self.inner + } + + /// Recursively deserialize the underlying value into a `T`. + /// + /// Equivalent to calling `T::load(self.clone())`. 
Use this from inner loaders that + /// want to delegate to another [`Loadable`]. + pub fn load(&self) -> Result + where + T: Loadable<'a>, + { + T::load(self.clone()) + } + + /// Returns `Some(Object)` if the value is a versioned object, else `None`. + pub fn as_object(&self) -> Option> { + match self.value { + save::Value::Object(versioned) => { + let object = Object { + inner: self.inner, + record: versioned.record(), + version: versioned.version(), + }; + Some(object) + } + _ => None, + } + } + + /// Returns `Some(s)` if the value is a string, else `None`. + pub fn as_str(&self) -> Option<&'a str> { + match self.value { + save::Value::String(s) => Some(s), + _ => None, + } + } + + /// Returns `Some(Array)` if the value is an array, else `None`. + pub fn as_array(&self) -> Option> { + match self.value { + save::Value::Array(array) => Some(Array::new(self.context(), array)), + _ => None, + } + } + + /// Returns `Some(Number)` if the value is numeric, else `None`. + /// + /// Use the conversion methods on [`Number`] (e.g. `as_u32`, `as_i64`) to narrow to + /// the target Rust type; out-of-range conversions return `None` and should be + /// surfaced as [`error::Kind::NumberOutOfRange`]. + pub fn as_number(&self) -> Option { + match self.value { + save::Value::Number(number) => Some(*number), + _ => None, + } + } + + /// Returns `Some(b)` if the value is a boolean, else `None`. + pub fn as_bool(&self) -> Option { + match self.value { + save::Value::Bool(value) => Some(*value), + _ => None, + } + } + + /// Returns `true` if the value is null. + /// + /// Used by [`Loadable`] impls for [`Option`] to detect the absent variant. + pub fn is_null(&self) -> bool { + matches!(self.value, save::Value::Null) + } + + pub(crate) fn as_handle(&self) -> Option<&save::Handle> { + match self.value { + save::Value::Handle(handle) => Some(handle), + _ => None, + } + } +} + +/// A versioned record reached through [`Context::as_object`]. 
+/// +/// `Object` is the entry point for record-based deserialization: it exposes the schema +/// version via [`Object::version`], the user keys via [`Object::keys`], typed field +/// extraction via [`Object::field`] (and the [`load_fields!`](crate::load_fields) +/// macro), and side-car artifact access via [`Object::read`]. +pub struct Object<'a> { + inner: &'a dyn LoadContext, + record: &'a save::Record<'a>, + version: Version, +} + +impl<'a> Object<'a> { + /// The schema [`Version`] recorded in the manifest for this object. + pub fn version(&self) -> Version { + self.version + } + + /// Iterate over the user keys of this record. Reserved keys (`$version`, + /// `$handle`) are tracked separately and never appear here. + pub fn keys(&self) -> save::Keys<'_, 'a> { + self.record.keys() + } + + /// Number of user keys in this record. + pub fn len(&self) -> usize { + self.record.len() + } + + /// Whether this record has no user keys. + pub fn is_empty(&self) -> bool { + self.record.is_empty() + } + + /// Return the sole user key of this record, used by enum loaders to dispatch + /// to a variant arm. Errors with a recoverable [`error::Kind::TypeMismatch`] + /// if the record has zero or multiple user keys (i.e. the wire shape does + /// not look like an enum). + pub fn single_key(&self) -> Result<&str> { + let mut keys = self.record.keys(); + let Some(first) = keys.next() else { + return Err(error::Kind::TypeMismatch.into()); + }; + if keys.next().is_some() { + return Err(error::Kind::TypeMismatch.into()); + } + Ok(first) + } + + /// Descend into the raw [`Context`] for `key`, without imposing a type. + /// Useful for enum variants whose payload is itself an [`Object`], an array, + /// or any other [`save::Value`]. Returns [`error::Kind::MissingField`] when + /// the key is absent. 
+ pub fn child(&self, key: &str) -> Result> { + match self.record.get(key) { + Some(value) => Ok(Context::new(self.context(), value)), + None => Err(error::Kind::MissingField.into()), + } + } + + /// Extract the value under `key` and deserialize it into a `T`. + /// + /// This is the typed counterpart to [`Object::child`] and the primitive used by the + /// [`load_fields!`](crate::load_fields) macro. + /// + /// # Errors + /// + /// Returns [`error::Kind::MissingField`] if the key is absent. Errors raised by + /// `T::load` (e.g. [`error::Kind::TypeMismatch`]) are propagated unchanged. + pub fn field(&self, key: &str) -> Result + where + T: Loadable<'a>, + { + match self.record.get(key) { + Some(value) => T::load(Context::new(self.context(), value)), + None => Err((error::Kind::MissingField).into()), + } + } + + /// Open the side-car artifact identified by `handle` for reading. + /// + /// The handle must have been previously written through the matching + /// [`save::Context::write`](super::save::Context::write) call and embedded in this + /// record. Returns [`error::Kind::MissingFile`] if the file is not registered in the + /// manifest or if the handle attempts to escape the manifest directory. + pub fn read(&self, handle: &save::Handle) -> Result> { + self.inner.read(handle.as_str()) + } + + fn context(&self) -> &'a dyn LoadContext { + self.inner + } +} + +/// A homogeneous sequence of values reached through [`Context::as_array`]. +/// +/// Backed by a borrowed `&[Value]`. Use [`Array::iter`] to walk the elements; each +/// item is yielded as a [`Context`] that can be further deserialized via +/// [`Context::load`]. +pub struct Array<'a> { + inner: &'a dyn LoadContext, + array: &'a [save::Value<'a>], +} + +impl<'a> Array<'a> { + fn new(inner: &'a dyn LoadContext, array: &'a [save::Value<'a>]) -> Self { + Self { inner, array } + } + + /// Number of elements in the array. 
+ pub fn len(&self) -> usize { + self.array.len() + } + + /// Returns `true` if the array is empty. + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Iterate over the elements, each as a [`Context`] ready for further deserialization. + pub fn iter(&self) -> Iter<'a> { + Iter::new(self.context(), self.array.iter()) + } + + fn context(&self) -> &'a dyn LoadContext { + self.inner + } +} + +/// Iterator returned by [`Array::iter`]. +pub struct Iter<'a> { + inner: &'a dyn LoadContext, + iter: std::slice::Iter<'a, save::Value<'a>>, +} + +impl<'a> Iter<'a> { + fn new(inner: &'a dyn LoadContext, iter: std::slice::Iter<'a, save::Value<'a>>) -> Self { + Self { inner, iter } + } +} + +impl<'a> Iterator for Iter<'a> { + type Item = Context<'a>; + fn next(&mut self) -> Option { + self.iter + .next() + .map(|value| Context::new(self.inner, value)) + } + + fn size_hint(&self) -> (usize, Option) { + self.iter.size_hint() + } +} + +impl ExactSizeIterator for Iter<'_> {} + +#[cfg(test)] +mod tests { + use super::*; + + /// Minimal [`LoadContext`] used to anchor a [`Context`] in tests; the `Object` + /// methods under test never touch the context's `value` / `read` operations. 
+ struct TestContext; + + impl LoadContext for TestContext { + fn value(&self) -> Result<&save::Value<'_>> { + unimplemented!("not exercised by Object accessor tests") + } + + fn read(&self, _key: &str) -> Result> { + unimplemented!("not exercised by Object accessor tests") + } + } + + #[test] + fn object_reports_populated_keys() { + let mut record = save::Record::empty(); + record.insert("alpha", save::Value::Null).unwrap(); + record.insert("beta", save::Value::Bool(true)).unwrap(); + let value = record.into_value(Version::new(1, 0)); + + let inner = TestContext; + let object = Context::new(&inner, &value) + .as_object() + .expect("a versioned object value must yield an Object"); + + assert_eq!(object.len(), 2); + assert!(!object.is_empty()); + let mut keys: Vec<&str> = object.keys().collect(); + keys.sort_unstable(); + assert_eq!(keys, ["alpha", "beta"]); + } + + #[test] + fn object_reports_empty_record() { + let value = save::Record::empty().into_value(Version::new(1, 0)); + + let inner = TestContext; + let object = Context::new(&inner, &value) + .as_object() + .expect("a versioned object value must yield an Object"); + + assert_eq!(object.len(), 0); + assert!(object.is_empty()); + assert_eq!(object.keys().count(), 0); + } +} diff --git a/diskann-record/src/load/error.rs b/diskann-record/src/load/error.rs new file mode 100644 index 000000000..9ade95790 --- /dev/null +++ b/diskann-record/src/load/error.rs @@ -0,0 +1,219 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Load-side error type and classification. +//! +//! The [`Error`] type wraps [`anyhow::Error`] for rich diagnostics and carries a +//! recoverable / critical bit used by probing call sites. The [`Kind`] enum enumerates +//! the well-known structural failure modes; [`Kind::is_recoverable`] is the canonical +//! source of truth for the recoverable / critical classification. 
+ +use std::fmt::{Debug, Display}; + +/// A specialized [`std::result::Result`] for load-side operations. +pub type Result = ::std::result::Result; + +/// Load-side error. +/// +/// Carries an inner [`anyhow::Error`] for rich diagnostics (chained context, +/// backtraces) along with a single `recoverable` bit. Recoverable errors are +/// the contract for probing APIs: a caller that tries multiple load strategies +/// (e.g. current version, then legacy) can distinguish "this attempt didn't +/// match, try another" from "the data is broken, stop now". +/// +/// Most constructors produce *critical* (non-recoverable) errors. Probing +/// call sites use the explicit `*_recoverable` constructors, or rely on the +/// [`From`] impl which classifies each [`Kind`] variant according to +/// [`Kind::is_recoverable`]. +#[derive(Debug)] +pub struct Error { + inner: anyhow::Error, + recoverable: bool, +} + +impl Error { + /// Construct a critical error from an underlying source error. + pub fn new(err: E) -> Self + where + E: std::error::Error + Send + Sync + 'static, + { + Self { + inner: anyhow::Error::new(err), + recoverable: false, + } + } + + /// Construct a critical error from a display message. + pub fn message(message: D) -> Self + where + D: Display + Debug + Send + Sync + 'static, + { + Self { + inner: anyhow::Error::msg(message), + recoverable: false, + } + } + + /// Construct a recoverable error from an underlying source. Suitable for + /// probing APIs that may attempt an alternative load strategy. + pub fn new_recoverable(err: E) -> Self + where + E: std::error::Error + Send + Sync + 'static, + { + Self { + inner: anyhow::Error::new(err), + recoverable: true, + } + } + + /// Construct a recoverable error from a display message. Suitable for + /// probing APIs that may attempt an alternative load strategy. 
+ pub fn message_recoverable(message: D) -> Self + where + D: Display + Debug + Send + Sync + 'static, + { + Self { + inner: anyhow::Error::msg(message), + recoverable: true, + } + } + + /// Attach additional context. The `recoverable` flag is preserved. + pub fn context(self, message: D) -> Self + where + D: Display + Send + Sync + 'static, + { + Self { + inner: self.inner.context(message), + recoverable: self.recoverable, + } + } + + /// Returns `true` if this error is recoverable. Probing call sites should + /// only fall back to alternative load strategies when this is `true`. + pub fn is_recoverable(&self) -> bool { + self.recoverable + } +} + +impl Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Load Error: {:?}", self.inner) + } +} + +impl std::error::Error for Error { + /// Returns the lower-level source of this error, if it exists. + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + Some(self.inner.as_ref()) + } +} + +/// Well-known classes of load-side failure. +/// +/// Used in two roles: +/// +/// * As the source of an [`Error`] via `From` (and the matching `From` for +/// [`Error`] which classifies recoverable / critical according to +/// [`Kind::is_recoverable`]). +/// * As a probe value in error chains — high-level callers can introspect the kind to +/// decide whether to try a fallback loader. +#[derive(Debug, Clone, Copy)] +#[non_exhaustive] +pub enum Kind { + /// The manifest's `$version` does not match the loader's expected + /// [`Load::VERSION`](crate::load::Load::VERSION). + VersionMismatch, + /// A required field is absent from the record. + MissingField, + /// The shape of the saved value does not match what the loader expected (e.g. found + /// an array where an object was needed). 
+ TypeMismatch, + /// The manifest's version is recognized as not matching the current schema, and the + /// type's [`Load::load_legacy`](crate::load::Load::load_legacy) has no upgrade path + /// for it. + UnknownVersion, + /// The variant tag read from the wire format does not match any known + /// variant of the target enum. + UnknownVariant, + /// A numeric value in the manifest does not fit in the requested Rust type + /// (either out of range or would lose precision). + NumberOutOfRange, + /// A `$handle` references a file name that is not registered in the + /// manifest's `files` set. + MissingFile, +} + +impl Kind { + /// Stable, human-readable description of this kind. Used as the default error + /// message when constructing an [`Error`] from a `Kind`. + pub const fn as_str(self) -> &'static str { + match self { + Self::VersionMismatch => "version mismatch", + Self::MissingField => "missing field", + Self::TypeMismatch => "type mismatch", + Self::UnknownVersion => "unknown version", + Self::UnknownVariant => "unknown variant", + Self::NumberOutOfRange => "number out of range for target type", + Self::MissingFile => "handle references a file not present in the manifest", + } + } + + /// Whether an error of this kind should be treated as recoverable by + /// probing APIs (i.e., suitable for triggering a fallback to an alternative + /// load strategy). + /// + /// Recoverable kinds describe "the data did not match what this loader + /// expected" (a different version or shape might still succeed). Critical + /// kinds describe structural or integrity problems where retrying would be + /// pointless or unsafe. + pub const fn is_recoverable(self) -> bool { + match self { + // Shape/version probing signals — another loader might succeed. + Self::VersionMismatch | Self::MissingField | Self::TypeMismatch => true, + // Structural / integrity failures — give up. 
+ Self::UnknownVersion + | Self::UnknownVariant + | Self::NumberOutOfRange + | Self::MissingFile => false, + } + } +} + +impl std::error::Error for Kind {} + +impl std::fmt::Display for Kind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(self.as_str()) + } +} + +impl From for Error { + fn from(kind: Kind) -> Self { + Self { + inner: anyhow::Error::new(kind), + recoverable: kind.is_recoverable(), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn context_preserves_recoverable_flag() { + assert!( + Error::from(Kind::TypeMismatch) + .context("extra") + .is_recoverable() + ); + assert!( + !Error::from(Kind::MissingFile) + .context("extra") + .is_recoverable() + ); + } +} diff --git a/diskann-record/src/load/mod.rs b/diskann-record/src/load/mod.rs new file mode 100644 index 000000000..3efe47540 --- /dev/null +++ b/diskann-record/src/load/mod.rs @@ -0,0 +1,266 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! # Loading Records from Disk +//! +//! This module mirrors the [`super::save`] side. User types implement [`Load`] (or, for +//! primitive-like leaves, [`Loadable`]) and obtain an [`Object`] / [`Context`] from which +//! they extract individual fields and side-car artifacts. +//! +//! The generic entry point is `load`; `load_from_disk` (available under the `disk` +//! feature) is the disk-backed convenience wrapper that reads a manifest and dispatches +//! into the user type's [`Load`] impl. +//! +//! # Reading Records +//! +//! The [`load_fields!`](crate::load_fields) macro is the idiomatic way to extract a fixed +//! set of named fields from an [`Object`] into local bindings. It mirrors the structure +//! of [`save_fields!`](crate::save_fields). +//! +//! # Version Dispatch +//! +//! Each [`Load`] impl declares a [`VERSION`](Load::VERSION). If the version stored in the +//! manifest matches, [`Load::load`] is called. 
Otherwise [`Load::load_legacy`] is invoked +//! so the impl can perform a custom upgrade; returning an +//! [`error::Kind::UnknownVersion`] from `load_legacy` indicates the loader has no upgrade +//! path for that schema. +//! +//! # Recoverable vs. Critical Errors +//! +//! Load errors are tagged as recoverable or critical. Probing call sites that try +//! multiple loaders should only retry when [`Error::is_recoverable`] returns `true`. See +//! [`error::Kind::is_recoverable`] for the classification. + +pub mod error; +pub use error::{Error, Result}; + +mod context; +pub(crate) use context::LoadContext; +pub use context::{Context, Object, Reader}; + +#[cfg(feature = "disk")] +use std::path::Path; + +use crate::{Version, save}; + +/// Deserialize a `T` from `context`. +/// +/// This is the generic entry point: it asks `context` for its root value and threads a +/// [`Context`] borrowing `context` through `T`'s [`Loadable::load`] impl. +/// `load_from_disk` is the disk-backed convenience wrapper available under the `disk` +/// feature. +/// +/// # Errors +/// +/// Returns [`Error`] if the context cannot produce a root value or if `T`'s loader fails. +pub(crate) fn load<'a, T, C>(context: &'a C) -> Result +where + T: Loadable<'a>, + C: LoadContext, +{ + let value = context.value()?; + T::load(Context::new(context, value)) +} + +/// Reload a value previously written by [`save::save_to_disk`]. +/// +/// `metadata` is the manifest JSON path produced by the saver, and `dir` is the +/// directory holding any side-car artifacts. +/// +/// # Errors +/// +/// Returns [`Error`] if the manifest is missing or malformed, if a referenced artifact is +/// missing, or if a user [`Load`] impl fails (e.g. due to a version mismatch with no +/// upgrade path). 
+#[cfg(feature = "disk")] +pub fn load_from_disk(metadata: &Path, dir: &Path) -> Result +where + T: for<'a> Loadable<'a>, +{ + let context = crate::backend::disk::DiskLoadContext::new(metadata, dir)?; + load(&context) +} + +/// Implemented by user types that can be reloaded from a versioned [`Object`]. +/// +/// This is the symmetric counterpart to [`super::save::Save`]. Implementations describe +/// how to reconstruct `Self` from the manifest representation, and how to upgrade +/// records written by older schemas via [`Self::load_legacy`]. +/// +/// # Enums +/// +/// Enum types dispatch on the single non-reserved key of the object (see +/// [`Object::single_key`]) and recurse via [`Object::child`] into the payload. +pub trait Load<'a>: Sized { + /// The schema version this impl was written against. + /// + /// Compared with the manifest's version to choose between [`Self::load`] and + /// [`Self::load_legacy`]. + const VERSION: Version; + + /// Reconstruct `Self` from an object whose `$version` matches [`Self::VERSION`]. + fn load(object: Object<'a>) -> Result; + + /// Reconstruct `Self` from an object whose `$version` does *not* match + /// [`Self::VERSION`]. + /// + /// Implementations may either upgrade the older record or refuse with + /// [`error::Kind::UnknownVersion`] when no upgrade is possible. + fn load_legacy(object: Object<'a>) -> Result; +} + +/// Implemented by any value that can be deserialized from a [`Context`]. +/// +/// This is the bottom of the trait hierarchy and is implemented for the same set of +/// primitive-like types as [`super::save::Saveable`]. Most user types should implement +/// [`Load`] (which gets a [`Loadable`] impl for free via the blanket below) rather than +/// [`Loadable`] directly. +pub trait Loadable<'a>: Sized { + /// Deserialize `Self` from a [`Context`]. 
+ fn load(context: Context<'a>) -> Result; +} + +impl<'a, T> Loadable<'a> for T +where + T: Load<'a>, +{ + fn load(context: Context<'a>) -> Result { + let object = context.as_object().ok_or(error::Kind::TypeMismatch)?; + let version = object.version(); + if version == T::VERSION { + T::load(object) + } else { + T::load_legacy(object) + } + } +} + +//////////// +// Macros // +//////////// + +/// Extract a fixed set of named fields from an [`Object`] into local bindings. +/// +/// Each name in the list becomes a `let` binding of the same name. An optional `: T` +/// suffix selects the [`Loadable`] target type; without it, type inference picks the +/// type from the surrounding context. Errors from individual fields are propagated with +/// `?`. +/// +/// ```ignore +/// load_fields!(object, [ +/// dim: usize, +/// label, // type inferred +/// vectors: save::Handle, +/// ]); +/// ``` +#[macro_export] +macro_rules! load_fields { + (@field $object:ident, $field:ident: $T:ty) => { + let $field: $T = $object.field(stringify!($field))?; + }; + (@field $object:ident, $field:ident) => { + let $field = $object.field(stringify!($field))?; + }; + ($object:ident, [$($field:ident $(: $ty:ty)?),+ $(,)?]) => { + $( + $crate::load_fields!(@field $object, $field $(: $ty)?); + )+ + }; +} + +/////////////// +// Bootstrap // +/////////////// + +impl<'a> Loadable<'a> for &'a str { + fn load(context: Context<'a>) -> Result { + context + .as_str() + .ok_or_else(|| error::Kind::TypeMismatch.into()) + } +} + +impl Loadable<'_> for String { + fn load(context: Context<'_>) -> Result { + context.load::<&str>().map(|s| s.into()) + } +} + +impl Loadable<'_> for save::Handle { + fn load(context: Context<'_>) -> Result { + context + .as_handle() + .cloned() + .ok_or_else(|| error::Kind::TypeMismatch.into()) + } +} + +impl Loadable<'_> for bool { + fn load(context: Context<'_>) -> Result { + context + .as_bool() + .ok_or_else(|| error::Kind::TypeMismatch.into()) + } +} + +impl<'a, T> Loadable<'a> for 
Option +where + T: Loadable<'a>, +{ + fn load(context: Context<'a>) -> Result { + if context.is_null() { + Ok(None) + } else { + T::load(context).map(Some) + } + } +} + +impl<'a, T> Loadable<'a> for Vec +where + T: Loadable<'a>, +{ + fn load(context: Context<'a>) -> Result { + match context.as_array() { + Some(array) => array.iter().map(T::load).collect(), + None => Err((error::Kind::TypeMismatch).into()), + } + } +} + +macro_rules! load_number { + ($T:ty) => { + impl Loadable<'_> for $T { + fn load(context: Context<'_>) -> Result { + match context.as_number() { + Some(n) => n.try_into().map_err(|_| error::Kind::NumberOutOfRange.into()), + None => Err((error::Kind::TypeMismatch).into()), + } + } + } + }; + ($($Ts:ty),+ $(,)?) => { + $(load_number!($Ts);)+ + } +} + +load_number!(u8, u16, u32, u64, usize, i8, i16, i32, i64, isize, f32, f64); + +// NonZero* primitives are loaded by deserializing the inner numeric type and then +// validating it is non-zero. A zero value produces a `NumberOutOfRange` light error. +macro_rules! load_nonzero { + ($T:ty, $Inner:ty) => { + impl Loadable<'_> for $T { + fn load(context: Context<'_>) -> Result { + let inner: $Inner = context.load()?; + <$T>::new(inner).ok_or_else(|| error::Kind::NumberOutOfRange.into()) + } + } + }; +} + +load_nonzero!(std::num::NonZeroU32, u32); +load_nonzero!(std::num::NonZeroU64, u64); +load_nonzero!(std::num::NonZeroUsize, usize); diff --git a/diskann-record/src/number.rs b/diskann-record/src/number.rs new file mode 100644 index 000000000..3a2b8b34f --- /dev/null +++ b/diskann-record/src/number.rs @@ -0,0 +1,325 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Lossless container for the on-wire numeric kinds (`u64`, `i64`, `f64`). +//! +//! [`Number`] is the value type produced by the manifest deserializer for every JSON +//! number. The conversion accessors (`as_u32`, `as_i64`, etc.) attempt to narrow into a +//! 
#[cfg(feature = "serde")]
use serde::de::{self, Visitor};
#[cfg(feature = "serde")]
use serde::{Deserialize, Deserializer, Serialize, Serializer};

/// A numeric value carried in a manifest, preserving the kind the writer chose.
///
/// The wire format distinguishes unsigned, signed, and floating-point numbers; the
/// deserializer preserves that distinction by selecting the matching variant. Use the
/// narrowing accessors (e.g. [`Number::as_u32`], [`Number::as_f64`]) to extract a Rust
/// value of the desired type.
#[derive(Debug, Clone, Copy)]
pub enum Number {
    U64(u64),
    I64(i64),
    F64(f64),
}

impl Number {
    /// Returns the string sentinel for a non-finite `f64`, or `None` for any finite or
    /// integer value.
    ///
    /// JSON cannot represent `NaN` or ±infinity as numeric literals, so these are
    /// encoded as the strings `"nan"`, `"inf"`, and `"neg_inf"` instead.
    /// [`Number::from_sentinel`] is the inverse.
    pub(crate) fn sentinel(self) -> Option<&'static str> {
        match self {
            Self::F64(v) if v.is_nan() => Some("nan"),
            Self::F64(v) if v == f64::INFINITY => Some("inf"),
            Self::F64(v) if v == f64::NEG_INFINITY => Some("neg_inf"),
            _ => None,
        }
    }

    /// Decodes a non-finite `f64` sentinel produced by [`Number::sentinel`], or `None`
    /// for any other string.
    pub(crate) fn from_sentinel(s: &str) -> Option<Self> {
        match s {
            "nan" => Some(Self::F64(f64::NAN)),
            "inf" => Some(Self::F64(f64::INFINITY)),
            "neg_inf" => Some(Self::F64(f64::NEG_INFINITY)),
            _ => None,
        }
    }
}

#[cfg(feature = "serde")]
impl Serialize for Number {
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        // Non-finite floats have no JSON literal; they travel as string sentinels.
        if let Some(sentinel) = self.sentinel() {
            return serializer.serialize_str(sentinel);
        }
        match *self {
            Self::U64(v) => serializer.serialize_u64(v),
            Self::I64(v) => serializer.serialize_i64(v),
            Self::F64(v) => serializer.serialize_f64(v),
        }
    }
}

#[cfg(feature = "serde")]
impl<'de> Deserialize<'de> for Number {
    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        struct NumberVisitor;

        impl<'de> Visitor<'de> for NumberVisitor {
            type Value = Number;

            fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                f.write_str("a number")
            }

            fn visit_u64<E: de::Error>(self, v: u64) -> Result<Number, E> {
                Ok(Number::U64(v))
            }

            fn visit_i64<E: de::Error>(self, v: i64) -> Result<Number, E> {
                Ok(Number::I64(v))
            }

            fn visit_f64<E: de::Error>(self, v: f64) -> Result<Number, E> {
                Ok(Number::F64(v))
            }

            // Strings are only accepted when they are one of the non-finite sentinels.
            fn visit_str<E: de::Error>(self, v: &str) -> Result<Number, E> {
                Number::from_sentinel(v).ok_or_else(|| {
                    de::Error::custom(format!("expected a number or numeric sentinel, got {v:?}"))
                })
            }
        }

        deserializer.deserialize_any(NumberVisitor)
    }
}

// Float -> float narrowing: accept the cast only if it survives the round trip. An
// overflowing `f64 as f32` becomes an infinity, which no finite input equals, so the
// guard is reliable here.
macro_rules! try_cast {
    ($v:ident :$T:ty => $U:ty) => {{
        let c = $v as $U;
        if c as $T == $v { Some(c) } else { None }
    }};
}

// Float -> integer: a plain `as` round trip is NOT reliable because `as` saturates.
// `2^64 as u64` is `u64::MAX`, and `u64::MAX as f64` rounds back up to `2^64`, so the
// out-of-range input would compare equal. Going through `i128` (wide enough for every
// target type) makes the range check exact: the integrality test happens in `i128`,
// and `TryFrom` rejects anything the target cannot hold.
macro_rules! float_to_int {
    ($v:ident => $T:ty) => {{
        let wide = $v as i128;
        if wide as f64 == $v {
            <$T>::try_from(wide).ok()
        } else {
            None
        }
    }};
}

// Integer -> float: same saturation hazard in the other direction (`u64::MAX as f64`
// is `2^64`, which saturates back to `u64::MAX`). Comparing in `i128` cannot saturate
// for any value that started out as a `u64` or `i64`.
macro_rules! int_to_float {
    ($v:ident => $U:ty) => {{
        let c = $v as $U;
        if c as i128 == i128::from($v) { Some(c) } else { None }
    }};
}

macro_rules! int {
    ($f:ident, $T:ty) => {
        /// Narrow to the integer type named by this accessor. Returns `None` when the
        /// value is out of range for that type or is not an integer.
        pub fn $f(self) -> Option<$T> {
            match self {
                Self::U64(v) => v.try_into().ok(),
                Self::I64(v) => v.try_into().ok(),
                Self::F64(v) => float_to_int!(v => $T),
            }
        }
    }
}

macro_rules! float {
    ($f:ident, $T:ty) => {
        /// Convert to the float type named by this accessor. Returns `None` when the
        /// value cannot be represented exactly in that type.
        pub fn $f(self) -> Option<$T> {
            match self {
                Self::U64(v) => int_to_float!(v => $T),
                Self::I64(v) => int_to_float!(v => $T),
                // NaN is representable in every float target but never equals itself, so the
                // `try_cast!` round-trip guard would reject it; pass it through directly.
                Self::F64(v) if v.is_nan() => Some(v as $T),
                Self::F64(v) => try_cast!(v:f64 => $T),
            }
        }
    }
}

impl Number {
    int!(as_u8, u8);
    int!(as_u16, u16);
    int!(as_u32, u32);
    int!(as_u64, u64);
    int!(as_usize, usize);

    int!(as_i8, i8);
    int!(as_i16, i16);
    int!(as_i32, i32);
    int!(as_i64, i64);
    int!(as_isize, isize);

    float!(as_f32, f32);
    float!(as_f64, f64);
}

// Lossless widening conversions into the matching `Number` variant.
macro_rules! from {
    ($T:ty => $variant:ident) => {
        impl From<$T> for Number {
            fn from(v: $T) -> Self {
                Self::$variant(v.into())
            }
        }
    };
    ($($T:ty => $variant:ident),+ $(,)?) => {
        $(from!($T => $variant);)+
    }
}

from!(
    u64 => U64,
    u32 => U64,
    u16 => U64,
    u8 => U64,
    i64 => I64,
    i32 => I64,
    i16 => I64,
    i8 => I64,
    f32 => F64,
    f64 => F64,
);

impl From<usize> for Number {
    fn from(v: usize) -> Self {
        // `usize` has no `Into<u64>`; the conversion can only fail on a target whose
        // pointers are wider than 64 bits.
        Self::U64(
            v.try_into()
                .expect("usize wider than 64 bits is not supported"),
        )
    }
}

// `TryFrom<Number>` for every primitive, delegating to the narrowing accessors.
macro_rules! try_from {
    ($T:ty => $f:ident) => {
        impl TryFrom<Number> for $T {
            type Error = ();
            fn try_from(number: Number) -> Result<$T, Self::Error> {
                number.$f().ok_or(())
            }
        }
    };
    ($($T:ty => $f:ident),+ $(,)?) => {
        $(try_from!($T => $f);)+
    }
}

try_from!(
    u64 => as_u64,
    u32 => as_u32,
    u16 => as_u16,
    u8 => as_u8,
    usize => as_usize,

    i64 => as_i64,
    i32 => as_i32,
    i16 => as_i16,
    i8 => as_i8,
    isize => as_isize,

    f32 => as_f32,
    f64 => as_f64,
);
+ assert_eq!(Number::F64(f64::INFINITY).as_f64(), Some(f64::INFINITY)); + } + + #[test] + fn nan_round_trips_through_try_from() { + let back = f64::try_from(Number::F64(f64::NAN)).expect("NaN must convert back to f64"); + assert!(back.is_nan()); + } + + #[test] + fn try_from_surfaces_out_of_range() { + assert!(u8::try_from(Number::U64(300)).is_err()); + assert_eq!(u16::try_from(Number::U64(300)).unwrap(), 300); + assert!(usize::try_from(Number::I64(-1)).is_err()); + } + + #[cfg(feature = "disk")] + #[test] + fn non_finite_floats_round_trip_via_sentinels() { + for (value, sentinel) in [ + (f64::NAN, "\"nan\""), + (f64::INFINITY, "\"inf\""), + (f64::NEG_INFINITY, "\"neg_inf\""), + ] { + let json = serde_json::to_string(&Number::F64(value)).unwrap(); + assert_eq!(json, sentinel); + + let back: Number = serde_json::from_str(&json).unwrap(); + match back { + Number::F64(v) if value.is_nan() => assert!(v.is_nan()), + Number::F64(v) => assert_eq!(v, value), + other => panic!("expected F64, got {other:?}"), + } + } + } + + #[cfg(feature = "disk")] + #[test] + fn finite_floats_serialize_as_json_numbers() { + assert_eq!(serde_json::to_string(&Number::F64(1.5)).unwrap(), "1.5"); + assert_eq!(serde_json::to_string(&Number::U64(7)).unwrap(), "7"); + assert_eq!(serde_json::to_string(&Number::I64(-7)).unwrap(), "-7"); + } + + #[cfg(feature = "disk")] + #[test] + fn json_numbers_select_matching_variant() { + // Unsigned, signed, and floating-point literals each route to the visitor + // method that preserves the writer's chosen kind. 
+ assert!(matches!(serde_json::from_str("7").unwrap(), Number::U64(7))); + assert!(matches!( + serde_json::from_str("-7").unwrap(), + Number::I64(-7) + )); + match serde_json::from_str("1.5").unwrap() { + Number::F64(v) => assert_eq!(v, 1.5), + other => panic!("expected F64, got {other:?}"), + } + } + + #[cfg(feature = "disk")] + #[test] + fn non_sentinel_string_is_rejected() { + let err = serde_json::from_str::("\"bogus\"").unwrap_err(); + assert!( + err.to_string().contains("numeric sentinel"), + "unexpected error message: {err}" + ); + } +} diff --git a/diskann-record/src/save/context.rs b/diskann-record/src/save/context.rs new file mode 100644 index 000000000..cc9a6381b --- /dev/null +++ b/diskann-record/src/save/context.rs @@ -0,0 +1,202 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Save-side context and side-car writer. +//! +//! [`Context`] is the cheap handle handed to every [`super::Save::save`] impl. It owns +//! nothing visible to the user; cloning it is free, and it can be passed to children to +//! propagate the same artifact-tracking state. +//! +//! [`Writer`] is the borrowed side-car artifact handle returned by [`Context::write`]. +//! It implements [`std::io::Write`] and [`std::io::Seek`]; calling +//! [`Writer::finish`] flushes the buffer and yields a [`Handle`] that can be inserted +//! into a [`super::Record`]. + +use std::io::BufWriter; + +use crate::save::{Error, Handle, Result, Value}; + +/// Generate forwarding [`std::io::Write`] and [`std::io::Seek`] impls for `$T` that +/// delegate every method to its `$field` member. +macro_rules! 
delegate_write_and_seek { + ($field:ident, $T:ty) => { + impl std::io::Write for $T { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + self.$field.write(buf) + } + fn flush(&mut self) -> std::io::Result<()> { + self.$field.flush() + } + fn write_vectored(&mut self, bufs: &[std::io::IoSlice<'_>]) -> std::io::Result { + self.$field.write_vectored(bufs) + } + fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> { + self.$field.write_all(buf) + } + fn write_fmt(&mut self, args: std::fmt::Arguments<'_>) -> std::io::Result<()> { + self.$field.write_fmt(args) + } + } + + impl std::io::Seek for $T { + fn seek(&mut self, pos: std::io::SeekFrom) -> std::io::Result { + self.$field.seek(pos) + } + fn rewind(&mut self) -> std::io::Result<()> { + self.$field.rewind() + } + fn stream_position(&mut self) -> std::io::Result { + self.$field.stream_position() + } + fn seek_relative(&mut self, offset: i64) -> std::io::Result<()> { + self.$field.seek_relative(offset) + } + } + }; +} + +pub(crate) use delegate_write_and_seek; + +/// The backing store for a save operation. +/// +/// A `SaveContext` decides where side-car artifacts are written ([`SaveContext::write`]) +/// and how the final manifest is committed ([`SaveContext::finish`]). The concrete +/// implementations live under `crate::backend`: a disk-backed context (under the `disk` +/// feature) and an in-memory context. Alternative implementations (e.g. a virtual +/// filesystem) can be supplied for testing or to avoid touching the filesystem. +/// +/// The generic [`save`](super::save) entry point is parameterized over this trait so +/// that the base crate carries no hard dependency on any particular implementation. +pub(crate) trait SaveContext { + /// The value produced once the manifest has been committed by + /// [`SaveContext::finish`]. For the disk-backed context this is `()`. + type Output; + + /// Allocate a new side-car artifact, optionally tagged with a human-readable `key`. 
+ /// + /// The artifact's on-disk name (and the value stored in its [`Handle`]) is prefixed + /// with the count of artifacts written so far; when `key` is `Some`, it is appended as + /// `{count}-{key}` purely to aid debugging. Reusing the same `key` across calls is + /// allowed — the count prefix disambiguates the artifacts. + /// + /// # Errors + /// + /// Returns [`Error`] if `key` is `Some` but not a simple relative file name, or if the + /// underlying artifact cannot be created. + fn write(&self, key: Option<&str>) -> Result>; + + /// Commit the manifest `value`, consuming the context. + /// + /// # Errors + /// + /// Returns [`Error`] if the manifest cannot be serialized or committed. + fn finish(self, value: Value<'_>) -> Result; +} + +/// Object-safe view of the artifact-allocating half of a [`SaveContext`]. +/// +/// [`Context`] holds a `&dyn GetWrite` so that the same handle can be threaded through +/// every [`Save::save`](super::Save) impl regardless of the concrete context type. +/// [`SaveContext::finish`] (which consumes `self` and names an associated type) is not +/// object safe, so it is deliberately excluded from this trait. +pub(super) trait GetWrite { + fn write(&self, key: Option<&str>) -> Result>; +} + +impl GetWrite for T +where + T: SaveContext, +{ + fn write(&self, key: Option<&str>) -> Result> { + ::write(self, key) + } +} + +/// A cheap, clonable handle threaded through every [`Save::save`](super::Save) impl. +/// +/// `Context` exposes one operation — [`Context::write`] — for allocating a side-car +/// artifact. The same context is passed to nested [`Save`](super::Save) impls (typically +/// via the [`save_fields!`](crate::save_fields) macro), so a single save tree shares +/// artifact-name bookkeeping. It borrows the backing `SaveContext` as an object-safe +/// `GetWrite` so that the save tree is agnostic to the concrete context type. 
+#[derive(Clone)] +pub struct Context<'a> { + inner: &'a dyn GetWrite, +} + +impl<'a> Context<'a> { + pub(super) fn new(inner: &'a dyn GetWrite) -> Self { + Self { inner } + } + + /// Allocate a new side-car artifact in the manifest directory, optionally tagging it + /// with a human-readable `key`. + /// + /// The artifact is named with the count of artifacts written so far (with `key`, when + /// `Some`, appended as a readability hint), so the same `key` may be passed to + /// multiple calls. The returned [`Writer`] is positioned at offset 0 and implements + /// [`std::io::Write`] / [`std::io::Seek`]. Call [`Writer::finish`] to obtain a + /// [`Handle`] that may be inserted into a [`Record`](super::Record). + /// + /// # Errors + /// + /// Returns [`Error`] if `key` is `Some` but not a simple relative file name, or if the + /// underlying file cannot be created. + pub fn write(&self, key: Option<&str>) -> Result> { + self.inner.write(key) + } +} + +/// The backend-specific half of a [`Writer`]. +/// +/// Each [`SaveContext`](super::SaveContext) implementation supplies its own +/// `WriterInner` (e.g. an in-memory cursor or an on-disk file). [`Writer`] wraps it in a +/// [`BufWriter`] and forwards [`std::io::Write`] / [`std::io::Seek`] to it; on +/// [`Writer::finish`] the (flushed) inner is consumed to commit the artifact and yield a +/// [`Handle`]. +pub(crate) trait WriterInner: std::io::Write + std::io::Seek + std::fmt::Debug { + /// Commit the completed artifact under `name`, returning its [`Handle`]. + fn finish(self: Box, name: String) -> Result; +} + +/// A borrowed side-car artifact writer produced by [`Context::write`]. +/// +/// Implements [`std::io::Write`] and [`std::io::Seek`]. Writes are buffered; calling +/// [`Writer::finish`] flushes the buffer, commits the artifact through the backing +/// writer, and returns a [`Handle`]. 
+#[derive(Debug)] +pub struct Writer<'a> { + inner: BufWriter>, + name: String, +} + +impl<'a> Writer<'a> { + /// Wrap a backend-specific [`WriterInner`] into a buffered [`Writer`] named `name`. + pub(crate) fn new(inner: T, name: String) -> Self + where + T: WriterInner + 'a, + { + Self { + inner: BufWriter::new(Box::new(inner)), + name, + } + } + + /// Flush and close the writer, returning a [`Handle`] for the artifact. + /// + /// Insert the returned handle into a [`Record`](super::Record) (typically via + /// [`Record::insert`](super::Record::insert)) so that load-side code can locate the + /// artifact through the manifest. + pub fn finish(self) -> Result { + // into_inner() flushes the buffered bytes into the backend before we commit it. + let inner = self + .inner + .into_inner() + .map_err(|err| Error::new(err.into_error()))?; + inner.finish(self.name) + } +} + +delegate_write_and_seek!(inner, Writer<'_>); diff --git a/diskann-record/src/save/error.rs b/diskann-record/src/save/error.rs new file mode 100644 index 000000000..7ea110704 --- /dev/null +++ b/diskann-record/src/save/error.rs @@ -0,0 +1,65 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Save-side error type. +//! +//! Mirrors the load-side [`super::super::load::Error`] in shape but does not carry a +//! recoverable / critical distinction: every save failure is terminal because no +//! probing fallback exists on the writer side. + +use std::fmt::{Debug, Display}; + +/// A specialized [`std::result::Result`] for save-side operations. +pub type Result = ::std::result::Result; + +/// Save-side error. +/// +/// Wraps [`anyhow::Error`] for rich context chains (see [`Error::context`]) and is +/// returned from every fallible save-side operation, including [`super::Save::save`] +/// impls and `save_to_disk`. +#[derive(Debug)] +pub struct Error { + inner: anyhow::Error, +} + +impl Error { + /// Wrap an underlying source error. 
+ pub fn new(err: E) -> Self + where + E: std::error::Error + Send + Sync + 'static, + { + Error { + inner: anyhow::Error::new(err), + } + } + + /// Construct an error from a display message with no source. + pub fn message(message: D) -> Self + where + D: Display + Debug + Send + Sync + 'static, + { + Error { + inner: anyhow::Error::msg(message), + } + } + + /// Attach additional context describing what was being attempted. + pub fn context(self, message: D) -> Self + where + D: Display + Send + Sync + 'static, + { + Error { + inner: self.inner.context(message), + } + } +} + +impl Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Save Error: {:?}", self.inner) + } +} + +impl std::error::Error for Error {} diff --git a/diskann-record/src/save/mod.rs b/diskann-record/src/save/mod.rs new file mode 100644 index 000000000..606b9a122 --- /dev/null +++ b/diskann-record/src/save/mod.rs @@ -0,0 +1,288 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! # Saving Records to Disk +//! +//! This module provides the writer-side of the framework. User types implement [`Save`] +//! (or, for primitive-like leaves, [`Saveable`]) and obtain a [`Context`] from which they +//! request side-car artifact writers and assemble a [`Record`] of named fields. +//! +//! The generic entry point is `save`; `save_to_disk` (available under the `disk` +//! feature) is the disk-backed convenience wrapper that serializes a value into a +//! caller-chosen directory plus a manifest path. +//! +//! # Building Records +//! +//! The [`save_fields!`](crate::save_fields) macro is the idiomatic way to build a record +//! from a struct or destructured enum variant. It handles per-field error context and +//! invokes [`Saveable::save`] on each value. +//! +//! # Side-Car Artifacts +//! +//! Binary blobs (e.g. vector buffers) are written to side-car files via +//! [`Context::write`], which returns a [`Writer`]. 
The handle returned by +//! [`Writer::finish`](context::Writer::finish) can be embedded into the record as a +//! [`Handle`]; it serializes as a `$handle` reference and is rehydrated on the load side. +pub use crate::value::{Handle, Keys, Record, Value, Versioned}; + +mod context; +pub use context::{Context, Writer}; +pub(crate) use context::{SaveContext, WriterInner, delegate_write_and_seek}; + +mod error; +pub use error::{Error, Result}; + +use crate::Version; + +/// Serialize `x` into `context`, returning the context's committed output. +/// +/// This is the generic entry point: it threads a [`Context`] borrowing `context` through +/// the value's [`Saveable::save`] impl, then commits the resulting manifest via +/// [`SaveContext::finish`]. `save_to_disk` is the disk-backed convenience wrapper +/// available under the `disk` feature. +/// +/// # Errors +/// +/// Returns [`Error`] if a user impl returns an error or if the context fails to commit +/// the manifest. +pub(crate) fn save(x: &T, context: C) -> Result +where + T: Saveable, + C: SaveContext, +{ + let value = x.save(Context::new(&context))?; + context.finish(value) +} + +/// Serialize `x` to disk. +/// +/// The manifest (a JSON document) is written atomically to `metadata`; any side-car +/// artifacts the type's [`Save::save`] impl creates via [`Context::write`] are written +/// into `dir`. +/// +/// # Errors +/// +/// Returns [`Error`] if the directory cannot be written to, if the manifest cannot be +/// serialized, or if a user impl returns an error. +#[cfg(feature = "disk")] +pub fn save_to_disk( + x: &T, + dir: impl AsRef, + metadata: impl AsRef, +) -> Result<()> +where + T: Saveable, +{ + let context = + crate::backend::disk::DiskSaveContext::new(dir.as_ref().into(), metadata.as_ref().into())?; + save(x, context) +} + +/// Implemented by user types that map to a versioned [`Record`]. +/// +/// This is the primary trait for structured user types. 
A [`Save`] impl describes the +/// versioned schema of `Self`: its associated [`VERSION`](Self::VERSION) is attached to +/// the [`Record`] produced by [`Self::save`](Self::save). +/// +/// # Enums +/// +/// Enum types are encoded by returning a [`Record`] with a single user key whose name +/// is the variant tag and whose value is the variant's payload (frequently +/// [`Value::Null`] for unit variants). See the crate-level docs for a worked example. +pub trait Save { + /// The schema version attached to records produced by this impl. + /// + /// Loaders compare this against the version stored in the manifest to decide + /// between [`Load::load`](crate::load::Load::load) and + /// [`Load::load_legacy`](crate::load::Load::load_legacy). + const VERSION: Version; + + /// Serialize `self` into a [`Record`]. + /// + /// Use the supplied [`Context`] to request side-car artifact writers. Use the + /// [`save_fields!`](crate::save_fields) macro to populate the record. + fn save(&self, context: Context<'_>) -> Result>; +} + +/// Implemented by any value that can be written into a [`Value`]. +/// +/// This is the bottom of the trait hierarchy and is implemented for: +/// +/// * Primitive numeric types (signed, unsigned, floats, `NonZero*`). +/// * [`bool`], [`str`], [`String`], and [`Handle`]. +/// * [`Option`] (serializes `None` as [`Value::Null`]). +/// * `&[T]` and [`Vec`] (serialize as [`Value::Array`]). +/// * Any `T: Save` (wraps the produced record in [`Value::Object`] with the type's +/// [`Save::VERSION`]). +/// +/// Most user types should implement [`Save`] (which gets a [`Saveable`] impl for free +/// via the blanket below) rather than [`Saveable`] directly. +pub trait Saveable { + /// Serialize `self` into a [`Value`]. 
+ fn save(&self, context: Context<'_>) -> Result>; +} + +impl Saveable for T +where + T: Save, +{ + fn save(&self, context: Context<'_>) -> Result> { + let record = self.save(context)?; + Ok(record.into_value(T::VERSION)) + } +} + +////////////////// +// Random Stuff // +////////////////// + +impl Saveable for [T] +where + T: Saveable, +{ + fn save(&self, context: Context<'_>) -> Result> { + let values: Result> = self.iter().map(|t| t.save(context.clone())).collect(); + values.map(Value::Array) + } +} + +impl Saveable for Vec +where + T: Saveable, +{ + fn save(&self, context: Context<'_>) -> Result> { + self.as_slice().save(context) + } +} + +impl Saveable for str { + fn save(&self, _: Context<'_>) -> Result> { + Ok(Value::String(self.into())) + } +} + +impl Saveable for String { + fn save(&self, _: Context<'_>) -> Result> { + Ok(Value::String(self.as_str().into())) + } +} + +impl Saveable for Handle { + fn save(&self, _: Context<'_>) -> Result> { + Ok(Value::Handle(self.clone())) + } +} + +impl Saveable for bool { + fn save(&self, _: Context<'_>) -> Result> { + Ok(Value::Bool(*self)) + } +} + +impl Saveable for Option +where + T: Saveable, +{ + fn save(&self, context: Context<'_>) -> Result> { + match self { + None => Ok(Value::Null), + Some(t) => t.save(context), + } + } +} + +macro_rules! save_number { + ($T:ty) => { + impl Saveable for $T { + fn save(&self, _: Context<'_>) -> Result> { + Ok(Value::Number((*self).into())) + } + } + }; + ($($Ts:ty),+ $(,)?) => { + $(save_number!($Ts);)+ + } +} + +save_number!(usize, u64, u32, u16, u8, i64, i32, i16, i8, f32, f64); + +// NonZero* primitives serialize as their inner numeric type. Loaders reject zero. +macro_rules! save_nonzero { + ($T:ty) => { + impl Saveable for $T { + fn save(&self, _: Context<'_>) -> Result> { + Ok(Value::Number(self.get().into())) + } + } + }; + ($($Ts:ty),+ $(,)?) 
=> { + $(save_nonzero!($Ts);)+ + } +} + +save_nonzero!( + std::num::NonZeroU32, + std::num::NonZeroU64, + std::num::NonZeroUsize +); + +#[derive(Debug, Clone, Copy)] +#[doc(hidden)] +pub struct Serializing(pub &'static str); + +impl std::fmt::Display for Serializing { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "while serializing field \"{}\"", self.0) + } +} + +/// Build a [`Record`] from a list of fields. +/// +/// Two forms are supported: +/// +/// * `save_fields!(self, context, [a, b, c])` reads each field as `self.a`, +/// `self.b`, etc. Use this from `Save::save` for plain structs. +/// * `save_fields!(context, [a, b, c])` reads each field from a local binding of +/// the same name. Use this inside enum match arms where the variant's payload +/// has already been destructured into local bindings. Those bindings are +/// assumed to be references (which is automatic when matching against `&self`); +/// for an owned local, take a reference explicitly first. +#[macro_export] +macro_rules! save_fields { + ($me:ident, $context:ident, [$($field:ident),+ $(,)?]) => {{ + $crate::save::Record::from_iter( + [ + $( + ( + ::std::borrow::Cow::Borrowed(stringify!($field)), + <_ as $crate::save::Saveable>::save( + &$me.$field, + $context.clone() + ).map_err(|err| { + err.context($crate::save::Serializing(stringify!($field))) + })? + ), + )+ + ] + ) + }}; + ($context:ident, [$($field:ident),+ $(,)?]) => {{ + $crate::save::Record::from_iter( + [ + $( + ( + ::std::borrow::Cow::Borrowed(stringify!($field)), + <_ as $crate::save::Saveable>::save( + $field, + $context.clone() + ).map_err(|err| { + err.context($crate::save::Serializing(stringify!($field))) + })? + ), + )+ + ] + ) + }}; +} diff --git a/diskann-record/src/value.rs b/diskann-record/src/value.rs new file mode 100644 index 000000000..c81b9ad73 --- /dev/null +++ b/diskann-record/src/value.rs @@ -0,0 +1,467 @@ +/* + * Copyright (c) Microsoft Corporation. 
+ * Licensed under the MIT license. + */ + +//! Wire-level value types used in the on-disk manifest. +//! +//! These types are the shared currency of both halves of the framework: +//! user value -> [`save`](crate::save) -> [`Value`] in the save path, and the +//! [`Value`] -> [`load`](crate::load) -> user value in the load path. +//! +//! Every field stored in a manifest is one of: +//! +//! * [`Value::Null`] / [`Value::Bool`] / [`Value::Number`] / [`Value::String`] — +//! primitive scalars. +//! * [`Value::Array`] — a homogeneous sequence (used by `Vec` and `&[T]`). +//! * [`Value::Object`] — a [`Versioned`] [`Record`] (the canonical encoding for a +//! `T: crate::save::Save`). +//! * [`Value::Handle`] — a reference to a side-car artifact (produced by +//! [`crate::save::Context::write`] + [`crate::save::Writer::finish`]). +//! +//! Most user code never touches these enums directly. On the save side, +//! [`crate::save::Saveable`] impls turn Rust values into [`Value`]s and the +//! [`save_fields!`](crate::save_fields) macro assembles the surrounding [`Record`]; on +//! the load side, the [`crate::load`] accessors walk the same [`Value`] tree back into +//! Rust values. + +use std::{borrow::Cow, collections::HashMap}; + +#[cfg(feature = "serde")] +use serde::{ + Deserialize, Deserializer, Serialize, Serializer, + de::{self, MapAccess, SeqAccess, Visitor}, + ser::SerializeStruct, +}; + +use crate::{Number, Version, save::Error}; + +/// The wire-level union of every saveable kind. +/// +/// See the module-level docs for an overview of when each variant is produced. The +/// borrowing parameter `'a` lets [`Value::String`] and nested +/// records reuse memory owned by the caller without copying. 
+#[derive(Debug)] +pub enum Value<'a> { + Null, + Bool(bool), + Number(Number), + String(Cow<'a, str>), + Array(Vec>), + Object(Versioned<'a>), + Handle(Handle), +} + +#[cfg(feature = "serde")] +impl Serialize for Value<'_> { + fn serialize(&self, ser: S) -> Result { + match self { + Self::Null => ser.serialize_none(), + Self::Bool(b) => ser.serialize_bool(*b), + Self::Number(n) => n.serialize(ser), + Self::String(s) => ser.serialize_str(s), + Self::Array(a) => a.serialize(ser), + Self::Object(v) => v.serialize(ser), + Self::Handle(h) => h.serialize(ser), + } + } +} + +#[cfg(feature = "serde")] +impl<'de> Deserialize<'de> for Value<'static> { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct Inner; + + impl<'de> Visitor<'de> for Inner { + type Value = Value<'static>; + + fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.write_str("a valid Value") + } + + fn visit_unit(self) -> Result, E> { + Ok(Value::Null) + } + + fn visit_none(self) -> Result, E> { + Ok(Value::Null) + } + + fn visit_some(self, deserializer: D) -> Result, D::Error> + where + D: Deserializer<'de>, + { + Value::deserialize(deserializer) + } + + fn visit_bool(self, v: bool) -> Result, E> { + Ok(Value::Bool(v)) + } + + fn visit_u64(self, v: u64) -> Result, E> { + Ok(Value::Number(Number::U64(v))) + } + + fn visit_i64(self, v: i64) -> Result, E> { + Ok(Value::Number(Number::I64(v))) + } + + fn visit_f64(self, v: f64) -> Result, E> { + Ok(Value::Number(Number::F64(v))) + } + + fn visit_str(self, v: &str) -> Result, E> { + if let Some(n) = Number::from_sentinel(v) { + return Ok(Value::Number(n)); + } + Ok(Value::String(Cow::Owned(v.to_owned()))) + } + + fn visit_string(self, v: String) -> Result, E> { + if let Some(n) = Number::from_sentinel(&v) { + return Ok(Value::Number(n)); + } + Ok(Value::String(Cow::Owned(v))) + } + + fn visit_seq(self, mut seq: A) -> Result, A::Error> + where + A: SeqAccess<'de>, + { + let mut values = 
Vec::with_capacity(seq.size_hint().unwrap_or(0)); + while let Some(v) = seq.next_element()? { + values.push(v); + } + Ok(Value::Array(values)) + } + + fn visit_map(self, mut map: A) -> Result, A::Error> + where + A: MapAccess<'de>, + { + // TODO: Handle invariants that only one of our reserved words are present. + let mut version: Option = None; + let mut handle_name: Option = None; + let mut fields: HashMap, Value<'static>> = HashMap::new(); + + while let Some(key) = map.next_key::()? { + match key.as_str() { + "$version" => { + version = Some(map.next_value()?); + } + "$handle" => { + handle_name = Some(map.next_value()?); + } + _ => { + let value = map.next_value()?; + fields.insert(Cow::Owned(key), value); + } + } + } + + if let Some(name) = handle_name { + if version.is_some() || !fields.is_empty() { + return Err(de::Error::custom( + "handle object must contain only a \"$handle\" field", + )); + } + return Ok(Value::Handle(Handle(name))); + } + + if let Some(version) = version { + let record = Record { record: fields }; + return Ok(Value::Object(Versioned { record, version })); + } + + Err(de::Error::custom( + "map must contain either \"$version\" or \"$handle\"", + )) + } + } + + deserializer.deserialize_any(Inner) + } +} + +impl From for Value<'_> { + fn from(handle: Handle) -> Self { + Self::Handle(handle) + } +} + +impl Value<'_> { + /// Convert this value into a fully owned [`Value<'static>`], deep-copying any borrowed + /// string or byte data. + /// + /// This is the allocation-based equivalent of round-tripping through the wire format: + /// it severs every borrow from the originating data so the result can be stored + /// independently of its source (for example inside an in-memory + /// [`crate::MemoryContext`]). 
+ pub fn into_owned(self) -> Value<'static> { + match self { + Self::Null => Value::Null, + Self::Bool(b) => Value::Bool(b), + Self::Number(n) => Value::Number(n), + Self::String(s) => Value::String(Cow::Owned(s.into_owned())), + Self::Array(values) => { + Value::Array(values.into_iter().map(Value::into_owned).collect()) + } + Self::Object(versioned) => Value::Object(versioned.into_owned()), + Self::Handle(handle) => Value::Handle(handle), + } + } +} + +/// A map of named [`Value`]s. +/// +/// `Record` is the body of an object in the manifest. On the save side each call to +/// [`crate::save::Save::save`] returns one, and [`Record::into_value`] wraps it as a +/// [`Versioned`] [`Value::Object`] ready for insertion into another record; on the load +/// side the same record is read back through [`crate::load::Object`]. Keys beginning +/// with `$` are reserved for framework metadata (see [`crate::is_reserved`]). +#[derive(Debug)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde", serde(transparent))] +pub struct Record<'a> { + record: HashMap, Value<'a>>, +} + +impl<'a> Record<'a> { + /// Construct an empty record. Useful for unit enum variants. + pub fn empty() -> Self { + Self { + record: HashMap::new(), + } + } + + /// Returns `true` if a value is registered under `key`. + pub fn contains_key(&self, key: &str) -> bool { + self.record.contains_key(key) + } + + /// Look up the [`Value`] registered under `key`, if any. + pub fn get(&self, key: &str) -> Option<&Value<'a>> { + self.record.get(key) + } + + /// Number of (user) keys in this record. Reserved keys (`$version`, `$handle`) + /// are tracked elsewhere and never appear here. + pub fn len(&self) -> usize { + self.record.len() + } + + /// Returns `true` if this record has no user keys. + pub fn is_empty(&self) -> bool { + self.record.is_empty() + } + + /// Iterate over the user keys in this record. Order is unspecified. 
+ pub fn keys(&self) -> Keys<'_, 'a> { + Keys { + inner: self.record.keys(), + } + } + + /// Insert `value` under `key`. + /// + /// # Errors + /// + /// Returns [`Error`] if `key` begins with `$`, which is reserved for the + /// save/load framework (see [`crate::is_reserved`]). + pub fn insert(&mut self, key: K, value: V) -> crate::save::Result>> + where + K: Into>, + V: Into>, + { + let key = key.into(); + if crate::is_reserved(&key) { + return Err(Error::message(format!( + "record key {:?} is reserved (keys starting with `$` are reserved for the \ + save/load framework)", + key, + ))); + } + + Ok(self.record.insert(key, value.into())) + } + + /// Wrap this record as a versioned [`Value`] ready for insertion into another + /// record. Use this from enum [`Save`](crate::save::Save) impls to attach the + /// outer type's version to an inline variant payload. + pub fn into_value(self, version: Version) -> Value<'a> { + Value::Object(Versioned::new(self, version)) + } + + /// Convert this record into a fully owned [`Record<'static>`], deep-copying borrowed + /// keys and values. See [`Value::into_owned`]. + pub fn into_owned(self) -> Record<'static> { + Record { + record: self + .record + .into_iter() + .map(|(key, value)| (Cow::Owned(key.into_owned()), value.into_owned())) + .collect(), + } + } +} + +/// Iterator over the keys of a [`Record`]. +pub struct Keys<'r, 'a> { + inner: std::collections::hash_map::Keys<'r, Cow<'a, str>, Value<'a>>, +} + +impl<'r, 'a> Iterator for Keys<'r, 'a> { + type Item = &'r str; + + fn next(&mut self) -> Option { + self.inner.next().map(|k| k.as_ref()) + } + + fn size_hint(&self) -> (usize, Option) { + self.inner.size_hint() + } +} + +impl ExactSizeIterator for Keys<'_, '_> {} + +impl<'a> FromIterator<(Cow<'a, str>, Value<'a>)> for Record<'a> { + fn from_iter, Value<'a>)>>(itr: I) -> Self { + Self { + record: itr.into_iter().collect(), + } + } +} + +/// A [`Record`] paired with the schema [`Version`] used to produce it. 
+/// +/// Serialized as a normal object plus a `$version` field on the wire. Constructed by +/// [`Record::into_value`]. +#[derive(Debug)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct Versioned<'a> { + #[cfg_attr(feature = "serde", serde(flatten))] + record: Record<'a>, + #[cfg_attr(feature = "serde", serde(rename = "$version"))] + version: Version, +} + +impl<'a> Versioned<'a> { + pub(crate) fn new(record: Record<'a>, version: Version) -> Self { + Self { record, version } + } + + pub(crate) fn version(&self) -> Version { + self.version + } + + pub(crate) fn record(&self) -> &Record<'a> { + &self.record + } + + pub(crate) fn into_owned(self) -> Versioned<'static> { + Versioned { + record: self.record.into_owned(), + version: self.version, + } + } +} + +/// A reference to a side-car artifact in the manifest directory. +/// +/// Produced by [`Writer::finish`](crate::save::Writer::finish) after a side-car write completes and +/// inserted into a [`Record`] like any other value. Serializes as `{"$handle": ""}` +/// on the wire; the load side rehydrates it through +/// [`crate::load::Object::read`]. 
+#[derive(Debug, Clone)] +pub struct Handle(String); + +impl Handle { + pub(crate) fn new(string: String) -> Self { + Self(string) + } + + pub(crate) fn as_str(&self) -> &str { + &self.0 + } +} + +#[cfg(feature = "serde")] +impl Serialize for Handle { + fn serialize(&self, ser: S) -> Result { + let mut handle = ser.serialize_struct("Handle", 1)?; + handle.serialize_field("$handle", &self.0)?; + handle.end() + } +} + +#[cfg(feature = "serde")] +impl<'de> Deserialize<'de> for Handle { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + #[derive(Deserialize)] + struct Helper { + #[serde(rename = "$handle")] + handle: String, + } + let helper = Helper::deserialize(deserializer)?; + Ok(Handle(helper.handle)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn insert_rejects_reserved_key() { + let mut record = Record::empty(); + record + .insert("$version", Value::Null) + .expect_err("reserved key must be rejected"); + record + .insert("ok", Value::Bool(true)) + .expect("normal key must be accepted"); + assert!(record.contains_key("ok")); + assert_eq!(record.len(), 1); + } + + #[cfg(feature = "disk")] + #[test] + fn deserialize_rejects_handle_with_extra_fields() { + let json = r#"{ "$handle": "a.bin", "$version": "0.0" }"#; + serde_json::from_str::>(json) + .expect_err("handle object with extra fields must be rejected"); + } + + #[cfg(feature = "disk")] + #[test] + fn deserialize_rejects_object_without_version_or_handle() { + let json = r#"{ "field": 1 }"#; + serde_json::from_str::>(json) + .expect_err("object without $version or $handle must be rejected"); + } + + // The non-finite float sentinels (`"nan"`, `"inf"`, `"neg_inf"`) are reserved wire + // tokens: a JSON string equal to one of them always deserializes back as the + // corresponding `Number::F64`, never as a `Value::String`. 
This is intentional — those + // exact strings are not used as user string values in this repository, so the + // round-trip ambiguity is resolved in favor of the float sentinel. + #[cfg(feature = "disk")] + #[test] + fn float_sentinel_strings_deserialize_as_numbers() { + for (s, is_nan) in [("nan", true), ("inf", false), ("neg_inf", false)] { + let json = serde_json::to_string(&Value::String(Cow::Borrowed(s))).unwrap(); + let back: Value<'static> = serde_json::from_str(&json).unwrap(); + match back { + Value::Number(Number::F64(v)) if is_nan => assert!(v.is_nan()), + Value::Number(Number::F64(_)) => {} + other => panic!("sentinel {s:?} deserialized as {other:?}, expected Number::F64"), + } + } + } +} diff --git a/diskann-record/src/version.rs b/diskann-record/src/version.rs new file mode 100644 index 000000000..444e9dded --- /dev/null +++ b/diskann-record/src/version.rs @@ -0,0 +1,128 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Semver-style version stamps embedded in every saved object. + +#[cfg(feature = "serde")] +use serde::{Deserialize, Deserializer, Serialize, Serializer, de}; + +/// A semver-style schema version attached to every saved record. +/// +/// Each [`crate::save::Save`] / [`crate::load::Load`] impl declares its +/// `const VERSION: Version`. On load, the version recorded in the manifest is compared +/// against the declared version to dispatch between +/// [`Load::load`](crate::load::Load::load) and +/// [`Load::load_legacy`](crate::load::Load::load_legacy). +/// +/// The framework treats versions as opaque pairs and only checks them for equality; +/// ordering / semver semantics are entirely up to the implementing type. +/// +/// The `major.minor` format aids code readability: a reader can tell at a glance +/// that, for example, version `1.0` is compatible with version `1.2`. +/// +/// On the wire, a `Version` is encoded as a single string of the form +/// `"major.minor"` (e.g. `"0.0"`). 
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Version {
    /// First component of the `major.minor` pair.
    pub major: u32,
    /// Second component of the `major.minor` pair.
    pub minor: u32,
}

impl Version {
    /// Construct a [`Version`] from its two components.
    pub const fn new(major: u32, minor: u32) -> Self {
        Self { major, minor }
    }
}

impl std::fmt::Display for Version {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}.{}", self.major, self.minor)
    }
}

impl std::str::FromStr for Version {
    type Err = ParseVersionError;

    /// Parse a `"major.minor"` string.
    ///
    /// Each component must be a non-empty run of ASCII digits that fits in a `u32`.
    ///
    /// # Errors
    ///
    /// Returns [`ParseVersionError`] for anything else: a missing or extra component,
    /// surrounding whitespace, a sign, or a value that overflows `u32`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // `u32::from_str` accepts a leading `+`, which `Display` never produces, so the
        // digits-only check runs before the numeric parse.
        fn component(part: &str) -> Option<u32> {
            if part.is_empty() || !part.bytes().all(|b| b.is_ascii_digit()) {
                return None;
            }
            part.parse().ok()
        }

        // A third component leaves a `.` in the minor half, which the digit check rejects.
        s.split_once('.')
            .and_then(|(major, minor)| {
                Some(Version {
                    major: component(major)?,
                    minor: component(minor)?,
                })
            })
            .ok_or_else(|| ParseVersionError(s.to_owned()))
    }
}

/// Error returned when a string cannot be parsed as a [`Version`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParseVersionError(String);

impl std::fmt::Display for ParseVersionError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "unknown version {:?}: expected two `.`-separated u32 components",
            self.0,
        )
    }
}

impl std::error::Error for ParseVersionError {}

#[cfg(feature = "serde")]
impl Serialize for Version {
    fn serialize<S: Serializer>(&self, ser: S) -> Result<S::Ok, S::Error> {
        // Written as the `Display` form, i.e. a single "major.minor" string.
        ser.collect_str(self)
    }
}

#[cfg(feature = "serde")]
impl<'de> Deserialize<'de> for Version {
    fn deserialize<D: Deserializer<'de>>(de: D) -> Result<Self, D::Error> {
        struct VersionVisitor;

        impl de::Visitor<'_> for VersionVisitor {
            type Value = Version;

            fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                f.write_str("a version string of the form \"major.minor\"")
            }

            fn visit_str<E: de::Error>(self, v: &str) -> Result<Version, E> {
                v.parse().map_err(E::custom)
            }
        }

        de.deserialize_str(VersionVisitor)
    }
}

#[cfg(all(test, feature = "disk"))]
mod tests {
    use super::*;

    #[test]
    fn serializes_as_dotted_string() {
        let json = serde_json::to_string(&Version::new(1, 2)).unwrap();
        assert_eq!(json, "\"1.2\"");
    }

    #[test]
    fn round_trips_through_json() {
        let v = Version::new(4, 5);
        let back: Version = serde_json::from_str(&serde_json::to_string(&v).unwrap()).unwrap();
        assert_eq!(v, back);
    }

    #[test]
    fn rejects_malformed_strings() {
        for bad in [
            "\"1\"",
            "\"1.2.3\"",
            "\"1.x\"",
            "\"abc\"",
            "\"+1.2\"",
            "\"1.+2\"",
        ] {
            serde_json::from_str::<Version>(bad)
                .expect_err("malformed version string must be rejected");
        }
    }
}