From 45ddbcf0d127084161f9757bd0feea29ffd9c235 Mon Sep 17 00:00:00 2001 From: Suhas Jayaram Subramanya Date: Thu, 18 Jun 2026 10:46:29 -0700 Subject: [PATCH 01/23] split saveload PR into two -- core with all the Save[Saveable]/Load[Loadable] traits [THIS PR] + impls for structs [TODO] --- Cargo.lock | 10 + Cargo.toml | 2 + diskann-record/Cargo.toml | 19 + diskann-record/src/lib.rs | 565 +++++++++++++++++++++++++++++ diskann-record/src/load/context.rs | 398 ++++++++++++++++++++ diskann-record/src/load/error.rs | 200 ++++++++++ diskann-record/src/load/mod.rs | 243 +++++++++++++ diskann-record/src/number.rs | 176 +++++++++ diskann-record/src/save/context.rs | 206 +++++++++++ diskann-record/src/save/error.rs | 65 ++++ diskann-record/src/save/mod.rs | 266 ++++++++++++++ diskann-record/src/save/value.rs | 362 ++++++++++++++++++ diskann-record/src/version.rs | 82 +++++ 13 files changed, 2594 insertions(+) create mode 100644 diskann-record/Cargo.toml create mode 100644 diskann-record/src/lib.rs create mode 100644 diskann-record/src/load/context.rs create mode 100644 diskann-record/src/load/error.rs create mode 100644 diskann-record/src/load/mod.rs create mode 100644 diskann-record/src/number.rs create mode 100644 diskann-record/src/save/context.rs create mode 100644 diskann-record/src/save/error.rs create mode 100644 diskann-record/src/save/mod.rs create mode 100644 diskann-record/src/save/value.rs create mode 100644 diskann-record/src/version.rs diff --git a/Cargo.lock b/Cargo.lock index 1ecde9f9f..8f6f7272e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -912,6 +912,16 @@ dependencies = [ "trybuild", ] +[[package]] +name = "diskann-record" +version = "0.54.0" +dependencies = [ + "anyhow", + "serde", + "serde_json", + "tempfile", +] + [[package]] name = "diskann-tools" version = "0.54.0" diff --git a/Cargo.toml b/Cargo.toml index b285a94f6..d95ee5a1f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,6 +22,7 @@ members = [ "diskann-benchmark-core", "diskann-benchmark-simd", "diskann-benchmark", + "diskann-record", "diskann-tools", "vectorset", "diskann-bftree", @@ -66,6 +67,7 @@ diskann-benchmark-runner = { path = "diskann-benchmark-runner", version = "0.54. diskann-benchmark-core = { path = "diskann-benchmark-core", version = "0.54.0" } diskann-tools = { path = "diskann-tools", version = "0.54.0" } diskann-bftree = {path = "diskann-bftree", version = "0.54.0" } +diskann-record = { path = "diskann-record", version = "0.54.0" } # External dependencies (shared versions) anyhow = "1.0.98" diff --git a/diskann-record/Cargo.toml b/diskann-record/Cargo.toml new file mode 100644 index 000000000..8c26ae775 --- /dev/null +++ b/diskann-record/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "diskann-record" +version.workspace = true +description.workspace = true +authors.workspace = true +repository.workspace = true +license.workspace = true +edition = "2024" + +[dependencies] +anyhow.workspace = true +serde = { workspace = true, features = ["derive"] } +serde_json.workspace = true + +[dev-dependencies] +tempfile.workspace = true + +[lints] +workspace = true diff --git a/diskann-record/src/lib.rs b/diskann-record/src/lib.rs new file mode 100644 index 000000000..0a94b277c --- /dev/null +++ b/diskann-record/src/lib.rs @@ -0,0 +1,565 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! # Versioned Save/Load for DiskANN +//! +//! This crate provides a small framework for persisting structured Rust values to disk +//! as a JSON manifest plus a set of side-car binary artifacts, and reloading them later. +//! It is the substrate used by `diskann` providers and indexes to implement durable +//! checkpoints. +//! +//! The model is: +//! +//! * Each [`save::Save`] / [`load::Load`] implementation describes how a single Rust type +//! maps to a [`save::Record`] (a versioned map of named fields). +//! * Field values are either [`save::Value`]s embedded directly in the manifest, or +//! [`save::Handle`]s pointing at side-car binary artifacts written via the +//! [`save::Context`]. +//! * Every record carries a [`Version`] so that loaders can detect schema changes and +//! either upgrade ([`load::Load::load_legacy`]) or fall back through a probing chain +//! (see [`load::Error::is_recoverable`]). +//! +//! # Entry Points +//! +//! - [`save::save_to_disk`]: Save a value to a directory plus a manifest path. +//! - [`load::load_from_disk`]: Reload a value from a manifest and its artifact directory. +//! +//! # Defining Save / Load +//! +//! User code is expected to implement [`save::Save`] and [`load::Load`] for the types it +//! wants to persist. For plain structs, the [`save_fields!`] and [`load_fields!`] macros +//! handle the field-by-field plumbing. See [`save`] and [`load`] for the relevant traits +//! and helpers. +//! +//! ## Example +//! +//! ```ignore +//! use diskann_record::{Version, save, load}; +//! +//! #[derive(Debug, PartialEq)] +//! struct Config { dim: usize, label: String } +//! +//! impl save::Save for Config { +//! const VERSION: Version = Version::new(0, 0, 0); +//! fn save(&self, context: save::Context<'_>) -> save::Result> { +//! Ok(diskann_record::save_fields!(self, context, [dim, label])) +//! } +//! } +//! +//! impl load::Load<'_> for Config { +//! const VERSION: Version = Version::new(0, 0, 0); +//! fn load(object: load::Object<'_>) -> load::Result { +//! diskann_record::load_fields!(object, [dim: usize, label: String]); +//! Ok(Self { dim, label }) +//! } +//! fn load_legacy(_: load::Object<'_>) -> load::Result { +//! Err(load::error::Kind::UnknownVersion.into()) +//! } +//! } +//! ``` +//! +//! # Wire Format +//! +//! The manifest is JSON. Every object carries a `$version` field; side-car artifacts are +//! referenced through `$handle` strings whose value is a file name relative to the +//! manifest directory. Keys beginning with `$` are reserved for framework metadata and +//! cannot be used as user field names (see [`is_reserved`]). +//! +//! # Platform Requirements +//! +//! `usize` and `isize` are serialized as 64-bit numbers. The crate statically asserts +//! that `usize::BITS == 64` to guarantee that the saver never produces values the +//! canonical wire width cannot represent. Loaders still range-check at runtime. +//! +//! # Error Handling +//! +//! Both [`save::Error`] and [`load::Error`] wrap [`anyhow::Error`] for rich context +//! chains. Load errors additionally carry a recoverable / critical bit, used by probing +//! call sites to decide whether to fall back to an alternative loader. See +//! [`load::error::Kind`] for the classification. + +mod number; +pub use number::Number; + +mod version; +pub use version::Version; + +pub mod load; +pub mod save; + +// Canonical wire width for `usize` and `isize` in manifests is 64 bits. Saving a value +// on a 64-bit platform and loading it on a 32-bit platform (or vice versa) could silently +// truncate values that exceed `u32::MAX` / `i32::MAX`. We therefore require a 64-bit +// platform at compile time. Loaders still range-check at runtime, but this check ensures +// the saver never emits values that the canonical width cannot represent. +const _: () = assert!( + usize::BITS == 64, + "diskann-record requires a 64-bit target: usize/isize MUST be 64 bits wide !!", +); + +/// Return `true` if `s` is a reserved manifest key. +/// +/// Keys beginning with `$` are reserved for framework metadata (e.g. `$version`, +/// `$handle`) and may not be used as user field names. Attempting to insert one via +/// [`save::Record::insert`] returns an error. +#[doc(hidden)] +pub const fn is_reserved(s: &str) -> bool { + if let Some(first) = s.as_bytes().first() + && *first == b"$"[0] + { + true + } else { + false + } +} + +/////////// +// Tests // +/////////// + +#[cfg(test)] +mod tests { + use super::*; + + use std::io::{Read, Write}; + use std::path::{Path, PathBuf}; + + #[derive(Debug, PartialEq)] + struct Test { + x: String, + y: f32, + enabled: bool, + inner: Inner, + // We write this as a binary file. + vector: Vec, + nickname: Option, + absent: Option, + } + + #[derive(Debug, PartialEq)] + struct Inner { + z: usize, + w: Vec, + flags: Vec, + maybe_count: Option, + maybe_missing: Option, + sparse: Vec>, + } + + impl save::Save for Inner { + const VERSION: Version = Version::new(0, 0, 0); + fn save(&self, context: save::Context<'_>) -> save::Result> { + Ok(save_fields!( + self, + context, + [z, w, flags, maybe_count, maybe_missing, sparse] + )) + } + } + + impl save::Save for Test { + const VERSION: Version = Version::new(0, 0, 0); + fn save(&self, context: save::Context<'_>) -> save::Result> { + // We save `x`, `y`, and `inner` directly into the manifest. + // The raw vector data we instead store in an auxiliary file. + + let mut io = context.write("auxiliary.bin")?; + io.write_all(&self.vector).map_err(save::Error::new)?; + + let mut record = save_fields!(self, context, [x, y, enabled, inner, nickname, absent]); + record.insert("vector", io.finish()?)?; + Ok(record) + } + } + + impl load::Load<'_> for Test { + const VERSION: Version = Version::new(0, 0, 0); + fn load(object: load::Object<'_>) -> load::Result { + load_fields!( + object, + [ + x, + y, + enabled, + inner, + nickname: Option, + absent: Option, + vector: save::Handle, + ] + ); + + let mut io = object.read(&vector)?; + let mut vector = Vec::new(); + io.read_to_end(&mut vector).unwrap(); + + Ok(Self { + x, + y, + enabled, + inner, + vector, + nickname, + absent, + }) + } + + fn load_legacy(_object: load::Object<'_>) -> load::Result { + panic!("nope!"); + } + } + + impl load::Load<'_> for Inner { + const VERSION: Version = Version::new(0, 0, 0); + fn load(object: load::Object<'_>) -> load::Result { + load_fields!( + object, + [ + z, + w, + flags, + maybe_count: Option, + maybe_missing: Option, + sparse: Vec>, + ] + ); + Ok(Self { + z, + w, + flags, + maybe_count, + maybe_missing, + sparse, + }) + } + + fn load_legacy(_object: load::Object<'_>) -> load::Result { + panic!("nope!"); + } + } + + #[test] + fn round_trip_uses_isolated_temp_dir() -> anyhow::Result<()> { + let inner = Inner { + z: 10, + w: vec![-1, -2, -3], + flags: vec![true, false, true], + maybe_count: Some(42), + maybe_missing: None, + sparse: vec![Some(1), None, Some(-3), None], + }; + + let t = Test { + x: "hello".into(), + y: 5.0, + enabled: true, + inner, + vector: vec![0, 1, 2, 3, 4, 5], + nickname: Some("friend".into()), + absent: None, + }; + + // Keep the TempDir guard alive for the full round trip; Drop removes the + // manifest and auxiliary artifact after the assertion completes. + let temp_dir = tempfile::tempdir()?; + let dir = temp_dir.path(); + let metadata = dir.join("metadata.json"); + + save::save_to_disk(&t, dir, &metadata)?; + let we_are_back: Test = load::load_from_disk(&metadata, dir)?; + + assert_eq!(t, we_are_back); + Ok(()) + } + + ///////////////////////// + // Enum support: round // + ///////////////////////// + + #[derive(Debug, PartialEq)] + enum Metric { + L2, + Cosine, + Weighted { weights: Vec }, + } + + impl save::Save for Metric { + const VERSION: Version = Version::new(0, 0, 0); + fn save(&self, context: save::Context<'_>) -> save::Result> { + let mut record = save::Record::empty(); + match self { + Self::L2 => { + record.insert("L2", save::Value::Null)?; + } + Self::Cosine => { + record.insert("Cosine", save::Value::Null)?; + } + Self::Weighted { weights } => { + let payload = save_fields!(context, [weights]).into_value(Self::VERSION); + record.insert("Weighted", payload)?; + } + } + Ok(record) + } + } + + impl load::Load<'_> for Metric { + const VERSION: Version = Version::new(0, 0, 0); + fn load(object: load::Object<'_>) -> load::Result { + match object.single_key()? { + "L2" => Ok(Self::L2), + "Cosine" => Ok(Self::Cosine), + "Weighted" => { + let inner = object + .child("Weighted")? + .as_object() + .ok_or(load::error::Kind::TypeMismatch)?; + load_fields!(inner, [weights: Vec]); + Ok(Self::Weighted { weights }) + } + _ => Err(load::error::Kind::UnknownVariant.into()), + } + } + fn load_legacy(_object: load::Object<'_>) -> load::Result { + Err(load::error::Kind::UnknownVersion.into()) + } + } + + #[derive(Debug, PartialEq)] + struct MetricBag { + primary: Metric, + alternatives: Vec, + } + + impl save::Save for MetricBag { + const VERSION: Version = Version::new(0, 0, 0); + fn save(&self, context: save::Context<'_>) -> save::Result> { + Ok(save_fields!(self, context, [primary, alternatives])) + } + } + + impl load::Load<'_> for MetricBag { + const VERSION: Version = Version::new(0, 0, 0); + fn load(object: load::Object<'_>) -> load::Result { + load_fields!(object, [primary: Metric, alternatives: Vec]); + Ok(Self { + primary, + alternatives, + }) + } + fn load_legacy(_object: load::Object<'_>) -> load::Result { + panic!("nope!"); + } + } + + #[test] + fn enum_round_trip_through_disk() -> anyhow::Result<()> { + let bag = MetricBag { + primary: Metric::Weighted { + weights: vec![0.25, 0.5, 0.25], + }, + alternatives: vec![ + Metric::L2, + Metric::Cosine, + Metric::Weighted { weights: vec![1.0] }, + ], + }; + + let temp_dir = tempfile::tempdir()?; + let dir = temp_dir.path(); + let metadata = dir.join("metadata.json"); + + save::save_to_disk(&bag, dir, &metadata)?; + let restored: MetricBag = load::load_from_disk(&metadata, dir)?; + + assert_eq!(bag, restored); + Ok(()) + } + + #[derive(Debug, PartialEq)] + struct StructShape { + x: i32, + } + + impl save::Save for StructShape { + const VERSION: Version = Version::new(0, 0, 0); + fn save(&self, context: save::Context<'_>) -> save::Result> { + Ok(save_fields!(self, context, [x])) + } + } + + impl load::Load<'_> for StructShape { + const VERSION: Version = Version::new(0, 0, 0); + fn load(object: load::Object<'_>) -> load::Result { + load_fields!(object, [x: i32]); + Ok(Self { x }) + } + fn load_legacy(_object: load::Object<'_>) -> load::Result { + panic!("nope!"); + } + } + + #[derive(Debug, PartialEq)] + enum EnumShape { + Only { x: i32 }, + } + + impl save::Save for EnumShape { + const VERSION: Version = Version::new(0, 0, 0); + fn save(&self, context: save::Context<'_>) -> save::Result> { + let mut record = save::Record::empty(); + match self { + Self::Only { x } => { + let payload = save_fields!(context, [x]).into_value(Self::VERSION); + record.insert("Only", payload)?; + } + } + Ok(record) + } + } + + impl load::Load<'_> for EnumShape { + const VERSION: Version = Version::new(0, 0, 0); + fn load(object: load::Object<'_>) -> load::Result { + match object.single_key()? { + "Only" => { + let inner = object + .child("Only")? + .as_object() + .ok_or(load::error::Kind::TypeMismatch)?; + load_fields!(inner, [x: i32]); + Ok(Self::Only { x }) + } + _ => Err(load::error::Kind::UnknownVariant.into()), + } + } + fn load_legacy(_object: load::Object<'_>) -> load::Result { + panic!("nope!"); + } + } + + #[test] + fn loading_enum_as_struct_is_rejected() -> anyhow::Result<()> { + // Enum data has a single key "Only" whose payload is a versioned + // sub-object. Loading it as `StructShape` (which expects field `x`) + // surfaces `MissingField`. + let value = EnumShape::Only { x: 7 }; + let temp_dir = tempfile::tempdir()?; + let dir = temp_dir.path(); + let metadata = dir.join("metadata.json"); + + save::save_to_disk(&value, dir, &metadata)?; + let err = load::load_from_disk::(&metadata, dir) + .expect_err("loading enum data into a struct shape should fail"); + let msg = format!("{err}"); + assert!( + msg.contains("missing field"), + "expected MissingField error, got: {msg}" + ); + Ok(()) + } + + #[test] + fn loading_struct_as_enum_is_rejected() -> anyhow::Result<()> { + // Struct data has field `x`, which the enum loader sees as a candidate + // variant name. It doesn't match any arm, so we get `UnknownVariant`. + let value = StructShape { x: 7 }; + let temp_dir = tempfile::tempdir()?; + let dir = temp_dir.path(); + let metadata = dir.join("metadata.json"); + + save::save_to_disk(&value, dir, &metadata)?; + let err = load::load_from_disk::(&metadata, dir) + .expect_err("loading struct data into an enum shape should fail"); + let msg = format!("{err}"); + assert!( + msg.contains("unknown variant"), + "expected UnknownVariant error, got: {msg}" + ); + Ok(()) + } + + /////////////////////////////// + // Manifest directory escape // + /////////////////////////////// + + /// Minimal loadable type with a single handle field. Used by the + /// directory-escape tests below to drive `Object::read` against a + /// hand-crafted manifest. + #[derive(Debug)] + struct HandleOnly { + _blob: Vec, + } + + impl load::Load<'_> for HandleOnly { + const VERSION: Version = Version::new(0, 0, 0); + fn load(object: load::Object<'_>) -> load::Result { + load_fields!(object, [blob: save::Handle]); + let mut io = object.read(&blob)?; + let mut buf = Vec::new(); + io.read_to_end(&mut buf).map_err(load::Error::new)?; + Ok(Self { _blob: buf }) + } + fn load_legacy(_object: load::Object<'_>) -> load::Result { + panic!("nope!"); + } + } + + /// Write a hand-crafted manifest into `dir` whose root object exposes a + /// `blob` field referencing `handle_target`. Returns the metadata path. + fn write_handle_manifest(dir: &Path, handle_target: &str) -> std::io::Result { + let manifest = serde_json::json!({ + // Register the same target in `files` so the membership check + // would otherwise let it through — this isolates the new + // path-shape check as the thing rejecting the load. + "files": [handle_target], + "value": { + "$version": "0.0.0", + "blob": { "$handle": handle_target }, + }, + }); + let metadata = dir.join("metadata.json"); + std::fs::write(&metadata, serde_json::to_vec(&manifest)?)?; + Ok(metadata) + } + + #[test] + fn handle_with_parent_traversal_is_rejected() -> anyhow::Result<()> { + let temp_dir = tempfile::tempdir()?; + let dir = temp_dir.path(); + let metadata = write_handle_manifest(dir, "../escape.bin")?; + + let err = load::load_from_disk::(&metadata, dir) + .expect_err("handle escaping the manifest directory must be rejected"); + let msg = format!("{err}"); + assert!( + msg.contains("escapes the manifest directory"), + "expected manifest-escape rejection, got: {msg}" + ); + Ok(()) + } + + #[test] + fn handle_with_absolute_path_is_rejected() -> anyhow::Result<()> { + let temp_dir = tempfile::tempdir()?; + let dir = temp_dir.path(); + // Use a platform-appropriate absolute path. Both shapes should be + // rejected on their respective platforms; we test the native one. + let absolute = if cfg!(windows) { + "C:\\Windows\\System32\\drivers\\etc\\hosts" + } else { + "/etc/passwd" + }; + let metadata = write_handle_manifest(dir, absolute)?; + + let err = load::load_from_disk::(&metadata, dir) + .expect_err("absolute-path handle must be rejected"); + let msg = format!("{err}"); + assert!( + msg.contains("escapes the manifest directory"), + "expected manifest-escape rejection, got: {msg}" + ); + Ok(()) + } +} diff --git a/diskann-record/src/load/context.rs b/diskann-record/src/load/context.rs new file mode 100644 index 000000000..f3e03ac98 --- /dev/null +++ b/diskann-record/src/load/context.rs @@ -0,0 +1,398 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Load-side context, object, array, and side-car reader. +//! +//! [`Context`] is the cheap, clonable handle threaded through every +//! [`super::Load::load`] / [`super::Loadable::load`] impl. From a [`Context`], loaders +//! ask: +//! +//! * [`Context::as_object`] / [`Object::field`] for nested records. +//! * [`Context::as_array`] / [`Array::iter`] for sequences. +//! * [`Context::as_str`] / [`Context::as_number`] / [`Context::as_bool`] / [`Context::is_null`] for scalars. +//! * [`Object::read`] for side-car artifacts referenced by a +//! [`save::Handle`](super::save::Handle). +//! +//! [`Reader`] implements [`std::io::Read`] and [`std::io::Seek`] over the artifact file. + +use std::{ + collections::HashSet, + fs::File, + io::BufReader, + path::{Path, PathBuf}, +}; + +use crate::{ + Number, Version, + load::{Error, Loadable, Result, error}, + save, +}; + +#[derive(Debug, serde::Deserialize)] +pub(super) struct ContextInner { + dir: PathBuf, + files: HashSet, + value: save::Value<'static>, +} + +#[derive(Debug, serde::Deserialize)] +struct FileRepr { + files: HashSet, + value: save::Value<'static>, +} + +impl ContextInner { + pub(super) fn new(metadata: &Path, dir: &Path) -> Result { + let file = std::fs::File::open(metadata).map_err(|e| { + Error::new(e).context(format!("while trying to open {}", metadata.display())) + })?; + + let reader = std::io::BufReader::new(file); + let repr: FileRepr = serde_json::from_reader(reader) + .map_err(|e| Error::new(e).context("could not deserialize manifest"))?; + + let this = Self { + dir: dir.into(), + files: repr.files, + value: repr.value, + }; + Ok(this) + } + + pub(super) fn read(&self, key: &str) -> Result> { + let key_as_path: &Path = key.as_ref(); + if key.contains("..") || key_as_path.is_absolute() { + return Err(Error::from(error::Kind::MissingFile).context(format!( + "handle references file {:?} which escapes the manifest directory", + key, + ))); + } + if !self.files.contains(key_as_path) { + return Err(Error::from(error::Kind::MissingFile).context(format!( + "handle references file {:?} which is not registered in the manifest", + key, + ))); + } + + let full = self.dir.join(key); + let file = std::fs::File::open(&full).map_err(|err| { + Error::new(err).context(format!("while opening artifact file {}", full.display())) + })?; + let reader = Reader { + io: BufReader::new(file), + _lifetime: std::marker::PhantomData, + }; + + Ok(reader) + } + + pub(super) fn context(&self) -> Context<'_> { + Context::new(self, &self.value) + } +} + +/// A borrowed reader over a side-car artifact. +/// +/// Produced by [`Object::read`]. Implements [`std::io::Read`] and [`std::io::Seek`]. +pub struct Reader<'a> { + io: BufReader, + _lifetime: std::marker::PhantomData<&'a ()>, +} + +impl std::io::Read for Reader<'_> { + // Required method + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + self.io.read(buf) + } + + // Provided methods + fn read_vectored(&mut self, bufs: &mut [std::io::IoSliceMut<'_>]) -> std::io::Result { + self.io.read_vectored(bufs) + } + fn read_to_end(&mut self, buf: &mut Vec) -> std::io::Result { + self.io.read_to_end(buf) + } + fn read_to_string(&mut self, buf: &mut String) -> std::io::Result { + self.io.read_to_string(buf) + } + fn read_exact(&mut self, buf: &mut [u8]) -> std::io::Result<()> { + self.io.read_exact(buf) + } +} + +impl std::io::Seek for Reader<'_> { + fn seek(&mut self, pos: std::io::SeekFrom) -> std::io::Result { + self.io.seek(pos) + } + + fn rewind(&mut self) -> std::io::Result<()> { + self.io.rewind() + } + fn stream_position(&mut self) -> std::io::Result { + self.io.stream_position() + } + fn seek_relative(&mut self, offset: i64) -> std::io::Result<()> { + self.io.seek_relative(offset) + } +} + +/////////////////////// +// User facing types // +/////////////////////// + +/// A cheap, clonable handle threaded through every load impl. +/// +/// Loaders use the `as_*` accessors to peek at the underlying [`save::Value`] kind +/// (e.g. [`Context::as_object`], [`Context::as_array`], [`Context::as_str`]) and +/// [`Context::load`] to recursively deserialize a nested value into a concrete type. +#[derive(Debug, Clone)] +pub struct Context<'a> { + inner: &'a ContextInner, + value: &'a save::Value<'a>, +} + +impl<'a> Context<'a> { + fn new(inner: &'a ContextInner, value: &'a save::Value<'a>) -> Self { + Self { inner, value } + } + + fn context(&self) -> &'a ContextInner { + self.inner + } + + /// Recursively deserialize the underlying value into a `T`. + /// + /// Equivalent to calling `T::load(self.clone())`. Use this from inner loaders that + /// want to delegate to another [`Loadable`]. + pub fn load(&self) -> Result + where + T: Loadable<'a>, + { + T::load(self.clone()) + } + + /// Returns `Some(Object)` if the value is a versioned object, else `None`. + pub fn as_object(&self) -> Option> { + match self.value { + save::Value::Object(versioned) => { + let object = Object { + inner: self.inner, + record: versioned.record(), + version: versioned.version(), + }; + Some(object) + } + _ => None, + } + } + + /// Returns `Some(s)` if the value is a string, else `None`. + pub fn as_str(&self) -> Option<&'a str> { + match self.value { + save::Value::String(s) => Some(s), + _ => None, + } + } + + /// Returns `Some(Array)` if the value is an array, else `None`. + pub fn as_array(&self) -> Option> { + match self.value { + save::Value::Array(array) => Some(Array::new(self.context(), array)), + _ => None, + } + } + + /// Returns `Some(Number)` if the value is numeric, else `None`. + /// + /// Use the conversion methods on [`Number`] (e.g. `as_u32`, `as_i64`) to narrow to + /// the target Rust type; out-of-range conversions return `None` and should be + /// surfaced as [`error::Kind::NumberOutOfRange`]. + pub fn as_number(&self) -> Option { + match self.value { + save::Value::Number(number) => Some(*number), + _ => None, + } + } + + /// Returns `Some(b)` if the value is a boolean, else `None`. + pub fn as_bool(&self) -> Option { + match self.value { + save::Value::Bool(value) => Some(*value), + _ => None, + } + } + + /// Returns `true` if the value is null. + /// + /// Used by [`Loadable`] impls for [`Option`] to detect the absent variant. + pub fn is_null(&self) -> bool { + matches!(self.value, save::Value::Null) + } + + pub(crate) fn as_handle(&self) -> Option<&save::Handle> { + match self.value { + save::Value::Handle(handle) => Some(handle), + _ => None, + } + } +} + +/// A versioned record reached through [`Context::as_object`]. +/// +/// `Object` is the entry point for record-based deserialization: it exposes the schema +/// version via [`Object::version`], the user keys via [`Object::keys`], typed field +/// extraction via [`Object::field`] (and the [`load_fields!`](crate::load_fields) +/// macro), and side-car artifact access via [`Object::read`]. +#[derive(Debug)] +pub struct Object<'a> { + inner: &'a ContextInner, + record: &'a save::Record<'a>, + version: Version, +} + +impl<'a> Object<'a> { + /// The schema [`Version`] recorded in the manifest for this object. + pub fn version(&self) -> Version { + self.version + } + + /// Iterate over the user keys of this record. Reserved keys (`$version`, + /// `$handle`) are tracked separately and never appear here. + pub fn keys(&self) -> save::Keys<'_, 'a> { + self.record.keys() + } + + /// Number of user keys in this record. + pub fn len(&self) -> usize { + self.record.len() + } + + /// Whether this record has no user keys. + pub fn is_empty(&self) -> bool { + self.record.is_empty() + } + + /// Return the sole user key of this record, used by enum loaders to dispatch + /// to a variant arm. Errors with a recoverable [`error::Kind::TypeMismatch`] + /// if the record has zero or multiple user keys (i.e. the wire shape does + /// not look like an enum). + pub fn single_key(&self) -> Result<&str> { + let mut keys = self.record.keys(); + let Some(first) = keys.next() else { + return Err(error::Kind::TypeMismatch.into()); + }; + if keys.next().is_some() { + return Err(error::Kind::TypeMismatch.into()); + } + Ok(first) + } + + /// Descend into the raw [`Context`] for `key`, without imposing a type. + /// Useful for enum variants whose payload is itself an [`Object`], an array, + /// or any other [`save::Value`]. Returns [`error::Kind::MissingField`] when + /// the key is absent. + pub fn child(&self, key: &str) -> Result> { + match self.record.get(key) { + Some(value) => Ok(Context::new(self.context(), value)), + None => Err(error::Kind::MissingField.into()), + } + } + + /// Extract the value under `key` and deserialize it into a `T`. + /// + /// This is the typed counterpart to [`Object::child`] and the primitive used by the + /// [`load_fields!`](crate::load_fields) macro. + /// + /// # Errors + /// + /// Returns [`error::Kind::MissingField`] if the key is absent. Errors raised by + /// `T::load` (e.g. [`error::Kind::TypeMismatch`]) are propagated unchanged. + pub fn field(&self, key: &str) -> Result + where + T: Loadable<'a>, + { + match self.record.get(key) { + Some(value) => T::load(Context::new(self.context(), value)), + None => Err((error::Kind::MissingField).into()), + } + } + + /// Open the side-car artifact identified by `handle` for reading. + /// + /// The handle must have been previously written through the matching + /// [`save::Context::write`](super::save::Context::write) call and embedded in this + /// record. Returns [`error::Kind::MissingFile`] if the file is not registered in the + /// manifest or if the handle attempts to escape the manifest directory. + pub fn read(&self, handle: &save::Handle) -> Result> { + self.inner.read(handle.as_str()) + } + + fn context(&self) -> &'a ContextInner { + self.inner + } +} + +/// A homogeneous sequence of values reached through [`Context::as_array`]. +/// +/// Backed by a borrowed `&[Value]`. Use [`Array::iter`] to walk the elements; each +/// item is yielded as a [`Context`] that can be further deserialized via +/// [`Context::load`]. +#[derive(Debug)] +pub struct Array<'a> { + inner: &'a ContextInner, + array: &'a [save::Value<'a>], +} + +impl<'a> Array<'a> { + fn new(inner: &'a ContextInner, array: &'a [save::Value<'a>]) -> Self { + Self { inner, array } + } + + /// Number of elements in the array. + pub fn len(&self) -> usize { + self.array.len() + } + + /// Returns `true` if the array is empty. + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Iterate over the elements, each as a [`Context`] ready for further deserialization. + pub fn iter(&self) -> Iter<'a> { + Iter::new(self.context(), self.array.iter()) + } + + fn context(&self) -> &'a ContextInner { + self.inner + } +} + +/// Iterator returned by [`Array::iter`]. +pub struct Iter<'a> { + inner: &'a ContextInner, + iter: std::slice::Iter<'a, save::Value<'a>>, +} + +impl<'a> Iter<'a> { + fn new(inner: &'a ContextInner, iter: std::slice::Iter<'a, save::Value<'a>>) -> Self { + Self { inner, iter } + } +} + +impl<'a> Iterator for Iter<'a> { + type Item = Context<'a>; + fn next(&mut self) -> Option { + self.iter + .next() + .map(|value| Context::new(self.inner, value)) + } + + fn size_hint(&self) -> (usize, Option) { + self.iter.size_hint() + } +} + +impl ExactSizeIterator for Iter<'_> {} diff --git a/diskann-record/src/load/error.rs b/diskann-record/src/load/error.rs new file mode 100644 index 000000000..df17daf23 --- /dev/null +++ b/diskann-record/src/load/error.rs @@ -0,0 +1,200 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Load-side error type and classification. +//! +//! The [`Error`] type wraps [`anyhow::Error`] for rich diagnostics and carries a +//! recoverable / critical bit used by probing call sites. The [`Kind`] enum enumerates +//! the well-known structural failure modes; [`Kind::is_recoverable`] is the canonical +//! source of truth for the recoverable / critical classification. + +use std::fmt::{Debug, Display}; + +/// A specialized [`std::result::Result`] for load-side operations. +pub type Result = ::std::result::Result; + +/// Load-side error. +/// +/// Carries an inner [`anyhow::Error`] for rich diagnostics (chained context, +/// backtraces) along with a single `recoverable` bit. Recoverable errors are +/// the contract for probing APIs: a caller that tries multiple load strategies +/// (e.g. current version, then legacy) can distinguish "this attempt didn't +/// match, try another" from "the data is broken, stop now". +/// +/// Most constructors produce *critical* (non-recoverable) errors. Probing +/// call sites use the explicit `*_recoverable` constructors, or rely on the +/// [`From`] impl which classifies each [`Kind`] variant according to +/// [`Kind::is_recoverable`]. +#[derive(Debug)] +pub struct Error { + inner: anyhow::Error, + recoverable: bool, +} + +impl Error { + /// Construct a critical error from an underlying source error. + pub fn new(err: E) -> Self + where + E: std::error::Error + Send + Sync + 'static, + { + Self { + inner: anyhow::Error::new(err), + recoverable: false, + } + } + + /// Construct a critical error from a display message. + pub fn message(message: D) -> Self + where + D: Display + Debug + Send + Sync + 'static, + { + Self { + inner: anyhow::Error::msg(message), + recoverable: false, + } + } + + /// Construct a recoverable error from an underlying source. Suitable for + /// probing APIs that may attempt an alternative load strategy. + pub fn new_recoverable(err: E) -> Self + where + E: std::error::Error + Send + Sync + 'static, + { + Self { + inner: anyhow::Error::new(err), + recoverable: true, + } + } + + /// Construct a recoverable error from a display message. Suitable for + /// probing APIs that may attempt an alternative load strategy. + pub fn message_recoverable(message: D) -> Self + where + D: Display + Debug + Send + Sync + 'static, + { + Self { + inner: anyhow::Error::msg(message), + recoverable: true, + } + } + + /// Attach additional context. The `recoverable` flag is preserved. + pub fn context(self, message: D) -> Self + where + D: Display + Send + Sync + 'static, + { + Self { + inner: self.inner.context(message), + recoverable: self.recoverable, + } + } + + /// Returns `true` if this error is recoverable. Probing call sites should + /// only fall back to alternative load strategies when this is `true`. + pub fn is_recoverable(&self) -> bool { + self.recoverable + } +} + +impl Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Load Error: {:?}", self.inner) + } +} + +impl std::error::Error for Error { + /// Returns the lower-level source of this error, if it exists. + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + Some(self.inner.as_ref()) + } +} + +/// Well-known classes of load-side failure. +/// +/// Used in two roles: +/// +/// * As the source of an [`Error`] via `From` (and the matching `From` for +/// [`Error`] which classifies recoverable / critical according to +/// [`Kind::is_recoverable`]). +/// * As a probe value in error chains — high-level callers can introspect the kind to +/// decide whether to try a fallback loader. +#[derive(Debug, Clone, Copy)] +#[non_exhaustive] +pub enum Kind { + /// The manifest's `$version` does not match the loader's expected + /// [`Load::VERSION`](crate::load::Load::VERSION). + VersionMismatch, + /// A required field is absent from the record. + MissingField, + /// The shape of the saved value does not match what the loader expected (e.g. found + /// an array where an object was needed). + TypeMismatch, + /// The manifest's version is recognized as not matching the current schema, and the + /// type's [`Load::load_legacy`](crate::load::Load::load_legacy) has no upgrade path + /// for it. + UnknownVersion, + /// The variant tag read from the wire format does not match any known + /// variant of the target enum. + UnknownVariant, + /// A numeric value in the manifest does not fit in the requested Rust type + /// (either out of range or would lose precision). + NumberOutOfRange, + /// A `$handle` references a file name that is not registered in the + /// manifest's `files` set. + MissingFile, +} + +impl Kind { + /// Stable, human-readable description of this kind. Used as the default error + /// message when constructing an [`Error`] from a `Kind`. + pub const fn as_str(self) -> &'static str { + match self { + Self::VersionMismatch => "version mismatch", + Self::MissingField => "missing field", + Self::TypeMismatch => "type mismatch", + Self::UnknownVersion => "unknown version", + Self::UnknownVariant => "unknown variant", + Self::NumberOutOfRange => "number out of range for target type", + Self::MissingFile => "handle references a file not present in the manifest", + } + } + + /// Whether an error of this kind should be treated as recoverable by + /// probing APIs (i.e., suitable for triggering a fallback to an alternative + /// load strategy). + /// + /// Recoverable kinds describe "the data did not match what this loader + /// expected" (a different version or shape might still succeed). Critical + /// kinds describe structural or integrity problems where retrying would be + /// pointless or unsafe. + pub const fn is_recoverable(self) -> bool { + match self { + // Shape/version probing signals — another loader might succeed. + Self::VersionMismatch | Self::MissingField | Self::TypeMismatch => true, + // Structural / integrity failures — give up. + Self::UnknownVersion + | Self::UnknownVariant + | Self::NumberOutOfRange + | Self::MissingFile => false, + } + } +} + +impl std::error::Error for Kind {} + +impl std::fmt::Display for Kind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(self.as_str()) + } +} + +impl From for Error { + fn from(kind: Kind) -> Self { + Self { + inner: anyhow::Error::new(kind), + recoverable: kind.is_recoverable(), + } + } +} diff --git a/diskann-record/src/load/mod.rs b/diskann-record/src/load/mod.rs new file mode 100644 index 000000000..032cd4f27 --- /dev/null +++ b/diskann-record/src/load/mod.rs @@ -0,0 +1,243 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! # Loading Records from Disk +//! +//! This module mirrors the [`super::save`] side. User types implement [`Load`] (or, for +//! primitive-like leaves, [`Loadable`]) and obtain an [`Object`] / [`Context`] from which +//! they extract individual fields and side-car artifacts. +//! +//! The top-level entry point is [`load_from_disk`], which reads a manifest and dispatches +//! into the user type's [`Load`] impl. +//! +//! # Reading Records +//! +//! The [`load_fields!`](crate::load_fields) macro is the idiomatic way to extract a fixed +//! set of named fields from an [`Object`] into local bindings. It mirrors the structure +//! of [`save_fields!`](crate::save_fields). +//! +//! # Version Dispatch +//! +//! Each [`Load`] impl declares a [`VERSION`](Load::VERSION). If the version stored in the +//! manifest matches, [`Load::load`] is called. Otherwise [`Load::load_legacy`] is invoked +//! so the impl can perform a custom upgrade; returning an +//! [`error::Kind::UnknownVersion`] from `load_legacy` indicates the loader has no upgrade +//! path for that schema. +//! +//! # Recoverable vs. Critical Errors +//! +//! Load errors are tagged as recoverable or critical. Probing call sites that try +//! multiple loaders should only retry when [`Error::is_recoverable`] returns `true`. See +//! [`error::Kind::is_recoverable`] for the classification. + +pub mod error; +pub use error::{Error, Result}; + +mod context; +pub use context::{Context, Object}; + +use std::path::Path; + +use crate::{Version, save}; + +/// Reload a value previously written by [`save::save_to_disk`]. +/// +/// `metadata` is the manifest JSON path produced by the saver, and `dir` is the +/// directory holding any side-car artifacts. +/// +/// # Errors +/// +/// Returns [`Error`] if the manifest is missing or malformed, if a referenced artifact is +/// missing, or if a user [`Load`] impl fails (e.g. due to a version mismatch with no +/// upgrade path). +pub fn load_from_disk(metadata: &Path, dir: &Path) -> Result +where + T: for<'a> Loadable<'a>, +{ + let inner = context::ContextInner::new(metadata, dir)?; + inner.context().load() +} + +/// Implemented by user types that can be reloaded from a versioned [`Object`]. +/// +/// This is the symmetric counterpart to [`super::save::Save`]. Implementations describe +/// how to reconstruct `Self` from the manifest representation, and how to upgrade +/// records written by older schemas via [`Self::load_legacy`]. +/// +/// # Enums +/// +/// Enum types dispatch on the single non-reserved key of the object (see +/// [`Object::single_key`]) and recurse via [`Object::child`] into the payload. +pub trait Load<'a>: Sized { + /// The schema version this impl was written against. + /// + /// Compared with the manifest's version to choose between [`Self::load`] and + /// [`Self::load_legacy`]. + const VERSION: Version; + + /// Reconstruct `Self` from an object whose `$version` matches [`Self::VERSION`]. + fn load(object: Object<'a>) -> Result; + + /// Reconstruct `Self` from an object whose `$version` does *not* match + /// [`Self::VERSION`]. + /// + /// Implementations may either upgrade the older record or refuse with + /// [`error::Kind::UnknownVersion`] when no upgrade is possible. + fn load_legacy(object: Object<'a>) -> Result; +} + +/// Implemented by any value that can be deserialized from a [`Context`]. +/// +/// This is the bottom of the trait hierarchy and is implemented for the same set of +/// primitive-like types as [`super::save::Saveable`]. Most user types should implement +/// [`Load`] (which gets a [`Loadable`] impl for free via the blanket below) rather than +/// [`Loadable`] directly. +pub trait Loadable<'a>: Sized { + /// Deserialize `Self` from a [`Context`]. + fn load(context: Context<'a>) -> Result; +} + +impl<'a, T> Loadable<'a> for T +where + T: Load<'a>, +{ + fn load(context: Context<'a>) -> Result { + let object = context.as_object().ok_or(error::Kind::TypeMismatch)?; + let version = object.version(); + if version == T::VERSION { + T::load(object) + } else { + T::load_legacy(object) + } + } +} + +//////////// +// Macros // +//////////// + +/// Extract a fixed set of named fields from an [`Object`] into local bindings. +/// +/// Each name in the list becomes a `let` binding of the same name. An optional `: T` +/// suffix selects the [`Loadable`] target type; without it, type inference picks the +/// type from the surrounding context. Errors from individual fields are propagated with +/// `?`. +/// +/// ```ignore +/// load_fields!(object, [ +/// dim: usize, +/// label, // type inferred +/// vectors: save::Handle, +/// ]); +/// ``` +#[macro_export] +macro_rules! load_fields { + (@field $object:ident, $field:ident: $T:ty) => { + let $field: $T = $object.field(stringify!($field))?; + }; + (@field $object:ident, $field:ident) => { + let $field = $object.field(stringify!($field))?; + }; + ($object:ident, [$($field:ident $(: $ty:ty)?),+ $(,)?]) => { + $( + $crate::load_fields!(@field $object, $field $(: $ty)?); + )+ + }; +} + +/////////////// +// Bootstrap // +/////////////// + +impl<'a> Loadable<'a> for &'a str { + fn load(context: Context<'a>) -> Result { + context + .as_str() + .ok_or_else(|| error::Kind::TypeMismatch.into()) + } +} + +impl Loadable<'_> for String { + fn load(context: Context<'_>) -> Result { + context.load::<&str>().map(|s| s.into()) + } +} + +impl Loadable<'_> for save::Handle { + fn load(context: Context<'_>) -> Result { + context + .as_handle() + .cloned() + .ok_or_else(|| error::Kind::TypeMismatch.into()) + } +} + +impl Loadable<'_> for bool { + fn load(context: Context<'_>) -> Result { + context + .as_bool() + .ok_or_else(|| error::Kind::TypeMismatch.into()) + } +} + +impl<'a, T> Loadable<'a> for Option +where + T: Loadable<'a>, +{ + fn load(context: Context<'a>) -> Result { + if context.is_null() { + Ok(None) + } else { + T::load(context).map(Some) + } + } +} + +impl<'a, T> Loadable<'a> for Vec +where + T: Loadable<'a>, +{ + fn load(context: Context<'a>) -> Result { + match context.as_array() { + Some(array) => array.iter().map(T::load).collect(), + None => Err((error::Kind::TypeMismatch).into()), + } + } +} + +macro_rules! load_number { + ($T:ty) => { + impl Loadable<'_> for $T { + fn load(context: Context<'_>) -> Result { + match context.as_number() { + Some(n) => n.try_into().map_err(|_| error::Kind::NumberOutOfRange.into()), + None => Err((error::Kind::TypeMismatch).into()), + } + } + } + }; + ($($Ts:ty),+ $(,)?) => { + $(load_number!($Ts);)+ + } +} + +load_number!(u8, u16, u32, u64, usize, i8, i16, i32, i64, isize, f32, f64); + +// NonZero* primitives are loaded by deserializing the inner numeric type and then +// validating it is non-zero. A zero value produces a `NumberOutOfRange` light error. +macro_rules! load_nonzero { + ($T:ty, $Inner:ty) => { + impl Loadable<'_> for $T { + fn load(context: Context<'_>) -> Result { + let inner: $Inner = context.load()?; + <$T>::new(inner).ok_or_else(|| error::Kind::NumberOutOfRange.into()) + } + } + }; +} + +load_nonzero!(std::num::NonZeroU32, u32); +load_nonzero!(std::num::NonZeroU64, u64); +load_nonzero!(std::num::NonZeroUsize, usize); diff --git a/diskann-record/src/number.rs b/diskann-record/src/number.rs new file mode 100644 index 000000000..33b635dc2 --- /dev/null +++ b/diskann-record/src/number.rs @@ -0,0 +1,176 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Lossless container for the on-wire numeric kinds (`u64`, `i64`, `f64`). +//! +//! [`Number`] is the value type produced by the manifest deserializer for every JSON +//! number. The conversion accessors (`as_u32`, `as_i64`, etc.) attempt to narrow into a +//! target Rust type and return `None` when the value is out of range or would lose +//! precision; loaders surface this as [`crate::load::error::Kind::NumberOutOfRange`]. + +use serde::de::{self, Visitor}; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; + +/// A numeric value carried in a manifest, preserving the kind the writer chose. +/// +/// The wire format distinguishes unsigned, signed, and floating-point numbers; the +/// deserializer preserves that distinction by selecting the matching variant. Use the +/// narrowing accessors (e.g. [`Number::as_u32`], [`Number::as_f64`]) to extract a Rust +/// value of the desired type. +#[derive(Debug, Clone, Copy)] +pub enum Number { + U64(u64), + I64(i64), + F64(f64), +} + +impl Serialize for Number { + fn serialize(&self, serializer: S) -> Result { + match *self { + Self::U64(v) => serializer.serialize_u64(v), + Self::I64(v) => serializer.serialize_i64(v), + Self::F64(v) => serializer.serialize_f64(v), + } + } +} + +impl<'de> Deserialize<'de> for Number { + fn deserialize>(deserializer: D) -> Result { + struct NumberVisitor; + + impl<'de> Visitor<'de> for NumberVisitor { + type Value = Number; + + fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.write_str("a number") + } + + fn visit_u64(self, v: u64) -> Result { + Ok(Number::U64(v)) + } + + fn visit_i64(self, v: i64) -> Result { + Ok(Number::I64(v)) + } + + fn visit_f64(self, v: f64) -> Result { + Ok(Number::F64(v)) + } + } + + deserializer.deserialize_any(NumberVisitor) + } +} + +macro_rules! try_cast { + ($v:ident :$T:ty => $U:ty) => {{ + let c = $v as $U; + if c as $T == $v { Some(c) } else { None } + }}; +} + +macro_rules! int { + ($f:ident, $T:ty) => { + pub fn $f(self) -> Option<$T> { + match self { + Self::U64(v) => v.try_into().ok(), + Self::I64(v) => v.try_into().ok(), + Self::F64(v) => try_cast!(v:f64 => $T), + } + } + } +} + +macro_rules! float { + ($f:ident, $T:ty) => { + pub fn $f(self) -> Option<$T> { + match self { + Self::U64(v) => try_cast!(v:u64 => $T), + Self::I64(v) => try_cast!(v:i64 => $T), + Self::F64(v) => try_cast!(v:f64 => $T), + } + } + } +} + +impl Number { + int!(as_u8, u8); + int!(as_u16, u16); + int!(as_u32, u32); + int!(as_u64, u64); + int!(as_usize, usize); + + int!(as_i8, i8); + int!(as_i16, i16); + int!(as_i32, i32); + int!(as_i64, i64); + int!(as_isize, isize); + + float!(as_f32, f32); + float!(as_f64, f64); +} + +macro_rules! from { + ($T:ty => $variant:ident) => { + impl From<$T> for Number { + fn from(v: $T) -> Self { + Self::$variant(v.into()) + } + } + }; + ($($T:ty => $variant:ident),+ $(,)?) => { + $(from!($T => $variant);)+ + } +} + +from!( + u64 => U64, + u32 => U64, + u16 => U64, + u8 => U64, + i64 => I64, + i32 => I64, + i16 => I64, + i8 => I64, + f32 => F64, + f64 => F64, +); + +impl From for Number { + fn from(v: usize) -> Self { + Self::U64(v.try_into().unwrap()) + } +} + +macro_rules! try_from { + ($T:ty => $f:ident) => { + impl TryFrom for $T { + type Error = (); + fn try_from(number: Number) -> Result<$T, Self::Error> { + number.$f().ok_or(()) + } + } + }; + ($($T:ty => $f:ident),+ $(,)?) => { + $(try_from!($T => $f);)+ + } +} + +try_from!( + u64 => as_u64, + u32 => as_u32, + u16 => as_u16, + u8 => as_u8, + usize => as_usize, + + i64 => as_i64, + i32 => as_i32, + i16 => as_i16, + i8 => as_i8, + isize => as_isize, + + f32 => as_f32, + f64 => as_f64, +); diff --git a/diskann-record/src/save/context.rs b/diskann-record/src/save/context.rs new file mode 100644 index 000000000..01785ad7b --- /dev/null +++ b/diskann-record/src/save/context.rs @@ -0,0 +1,206 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Save-side context and side-car writer. +//! +//! [`Context`] is the cheap handle handed to every [`super::Save::save`] impl. It owns +//! nothing visible to the user; cloning it is free, and it can be passed to children to +//! propagate the same artifact-tracking state. +//! +//! [`Writer`] is the borrowed side-car artifact handle returned by [`Context::write`]. +//! It implements [`std::io::Write`] and [`std::io::Seek`]; calling +//! [`Writer::finish`] flushes the buffer and yields a [`Handle`] that can be inserted +//! into a [`super::Record`]. + +use std::{collections::HashSet, fs::File, io::BufWriter, path::PathBuf, sync::Mutex}; + +use crate::save::{Error, Handle, Result, Value}; + +/// The owned context behind a [`Context`]. +/// +/// Holds the manifest directory, the manifest path, and the set of artifact file names +/// registered so far. Lookup and insertion go through a [`Mutex`] so that concurrent +/// [`Save`](super::Save) impls cannot accidentally hand out the same artifact name twice. +#[derive(Debug)] +pub(super) struct ContextInner { + dir: PathBuf, + metadata: PathBuf, + files: Mutex>, +} + +#[derive(serde::Serialize)] +struct Final<'a> { + files: Vec<&'a str>, + value: &'a Value<'a>, +} + +impl ContextInner { + // TODO: Error if the directory looks bad? + pub(super) fn new(dir: PathBuf, metadata: PathBuf) -> Self { + Self { + dir, + metadata, + files: Mutex::new(HashSet::new()), + } + } + + pub(super) fn context(&self) -> Context<'_> { + Context { inner: self } + } + + pub(super) fn write(&self, key: &str) -> Result> { + // TODO: Proper disambiguation - making UUIDs etc. + let mut files = self + .files + .lock() + .unwrap_or_else(|poison| poison.into_inner()); + if !files.insert(key.into()) { + return Err(Error::message(format!( + "file name {:?} has already been registered with this save context", + key, + ))); + } + let full = self.dir.join(key); + let file = std::fs::File::create_new(&full).map_err(|err| { + Error::new(err).context(format!("while creating new file {}", full.display())) + })?; + Ok(Writer { + io: BufWriter::new(file), + name: key.into(), + _lifetime: std::marker::PhantomData, + }) + } + + /// Finalize the manifest. + /// + /// Writes the manifest JSON atomically: serializes to a `.temp` file first, + /// then renames it into place. Fails if the temp file already exists (an in-flight + /// save is in progress, or a previous run aborted between rename steps). + pub fn finish(self, value: Value<'_>) -> Result<()> { + let mut temp = self.metadata.clone().into_os_string(); + temp.push(".temp"); + let temp = PathBuf::from(temp); + if temp.exists() { + return Err(Error::message(format!( + "Temporary file {} already exists. Aborting!", + temp.display() + ))); + } + let files = self + .files + .into_inner() + .unwrap_or_else(|poison| poison.into_inner()); + let f = Final { + files: files.iter().map(|k| &**k).collect(), + value: &value, + }; + + let buffer = std::fs::File::create(&temp).map_err(|err| { + Error::new(err).context(format!( + "while creating temp manifest file {}", + temp.display() + )) + })?; + serde_json::to_writer_pretty(buffer, &f) + .map_err(|err| Error::new(err).context("while serializing manifest to JSON"))?; + std::fs::rename(&temp, &self.metadata).map_err(|err| { + Error::new(err).context(format!( + "while renaming temp manifest {} to final path {}", + temp.display(), + self.metadata.display() + )) + })?; + Ok(()) + } +} + +/// A cheap, clonable handle threaded through every [`Save::save`](super::Save) impl. +/// +/// `Context` exposes one operation — [`Context::write`] — for allocating a side-car +/// artifact. The same context is passed to nested [`Save`](super::Save) impls (typically +/// via the [`save_fields!`](crate::save_fields) macro), so a single save tree shares +/// artifact-name bookkeeping. +#[derive(Debug, Clone)] +pub struct Context<'a> { + inner: &'a ContextInner, +} + +impl<'a> Context<'a> { + /// Allocate a new side-car artifact named `key` in the manifest directory. + /// + /// The returned [`Writer`] is positioned at offset 0 and implements + /// [`std::io::Write`] / [`std::io::Seek`]. Call [`Writer::finish`] to obtain a + /// [`Handle`] that may be inserted into a [`Record`](super::Record). + /// + /// # Errors + /// + /// Returns [`Error`] if `key` has already been registered with this context (names + /// must be unique within a single save), or if the underlying file cannot be created + /// (e.g. because the artifact already exists on disk). + pub fn write(&self, key: &str) -> Result> { + self.inner.write(key) + } +} + +/// A borrowed side-car artifact writer produced by [`Context::write`]. +/// +/// Implements [`std::io::Write`] and [`std::io::Seek`]. Writes are buffered; calling +/// [`Writer::finish`] flushes the buffer, closes the file, and returns a [`Handle`]. +#[derive(Debug)] +pub struct Writer<'a> { + io: BufWriter, + name: String, + _lifetime: std::marker::PhantomData<&'a ()>, +} + +impl Writer<'_> { + /// Flush and close the writer, returning a [`Handle`] for the artifact. + /// + /// Insert the returned handle into a [`Record`](super::Record) (typically via + /// [`Record::insert`](super::Record::insert)) so that load-side code can locate the + /// artifact through the manifest. + pub fn finish(self) -> Result { + // NOTE: self.io.into_inner() will flush the buffer and close the file. + self.io + .into_inner() + .map_err(|err| Error::new(err.into_error()))?; + Ok(Handle::new(self.name)) + } +} + +impl std::io::Write for Writer<'_> { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + self.io.write(buf) + } + + fn flush(&mut self) -> std::io::Result<()> { + self.io.flush() + } + fn write_vectored(&mut self, bufs: &[std::io::IoSlice<'_>]) -> std::io::Result { + self.io.write_vectored(bufs) + } + fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> { + self.io.write_all(buf) + } + fn write_fmt(&mut self, args: std::fmt::Arguments<'_>) -> std::io::Result<()> { + self.io.write_fmt(args) + } +} + +impl std::io::Seek for Writer<'_> { + fn seek(&mut self, pos: std::io::SeekFrom) -> std::io::Result { + self.io.seek(pos) + } + + fn rewind(&mut self) -> std::io::Result<()> { + self.io.rewind() + } + fn stream_position(&mut self) -> std::io::Result { + self.io.stream_position() + } + fn seek_relative(&mut self, offset: i64) -> std::io::Result<()> { + self.io.seek_relative(offset) + } +} diff --git a/diskann-record/src/save/error.rs b/diskann-record/src/save/error.rs new file mode 100644 index 000000000..5262fad2a --- /dev/null +++ b/diskann-record/src/save/error.rs @@ -0,0 +1,65 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Save-side error type. +//! +//! Mirrors the load-side [`super::super::load::Error`] in shape but does not carry a +//! recoverable / critical distinction: every save failure is terminal because no +//! probing fallback exists on the writer side. + +use std::fmt::{Debug, Display}; + +/// A specialized [`std::result::Result`] for save-side operations. +pub type Result = ::std::result::Result; + +/// Save-side error. +/// +/// Wraps [`anyhow::Error`] for rich context chains (see [`Error::context`]) and is +/// returned from every fallible save-side operation, including [`super::Save::save`] +/// impls and [`super::save_to_disk`]. +#[derive(Debug)] +pub struct Error { + inner: anyhow::Error, +} + +impl Error { + /// Wrap an underlying source error. + pub fn new(err: E) -> Self + where + E: std::error::Error + Send + Sync + 'static, + { + Error { + inner: anyhow::Error::new(err), + } + } + + /// Construct an error from a display message with no source. + pub fn message(message: D) -> Self + where + D: Display + Debug + Send + Sync + 'static, + { + Error { + inner: anyhow::Error::msg(message), + } + } + + /// Attach additional context describing what was being attempted. + pub fn context(self, message: D) -> Self + where + D: Display + Send + Sync + 'static, + { + Error { + inner: self.inner.context(message), + } + } +} + +impl Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Save Error: {:?}", self.inner) + } +} + +impl std::error::Error for Error {} diff --git a/diskann-record/src/save/mod.rs b/diskann-record/src/save/mod.rs new file mode 100644 index 000000000..bcb7027fd --- /dev/null +++ b/diskann-record/src/save/mod.rs @@ -0,0 +1,266 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! # Saving Records to Disk +//! +//! This module provides the writer-side of the framework. User types implement [`Save`] +//! (or, for primitive-like leaves, [`Saveable`]) and obtain a [`Context`] from which they +//! request side-car artifact writers and assemble a [`Record`] of named fields. +//! +//! The top-level entry point is [`save_to_disk`], which serializes a value into a +//! caller-chosen directory plus a manifest path. +//! +//! # Building Records +//! +//! The [`save_fields!`](crate::save_fields) macro is the idiomatic way to build a record +//! from a struct or destructured enum variant. It handles per-field error context and +//! invokes [`Saveable::save`] on each value. +//! +//! # Side-Car Artifacts +//! +//! Binary blobs (e.g. vector buffers) are written to side-car files via +//! [`Context::write`], which returns a [`Writer`](Context). The handle returned by +//! [`Writer::finish`](context::Writer::finish) can be embedded into the record as a +//! [`Handle`]; it serializes as a `$handle` reference and is rehydrated on the load side. +mod value; +pub use value::{Handle, Keys, Record, Value, Versioned}; + +mod context; +pub use context::{Context, Writer}; + +mod error; +pub use error::{Error, Result}; + +use crate::Version; + +/// Serialize `x` to disk. +/// +/// The manifest (a JSON document) is written atomically to `metadata`; any side-car +/// artifacts the type's [`Save::save`] impl creates via [`Context::write`] are written +/// into `dir`. +/// +/// # Errors +/// +/// Returns [`Error`] if the directory cannot be written to, if the manifest cannot be +/// serialized, or if a user impl returns an error. +pub fn save_to_disk( + x: &T, + dir: impl AsRef, + metadata: impl AsRef, +) -> Result<()> +where + T: Saveable, +{ + let inner = context::ContextInner::new(dir.as_ref().into(), metadata.as_ref().into()); + let value = x.save(inner.context())?; + inner.finish(value) +} + +/// Implemented by user types that map to a versioned [`Record`]. +/// +/// This is the primary trait for structured user types. A [`Save`] impl describes the +/// versioned schema of `Self`: its associated [`VERSION`](Self::VERSION) is attached to +/// the [`Record`] produced by [`Self::save`](Self::save). +/// +/// # Enums +/// +/// Enum types are encoded by returning a [`Record`] with a single user key whose name +/// is the variant tag and whose value is the variant's payload (frequently +/// [`Value::Null`] for unit variants). See the crate-level docs for a worked example. +pub trait Save { + /// The schema version attached to records produced by this impl. + /// + /// Loaders compare this against the version stored in the manifest to decide + /// between [`Load::load`](crate::load::Load::load) and + /// [`Load::load_legacy`](crate::load::Load::load_legacy). + const VERSION: Version; + + /// Serialize `self` into a [`Record`]. + /// + /// Use the supplied [`Context`] to request side-car artifact writers. Use the + /// [`save_fields!`](crate::save_fields) macro to populate the record. + fn save(&self, context: Context<'_>) -> Result>; +} + +/// Implemented by any value that can be written into a [`Value`]. +/// +/// This is the bottom of the trait hierarchy and is implemented for: +/// +/// * Primitive numeric types (signed, unsigned, floats, `NonZero*`). +/// * [`bool`], [`str`], [`String`], and [`Handle`]. +/// * [`Option`] (serializes `None` as [`Value::Null`]). +/// * `&[T]` and [`Vec`] (serialize as [`Value::Array`]). +/// * Any `T: Save` (wraps the produced record in [`Value::Object`] with the type's +/// [`Save::VERSION`]). +/// +/// Most user types should implement [`Save`] (which gets a [`Saveable`] impl for free +/// via the blanket below) rather than [`Saveable`] directly. +pub trait Saveable { + /// Serialize `self` into a [`Value`]. + fn save(&self, context: Context<'_>) -> Result>; +} + +impl Saveable for T +where + T: Save, +{ + fn save(&self, context: Context<'_>) -> Result> { + let record = self.save(context)?; + Ok(record.into_value(T::VERSION)) + } +} + +////////////////// +// Random Stuff // +////////////////// + +impl Saveable for [T] +where + T: Saveable, +{ + fn save(&self, context: Context<'_>) -> Result> { + let values: Result> = self.iter().map(|t| t.save(context.clone())).collect(); + values.map(Value::Array) + } +} + +impl Saveable for Vec +where + T: Saveable, +{ + fn save(&self, context: Context<'_>) -> Result> { + self.as_slice().save(context) + } +} + +impl Saveable for str { + fn save(&self, _: Context<'_>) -> Result> { + Ok(Value::String(self.into())) + } +} + +impl Saveable for String { + fn save(&self, _: Context<'_>) -> Result> { + Ok(Value::String(self.as_str().into())) + } +} + +impl Saveable for Handle { + fn save(&self, _: Context<'_>) -> Result> { + Ok(Value::Handle(self.clone())) + } +} + +impl Saveable for bool { + fn save(&self, _: Context<'_>) -> Result> { + Ok(Value::Bool(*self)) + } +} + +impl Saveable for Option +where + T: Saveable, +{ + fn save(&self, context: Context<'_>) -> Result> { + match self { + None => Ok(Value::Null), + Some(t) => t.save(context), + } + } +} + +macro_rules! save_number { + ($T:ty) => { + impl Saveable for $T { + fn save(&self, _: Context<'_>) -> Result> { + Ok(Value::Number((*self).into())) + } + } + }; + ($($Ts:ty),+ $(,)?) => { + $(save_number!($Ts);)+ + } +} + +save_number!(usize, u64, u32, u16, u8, i64, i32, i16, i8, f32, f64); + +// NonZero* primitives serialize as their inner numeric type. Loaders reject zero. +macro_rules! save_nonzero { + ($T:ty) => { + impl Saveable for $T { + fn save(&self, _: Context<'_>) -> Result> { + Ok(Value::Number(self.get().into())) + } + } + }; + ($($Ts:ty),+ $(,)?) => { + $(save_nonzero!($Ts);)+ + } +} + +save_nonzero!( + std::num::NonZeroU32, + std::num::NonZeroU64, + std::num::NonZeroUsize +); + +#[derive(Debug, Clone, Copy)] +#[doc(hidden)] +pub struct Serializing(pub &'static str); + +impl std::fmt::Display for Serializing { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "while serializing field \"{}\"", self.0) + } +} + +/// Build a [`Record`] from a list of fields. +/// +/// Two forms are supported: +/// +/// * `save_fields!(self, context, [a, b, c])` reads each field as `self.a`, +/// `self.b`, etc. Use this from `Save::save` for plain structs. +/// * `save_fields!(context, [a, b, c])` reads each field from a local binding of +/// the same name. Use this inside enum match arms where the variant's payload +/// has already been destructured into local bindings. Those bindings are +/// assumed to be references (which is automatic when matching against `&self`); +/// for an owned local, take a reference explicitly first. +#[macro_export] +macro_rules! save_fields { + ($me:ident, $context:ident, [$($field:ident),+ $(,)?]) => {{ + $crate::save::Record::from_iter( + [ + $( + ( + ::std::borrow::Cow::Borrowed(stringify!($field)), + <_ as $crate::save::Saveable>::save( + &$me.$field, + $context.clone() + ).map_err(|err| { + err.context($crate::save::Serializing(stringify!($field))) + })? + ), + )+ + ] + ) + }}; + ($context:ident, [$($field:ident),+ $(,)?]) => {{ + $crate::save::Record::from_iter( + [ + $( + ( + ::std::borrow::Cow::Borrowed(stringify!($field)), + <_ as $crate::save::Saveable>::save( + $field, + $context.clone() + ).map_err(|err| { + err.context($crate::save::Serializing(stringify!($field))) + })? + ), + )+ + ] + ) + }}; +} diff --git a/diskann-record/src/save/value.rs b/diskann-record/src/save/value.rs new file mode 100644 index 000000000..4a3f7cc42 --- /dev/null +++ b/diskann-record/src/save/value.rs @@ -0,0 +1,362 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Wire-level value types used in the on-disk manifest. +//! +//! Every saveable field is one of: +//! +//! * [`Value::Null`] / [`Value::Bool`] / [`Value::Number`] / [`Value::String`] / +//! [`Value::Bytes`] — primitive scalars. +//! * [`Value::Array`] — a homogeneous sequence (used by `Vec` and `&[T]`). +//! * [`Value::Object`] — a [`Versioned`] [`Record`] (the canonical encoding for a +//! `T: super::Save`). +//! * [`Value::Handle`] — a reference to a side-car artifact (produced by +//! [`super::Context::write`] + [`super::context::Writer::finish`]). +//! +//! Most user code never touches these enums directly: [`super::Saveable`] impls turn +//! Rust values into [`Value`]s, and the [`save_fields!`](crate::save_fields) macro +//! assembles the surrounding [`Record`]. + +use std::{borrow::Cow, collections::HashMap}; + +use serde::{ + Deserialize, Deserializer, Serialize, Serializer, + de::{self, MapAccess, SeqAccess, Visitor}, + ser::SerializeStruct, +}; + +use crate::{Number, Version, save::Error}; + +/// The wire-level union of every saveable kind. +/// +/// See the module-level docs for an overview of when each variant is produced. The +/// borrowing parameter `'a` lets [`Value::String`], [`Value::Bytes`], and nested +/// records reuse memory owned by the caller without copying. +#[derive(Debug)] +pub enum Value<'a> { + Null, + Bool(bool), + Number(Number), + String(Cow<'a, str>), + Bytes(Cow<'a, [u8]>), + Array(Vec>), + Object(Versioned<'a>), + Handle(Handle), +} + +impl Serialize for Value<'_> { + fn serialize(&self, ser: S) -> Result { + match self { + Self::Null => ser.serialize_none(), + Self::Bool(b) => ser.serialize_bool(*b), + Self::Number(n) => n.serialize(ser), + Self::String(s) => ser.serialize_str(s), + Self::Bytes(b) => ser.serialize_bytes(b), + Self::Array(a) => a.serialize(ser), + Self::Object(v) => v.serialize(ser), + Self::Handle(h) => h.serialize(ser), + } + } +} + +impl<'de> Deserialize<'de> for Value<'static> { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct Inner; + + impl<'de> Visitor<'de> for Inner { + type Value = Value<'static>; + + fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.write_str("a valid Value") + } + + fn visit_unit(self) -> Result, E> { + Ok(Value::Null) + } + + fn visit_none(self) -> Result, E> { + Ok(Value::Null) + } + + fn visit_some(self, deserializer: D) -> Result, D::Error> + where + D: Deserializer<'de>, + { + Value::deserialize(deserializer) + } + + fn visit_bool(self, v: bool) -> Result, E> { + Ok(Value::Bool(v)) + } + + fn visit_u64(self, v: u64) -> Result, E> { + Ok(Value::Number(Number::U64(v))) + } + + fn visit_i64(self, v: i64) -> Result, E> { + Ok(Value::Number(Number::I64(v))) + } + + fn visit_f64(self, v: f64) -> Result, E> { + Ok(Value::Number(Number::F64(v))) + } + + fn visit_str(self, v: &str) -> Result, E> { + Ok(Value::String(Cow::Owned(v.to_owned()))) + } + + fn visit_string(self, v: String) -> Result, E> { + Ok(Value::String(Cow::Owned(v))) + } + + fn visit_bytes(self, v: &[u8]) -> Result, E> { + Ok(Value::Bytes(Cow::Owned(v.to_owned()))) + } + + fn visit_byte_buf(self, v: Vec) -> Result, E> { + Ok(Value::Bytes(Cow::Owned(v))) + } + + fn visit_seq(self, mut seq: A) -> Result, A::Error> + where + A: SeqAccess<'de>, + { + let mut values = Vec::with_capacity(seq.size_hint().unwrap_or(0)); + while let Some(v) = seq.next_element()? { + values.push(v); + } + Ok(Value::Array(values)) + } + + fn visit_map(self, mut map: A) -> Result, A::Error> + where + A: MapAccess<'de>, + { + // TODO: Handle invariants that only one of our reserved words are present. + let mut version: Option = None; + let mut handle_name: Option = None; + let mut fields: HashMap, Value<'static>> = HashMap::new(); + + while let Some(key) = map.next_key::()? { + match key.as_str() { + "$version" => { + version = Some(map.next_value()?); + } + "$handle" => { + handle_name = Some(map.next_value()?); + } + _ => { + let value = map.next_value()?; + fields.insert(Cow::Owned(key), value); + } + } + } + + if let Some(name) = handle_name { + if version.is_some() || !fields.is_empty() { + return Err(de::Error::custom( + "handle object must contain only a \"$handle\" field", + )); + } + return Ok(Value::Handle(Handle(name))); + } + + if let Some(version) = version { + let record = Record { record: fields }; + return Ok(Value::Object(Versioned { record, version })); + } + + Err(de::Error::custom( + "map must contain either \"$version\" or \"$handle\"", + )) + } + } + + deserializer.deserialize_any(Inner) + } +} + +impl From for Value<'_> { + fn from(handle: Handle) -> Self { + Self::Handle(handle) + } +} + +/// A map of named [`Value`]s. +/// +/// `Record` is the body of a saved object: each call to [`super::Save::save`] returns +/// one, and [`Record::into_value`] wraps it as a [`Versioned`] [`Value::Object`] ready +/// for insertion into another record. Keys beginning with `$` are reserved for +/// framework metadata (see [`crate::is_reserved`]). +#[derive(Debug, Serialize, Deserialize)] +#[serde(transparent)] +pub struct Record<'a> { + record: HashMap, Value<'a>>, +} + +impl<'a> Record<'a> { + /// Construct an empty record. Useful for unit enum variants. + pub fn empty() -> Self { + Self { + record: HashMap::new(), + } + } + + /// Returns `true` if a value is registered under `key`. + pub fn contains_key(&self, key: &str) -> bool { + self.record.contains_key(key) + } + + /// Look up the [`Value`] registered under `key`, if any. + pub fn get(&self, key: &str) -> Option<&Value<'a>> { + self.record.get(key) + } + + /// Number of (user) keys in this record. Reserved keys (`$version`, `$handle`) + /// are tracked elsewhere and never appear here. + pub fn len(&self) -> usize { + self.record.len() + } + + /// Returns `true` if this record has no user keys. + pub fn is_empty(&self) -> bool { + self.record.is_empty() + } + + /// Iterate over the user keys in this record. Order is unspecified. + pub fn keys(&self) -> Keys<'_, 'a> { + Keys { + inner: self.record.keys(), + } + } + + /// Insert `value` under `key`. + /// + /// # Errors + /// + /// Returns [`Error`] if `key` begins with `$`, which is reserved for the + /// save/load framework (see [`crate::is_reserved`]). + pub fn insert(&mut self, key: K, value: V) -> crate::save::Result>> + where + K: Into>, + V: Into>, + { + let key = key.into(); + if crate::is_reserved(&key) { + return Err(Error::message(format!( + "record key {:?} is reserved (keys starting with `$` are reserved for the \ + save/load framework)", + key, + ))); + } + + Ok(self.record.insert(key, value.into())) + } + + /// Wrap this record as a versioned [`Value`] ready for insertion into another + /// record. Use this from enum [`Save`](crate::save::Save) impls to attach the + /// outer type's version to an inline variant payload. + pub fn into_value(self, version: Version) -> Value<'a> { + Value::Object(Versioned::new(self, version)) + } +} + +/// Iterator over the keys of a [`Record`]. +pub struct Keys<'r, 'a> { + inner: std::collections::hash_map::Keys<'r, Cow<'a, str>, Value<'a>>, +} + +impl<'r, 'a> Iterator for Keys<'r, 'a> { + type Item = &'r str; + + fn next(&mut self) -> Option { + self.inner.next().map(|k| k.as_ref()) + } + + fn size_hint(&self) -> (usize, Option) { + self.inner.size_hint() + } +} + +impl ExactSizeIterator for Keys<'_, '_> {} + +impl<'a> FromIterator<(Cow<'a, str>, Value<'a>)> for Record<'a> { + fn from_iter, Value<'a>)>>(itr: I) -> Self { + Self { + record: itr.into_iter().collect(), + } + } +} + +/// A [`Record`] paired with the schema [`Version`] used to produce it. +/// +/// Serialized as a normal object plus a `$version` field on the wire. Constructed by +/// [`Record::into_value`]. +#[derive(Debug, Serialize, Deserialize)] +pub struct Versioned<'a> { + #[serde(flatten)] + record: Record<'a>, + #[serde(rename = "$version")] + version: Version, +} + +impl<'a> Versioned<'a> { + pub(crate) fn new(record: Record<'a>, version: Version) -> Self { + Self { record, version } + } + + pub(crate) fn version(&self) -> Version { + self.version + } + + pub(crate) fn record(&self) -> &Record<'a> { + &self.record + } +} + +/// A reference to a side-car artifact in the manifest directory. +/// +/// Produced by [`Writer::finish`](super::Writer::finish) after a side-car write completes and +/// inserted into a [`Record`] like any other value. Serializes as `{"$handle": ""}` +/// on the wire; the load side rehydrates it through +/// [`crate::load::Object::read`]. +#[derive(Debug, Clone)] +pub struct Handle(String); + +impl Handle { + pub(crate) fn new(string: String) -> Self { + Self(string) + } + + pub(crate) fn as_str(&self) -> &str { + &self.0 + } +} + +impl Serialize for Handle { + fn serialize(&self, ser: S) -> Result { + let mut handle = ser.serialize_struct("Handle", 1)?; + handle.serialize_field("$handle", &self.0)?; + handle.end() + } +} + +impl<'de> Deserialize<'de> for Handle { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + #[derive(Deserialize)] + struct Helper { + #[serde(rename = "$handle")] + handle: String, + } + let helper = Helper::deserialize(deserializer)?; + Ok(Handle(helper.handle)) + } +} diff --git a/diskann-record/src/version.rs b/diskann-record/src/version.rs new file mode 100644 index 000000000..cbd7eab6f --- /dev/null +++ b/diskann-record/src/version.rs @@ -0,0 +1,82 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Semver-style version stamps embedded in every saved object. + +use serde::{Deserialize, Deserializer, Serialize, Serializer, de}; + +/// A semver-style schema version attached to every saved record. +/// +/// Each [`crate::save::Save`] / [`crate::load::Load`] impl declares its +/// `const VERSION: Version`. On load, the version recorded in the manifest is compared +/// against the declared version to dispatch between +/// [`Load::load`](crate::load::Load::load) and +/// [`Load::load_legacy`](crate::load::Load::load_legacy). +/// +/// The framework treats versions as opaque triples and only checks them for equality; +/// ordering / semver semantics are entirely up to the implementing type. +/// +/// On the wire, a `Version` is encoded as a single string of the form +/// `"major.minor.patch"` (e.g. `"0.0.0"`). +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct Version { + pub major: u32, + pub minor: u32, + pub patch: u32, +} + +impl Version { + /// Construct a [`Version`] from its three components. + pub const fn new(major: u32, minor: u32, patch: u32) -> Self { + Self { + major, + minor, + patch, + } + } +} + +impl Serialize for Version { + fn serialize(&self, ser: S) -> Result { + ser.collect_str(&format_args!( + "{}.{}.{}", + self.major, self.minor, self.patch + )) + } +} + +impl<'de> Deserialize<'de> for Version { + fn deserialize>(de: D) -> Result { + struct VersionVisitor; + + impl de::Visitor<'_> for VersionVisitor { + type Value = Version; + + fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.write_str("a version string of the form \"major.minor.patch\"") + } + + fn visit_str(self, v: &str) -> Result { + let mut parts = v.split('.'); + let major = parts.next().and_then(|s| s.parse::().ok()); + let minor = parts.next().and_then(|s| s.parse::().ok()); + let patch = parts.next().and_then(|s| s.parse::().ok()); + match (major, minor, patch, parts.next()) { + (Some(major), Some(minor), Some(patch), None) => Ok(Version { + major, + minor, + patch, + }), + _ => Err(E::custom(format!( + "unknown version {:?}: expected three `.`-separated u32 components", + v, + ))), + } + } + } + + de.deserialize_str(VersionVisitor) + } +} From 7d29f1c228fe3c8d67323d37b71f5c81aee47cff Mon Sep 17 00:00:00 2001 From: Suhas Jayaram Subramanya Date: Thu, 18 Jun 2026 11:07:13 -0700 Subject: [PATCH 02/23] making copilot review happy --- diskann-record/src/save/context.rs | 58 ++++++++++++++++++++---------- diskann-record/src/save/mod.rs | 2 +- 2 files changed, 41 insertions(+), 19 deletions(-) diff --git a/diskann-record/src/save/context.rs b/diskann-record/src/save/context.rs index 01785ad7b..e8cdcc66b 100644 --- a/diskann-record/src/save/context.rs +++ b/diskann-record/src/save/context.rs @@ -51,6 +51,19 @@ impl ContextInner { } pub(super) fn write(&self, key: &str) -> Result> { + // Reject absolute paths, parent traversal, and multi-component paths. Handles must be + // simple file names relative to the manifest directory. + let mut components = std::path::Path::new(key).components(); + match components.next() { + Some(std::path::Component::Normal(_)) if components.next().is_none() => {} + _ => { + return Err(Error::message(format!( + "artifact file name {:?} must be a relative file name with no path separators", + key, + ))); + } + } + // TODO: Proper disambiguation - making UUIDs etc. let mut files = self .files @@ -63,6 +76,12 @@ impl ContextInner { ))); } let full = self.dir.join(key); + if full.exists() { + return Err(Error::message(format!( + "file {} already exists", + full.display() + ))); + } let file = std::fs::File::create_new(&full).map_err(|err| { Error::new(err).context(format!("while creating new file {}", full.display())) })?; @@ -79,30 +98,33 @@ impl ContextInner { /// then renames it into place. Fails if the temp file already exists (an in-flight /// save is in progress, or a previous run aborted between rename steps). pub fn finish(self, value: Value<'_>) -> Result<()> { - let mut temp = self.metadata.clone().into_os_string(); - temp.push(".temp"); - let temp = PathBuf::from(temp); - if temp.exists() { - return Err(Error::message(format!( - "Temporary file {} already exists. Aborting!", - temp.display() - ))); - } let files = self - .files - .into_inner() - .unwrap_or_else(|poison| poison.into_inner()); + .files + .into_inner() + .unwrap_or_else(|poison| poison.into_inner()); let f = Final { files: files.iter().map(|k| &**k).collect(), value: &value, }; - - let buffer = std::fs::File::create(&temp).map_err(|err| { - Error::new(err).context(format!( - "while creating temp manifest file {}", - temp.display() - )) + + // Fail if the temp file already exists + let mut temp = self.metadata.clone().into_os_string(); + temp.push(".temp"); + let temp = PathBuf::from(temp); + let buffer = std::fs::File::create_new(&temp).map_err(|err| { + if err.kind() == std::io::ErrorKind::AlreadyExists { + Error::message(format!( + "Temporary file {} already exists. Aborting!", + temp.display() + )) + } else { + Error::new(err).context(format!( + "while creating temp manifest file {}", + temp.display() + )) + } })?; + serde_json::to_writer_pretty(buffer, &f) .map_err(|err| Error::new(err).context("while serializing manifest to JSON"))?; std::fs::rename(&temp, &self.metadata).map_err(|err| { diff --git a/diskann-record/src/save/mod.rs b/diskann-record/src/save/mod.rs index bcb7027fd..268edaab9 100644 --- a/diskann-record/src/save/mod.rs +++ b/diskann-record/src/save/mod.rs @@ -21,7 +21,7 @@ //! # Side-Car Artifacts //! //! Binary blobs (e.g. vector buffers) are written to side-car files via -//! [`Context::write`], which returns a [`Writer`](Context). The handle returned by +//! [`Context::write`], which returns a [`Writer`]. The handle returned by //! [`Writer::finish`](context::Writer::finish) can be embedded into the record as a //! [`Handle`]; it serializes as a `$handle` reference and is rehydrated on the load side. mod value; From f0646107d0ba595d51f35cc3998a2532831e1064 Mon Sep 17 00:00:00 2001 From: Suhas Jayaram Subramanya Date: Thu, 18 Jun 2026 11:09:06 -0700 Subject: [PATCH 03/23] making copilot review happy --- diskann-record/src/load/context.rs | 14 +++++++++----- diskann-record/src/save/context.rs | 10 +++++----- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/diskann-record/src/load/context.rs b/diskann-record/src/load/context.rs index f3e03ac98..3bf341232 100644 --- a/diskann-record/src/load/context.rs +++ b/diskann-record/src/load/context.rs @@ -63,11 +63,15 @@ impl ContextInner { pub(super) fn read(&self, key: &str) -> Result> { let key_as_path: &Path = key.as_ref(); - if key.contains("..") || key_as_path.is_absolute() { - return Err(Error::from(error::Kind::MissingFile).context(format!( - "handle references file {:?} which escapes the manifest directory", - key, - ))); + let mut components = key_as_path.components(); + match components.next() { + Some(std::path::Component::Normal(_)) if components.next().is_none() => {} + _ => { + return Err(Error::from(error::Kind::MissingFile).context(format!( + "handle references file {:?} which escapes the manifest directory", + key, + ))); + } } if !self.files.contains(key_as_path) { return Err(Error::from(error::Kind::MissingFile).context(format!( diff --git a/diskann-record/src/save/context.rs b/diskann-record/src/save/context.rs index e8cdcc66b..b6119cf26 100644 --- a/diskann-record/src/save/context.rs +++ b/diskann-record/src/save/context.rs @@ -99,14 +99,14 @@ impl ContextInner { /// save is in progress, or a previous run aborted between rename steps). pub fn finish(self, value: Value<'_>) -> Result<()> { let files = self - .files - .into_inner() - .unwrap_or_else(|poison| poison.into_inner()); + .files + .into_inner() + .unwrap_or_else(|poison| poison.into_inner()); let f = Final { files: files.iter().map(|k| &**k).collect(), value: &value, }; - + // Fail if the temp file already exists let mut temp = self.metadata.clone().into_os_string(); temp.push(".temp"); @@ -124,7 +124,7 @@ impl ContextInner { )) } })?; - + serde_json::to_writer_pretty(buffer, &f) .map_err(|err| Error::new(err).context("while serializing manifest to JSON"))?; std::fs::rename(&temp, &self.metadata).map_err(|err| { From 5ca175ca1e72563dfe3c87e9522a0b5edd0fa049 Mon Sep 17 00:00:00 2001 From: Suhas Jayaram Subramanya Date: Mon, 22 Jun 2026 10:47:33 -0700 Subject: [PATCH 04/23] moved up one level since both load/save paths use them --- diskann-record/src/lib.rs | 3 +++ diskann-record/src/save/mod.rs | 3 +-- diskann-record/src/{save => }/value.rs | 29 ++++++++++++++++---------- 3 files changed, 22 insertions(+), 13 deletions(-) rename diskann-record/src/{save => }/value.rs (89%) diff --git a/diskann-record/src/lib.rs b/diskann-record/src/lib.rs index 0a94b277c..a1729e512 100644 --- a/diskann-record/src/lib.rs +++ b/diskann-record/src/lib.rs @@ -86,6 +86,9 @@ pub use number::Number; mod version; pub use version::Version; +mod value; +pub use value::{Handle, Keys, Record, Value, Versioned}; + pub mod load; pub mod save; diff --git a/diskann-record/src/save/mod.rs b/diskann-record/src/save/mod.rs index 268edaab9..f1d4534d5 100644 --- a/diskann-record/src/save/mod.rs +++ b/diskann-record/src/save/mod.rs @@ -24,8 +24,7 @@ //! [`Context::write`], which returns a [`Writer`]. The handle returned by //! [`Writer::finish`](context::Writer::finish) can be embedded into the record as a //! [`Handle`]; it serializes as a `$handle` reference and is rehydrated on the load side. -mod value; -pub use value::{Handle, Keys, Record, Value, Versioned}; +pub use crate::value::{Handle, Keys, Record, Value, Versioned}; mod context; pub use context::{Context, Writer}; diff --git a/diskann-record/src/save/value.rs b/diskann-record/src/value.rs similarity index 89% rename from diskann-record/src/save/value.rs rename to diskann-record/src/value.rs index 4a3f7cc42..a62096649 100644 --- a/diskann-record/src/save/value.rs +++ b/diskann-record/src/value.rs @@ -5,19 +5,25 @@ //! Wire-level value types used in the on-disk manifest. //! -//! Every saveable field is one of: +//! These types are the shared currency of both halves of the framework: +//! user value -> [`save`](crate::save) -> [`Value`] in the save path, and the +//! [`Value`] -> [`load`](crate::load) -> user value in the load path. +//! +//! Every field stored in a manifest is one of: //! //! * [`Value::Null`] / [`Value::Bool`] / [`Value::Number`] / [`Value::String`] / //! [`Value::Bytes`] — primitive scalars. //! * [`Value::Array`] — a homogeneous sequence (used by `Vec` and `&[T]`). //! * [`Value::Object`] — a [`Versioned`] [`Record`] (the canonical encoding for a -//! `T: super::Save`). +//! `T: crate::save::Save`). //! * [`Value::Handle`] — a reference to a side-car artifact (produced by -//! [`super::Context::write`] + [`super::context::Writer::finish`]). +//! [`crate::save::Context::write`] + [`crate::save::Writer::finish`]). //! -//! Most user code never touches these enums directly: [`super::Saveable`] impls turn -//! Rust values into [`Value`]s, and the [`save_fields!`](crate::save_fields) macro -//! assembles the surrounding [`Record`]. +//! Most user code never touches these enums directly. On the save side, +//! [`crate::save::Saveable`] impls turn Rust values into [`Value`]s and the +//! [`save_fields!`](crate::save_fields) macro assembles the surrounding [`Record`]; on +//! the load side, the [`crate::load`] accessors walk the same [`Value`] tree back into +//! Rust values. use std::{borrow::Cow, collections::HashMap}; @@ -189,10 +195,11 @@ impl From for Value<'_> { /// A map of named [`Value`]s. /// -/// `Record` is the body of a saved object: each call to [`super::Save::save`] returns -/// one, and [`Record::into_value`] wraps it as a [`Versioned`] [`Value::Object`] ready -/// for insertion into another record. Keys beginning with `$` are reserved for -/// framework metadata (see [`crate::is_reserved`]). +/// `Record` is the body of an object in the manifest. On the save side each call to +/// [`crate::save::Save::save`] returns one, and [`Record::into_value`] wraps it as a +/// [`Versioned`] [`Value::Object`] ready for insertion into another record; on the load +/// side the same record is read back through [`crate::load::Object`]. Keys beginning +/// with `$` are reserved for framework metadata (see [`crate::is_reserved`]). #[derive(Debug, Serialize, Deserialize)] #[serde(transparent)] pub struct Record<'a> { @@ -321,7 +328,7 @@ impl<'a> Versioned<'a> { /// A reference to a side-car artifact in the manifest directory. /// -/// Produced by [`Writer::finish`](super::Writer::finish) after a side-car write completes and +/// Produced by [`Writer::finish`](crate::save::Writer::finish) after a side-car write completes and /// inserted into a [`Record`] like any other value. Serializes as `{"$handle": ""}` /// on the wire; the load side rehydrates it through /// [`crate::load::Object::read`]. From 2cd7e6e7aeac2e117338920baabd78bf3dc061aa Mon Sep 17 00:00:00 2001 From: Suhas Jayaram Subramanya Date: Mon, 22 Jun 2026 11:17:06 -0700 Subject: [PATCH 05/23] removed ContextInner; added new SaveContext, LoadContext traits; save and load paths now use ; feature gated serde and disk impls --- diskann-record/Cargo.toml | 9 +++- diskann-record/src/lib.rs | 8 +-- diskann-record/src/load/context.rs | 81 ++++++++++++++++++++-------- diskann-record/src/load/mod.rs | 30 +++++++++-- diskann-record/src/number.rs | 4 ++ diskann-record/src/save/context.rs | 86 +++++++++++++++++++++++++----- diskann-record/src/save/error.rs | 2 +- diskann-record/src/save/mod.rs | 31 +++++++++-- diskann-record/src/value.rs | 17 ++++-- diskann-record/src/version.rs | 3 ++ 10 files changed, 216 insertions(+), 55 deletions(-) diff --git a/diskann-record/Cargo.toml b/diskann-record/Cargo.toml index 8c26ae775..24bf6db2f 100644 --- a/diskann-record/Cargo.toml +++ b/diskann-record/Cargo.toml @@ -9,8 +9,13 @@ edition = "2024" [dependencies] anyhow.workspace = true -serde = { workspace = true, features = ["derive"] } -serde_json.workspace = true +serde = { workspace = true, features = ["derive"], optional = true } +serde_json = { workspace = true, optional = true } + +[features] +default = ["disk"] +serde = ["dep:serde"] +disk = ["serde", "dep:serde_json"] [dev-dependencies] tempfile.workspace = true diff --git a/diskann-record/src/lib.rs b/diskann-record/src/lib.rs index a1729e512..76f5ef034 100644 --- a/diskann-record/src/lib.rs +++ b/diskann-record/src/lib.rs @@ -23,8 +23,10 @@ //! //! # Entry Points //! -//! - [`save::save_to_disk`]: Save a value to a directory plus a manifest path. -//! - [`load::load_from_disk`]: Reload a value from a manifest and its artifact directory. +//! - `save::save_to_disk` (requires the `disk` feature): Save a value to a directory +//! plus a manifest path. +//! - `load::load_from_disk` (requires the `disk` feature): Reload a value from a +//! manifest and its artifact directory. //! //! # Defining Save / Load //! @@ -122,7 +124,7 @@ pub const fn is_reserved(s: &str) -> bool { // Tests // /////////// -#[cfg(test)] +#[cfg(all(test, feature = "disk"))] mod tests { use super::*; diff --git a/diskann-record/src/load/context.rs b/diskann-record/src/load/context.rs index 3bf341232..aec198930 100644 --- a/diskann-record/src/load/context.rs +++ b/diskann-record/src/load/context.rs @@ -17,33 +17,67 @@ //! //! [`Reader`] implements [`std::io::Read`] and [`std::io::Seek`] over the artifact file. +use std::{fs::File, io::BufReader}; + +#[cfg(feature = "disk")] use std::{ collections::HashSet, - fs::File, - io::BufReader, path::{Path, PathBuf}, }; +#[cfg(feature = "disk")] +use crate::load::Error; use crate::{ Number, Version, - load::{Error, Loadable, Result, error}, + load::{Loadable, Result, error}, save, }; +/// The backing store for a load operation. +/// +/// A `LoadContext` supplies the root manifest [`save::Value`] ([`LoadContext::value`]) +/// and resolves side-car artifacts referenced by handles ([`LoadContext::read`]). The +/// default, disk-backed implementation (`DiskContext`) lives in this module under the +/// `disk` feature; alternative implementations (e.g. a virtual filesystem or a purely +/// in-memory store) can be supplied for testing. +/// +/// The generic [`load`](super::load) entry point is parameterized over this trait, and +/// [`Context`] / [`Object`] / `Array` borrow it as an object-safe `&dyn LoadContext` +/// so the load tree is agnostic to the concrete context type. +pub trait LoadContext { + /// The root value of the manifest. + /// + /// # Errors + /// + /// Returns [`Error`](crate::load::Error) if the root value cannot be produced. + fn value(&self) -> Result<&save::Value<'_>>; + + /// Open the side-car artifact identified by `key` for reading. + /// + /// # Errors + /// + /// Returns [`error::Kind::MissingFile`] if the file is not registered with this + /// context or if `key` escapes the manifest directory. + fn read(&self, key: &str) -> Result>; +} + +#[cfg(feature = "disk")] #[derive(Debug, serde::Deserialize)] -pub(super) struct ContextInner { +pub(super) struct DiskContext { dir: PathBuf, files: HashSet, value: save::Value<'static>, } +#[cfg(feature = "disk")] #[derive(Debug, serde::Deserialize)] struct FileRepr { files: HashSet, value: save::Value<'static>, } -impl ContextInner { +#[cfg(feature = "disk")] +impl DiskContext { pub(super) fn new(metadata: &Path, dir: &Path) -> Result { let file = std::fs::File::open(metadata).map_err(|e| { Error::new(e).context(format!("while trying to open {}", metadata.display())) @@ -60,8 +94,15 @@ impl ContextInner { }; Ok(this) } +} + +#[cfg(feature = "disk")] +impl LoadContext for DiskContext { + fn value(&self) -> Result<&save::Value<'_>> { + Ok(&self.value) + } - pub(super) fn read(&self, key: &str) -> Result> { + fn read(&self, key: &str) -> Result> { let key_as_path: &Path = key.as_ref(); let mut components = key_as_path.components(); match components.next() { @@ -91,10 +132,6 @@ impl ContextInner { Ok(reader) } - - pub(super) fn context(&self) -> Context<'_> { - Context::new(self, &self.value) - } } /// A borrowed reader over a side-car artifact. @@ -151,18 +188,18 @@ impl std::io::Seek for Reader<'_> { /// Loaders use the `as_*` accessors to peek at the underlying [`save::Value`] kind /// (e.g. [`Context::as_object`], [`Context::as_array`], [`Context::as_str`]) and /// [`Context::load`] to recursively deserialize a nested value into a concrete type. -#[derive(Debug, Clone)] +#[derive(Clone)] pub struct Context<'a> { - inner: &'a ContextInner, + inner: &'a dyn LoadContext, value: &'a save::Value<'a>, } impl<'a> Context<'a> { - fn new(inner: &'a ContextInner, value: &'a save::Value<'a>) -> Self { + pub(super) fn new(inner: &'a dyn LoadContext, value: &'a save::Value<'a>) -> Self { Self { inner, value } } - fn context(&self) -> &'a ContextInner { + fn context(&self) -> &'a dyn LoadContext { self.inner } @@ -249,9 +286,8 @@ impl<'a> Context<'a> { /// version via [`Object::version`], the user keys via [`Object::keys`], typed field /// extraction via [`Object::field`] (and the [`load_fields!`](crate::load_fields) /// macro), and side-car artifact access via [`Object::read`]. -#[derive(Debug)] pub struct Object<'a> { - inner: &'a ContextInner, + inner: &'a dyn LoadContext, record: &'a save::Record<'a>, version: Version, } @@ -333,7 +369,7 @@ impl<'a> Object<'a> { self.inner.read(handle.as_str()) } - fn context(&self) -> &'a ContextInner { + fn context(&self) -> &'a dyn LoadContext { self.inner } } @@ -343,14 +379,13 @@ impl<'a> Object<'a> { /// Backed by a borrowed `&[Value]`. Use [`Array::iter`] to walk the elements; each /// item is yielded as a [`Context`] that can be further deserialized via /// [`Context::load`]. -#[derive(Debug)] pub struct Array<'a> { - inner: &'a ContextInner, + inner: &'a dyn LoadContext, array: &'a [save::Value<'a>], } impl<'a> Array<'a> { - fn new(inner: &'a ContextInner, array: &'a [save::Value<'a>]) -> Self { + fn new(inner: &'a dyn LoadContext, array: &'a [save::Value<'a>]) -> Self { Self { inner, array } } @@ -369,19 +404,19 @@ impl<'a> Array<'a> { Iter::new(self.context(), self.array.iter()) } - fn context(&self) -> &'a ContextInner { + fn context(&self) -> &'a dyn LoadContext { self.inner } } /// Iterator returned by [`Array::iter`]. pub struct Iter<'a> { - inner: &'a ContextInner, + inner: &'a dyn LoadContext, iter: std::slice::Iter<'a, save::Value<'a>>, } impl<'a> Iter<'a> { - fn new(inner: &'a ContextInner, iter: std::slice::Iter<'a, save::Value<'a>>) -> Self { + fn new(inner: &'a dyn LoadContext, iter: std::slice::Iter<'a, save::Value<'a>>) -> Self { Self { inner, iter } } } diff --git a/diskann-record/src/load/mod.rs b/diskann-record/src/load/mod.rs index 032cd4f27..8f6a87e43 100644 --- a/diskann-record/src/load/mod.rs +++ b/diskann-record/src/load/mod.rs @@ -9,7 +9,8 @@ //! primitive-like leaves, [`Loadable`]) and obtain an [`Object`] / [`Context`] from which //! they extract individual fields and side-car artifacts. //! -//! The top-level entry point is [`load_from_disk`], which reads a manifest and dispatches +//! The generic entry point is [`load`]; `load_from_disk` (available under the `disk` +//! feature) is the disk-backed convenience wrapper that reads a manifest and dispatches //! into the user type's [`Load`] impl. //! //! # Reading Records @@ -36,12 +37,32 @@ pub mod error; pub use error::{Error, Result}; mod context; -pub use context::{Context, Object}; +pub use context::{Context, LoadContext, Object}; +#[cfg(feature = "disk")] use std::path::Path; use crate::{Version, save}; +/// Deserialize a `T` from `context`. +/// +/// This is the generic entry point: it asks `context` for its root value and threads a +/// [`Context`] borrowing `context` through `T`'s [`Loadable::load`] impl. +/// `load_from_disk` is the disk-backed convenience wrapper available under the `disk` +/// feature. +/// +/// # Errors +/// +/// Returns [`Error`] if the context cannot produce a root value or if `T`'s loader fails. +pub fn load<'a, T, C>(context: &'a C) -> Result +where + T: Loadable<'a>, + C: LoadContext, +{ + let value = context.value()?; + T::load(Context::new(context, value)) +} + /// Reload a value previously written by [`save::save_to_disk`]. /// /// `metadata` is the manifest JSON path produced by the saver, and `dir` is the @@ -52,12 +73,13 @@ use crate::{Version, save}; /// Returns [`Error`] if the manifest is missing or malformed, if a referenced artifact is /// missing, or if a user [`Load`] impl fails (e.g. due to a version mismatch with no /// upgrade path). +#[cfg(feature = "disk")] pub fn load_from_disk(metadata: &Path, dir: &Path) -> Result where T: for<'a> Loadable<'a>, { - let inner = context::ContextInner::new(metadata, dir)?; - inner.context().load() + let context = context::DiskContext::new(metadata, dir)?; + load(&context) } /// Implemented by user types that can be reloaded from a versioned [`Object`]. diff --git a/diskann-record/src/number.rs b/diskann-record/src/number.rs index 33b635dc2..15b9644aa 100644 --- a/diskann-record/src/number.rs +++ b/diskann-record/src/number.rs @@ -10,7 +10,9 @@ //! target Rust type and return `None` when the value is out of range or would lose //! precision; loaders surface this as [`crate::load::error::Kind::NumberOutOfRange`]. +#[cfg(feature = "serde")] use serde::de::{self, Visitor}; +#[cfg(feature = "serde")] use serde::{Deserialize, Deserializer, Serialize, Serializer}; /// A numeric value carried in a manifest, preserving the kind the writer chose. @@ -26,6 +28,7 @@ pub enum Number { F64(f64), } +#[cfg(feature = "serde")] impl Serialize for Number { fn serialize(&self, serializer: S) -> Result { match *self { @@ -36,6 +39,7 @@ impl Serialize for Number { } } +#[cfg(feature = "serde")] impl<'de> Deserialize<'de> for Number { fn deserialize>(deserializer: D) -> Result { struct NumberVisitor; diff --git a/diskann-record/src/save/context.rs b/diskann-record/src/save/context.rs index b6119cf26..99c56b5b8 100644 --- a/diskann-record/src/save/context.rs +++ b/diskann-record/src/save/context.rs @@ -14,29 +14,85 @@ //! [`Writer::finish`] flushes the buffer and yields a [`Handle`] that can be inserted //! into a [`super::Record`]. -use std::{collections::HashSet, fs::File, io::BufWriter, path::PathBuf, sync::Mutex}; +use std::{fs::File, io::BufWriter}; + +#[cfg(feature = "disk")] +use std::{collections::HashSet, path::PathBuf, sync::Mutex}; use crate::save::{Error, Handle, Result, Value}; -/// The owned context behind a [`Context`]. +/// The backing store for a save operation. +/// +/// A `SaveContext` decides where side-car artifacts are written ([`SaveContext::write`]) +/// and how the final manifest is committed ([`SaveContext::finish`]). The default, +/// disk-backed implementation (`DiskContext`) lives in this module under the `disk` +/// feature; alternative implementations (e.g. a virtual filesystem or a purely in-memory +/// store) can be supplied for testing or to avoid touching the filesystem. +/// +/// The generic [`save`](super::save) entry point is parameterized over this trait so +/// that the base crate carries no hard dependency on any particular implementation. +pub trait SaveContext { + /// The value produced once the manifest has been committed by + /// [`SaveContext::finish`]. For the disk-backed context this is `()`. + type Output; + + /// Allocate a new side-car artifact named `key`. + /// + /// # Errors + /// + /// Returns [`Error`] if `key` is not a simple relative file name, if it has already + /// been registered, or if the underlying artifact cannot be created. + fn write(&self, key: &str) -> Result>; + + /// Commit the manifest `value`, consuming the context. + /// + /// # Errors + /// + /// Returns [`Error`] if the manifest cannot be serialized or committed. + fn finish(self, value: Value<'_>) -> Result; +} + +/// Object-safe view of the artifact-allocating half of a [`SaveContext`]. +/// +/// [`Context`] holds a `&dyn GetWrite` so that the same handle can be threaded through +/// every [`Save::save`](super::Save) impl regardless of the concrete context type. +/// [`SaveContext::finish`] (which consumes `self` and names an associated type) is not +/// object safe, so it is deliberately excluded from this trait. +pub(super) trait GetWrite { + fn write(&self, key: &str) -> Result>; +} + +impl GetWrite for T +where + T: SaveContext, +{ + fn write(&self, key: &str) -> Result> { + ::write(self, key) + } +} + +/// The disk-backed [`SaveContext`]. /// /// Holds the manifest directory, the manifest path, and the set of artifact file names /// registered so far. Lookup and insertion go through a [`Mutex`] so that concurrent /// [`Save`](super::Save) impls cannot accidentally hand out the same artifact name twice. +#[cfg(feature = "disk")] #[derive(Debug)] -pub(super) struct ContextInner { +pub(super) struct DiskContext { dir: PathBuf, metadata: PathBuf, files: Mutex>, } +#[cfg(feature = "disk")] #[derive(serde::Serialize)] struct Final<'a> { files: Vec<&'a str>, value: &'a Value<'a>, } -impl ContextInner { +#[cfg(feature = "disk")] +impl DiskContext { // TODO: Error if the directory looks bad? pub(super) fn new(dir: PathBuf, metadata: PathBuf) -> Self { Self { @@ -45,12 +101,13 @@ impl ContextInner { files: Mutex::new(HashSet::new()), } } +} - pub(super) fn context(&self) -> Context<'_> { - Context { inner: self } - } +#[cfg(feature = "disk")] +impl SaveContext for DiskContext { + type Output = (); - pub(super) fn write(&self, key: &str) -> Result> { + fn write(&self, key: &str) -> Result> { // Reject absolute paths, parent traversal, and multi-component paths. Handles must be // simple file names relative to the manifest directory. let mut components = std::path::Path::new(key).components(); @@ -97,7 +154,7 @@ impl ContextInner { /// Writes the manifest JSON atomically: serializes to a `.temp` file first, /// then renames it into place. Fails if the temp file already exists (an in-flight /// save is in progress, or a previous run aborted between rename steps). - pub fn finish(self, value: Value<'_>) -> Result<()> { + fn finish(self, value: Value<'_>) -> Result<()> { let files = self .files .into_inner() @@ -143,13 +200,18 @@ impl ContextInner { /// `Context` exposes one operation — [`Context::write`] — for allocating a side-car /// artifact. The same context is passed to nested [`Save`](super::Save) impls (typically /// via the [`save_fields!`](crate::save_fields) macro), so a single save tree shares -/// artifact-name bookkeeping. -#[derive(Debug, Clone)] +/// artifact-name bookkeeping. It borrows the backing [`SaveContext`] as an object-safe +/// `GetWrite` so that the save tree is agnostic to the concrete context type. +#[derive(Clone)] pub struct Context<'a> { - inner: &'a ContextInner, + inner: &'a dyn GetWrite, } impl<'a> Context<'a> { + pub(super) fn new(inner: &'a dyn GetWrite) -> Self { + Self { inner } + } + /// Allocate a new side-car artifact named `key` in the manifest directory. /// /// The returned [`Writer`] is positioned at offset 0 and implements diff --git a/diskann-record/src/save/error.rs b/diskann-record/src/save/error.rs index 5262fad2a..7ea110704 100644 --- a/diskann-record/src/save/error.rs +++ b/diskann-record/src/save/error.rs @@ -18,7 +18,7 @@ pub type Result = ::std::result::Result; /// /// Wraps [`anyhow::Error`] for rich context chains (see [`Error::context`]) and is /// returned from every fallible save-side operation, including [`super::Save::save`] -/// impls and [`super::save_to_disk`]. +/// impls and `save_to_disk`. #[derive(Debug)] pub struct Error { inner: anyhow::Error, diff --git a/diskann-record/src/save/mod.rs b/diskann-record/src/save/mod.rs index f1d4534d5..c54f495c8 100644 --- a/diskann-record/src/save/mod.rs +++ b/diskann-record/src/save/mod.rs @@ -9,7 +9,8 @@ //! (or, for primitive-like leaves, [`Saveable`]) and obtain a [`Context`] from which they //! request side-car artifact writers and assemble a [`Record`] of named fields. //! -//! The top-level entry point is [`save_to_disk`], which serializes a value into a +//! The generic entry point is [`save`]; `save_to_disk` (available under the `disk` +//! feature) is the disk-backed convenience wrapper that serializes a value into a //! caller-chosen directory plus a manifest path. //! //! # Building Records @@ -27,13 +28,33 @@ pub use crate::value::{Handle, Keys, Record, Value, Versioned}; mod context; -pub use context::{Context, Writer}; +pub use context::{Context, SaveContext, Writer}; mod error; pub use error::{Error, Result}; use crate::Version; +/// Serialize `x` into `context`, returning the context's committed output. +/// +/// This is the generic entry point: it threads a [`Context`] borrowing `context` through +/// the value's [`Saveable::save`] impl, then commits the resulting manifest via +/// [`SaveContext::finish`]. `save_to_disk` is the disk-backed convenience wrapper +/// available under the `disk` feature. +/// +/// # Errors +/// +/// Returns [`Error`] if a user impl returns an error or if the context fails to commit +/// the manifest. +pub fn save(x: &T, context: C) -> Result +where + T: Saveable, + C: SaveContext, +{ + let value = x.save(Context::new(&context))?; + context.finish(value) +} + /// Serialize `x` to disk. /// /// The manifest (a JSON document) is written atomically to `metadata`; any side-car @@ -44,6 +65,7 @@ use crate::Version; /// /// Returns [`Error`] if the directory cannot be written to, if the manifest cannot be /// serialized, or if a user impl returns an error. +#[cfg(feature = "disk")] pub fn save_to_disk( x: &T, dir: impl AsRef, @@ -52,9 +74,8 @@ pub fn save_to_disk( where T: Saveable, { - let inner = context::ContextInner::new(dir.as_ref().into(), metadata.as_ref().into()); - let value = x.save(inner.context())?; - inner.finish(value) + let context = context::DiskContext::new(dir.as_ref().into(), metadata.as_ref().into()); + save(x, context) } /// Implemented by user types that map to a versioned [`Record`]. diff --git a/diskann-record/src/value.rs b/diskann-record/src/value.rs index a62096649..9554eaab7 100644 --- a/diskann-record/src/value.rs +++ b/diskann-record/src/value.rs @@ -27,6 +27,7 @@ use std::{borrow::Cow, collections::HashMap}; +#[cfg(feature = "serde")] use serde::{ Deserialize, Deserializer, Serialize, Serializer, de::{self, MapAccess, SeqAccess, Visitor}, @@ -52,6 +53,7 @@ pub enum Value<'a> { Handle(Handle), } +#[cfg(feature = "serde")] impl Serialize for Value<'_> { fn serialize(&self, ser: S) -> Result { match self { @@ -67,6 +69,7 @@ impl Serialize for Value<'_> { } } +#[cfg(feature = "serde")] impl<'de> Deserialize<'de> for Value<'static> { fn deserialize(deserializer: D) -> Result where @@ -200,8 +203,9 @@ impl From for Value<'_> { /// [`Versioned`] [`Value::Object`] ready for insertion into another record; on the load /// side the same record is read back through [`crate::load::Object`]. Keys beginning /// with `$` are reserved for framework metadata (see [`crate::is_reserved`]). -#[derive(Debug, Serialize, Deserialize)] -#[serde(transparent)] +#[derive(Debug)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde", serde(transparent))] pub struct Record<'a> { record: HashMap, Value<'a>>, } @@ -304,11 +308,12 @@ impl<'a> FromIterator<(Cow<'a, str>, Value<'a>)> for Record<'a> { /// /// Serialized as a normal object plus a `$version` field on the wire. Constructed by /// [`Record::into_value`]. -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct Versioned<'a> { - #[serde(flatten)] + #[cfg_attr(feature = "serde", serde(flatten))] record: Record<'a>, - #[serde(rename = "$version")] + #[cfg_attr(feature = "serde", serde(rename = "$version"))] version: Version, } @@ -345,6 +350,7 @@ impl Handle { } } +#[cfg(feature = "serde")] impl Serialize for Handle { fn serialize(&self, ser: S) -> Result { let mut handle = ser.serialize_struct("Handle", 1)?; @@ -353,6 +359,7 @@ impl Serialize for Handle { } } +#[cfg(feature = "serde")] impl<'de> Deserialize<'de> for Handle { fn deserialize(deserializer: D) -> Result where diff --git a/diskann-record/src/version.rs b/diskann-record/src/version.rs index cbd7eab6f..9dee90eb5 100644 --- a/diskann-record/src/version.rs +++ b/diskann-record/src/version.rs @@ -5,6 +5,7 @@ //! Semver-style version stamps embedded in every saved object. +#[cfg(feature = "serde")] use serde::{Deserialize, Deserializer, Serialize, Serializer, de}; /// A semver-style schema version attached to every saved record. @@ -38,6 +39,7 @@ impl Version { } } +#[cfg(feature = "serde")] impl Serialize for Version { fn serialize(&self, ser: S) -> Result { ser.collect_str(&format_args!( @@ -47,6 +49,7 @@ impl Serialize for Version { } } +#[cfg(feature = "serde")] impl<'de> Deserialize<'de> for Version { fn deserialize>(de: D) -> Result { struct VersionVisitor; From 6cea2514caf4e4b7ceb950641804b36dff0ee0d6 Mon Sep 17 00:00:00 2001 From: Suhas Jayaram Subramanya Date: Mon, 22 Jun 2026 11:33:03 -0700 Subject: [PATCH 06/23] added file-level tests to improve coverage --- diskann-record/src/load/context.rs | 38 ++++++++++++++++ diskann-record/src/load/error.rs | 19 ++++++++ diskann-record/src/number.rs | 28 ++++++++++++ diskann-record/src/save/context.rs | 71 ++++++++++++++++++++++++++++-- diskann-record/src/save/mod.rs | 2 +- diskann-record/src/value.rs | 34 ++++++++++++++ diskann-record/src/version.rs | 26 +++++++++++ 7 files changed, 213 insertions(+), 5 deletions(-) diff --git a/diskann-record/src/load/context.rs b/diskann-record/src/load/context.rs index aec198930..c34139ae8 100644 --- a/diskann-record/src/load/context.rs +++ b/diskann-record/src/load/context.rs @@ -435,3 +435,41 @@ impl<'a> Iterator for Iter<'a> { } impl ExactSizeIterator for Iter<'_> {} + +#[cfg(all(test, feature = "disk"))] +mod tests { + use super::*; + + fn write_manifest(dir: &Path, files: &[&str]) -> PathBuf { + let manifest = serde_json::json!({ + "files": files, + "value": { "$version": "0.0.0" }, + }); + let metadata = dir.join("metadata.json"); + std::fs::write(&metadata, serde_json::to_vec(&manifest).unwrap()).unwrap(); + metadata + } + + #[test] + fn read_rejects_unregistered_file() { + let dir = tempfile::tempdir().unwrap(); + let metadata = write_manifest(dir.path(), &[]); + let ctx = DiskContext::new(&metadata, dir.path()).unwrap(); + let Err(err) = ctx.read("artifact.bin") else { + panic!("an unregistered file must be rejected"); + }; + assert!(format!("{err}").contains("not registered in the manifest")); + } + + #[test] + fn read_rejects_escaping_handle() { + let dir = tempfile::tempdir().unwrap(); + // Register the escaping name so only the path-shape check can reject it. + let metadata = write_manifest(dir.path(), &["../escape.bin"]); + let ctx = DiskContext::new(&metadata, dir.path()).unwrap(); + let Err(err) = ctx.read("../escape.bin") else { + panic!("a handle escaping the manifest directory must be rejected"); + }; + assert!(format!("{err}").contains("escapes the manifest directory")); + } +} diff --git a/diskann-record/src/load/error.rs b/diskann-record/src/load/error.rs index df17daf23..9ade95790 100644 --- a/diskann-record/src/load/error.rs +++ b/diskann-record/src/load/error.rs @@ -198,3 +198,22 @@ impl From for Error { } } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn context_preserves_recoverable_flag() { + assert!( + Error::from(Kind::TypeMismatch) + .context("extra") + .is_recoverable() + ); + assert!( + !Error::from(Kind::MissingFile) + .context("extra") + .is_recoverable() + ); + } +} diff --git a/diskann-record/src/number.rs b/diskann-record/src/number.rs index 15b9644aa..c7c97643e 100644 --- a/diskann-record/src/number.rs +++ b/diskann-record/src/number.rs @@ -178,3 +178,31 @@ try_from!( f32 => as_f32, f64 => as_f64, ); + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn integer_accessors_range_check() { + assert_eq!(Number::U64(255).as_u8(), Some(255)); + assert_eq!(Number::U64(256).as_u8(), None); + assert_eq!(Number::I64(-1).as_u8(), None); + assert_eq!(Number::I64(-5).as_i8(), Some(-5)); + assert_eq!(Number::I64(-129).as_i8(), None); + } + + #[test] + fn float_to_integer_requires_integral_value() { + assert_eq!(Number::F64(2.0).as_u32(), Some(2)); + assert_eq!(Number::F64(2.5).as_u32(), None); + assert_eq!(Number::F64(-1.0).as_u32(), None); + } + + #[test] + fn try_from_surfaces_out_of_range() { + assert!(u8::try_from(Number::U64(300)).is_err()); + assert_eq!(u16::try_from(Number::U64(300)).unwrap(), 300); + assert!(usize::try_from(Number::I64(-1)).is_err()); + } +} diff --git a/diskann-record/src/save/context.rs b/diskann-record/src/save/context.rs index 99c56b5b8..dd4ef43b1 100644 --- a/diskann-record/src/save/context.rs +++ b/diskann-record/src/save/context.rs @@ -93,13 +93,34 @@ struct Final<'a> { #[cfg(feature = "disk")] impl DiskContext { - // TODO: Error if the directory looks bad? - pub(super) fn new(dir: PathBuf, metadata: PathBuf) -> Self { - Self { + /// Create a disk-backed save context targeting `dir` for side-car artifacts and + /// `metadata` for the manifest. Validates that `dir` is an actual directory. + /// + /// # Errors + /// + /// Returns [`Error`] if `dir` does not exist, cannot be inspected, or exists but is + /// not a directory. + pub(super) fn new(dir: PathBuf, metadata: PathBuf) -> Result { + match std::fs::metadata(&dir) { + Ok(meta) if meta.is_dir() => {} + Ok(_) => { + return Err(Error::message(format!( + "path {} exists but is not a directory", + dir.display() + ))); + } + Err(err) => { + return Err( + Error::new(err).context(format!("while validating path {}", dir.display())) + ); + } + } + + Ok(Self { dir, metadata, files: Mutex::new(HashSet::new()), - } + }) } } @@ -288,3 +309,45 @@ impl std::io::Seek for Writer<'_> { self.io.seek_relative(offset) } } + +#[cfg(all(test, feature = "disk"))] +mod tests { + use super::*; + + #[test] + fn new_rejects_nonexistent_directory() { + let missing = PathBuf::from("does/not/exist/anywhere/at/all"); + let err = DiskContext::new(missing, "meta.json".into()) + .expect_err("a nonexistent directory must be rejected"); + assert!(format!("{err}").contains("while validating path")); + } + + #[test] + fn new_rejects_file_as_directory() { + let dir = tempfile::tempdir().unwrap(); + let file = dir.path().join("not_a_dir"); + std::fs::write(&file, b"hi").unwrap(); + let err = DiskContext::new(file, dir.path().join("meta.json")) + .expect_err("a file path must be rejected as a directory"); + assert!(format!("{err}").contains("is not a directory")); + } + + #[test] + fn write_rejects_path_separators_and_traversal() { + let dir = tempfile::tempdir().unwrap(); + let ctx = DiskContext::new(dir.path().into(), dir.path().join("meta.json")).unwrap(); + for bad in ["sub/dir.bin", "../escape.bin", "/abs.bin"] { + SaveContext::write(&ctx, bad).expect_err("keys with path separators must be rejected"); + } + } + + #[test] + fn write_rejects_duplicate_key() { + let dir = tempfile::tempdir().unwrap(); + let ctx = DiskContext::new(dir.path().into(), dir.path().join("meta.json")).unwrap(); + let _writer = SaveContext::write(&ctx, "artifact.bin").unwrap(); + let err = SaveContext::write(&ctx, "artifact.bin") + .expect_err("a duplicate artifact key must be rejected"); + assert!(format!("{err}").contains("already been registered")); + } +} diff --git a/diskann-record/src/save/mod.rs b/diskann-record/src/save/mod.rs index c54f495c8..1d938e3b1 100644 --- a/diskann-record/src/save/mod.rs +++ b/diskann-record/src/save/mod.rs @@ -74,7 +74,7 @@ pub fn save_to_disk( where T: Saveable, { - let context = context::DiskContext::new(dir.as_ref().into(), metadata.as_ref().into()); + let context = context::DiskContext::new(dir.as_ref().into(), metadata.as_ref().into())?; save(x, context) } diff --git a/diskann-record/src/value.rs b/diskann-record/src/value.rs index 9554eaab7..2daa1b12a 100644 --- a/diskann-record/src/value.rs +++ b/diskann-record/src/value.rs @@ -374,3 +374,37 @@ impl<'de> Deserialize<'de> for Handle { Ok(Handle(helper.handle)) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn insert_rejects_reserved_key() { + let mut record = Record::empty(); + record + .insert("$version", Value::Null) + .expect_err("reserved key must be rejected"); + record + .insert("ok", Value::Bool(true)) + .expect("normal key must be accepted"); + assert!(record.contains_key("ok")); + assert_eq!(record.len(), 1); + } + + #[cfg(feature = "disk")] + #[test] + fn deserialize_rejects_handle_with_extra_fields() { + let json = r#"{ "$handle": "a.bin", "$version": "0.0.0" }"#; + serde_json::from_str::>(json) + .expect_err("handle object with extra fields must be rejected"); + } + + #[cfg(feature = "disk")] + #[test] + fn deserialize_rejects_object_without_version_or_handle() { + let json = r#"{ "field": 1 }"#; + serde_json::from_str::>(json) + .expect_err("object without $version or $handle must be rejected"); + } +} diff --git a/diskann-record/src/version.rs b/diskann-record/src/version.rs index 9dee90eb5..fc823caae 100644 --- a/diskann-record/src/version.rs +++ b/diskann-record/src/version.rs @@ -83,3 +83,29 @@ impl<'de> Deserialize<'de> for Version { de.deserialize_str(VersionVisitor) } } + +#[cfg(all(test, feature = "disk"))] +mod tests { + use super::*; + + #[test] + fn serializes_as_dotted_string() { + let json = serde_json::to_string(&Version::new(1, 2, 3)).unwrap(); + assert_eq!(json, "\"1.2.3\""); + } + + #[test] + fn round_trips_through_json() { + let v = Version::new(4, 5, 6); + let back: Version = serde_json::from_str(&serde_json::to_string(&v).unwrap()).unwrap(); + assert_eq!(v, back); + } + + #[test] + fn rejects_malformed_strings() { + for bad in ["\"1.2\"", "\"1.2.3.4\"", "\"1.x.3\"", "\"abc\""] { + serde_json::from_str::(bad) + .expect_err("malformed version string must be rejected"); + } + } +} From 14ceeccad64b7b99a66529082d2b7e1b3ffb2bbd Mon Sep 17 00:00:00 2001 From: Suhas Jayaram Subramanya Date: Mon, 22 Jun 2026 11:57:39 -0700 Subject: [PATCH 07/23] Changed SaveContext::write to take Option<&str>. This allows us to write a Vec where each struct might produce a file with the same key --- diskann-record/src/lib.rs | 2 +- diskann-record/src/save/context.rs | 108 +++++++++++++++++++---------- 2 files changed, 74 insertions(+), 36 deletions(-) diff --git a/diskann-record/src/lib.rs b/diskann-record/src/lib.rs index 76f5ef034..e7250d266 100644 --- a/diskann-record/src/lib.rs +++ b/diskann-record/src/lib.rs @@ -170,7 +170,7 @@ mod tests { // We save `x`, `y`, and `inner` directly into the manifest. // The raw vector data we instead store in an auxiliary file. - let mut io = context.write("auxiliary.bin")?; + let mut io = context.write(Some("auxiliary.bin"))?; io.write_all(&self.vector).map_err(save::Error::new)?; let mut record = save_fields!(self, context, [x, y, enabled, inner, nickname, absent]); diff --git a/diskann-record/src/save/context.rs b/diskann-record/src/save/context.rs index dd4ef43b1..9801211d2 100644 --- a/diskann-record/src/save/context.rs +++ b/diskann-record/src/save/context.rs @@ -36,13 +36,18 @@ pub trait SaveContext { /// [`SaveContext::finish`]. For the disk-backed context this is `()`. type Output; - /// Allocate a new side-car artifact named `key`. + /// Allocate a new side-car artifact, optionally tagged with a human-readable `key`. + /// + /// The artifact's on-disk name (and the value stored in its [`Handle`]) is prefixed + /// with the count of artifacts written so far; when `key` is `Some`, it is appended as + /// `{count}-{key}` purely to aid debugging. Reusing the same `key` across calls is + /// allowed — the count prefix disambiguates the artifacts. /// /// # Errors /// - /// Returns [`Error`] if `key` is not a simple relative file name, if it has already - /// been registered, or if the underlying artifact cannot be created. - fn write(&self, key: &str) -> Result>; + /// Returns [`Error`] if `key` is `Some` but not a simple relative file name, or if the + /// underlying artifact cannot be created. + fn write(&self, key: Option<&str>) -> Result>; /// Commit the manifest `value`, consuming the context. /// @@ -59,14 +64,14 @@ pub trait SaveContext { /// [`SaveContext::finish`] (which consumes `self` and names an associated type) is not /// object safe, so it is deliberately excluded from this trait. pub(super) trait GetWrite { - fn write(&self, key: &str) -> Result>; + fn write(&self, key: Option<&str>) -> Result>; } impl GetWrite for T where T: SaveContext, { - fn write(&self, key: &str) -> Result> { + fn write(&self, key: Option<&str>) -> Result> { ::write(self, key) } } @@ -128,32 +133,43 @@ impl DiskContext { impl SaveContext for DiskContext { type Output = (); - fn write(&self, key: &str) -> Result> { - // Reject absolute paths, parent traversal, and multi-component paths. Handles must be - // simple file names relative to the manifest directory. - let mut components = std::path::Path::new(key).components(); - match components.next() { - Some(std::path::Component::Normal(_)) if components.next().is_none() => {} - _ => { - return Err(Error::message(format!( - "artifact file name {:?} must be a relative file name with no path separators", - key, - ))); + fn write(&self, key: Option<&str>) -> Result> { + // When a human-readable hint is supplied it must be a simple relative file name: + // reject absolute paths, parent traversal, and multi-component paths so the prefix + // below produces a single, well-formed file name in the manifest directory. + if let Some(key) = key { + let mut components = std::path::Path::new(key).components(); + match components.next() { + Some(std::path::Component::Normal(_)) if components.next().is_none() => {} + _ => { + return Err(Error::message(format!( + "artifact file name hint {:?} must be a relative file name with no path \ + separators", + key, + ))); + } } } - // TODO: Proper disambiguation - making UUIDs etc. let mut files = self .files .lock() .unwrap_or_else(|poison| poison.into_inner()); - if !files.insert(key.into()) { + + // Prefix each artifact with the count of artifacts written so far, so that reusing + // the same `key` (or omitting it) still yields a unique file name. + let name = match key { + Some(key) => format!("{:03}-{}", files.len(), key), + None => format!("{:03}", files.len()), + }; + + if !files.insert(name.clone()) { return Err(Error::message(format!( - "file name {:?} has already been registered with this save context", - key, + "generated artifact name {:?} collides with an existing artifact", + name, ))); } - let full = self.dir.join(key); + let full = self.dir.join(&name); if full.exists() { return Err(Error::message(format!( "file {} already exists", @@ -165,7 +181,7 @@ impl SaveContext for DiskContext { })?; Ok(Writer { io: BufWriter::new(file), - name: key.into(), + name, _lifetime: std::marker::PhantomData, }) } @@ -233,18 +249,20 @@ impl<'a> Context<'a> { Self { inner } } - /// Allocate a new side-car artifact named `key` in the manifest directory. + /// Allocate a new side-car artifact in the manifest directory, optionally tagging it + /// with a human-readable `key`. /// - /// The returned [`Writer`] is positioned at offset 0 and implements + /// The artifact is named with the count of artifacts written so far (with `key`, when + /// `Some`, appended as a readability hint), so the same `key` may be passed to + /// multiple calls. The returned [`Writer`] is positioned at offset 0 and implements /// [`std::io::Write`] / [`std::io::Seek`]. Call [`Writer::finish`] to obtain a /// [`Handle`] that may be inserted into a [`Record`](super::Record). /// /// # Errors /// - /// Returns [`Error`] if `key` has already been registered with this context (names - /// must be unique within a single save), or if the underlying file cannot be created - /// (e.g. because the artifact already exists on disk). - pub fn write(&self, key: &str) -> Result> { + /// Returns [`Error`] if `key` is `Some` but not a simple relative file name, or if the + /// underlying file cannot be created. + pub fn write(&self, key: Option<&str>) -> Result> { self.inner.write(key) } } @@ -337,17 +355,37 @@ mod tests { let dir = tempfile::tempdir().unwrap(); let ctx = DiskContext::new(dir.path().into(), dir.path().join("meta.json")).unwrap(); for bad in ["sub/dir.bin", "../escape.bin", "/abs.bin"] { - SaveContext::write(&ctx, bad).expect_err("keys with path separators must be rejected"); + SaveContext::write(&ctx, Some(bad)) + .expect_err("keys with path separators must be rejected"); } } #[test] - fn write_rejects_duplicate_key() { + fn write_allows_duplicate_key() { + let dir = tempfile::tempdir().unwrap(); + let ctx = DiskContext::new(dir.path().into(), dir.path().join("meta.json")).unwrap(); + let first = SaveContext::write(&ctx, Some("artifact.bin")) + .unwrap() + .finish() + .unwrap(); + let second = SaveContext::write(&ctx, Some("artifact.bin")) + .unwrap() + .finish() + .unwrap(); + assert_ne!( + first.as_str(), + second.as_str(), + "duplicate keys must be disambiguated by the count prefix" + ); + assert_eq!(first.as_str(), "000-artifact.bin"); + assert_eq!(second.as_str(), "001-artifact.bin"); + } + + #[test] + fn write_allows_anonymous_artifact() { let dir = tempfile::tempdir().unwrap(); let ctx = DiskContext::new(dir.path().into(), dir.path().join("meta.json")).unwrap(); - let _writer = SaveContext::write(&ctx, "artifact.bin").unwrap(); - let err = SaveContext::write(&ctx, "artifact.bin") - .expect_err("a duplicate artifact key must be rejected"); - assert!(format!("{err}").contains("already been registered")); + let handle = SaveContext::write(&ctx, None).unwrap().finish().unwrap(); + assert!(!handle.as_str().is_empty()); } } From 407ff326f08c7ebccc488efd431cc437a26447ff Mon Sep 17 00:00:00 2001 From: Suhas Jayaram Subramanya Date: Tue, 23 Jun 2026 14:20:32 -0700 Subject: [PATCH 08/23] gave Reader a new(); added Display and from_str to Version, serde calls into it now; added in-memory ONLY variant of SaveContext and LoadContext --> moved to backend/; added an enum Backend to choose between Disk*Context and InMemory*Context; moved Disk*Context to backend/ --- diskann-record/src/backend/disk.rs | 339 +++++++++++++++++++++++++++ diskann-record/src/backend/memory.rs | 194 +++++++++++++++ diskann-record/src/backend/mod.rs | 9 + diskann-record/src/lib.rs | 3 + diskann-record/src/load/context.rs | 163 ++----------- diskann-record/src/load/mod.rs | 4 +- diskann-record/src/save/context.rs | 338 ++++++++------------------ diskann-record/src/save/mod.rs | 3 +- diskann-record/src/value.rs | 43 ++++ diskann-record/src/version.rs | 62 +++-- 10 files changed, 753 insertions(+), 405 deletions(-) create mode 100644 diskann-record/src/backend/disk.rs create mode 100644 diskann-record/src/backend/memory.rs create mode 100644 diskann-record/src/backend/mod.rs diff --git a/diskann-record/src/backend/disk.rs b/diskann-record/src/backend/disk.rs new file mode 100644 index 000000000..01a258b16 --- /dev/null +++ b/diskann-record/src/backend/disk.rs @@ -0,0 +1,339 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Disk-backed save/load contexts. +//! +//! [`DiskSaveContext`] writes the manifest as JSON plus side-car artifact files into a +//! directory; [`DiskLoadContext`] reads them back. The two halves are independent and +//! communicate only through the filesystem. Both are available under the `disk` feature. + +use std::{ + collections::HashSet, + io::BufReader, + path::{Path, PathBuf}, + sync::Mutex, +}; + +use crate::{ + load::{self, LoadContext, Reader}, + save::{self, SaveContext, Value, Writer}, +}; + +/// The disk-backed [`SaveContext`]. +/// +/// Holds the manifest directory, the manifest path, and the set of artifact file names +/// registered so far. Lookup and insertion go through a [`Mutex`] so that concurrent +/// [`Save`](crate::save::Save) impls cannot accidentally hand out the same artifact name +/// twice. +#[derive(Debug)] +pub(crate) struct DiskSaveContext { + dir: PathBuf, + metadata: PathBuf, + files: Mutex>, +} + +#[derive(serde::Serialize)] +struct Final<'a> { + files: Vec<&'a str>, + value: &'a Value<'a>, +} + +impl DiskSaveContext { + /// Create a disk-backed save context targeting `dir` for side-car artifacts and + /// `metadata` for the manifest. Validates that `dir` is an actual directory. + /// + /// # Errors + /// + /// Returns [`save::Error`] if `dir` does not exist, cannot be inspected, or exists but + /// is not a directory. + pub(crate) fn new(dir: PathBuf, metadata: PathBuf) -> save::Result { + match std::fs::metadata(&dir) { + Ok(meta) if meta.is_dir() => {} + Ok(_) => { + return Err(save::Error::message(format!( + "path {} exists but is not a directory", + dir.display() + ))); + } + Err(err) => { + return Err(save::Error::new(err) + .context(format!("while validating path {}", dir.display()))); + } + } + + Ok(Self { + dir, + metadata, + files: Mutex::new(HashSet::new()), + }) + } +} + +impl SaveContext for DiskSaveContext { + type Output = (); + + fn write(&self, key: Option<&str>) -> save::Result> { + // When a human-readable hint is supplied it must be a simple relative file name: + // reject absolute paths, parent traversal, and multi-component paths so the prefix + // below produces a single, well-formed file name in the manifest directory. + if let Some(key) = key { + let mut components = std::path::Path::new(key).components(); + match components.next() { + Some(std::path::Component::Normal(_)) if components.next().is_none() => {} + _ => { + return Err(save::Error::message(format!( + "artifact file name hint {:?} must be a relative file name with no path \ + separators", + key, + ))); + } + } + } + + let mut files = self + .files + .lock() + .unwrap_or_else(|poison| poison.into_inner()); + + // Prefix each artifact with the count of artifacts written so far, so that reusing + // the same `key` (or omitting it) still yields a unique file name. + let name = match key { + Some(key) => format!("{:03}-{}", files.len(), key), + None => format!("{:03}", files.len()), + }; + + if !files.insert(name.clone()) { + return Err(save::Error::message(format!( + "generated artifact name {:?} collides with an existing artifact", + name, + ))); + } + let full = self.dir.join(&name); + if full.exists() { + return Err(save::Error::message(format!( + "file {} already exists", + full.display() + ))); + } + let file = std::fs::File::create_new(&full).map_err(|err| { + save::Error::new(err).context(format!("while creating new file {}", full.display())) + })?; + Ok(Writer::file(name, file)) + } + + /// Finalize the manifest. + /// + /// Writes the manifest JSON atomically: serializes to a `.temp` file first, + /// then renames it into place. Fails if the temp file already exists (an in-flight + /// save is in progress, or a previous run aborted between rename steps). + fn finish(self, value: Value<'_>) -> save::Result<()> { + let files = self + .files + .into_inner() + .unwrap_or_else(|poison| poison.into_inner()); + let f = Final { + files: files.iter().map(|k| &**k).collect(), + value: &value, + }; + + // Fail if the temp file already exists + let mut temp = self.metadata.clone().into_os_string(); + temp.push(".temp"); + let temp = PathBuf::from(temp); + let buffer = std::fs::File::create_new(&temp).map_err(|err| { + if err.kind() == std::io::ErrorKind::AlreadyExists { + save::Error::message(format!( + "Temporary file {} already exists. Aborting!", + temp.display() + )) + } else { + save::Error::new(err).context(format!( + "while creating temp manifest file {}", + temp.display() + )) + } + })?; + + serde_json::to_writer_pretty(buffer, &f) + .map_err(|err| save::Error::new(err).context("while serializing manifest to JSON"))?; + std::fs::rename(&temp, &self.metadata).map_err(|err| { + save::Error::new(err).context(format!( + "while renaming temp manifest {} to final path {}", + temp.display(), + self.metadata.display() + )) + })?; + Ok(()) + } +} + +/// The disk-backed [`LoadContext`]. +/// +/// Reads the manifest produced by [`DiskSaveContext`] and resolves side-car artifact +/// handles against the manifest directory. +#[derive(Debug)] +pub(crate) struct DiskLoadContext { + dir: PathBuf, + files: HashSet, + value: Value<'static>, +} + +#[derive(serde::Deserialize)] +struct FileRepr { + files: HashSet, + value: Value<'static>, +} + +impl DiskLoadContext { + pub(crate) fn new(metadata: &Path, dir: &Path) -> load::Result { + let file = std::fs::File::open(metadata).map_err(|e| { + load::Error::new(e).context(format!("while trying to open {}", metadata.display())) + })?; + + let reader = BufReader::new(file); + let repr: FileRepr = serde_json::from_reader(reader) + .map_err(|e| load::Error::new(e).context("could not deserialize manifest"))?; + + Ok(Self { + dir: dir.into(), + files: repr.files, + value: repr.value, + }) + } +} + +impl LoadContext for DiskLoadContext { + fn value(&self) -> load::Result<&Value<'_>> { + Ok(&self.value) + } + + fn read(&self, key: &str) -> load::Result> { + let key_as_path: &Path = key.as_ref(); + let mut components = key_as_path.components(); + match components.next() { + Some(std::path::Component::Normal(_)) if components.next().is_none() => {} + _ => { + return Err( + load::Error::from(load::error::Kind::MissingFile).context(format!( + "handle references file {:?} which escapes the manifest directory", + key, + )), + ); + } + } + if !self.files.contains(key_as_path) { + return Err( + load::Error::from(load::error::Kind::MissingFile).context(format!( + "handle references file {:?} which is not registered in the manifest", + key, + )), + ); + } + + let full = self.dir.join(key); + let file = std::fs::File::open(&full).map_err(|err| { + load::Error::new(err).context(format!("while opening artifact file {}", full.display())) + })?; + + Ok(Reader::new(Box::new(file))) + } +} + +#[cfg(test)] +mod tests { + use std::path::{Path, PathBuf}; + + use super::*; + + #[test] + fn new_rejects_nonexistent_directory() { + let missing = PathBuf::from("does/not/exist/anywhere/at/all"); + let err = DiskSaveContext::new(missing, "meta.json".into()) + .expect_err("a nonexistent directory must be rejected"); + assert!(format!("{err}").contains("while validating path")); + } + + #[test] + fn new_rejects_file_as_directory() { + let dir = tempfile::tempdir().unwrap(); + let file = dir.path().join("not_a_dir"); + std::fs::write(&file, b"hi").unwrap(); + let err = DiskSaveContext::new(file, dir.path().join("meta.json")) + .expect_err("a file path must be rejected as a directory"); + assert!(format!("{err}").contains("is not a directory")); + } + + #[test] + fn write_rejects_path_separators_and_traversal() { + let dir = tempfile::tempdir().unwrap(); + let ctx = DiskSaveContext::new(dir.path().into(), dir.path().join("meta.json")).unwrap(); + for bad in ["sub/dir.bin", "../escape.bin", "/abs.bin"] { + SaveContext::write(&ctx, Some(bad)) + .expect_err("keys with path separators must be rejected"); + } + } + + #[test] + fn write_allows_duplicate_key() { + let dir = tempfile::tempdir().unwrap(); + let ctx = DiskSaveContext::new(dir.path().into(), dir.path().join("meta.json")).unwrap(); + let first = SaveContext::write(&ctx, Some("artifact.bin")) + .unwrap() + .finish() + .unwrap(); + let second = SaveContext::write(&ctx, Some("artifact.bin")) + .unwrap() + .finish() + .unwrap(); + assert_ne!( + first.as_str(), + second.as_str(), + "duplicate keys must be disambiguated by the count prefix" + ); + assert_eq!(first.as_str(), "000-artifact.bin"); + assert_eq!(second.as_str(), "001-artifact.bin"); + } + + #[test] + fn write_allows_anonymous_artifact() { + let dir = tempfile::tempdir().unwrap(); + let ctx = DiskSaveContext::new(dir.path().into(), dir.path().join("meta.json")).unwrap(); + let handle = SaveContext::write(&ctx, None).unwrap().finish().unwrap(); + assert!(!handle.as_str().is_empty()); + } + + fn write_manifest(dir: &Path, files: &[&str]) -> PathBuf { + let manifest = serde_json::json!({ + "files": files, + "value": { "$version": "0.0.0" }, + }); + let metadata = dir.join("metadata.json"); + std::fs::write(&metadata, serde_json::to_vec(&manifest).unwrap()).unwrap(); + metadata + } + + #[test] + fn read_rejects_unregistered_file() { + let dir = tempfile::tempdir().unwrap(); + let metadata = write_manifest(dir.path(), &[]); + let ctx = DiskLoadContext::new(&metadata, dir.path()).unwrap(); + let Err(err) = ctx.read("artifact.bin") else { + panic!("an unregistered file must be rejected"); + }; + assert!(format!("{err}").contains("not registered in the manifest")); + } + + #[test] + fn read_rejects_escaping_handle() { + let dir = tempfile::tempdir().unwrap(); + // Register the escaping name so only the path-shape check can reject it. + let metadata = write_manifest(dir.path(), &["../escape.bin"]); + let ctx = DiskLoadContext::new(&metadata, dir.path()).unwrap(); + let Err(err) = ctx.read("../escape.bin") else { + panic!("a handle escaping the manifest directory must be rejected"); + }; + assert!(format!("{err}").contains("escapes the manifest directory")); + } +} diff --git a/diskann-record/src/backend/memory.rs b/diskann-record/src/backend/memory.rs new file mode 100644 index 000000000..0bc9c1701 --- /dev/null +++ b/diskann-record/src/backend/memory.rs @@ -0,0 +1,194 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! In-memory save/load contexts. +//! +//! [`InMemorySaveContext`] and [`InMemoryContext`] mirror the disk-backed contexts but +//! keep the manifest value and every side-car artifact in memory. Saving through an +//! [`InMemorySaveContext`] yields an [`InMemoryContext`] (its +//! [`SaveContext::Output`](crate::save::SaveContext::Output)) that can be loaded directly +//! via [`crate::load::load`]. +//! +//! Unlike the disk path, this round trip never serializes through JSON: the manifest +//! [`Value`] is deep-copied to `'static` via [`Value::into_owned`] and side-car artifacts +//! are buffered in memory through a [`std::io::Cursor`]. As a result these contexts are +//! available regardless of the `disk` / `serde` features. + +use std::{collections::HashMap, io::Cursor, sync::Mutex}; + +use crate::{ + Value, + load::{self, LoadContext, Reader}, + save::{self, SaveContext, Writer}, +}; + +/// A save-side [`SaveContext`] that keeps every side-car artifact and the committed +/// manifest value in memory. +/// +/// [`SaveContext::finish`] consumes the context and returns an [`InMemoryContext`] ready +/// to be loaded with [`crate::load::load`]. +#[derive(Debug, Default)] +pub struct InMemorySaveContext { + files: Mutex>>, +} + +impl InMemorySaveContext { + /// Create an empty in-memory save context. + pub fn new() -> Self { + Self::default() + } +} + +impl SaveContext for InMemorySaveContext { + type Output = InMemoryContext; + + fn write(&self, key: Option<&str>) -> save::Result> { + // Mirror the disk context: a human-readable hint must be a simple relative file + // name so the generated artifact name is a single, well-formed key. + if let Some(key) = key { + let mut components = std::path::Path::new(key).components(); + match components.next() { + Some(std::path::Component::Normal(_)) if components.next().is_none() => {} + _ => { + return Err(save::Error::message(format!( + "artifact file name hint {:?} must be a relative file name with no path \ + separators", + key, + ))); + } + } + } + + let mut files = self + .files + .lock() + .unwrap_or_else(|poison| poison.into_inner()); + + // Prefix each artifact with the count of artifacts written so far, so that reusing + // the same `key` (or omitting it) still yields a unique name. + let name = match key { + Some(key) => format!("{:03}-{}", files.len(), key), + None => format!("{:03}", files.len()), + }; + + if files.contains_key(&name) { + return Err(save::Error::message(format!( + "generated artifact name {:?} collides with an existing artifact", + name, + ))); + } + // Reserve the name so the count advances and concurrent writers cannot collide; + // the placeholder is overwritten with the real bytes by `Writer::finish`. + files.insert(name.clone(), Vec::new()); + drop(files); + + Ok(Writer::memory(name, &self.files)) + } + + fn finish(self, value: Value<'_>) -> save::Result { + let files = self + .files + .into_inner() + .unwrap_or_else(|poison| poison.into_inner()); + Ok(InMemoryContext { + files, + value: value.into_owned(), + }) + } +} + +/// A load-side [`LoadContext`] backed entirely by in-memory buffers. +/// +/// Produced by [`InMemorySaveContext`] via [`SaveContext::finish`]. Holds the committed +/// manifest [`Value`] and every side-car artifact as an in-memory byte buffer, so loading +/// never serializes through JSON or touches the filesystem. +#[derive(Debug)] +pub struct InMemoryContext { + files: HashMap>, + value: Value<'static>, +} + +impl LoadContext for InMemoryContext { + fn value(&self) -> load::Result<&Value<'_>> { + Ok(&self.value) + } + + fn read(&self, key: &str) -> load::Result> { + match self.files.get(key) { + Some(bytes) => Ok(Reader::new(Box::new(Cursor::new(bytes.as_slice())))), + None => Err( + load::Error::from(load::error::Kind::MissingFile).context(format!( + "handle references artifact {:?} which is not registered in this context", + key, + )), + ), + } + } +} + +#[cfg(test)] +mod tests { + use std::io::{Read, Write}; + + use super::*; + use crate::{Version, load, save}; + + #[derive(Debug, PartialEq)] + struct Doc { + name: String, + blob: Vec, + } + + impl save::Save for Doc { + const VERSION: Version = Version::new(1, 0, 0); + fn save(&self, context: save::Context<'_>) -> save::Result> { + let mut io = context.write(Some("blob.bin"))?; + io.write_all(&self.blob).map_err(save::Error::new)?; + let mut record = crate::save_fields!(self, context, [name]); + record.insert("blob", io.finish()?)?; + Ok(record) + } + } + + impl load::Load<'_> for Doc { + const VERSION: Version = Version::new(1, 0, 0); + fn load(object: load::Object<'_>) -> load::Result { + crate::load_fields!(object, [name: String, blob: save::Handle]); + let mut io = object.read(&blob)?; + let mut blob = Vec::new(); + io.read_to_end(&mut blob).map_err(load::Error::new)?; + Ok(Self { name, blob }) + } + fn load_legacy(_: load::Object<'_>) -> load::Result { + Err(load::error::Kind::UnknownVersion.into()) + } + } + + #[test] + fn round_trips_without_serde_or_disk() { + let doc = Doc { + name: "example".to_owned(), + blob: vec![1, 2, 3, 4, 5], + }; + + let context = save::save(&doc, InMemorySaveContext::new()).unwrap(); + let restored: Doc = load::load(&context).unwrap(); + + assert_eq!(doc, restored); + } + + #[test] + fn read_rejects_unregistered_artifact() { + let context = InMemoryContext { + files: HashMap::new(), + value: Value::Null, + }; + let err = context + .read("missing.bin") + .err() + .expect("an unregistered artifact must be rejected"); + assert!(format!("{err}").contains("not registered in this context")); + } +} diff --git a/diskann-record/src/backend/mod.rs b/diskann-record/src/backend/mod.rs new file mode 100644 index 000000000..639cd9fa9 --- /dev/null +++ b/diskann-record/src/backend/mod.rs @@ -0,0 +1,9 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +#[cfg(feature = "disk")] +pub(crate) mod disk; + +pub(crate) mod memory; diff --git a/diskann-record/src/lib.rs b/diskann-record/src/lib.rs index e7250d266..3788c0d44 100644 --- a/diskann-record/src/lib.rs +++ b/diskann-record/src/lib.rs @@ -94,6 +94,9 @@ pub use value::{Handle, Keys, Record, Value, Versioned}; pub mod load; pub mod save; +mod backend; +pub use backend::memory::{InMemoryContext, InMemorySaveContext}; + // Canonical wire width for `usize` and `isize` in manifests is 64 bits. Saving a value // on a 64-bit platform and loading it on a 32-bit platform (or vice versa) could silently // truncate values that exceed `u32::MAX` / `i32::MAX`. We therefore require a 64-bit diff --git a/diskann-record/src/load/context.rs b/diskann-record/src/load/context.rs index c34139ae8..9cbcd1d0f 100644 --- a/diskann-record/src/load/context.rs +++ b/diskann-record/src/load/context.rs @@ -15,18 +15,11 @@ //! * [`Object::read`] for side-car artifacts referenced by a //! [`save::Handle`](super::save::Handle). //! -//! [`Reader`] implements [`std::io::Read`] and [`std::io::Seek`] over the artifact file. +//! [`Reader`] implements [`std::io::Read`] over a side-car artifact, regardless of the +//! provider's backing store. -use std::{fs::File, io::BufReader}; +use std::io::BufReader; -#[cfg(feature = "disk")] -use std::{ - collections::HashSet, - path::{Path, PathBuf}, -}; - -#[cfg(feature = "disk")] -use crate::load::Error; use crate::{ Number, Version, load::{Loadable, Result, error}, @@ -37,9 +30,9 @@ use crate::{ /// /// A `LoadContext` supplies the root manifest [`save::Value`] ([`LoadContext::value`]) /// and resolves side-car artifacts referenced by handles ([`LoadContext::read`]). The -/// default, disk-backed implementation (`DiskContext`) lives in this module under the -/// `disk` feature; alternative implementations (e.g. a virtual filesystem or a purely -/// in-memory store) can be supplied for testing. +/// concrete implementations live under [`crate::backend`]: a disk-backed context (under +/// the `disk` feature) and an in-memory context. Alternative implementations (e.g. a +/// virtual filesystem) can be supplied for testing. /// /// The generic [`load`](super::load) entry point is parameterized over this trait, and /// [`Context`] / [`Object`] / `Array` borrow it as an object-safe `&dyn LoadContext` @@ -61,87 +54,27 @@ pub trait LoadContext { fn read(&self, key: &str) -> Result>; } -#[cfg(feature = "disk")] -#[derive(Debug, serde::Deserialize)] -pub(super) struct DiskContext { - dir: PathBuf, - files: HashSet, - value: save::Value<'static>, -} - -#[cfg(feature = "disk")] -#[derive(Debug, serde::Deserialize)] -struct FileRepr { - files: HashSet, - value: save::Value<'static>, -} - -#[cfg(feature = "disk")] -impl DiskContext { - pub(super) fn new(metadata: &Path, dir: &Path) -> Result { - let file = std::fs::File::open(metadata).map_err(|e| { - Error::new(e).context(format!("while trying to open {}", metadata.display())) - })?; - - let reader = std::io::BufReader::new(file); - let repr: FileRepr = serde_json::from_reader(reader) - .map_err(|e| Error::new(e).context("could not deserialize manifest"))?; - - let this = Self { - dir: dir.into(), - files: repr.files, - value: repr.value, - }; - Ok(this) - } +/// A borrowed reader over a side-car artifact. +/// +/// Produced by [`Object::read`]. Implements [`std::io::Read`] over whatever backing +/// store the [`LoadContext`] provides, so non-file-backed providers (like an in-memory byte buffer) can supply an +/// arbitrary [`std::io::Read`]. +pub struct Reader<'a> { + io: BufReader>, } -#[cfg(feature = "disk")] -impl LoadContext for DiskContext { - fn value(&self) -> Result<&save::Value<'_>> { - Ok(&self.value) - } - - fn read(&self, key: &str) -> Result> { - let key_as_path: &Path = key.as_ref(); - let mut components = key_as_path.components(); - match components.next() { - Some(std::path::Component::Normal(_)) if components.next().is_none() => {} - _ => { - return Err(Error::from(error::Kind::MissingFile).context(format!( - "handle references file {:?} which escapes the manifest directory", - key, - ))); - } - } - if !self.files.contains(key_as_path) { - return Err(Error::from(error::Kind::MissingFile).context(format!( - "handle references file {:?} which is not registered in the manifest", - key, - ))); +impl<'a> Reader<'a> { + /// Build a reader over an arbitrary borrowed [`std::io::Read`] source. + /// + /// Used by non-file-backed [`LoadContext`] implementations (e.g. the in-memory + /// context) to expose a side-car artifact backed by a [`std::io::Cursor`]. + pub(crate) fn new(io: Box) -> Self { + Self { + io: BufReader::new(io), } - - let full = self.dir.join(key); - let file = std::fs::File::open(&full).map_err(|err| { - Error::new(err).context(format!("while opening artifact file {}", full.display())) - })?; - let reader = Reader { - io: BufReader::new(file), - _lifetime: std::marker::PhantomData, - }; - - Ok(reader) } } -/// A borrowed reader over a side-car artifact. -/// -/// Produced by [`Object::read`]. Implements [`std::io::Read`] and [`std::io::Seek`]. -pub struct Reader<'a> { - io: BufReader, - _lifetime: std::marker::PhantomData<&'a ()>, -} - impl std::io::Read for Reader<'_> { // Required method fn read(&mut self, buf: &mut [u8]) -> std::io::Result { @@ -163,22 +96,6 @@ impl std::io::Read for Reader<'_> { } } -impl std::io::Seek for Reader<'_> { - fn seek(&mut self, pos: std::io::SeekFrom) -> std::io::Result { - self.io.seek(pos) - } - - fn rewind(&mut self) -> std::io::Result<()> { - self.io.rewind() - } - fn stream_position(&mut self) -> std::io::Result { - self.io.stream_position() - } - fn seek_relative(&mut self, offset: i64) -> std::io::Result<()> { - self.io.seek_relative(offset) - } -} - /////////////////////// // User facing types // /////////////////////// @@ -435,41 +352,3 @@ impl<'a> Iterator for Iter<'a> { } impl ExactSizeIterator for Iter<'_> {} - -#[cfg(all(test, feature = "disk"))] -mod tests { - use super::*; - - fn write_manifest(dir: &Path, files: &[&str]) -> PathBuf { - let manifest = serde_json::json!({ - "files": files, - "value": { "$version": "0.0.0" }, - }); - let metadata = dir.join("metadata.json"); - std::fs::write(&metadata, serde_json::to_vec(&manifest).unwrap()).unwrap(); - metadata - } - - #[test] - fn read_rejects_unregistered_file() { - let dir = tempfile::tempdir().unwrap(); - let metadata = write_manifest(dir.path(), &[]); - let ctx = DiskContext::new(&metadata, dir.path()).unwrap(); - let Err(err) = ctx.read("artifact.bin") else { - panic!("an unregistered file must be rejected"); - }; - assert!(format!("{err}").contains("not registered in the manifest")); - } - - #[test] - fn read_rejects_escaping_handle() { - let dir = tempfile::tempdir().unwrap(); - // Register the escaping name so only the path-shape check can reject it. - let metadata = write_manifest(dir.path(), &["../escape.bin"]); - let ctx = DiskContext::new(&metadata, dir.path()).unwrap(); - let Err(err) = ctx.read("../escape.bin") else { - panic!("a handle escaping the manifest directory must be rejected"); - }; - assert!(format!("{err}").contains("escapes the manifest directory")); - } -} diff --git a/diskann-record/src/load/mod.rs b/diskann-record/src/load/mod.rs index 8f6a87e43..808745f7c 100644 --- a/diskann-record/src/load/mod.rs +++ b/diskann-record/src/load/mod.rs @@ -37,7 +37,7 @@ pub mod error; pub use error::{Error, Result}; mod context; -pub use context::{Context, LoadContext, Object}; +pub use context::{Context, LoadContext, Object, Reader}; #[cfg(feature = "disk")] use std::path::Path; @@ -78,7 +78,7 @@ pub fn load_from_disk(metadata: &Path, dir: &Path) -> Result where T: for<'a> Loadable<'a>, { - let context = context::DiskContext::new(metadata, dir)?; + let context = crate::backend::disk::DiskLoadContext::new(metadata, dir)?; load(&context) } diff --git a/diskann-record/src/save/context.rs b/diskann-record/src/save/context.rs index 9801211d2..248ab8448 100644 --- a/diskann-record/src/save/context.rs +++ b/diskann-record/src/save/context.rs @@ -14,20 +14,22 @@ //! [`Writer::finish`] flushes the buffer and yields a [`Handle`] that can be inserted //! into a [`super::Record`]. -use std::{fs::File, io::BufWriter}; - -#[cfg(feature = "disk")] -use std::{collections::HashSet, path::PathBuf, sync::Mutex}; +use std::{ + collections::HashMap, + fs::File, + io::{BufWriter, Cursor}, + sync::Mutex, +}; use crate::save::{Error, Handle, Result, Value}; /// The backing store for a save operation. /// /// A `SaveContext` decides where side-car artifacts are written ([`SaveContext::write`]) -/// and how the final manifest is committed ([`SaveContext::finish`]). The default, -/// disk-backed implementation (`DiskContext`) lives in this module under the `disk` -/// feature; alternative implementations (e.g. a virtual filesystem or a purely in-memory -/// store) can be supplied for testing or to avoid touching the filesystem. +/// and how the final manifest is committed ([`SaveContext::finish`]). The concrete +/// implementations live under [`crate::backend`]: a disk-backed context (under the `disk` +/// feature) and an in-memory context. Alternative implementations (e.g. a virtual +/// filesystem) can be supplied for testing or to avoid touching the filesystem. /// /// The generic [`save`](super::save) entry point is parameterized over this trait so /// that the base crate carries no hard dependency on any particular implementation. @@ -76,162 +78,6 @@ where } } -/// The disk-backed [`SaveContext`]. -/// -/// Holds the manifest directory, the manifest path, and the set of artifact file names -/// registered so far. Lookup and insertion go through a [`Mutex`] so that concurrent -/// [`Save`](super::Save) impls cannot accidentally hand out the same artifact name twice. -#[cfg(feature = "disk")] -#[derive(Debug)] -pub(super) struct DiskContext { - dir: PathBuf, - metadata: PathBuf, - files: Mutex>, -} - -#[cfg(feature = "disk")] -#[derive(serde::Serialize)] -struct Final<'a> { - files: Vec<&'a str>, - value: &'a Value<'a>, -} - -#[cfg(feature = "disk")] -impl DiskContext { - /// Create a disk-backed save context targeting `dir` for side-car artifacts and - /// `metadata` for the manifest. Validates that `dir` is an actual directory. - /// - /// # Errors - /// - /// Returns [`Error`] if `dir` does not exist, cannot be inspected, or exists but is - /// not a directory. - pub(super) fn new(dir: PathBuf, metadata: PathBuf) -> Result { - match std::fs::metadata(&dir) { - Ok(meta) if meta.is_dir() => {} - Ok(_) => { - return Err(Error::message(format!( - "path {} exists but is not a directory", - dir.display() - ))); - } - Err(err) => { - return Err( - Error::new(err).context(format!("while validating path {}", dir.display())) - ); - } - } - - Ok(Self { - dir, - metadata, - files: Mutex::new(HashSet::new()), - }) - } -} - -#[cfg(feature = "disk")] -impl SaveContext for DiskContext { - type Output = (); - - fn write(&self, key: Option<&str>) -> Result> { - // When a human-readable hint is supplied it must be a simple relative file name: - // reject absolute paths, parent traversal, and multi-component paths so the prefix - // below produces a single, well-formed file name in the manifest directory. - if let Some(key) = key { - let mut components = std::path::Path::new(key).components(); - match components.next() { - Some(std::path::Component::Normal(_)) if components.next().is_none() => {} - _ => { - return Err(Error::message(format!( - "artifact file name hint {:?} must be a relative file name with no path \ - separators", - key, - ))); - } - } - } - - let mut files = self - .files - .lock() - .unwrap_or_else(|poison| poison.into_inner()); - - // Prefix each artifact with the count of artifacts written so far, so that reusing - // the same `key` (or omitting it) still yields a unique file name. - let name = match key { - Some(key) => format!("{:03}-{}", files.len(), key), - None => format!("{:03}", files.len()), - }; - - if !files.insert(name.clone()) { - return Err(Error::message(format!( - "generated artifact name {:?} collides with an existing artifact", - name, - ))); - } - let full = self.dir.join(&name); - if full.exists() { - return Err(Error::message(format!( - "file {} already exists", - full.display() - ))); - } - let file = std::fs::File::create_new(&full).map_err(|err| { - Error::new(err).context(format!("while creating new file {}", full.display())) - })?; - Ok(Writer { - io: BufWriter::new(file), - name, - _lifetime: std::marker::PhantomData, - }) - } - - /// Finalize the manifest. - /// - /// Writes the manifest JSON atomically: serializes to a `.temp` file first, - /// then renames it into place. Fails if the temp file already exists (an in-flight - /// save is in progress, or a previous run aborted between rename steps). - fn finish(self, value: Value<'_>) -> Result<()> { - let files = self - .files - .into_inner() - .unwrap_or_else(|poison| poison.into_inner()); - let f = Final { - files: files.iter().map(|k| &**k).collect(), - value: &value, - }; - - // Fail if the temp file already exists - let mut temp = self.metadata.clone().into_os_string(); - temp.push(".temp"); - let temp = PathBuf::from(temp); - let buffer = std::fs::File::create_new(&temp).map_err(|err| { - if err.kind() == std::io::ErrorKind::AlreadyExists { - Error::message(format!( - "Temporary file {} already exists. Aborting!", - temp.display() - )) - } else { - Error::new(err).context(format!( - "while creating temp manifest file {}", - temp.display() - )) - } - })?; - - serde_json::to_writer_pretty(buffer, &f) - .map_err(|err| Error::new(err).context("while serializing manifest to JSON"))?; - std::fs::rename(&temp, &self.metadata).map_err(|err| { - Error::new(err).context(format!( - "while renaming temp manifest {} to final path {}", - temp.display(), - self.metadata.display() - )) - })?; - Ok(()) - } -} - /// A cheap, clonable handle threaded through every [`Save::save`](super::Save) impl. /// /// `Context` exposes one operation — [`Context::write`] — for allocating a side-car @@ -270,12 +116,48 @@ impl<'a> Context<'a> { /// A borrowed side-car artifact writer produced by [`Context::write`]. /// /// Implements [`std::io::Write`] and [`std::io::Seek`]. Writes are buffered; calling -/// [`Writer::finish`] flushes the buffer, closes the file, and returns a [`Handle`]. +/// [`Writer::finish`] flushes the buffer (or, for an in-memory context, deposits the +/// completed buffer into the store) and returns a [`Handle`]. #[derive(Debug)] pub struct Writer<'a> { - io: BufWriter, + inner: Backend<'a>, name: String, - _lifetime: std::marker::PhantomData<&'a ()>, +} + +/// The backing store a [`Writer`] writes into. +#[derive(Debug)] +enum Backend<'a> { + /// A file on disk; the bytes are persisted as they are written. + #[cfg_attr(not(feature = "disk"), allow(dead_code))] + File(BufWriter), + /// An in-memory buffer; on [`Writer::finish`] the completed buffer is inserted into + /// `store` under the writer's name. + Memory { + buffer: Cursor>, + store: &'a Mutex>>, + }, +} + +impl<'a> Writer<'a> { + /// Construct an in-memory writer that deposits its buffer into `store` on finish. + pub(crate) fn memory(name: String, store: &'a Mutex>>) -> Self { + Self { + inner: Backend::Memory { + buffer: Cursor::new(Vec::new()), + store, + }, + name, + } + } + + /// Construct a file-backed writer that streams bytes straight to `file`. + #[cfg(feature = "disk")] + pub(crate) fn file(name: String, file: File) -> Self { + Self { + inner: Backend::File(BufWriter::new(file)), + name, + } + } } impl Writer<'_> { @@ -285,107 +167,81 @@ impl Writer<'_> { /// [`Record::insert`](super::Record::insert)) so that load-side code can locate the /// artifact through the manifest. pub fn finish(self) -> Result { - // NOTE: self.io.into_inner() will flush the buffer and close the file. - self.io - .into_inner() - .map_err(|err| Error::new(err.into_error()))?; + match self.inner { + // NOTE: into_inner() will flush the buffer and close the file. + Backend::File(io) => { + io.into_inner() + .map_err(|err| Error::new(err.into_error()))?; + } + Backend::Memory { buffer, store } => { + store + .lock() + .unwrap_or_else(|poison| poison.into_inner()) + .insert(self.name.clone(), buffer.into_inner()); + } + } Ok(Handle::new(self.name)) } } impl std::io::Write for Writer<'_> { fn write(&mut self, buf: &[u8]) -> std::io::Result { - self.io.write(buf) + match &mut self.inner { + Backend::File(io) => io.write(buf), + Backend::Memory { buffer, .. } => buffer.write(buf), + } } fn flush(&mut self) -> std::io::Result<()> { - self.io.flush() + match &mut self.inner { + Backend::File(io) => io.flush(), + Backend::Memory { buffer, .. } => buffer.flush(), + } } fn write_vectored(&mut self, bufs: &[std::io::IoSlice<'_>]) -> std::io::Result { - self.io.write_vectored(bufs) + match &mut self.inner { + Backend::File(io) => io.write_vectored(bufs), + Backend::Memory { buffer, .. } => buffer.write_vectored(bufs), + } } fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> { - self.io.write_all(buf) + match &mut self.inner { + Backend::File(io) => io.write_all(buf), + Backend::Memory { buffer, .. } => buffer.write_all(buf), + } } fn write_fmt(&mut self, args: std::fmt::Arguments<'_>) -> std::io::Result<()> { - self.io.write_fmt(args) + match &mut self.inner { + Backend::File(io) => io.write_fmt(args), + Backend::Memory { buffer, .. } => buffer.write_fmt(args), + } } } impl std::io::Seek for Writer<'_> { fn seek(&mut self, pos: std::io::SeekFrom) -> std::io::Result { - self.io.seek(pos) + match &mut self.inner { + Backend::File(io) => io.seek(pos), + Backend::Memory { buffer, .. } => buffer.seek(pos), + } } fn rewind(&mut self) -> std::io::Result<()> { - self.io.rewind() + match &mut self.inner { + Backend::File(io) => io.rewind(), + Backend::Memory { buffer, .. } => buffer.rewind(), + } } fn stream_position(&mut self) -> std::io::Result { - self.io.stream_position() + match &mut self.inner { + Backend::File(io) => io.stream_position(), + Backend::Memory { buffer, .. } => buffer.stream_position(), + } } fn seek_relative(&mut self, offset: i64) -> std::io::Result<()> { - self.io.seek_relative(offset) - } -} - -#[cfg(all(test, feature = "disk"))] -mod tests { - use super::*; - - #[test] - fn new_rejects_nonexistent_directory() { - let missing = PathBuf::from("does/not/exist/anywhere/at/all"); - let err = DiskContext::new(missing, "meta.json".into()) - .expect_err("a nonexistent directory must be rejected"); - assert!(format!("{err}").contains("while validating path")); - } - - #[test] - fn new_rejects_file_as_directory() { - let dir = tempfile::tempdir().unwrap(); - let file = dir.path().join("not_a_dir"); - std::fs::write(&file, b"hi").unwrap(); - let err = DiskContext::new(file, dir.path().join("meta.json")) - .expect_err("a file path must be rejected as a directory"); - assert!(format!("{err}").contains("is not a directory")); - } - - #[test] - fn write_rejects_path_separators_and_traversal() { - let dir = tempfile::tempdir().unwrap(); - let ctx = DiskContext::new(dir.path().into(), dir.path().join("meta.json")).unwrap(); - for bad in ["sub/dir.bin", "../escape.bin", "/abs.bin"] { - SaveContext::write(&ctx, Some(bad)) - .expect_err("keys with path separators must be rejected"); + match &mut self.inner { + Backend::File(io) => io.seek_relative(offset), + Backend::Memory { buffer, .. } => buffer.seek_relative(offset), } } - - #[test] - fn write_allows_duplicate_key() { - let dir = tempfile::tempdir().unwrap(); - let ctx = DiskContext::new(dir.path().into(), dir.path().join("meta.json")).unwrap(); - let first = SaveContext::write(&ctx, Some("artifact.bin")) - .unwrap() - .finish() - .unwrap(); - let second = SaveContext::write(&ctx, Some("artifact.bin")) - .unwrap() - .finish() - .unwrap(); - assert_ne!( - first.as_str(), - second.as_str(), - "duplicate keys must be disambiguated by the count prefix" - ); - assert_eq!(first.as_str(), "000-artifact.bin"); - assert_eq!(second.as_str(), "001-artifact.bin"); - } - - #[test] - fn write_allows_anonymous_artifact() { - let dir = tempfile::tempdir().unwrap(); - let ctx = DiskContext::new(dir.path().into(), dir.path().join("meta.json")).unwrap(); - let handle = SaveContext::write(&ctx, None).unwrap().finish().unwrap(); - assert!(!handle.as_str().is_empty()); - } } diff --git a/diskann-record/src/save/mod.rs b/diskann-record/src/save/mod.rs index 1d938e3b1..e7b9123c7 100644 --- a/diskann-record/src/save/mod.rs +++ b/diskann-record/src/save/mod.rs @@ -74,7 +74,8 @@ pub fn save_to_disk( where T: Saveable, { - let context = context::DiskContext::new(dir.as_ref().into(), metadata.as_ref().into())?; + let context = + crate::backend::disk::DiskSaveContext::new(dir.as_ref().into(), metadata.as_ref().into())?; save(x, context) } diff --git a/diskann-record/src/value.rs b/diskann-record/src/value.rs index 2daa1b12a..f427513ad 100644 --- a/diskann-record/src/value.rs +++ b/diskann-record/src/value.rs @@ -196,6 +196,30 @@ impl From for Value<'_> { } } +impl Value<'_> { + /// Convert this value into a fully owned [`Value<'static>`], deep-copying any borrowed + /// string or byte data. + /// + /// This is the allocation-based equivalent of round-tripping through the wire format: + /// it severs every borrow from the originating data so the result can be stored + /// independently of its source (for example inside an in-memory + /// [`crate::InMemoryContext`]). + pub fn into_owned(self) -> Value<'static> { + match self { + Self::Null => Value::Null, + Self::Bool(b) => Value::Bool(b), + Self::Number(n) => Value::Number(n), + Self::String(s) => Value::String(Cow::Owned(s.into_owned())), + Self::Bytes(b) => Value::Bytes(Cow::Owned(b.into_owned())), + Self::Array(values) => { + Value::Array(values.into_iter().map(Value::into_owned).collect()) + } + Self::Object(versioned) => Value::Object(versioned.into_owned()), + Self::Handle(handle) => Value::Handle(handle), + } + } +} + /// A map of named [`Value`]s. /// /// `Record` is the body of an object in the manifest. On the save side each call to @@ -275,6 +299,18 @@ impl<'a> Record<'a> { pub fn into_value(self, version: Version) -> Value<'a> { Value::Object(Versioned::new(self, version)) } + + /// Convert this record into a fully owned [`Record<'static>`], deep-copying borrowed + /// keys and values. See [`Value::into_owned`]. + pub fn into_owned(self) -> Record<'static> { + Record { + record: self + .record + .into_iter() + .map(|(key, value)| (Cow::Owned(key.into_owned()), value.into_owned())) + .collect(), + } + } } /// Iterator over the keys of a [`Record`]. @@ -329,6 +365,13 @@ impl<'a> Versioned<'a> { pub(crate) fn record(&self) -> &Record<'a> { &self.record } + + pub(crate) fn into_owned(self) -> Versioned<'static> { + Versioned { + record: self.record.into_owned(), + version: self.version, + } + } } /// A reference to a side-car artifact in the manifest directory. diff --git a/diskann-record/src/version.rs b/diskann-record/src/version.rs index fc823caae..175c3893f 100644 --- a/diskann-record/src/version.rs +++ b/diskann-record/src/version.rs @@ -39,13 +39,51 @@ impl Version { } } +impl std::fmt::Display for Version { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}.{}.{}", self.major, self.minor, self.patch) + } +} + +impl std::str::FromStr for Version { + type Err = ParseVersionError; + + fn from_str(s: &str) -> Result { + let mut parts = s.split('.'); + let major = parts.next().and_then(|s| s.parse::().ok()); + let minor = parts.next().and_then(|s| s.parse::().ok()); + let patch = parts.next().and_then(|s| s.parse::().ok()); + match (major, minor, patch, parts.next()) { + (Some(major), Some(minor), Some(patch), None) => Ok(Version { + major, + minor, + patch, + }), + _ => Err(ParseVersionError(s.to_owned())), + } + } +} + +/// Error returned when a string cannot be parsed as a [`Version`]. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ParseVersionError(String); + +impl std::fmt::Display for ParseVersionError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "unknown version {:?}: expected three `.`-separated u32 components", + self.0, + ) + } +} + +impl std::error::Error for ParseVersionError {} + #[cfg(feature = "serde")] impl Serialize for Version { fn serialize(&self, ser: S) -> Result { - ser.collect_str(&format_args!( - "{}.{}.{}", - self.major, self.minor, self.patch - )) + ser.collect_str(self) } } @@ -62,21 +100,7 @@ impl<'de> Deserialize<'de> for Version { } fn visit_str(self, v: &str) -> Result { - let mut parts = v.split('.'); - let major = parts.next().and_then(|s| s.parse::().ok()); - let minor = parts.next().and_then(|s| s.parse::().ok()); - let patch = parts.next().and_then(|s| s.parse::().ok()); - match (major, minor, patch, parts.next()) { - (Some(major), Some(minor), Some(patch), None) => Ok(Version { - major, - minor, - patch, - }), - _ => Err(E::custom(format!( - "unknown version {:?}: expected three `.`-separated u32 components", - v, - ))), - } + v.parse().map_err(E::custom) } } From 07f00207e8cf24ba9ecb633d357e6cd88b21bfa1 Mon Sep 17 00:00:00 2001 From: Suhas Jayaram Subramanya Date: Tue, 23 Jun 2026 14:37:38 -0700 Subject: [PATCH 09/23] added fallback to None if key does not pass validation in SaveContext::write() --- diskann-record/src/backend/disk.rs | 38 ++++++++++++++++-------------- 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/diskann-record/src/backend/disk.rs b/diskann-record/src/backend/disk.rs index 01a258b16..49397e7cb 100644 --- a/diskann-record/src/backend/disk.rs +++ b/diskann-record/src/backend/disk.rs @@ -75,22 +75,15 @@ impl SaveContext for DiskSaveContext { type Output = (); fn write(&self, key: Option<&str>) -> save::Result> { - // When a human-readable hint is supplied it must be a simple relative file name: - // reject absolute paths, parent traversal, and multi-component paths so the prefix - // below produces a single, well-formed file name in the manifest directory. - if let Some(key) = key { + // When a human-readable hint is supplied it must be a simple relative file name. + // NOTE:: Absolute paths, parent traversal, and multi-component paths cannot produce a + // single, well-formed file name in the manifest directory, so they are ignored and + // treated as if no hint had been supplied. + let key = key.filter(|key| { let mut components = std::path::Path::new(key).components(); - match components.next() { - Some(std::path::Component::Normal(_)) if components.next().is_none() => {} - _ => { - return Err(save::Error::message(format!( - "artifact file name hint {:?} must be a relative file name with no path \ - separators", - key, - ))); - } - } - } + matches!(components.next(), Some(std::path::Component::Normal(_))) + && components.next().is_none() + }); let mut files = self .files @@ -266,12 +259,21 @@ mod tests { } #[test] - fn write_rejects_path_separators_and_traversal() { + fn write_ignores_path_separators_and_traversal() { let dir = tempfile::tempdir().unwrap(); let ctx = DiskSaveContext::new(dir.path().into(), dir.path().join("meta.json")).unwrap(); for bad in ["sub/dir.bin", "../escape.bin", "/abs.bin"] { - SaveContext::write(&ctx, Some(bad)) - .expect_err("keys with path separators must be rejected"); + let handle = SaveContext::write(&ctx, Some(bad)) + .expect("keys with path separators are treated as anonymous") + .finish() + .unwrap(); + let mut components = std::path::Path::new(handle.as_str()).components(); + assert!( + matches!(components.next(), Some(std::path::Component::Normal(_))) + && components.next().is_none(), + "generated name {:?} must be a single relative file name", + handle.as_str(), + ); } } From 0c7da43f4364743b5ba294ae8747b3fbbccbe610 Mon Sep 17 00:00:00 2001 From: Suhas Jayaram Subramanya Date: Tue, 23 Jun 2026 14:43:05 -0700 Subject: [PATCH 10/23] added a quick short README --- diskann-record/README.md | 14 ++++++++++++++ 1 file changed, 14 insertions(+) create mode 100644 diskann-record/README.md diff --git a/diskann-record/README.md b/diskann-record/README.md new file mode 100644 index 000000000..5731e2a68 --- /dev/null +++ b/diskann-record/README.md @@ -0,0 +1,14 @@ +# DiskANN Record + +This crate provides a small framework for persisting structured Rust values as a +manifest (can be serialized to JSON) plus a set of side-car binary artifacts, and +reloading them later. It is can be used by `diskann` providers and indexes to +implement durable, consistent and backward-compatible checkpoints. + +Types describe how they map to a versioned record by implementing the `save::Save` and +`load::Load` traits; the `save_fields!` and `load_fields!` macros handle the +field-by-field plumbing for plain structs. Every record carries a `Version` so loaders +can detect schema changes and either upgrade or fall back through a probing chain. + +The goal is to allow crates like `diskann` to checkpoint their state without depending on +a particular serialization backend. This crate has minimal dependencies by design. \ No newline at end of file From 3f82e6b50cf93fd49233bd43f2a9052e2f4bdeab Mon Sep 17 00:00:00 2001 From: Suhas Jayaram Subramanya Date: Wed, 24 Jun 2026 12:08:42 -0700 Subject: [PATCH 11/23] abstracted away Backend --> Writer now stores a dyn WriterInner + added WriterInner impls for DiskWriter and MemoryWriter; renamed InMemoryContext -> MemoryContext (same for InMemorySaveContext) --- diskann-record/src/backend/disk.rs | 23 +++- diskann-record/src/backend/memory.rs | 59 ++++++--- diskann-record/src/lib.rs | 2 +- diskann-record/src/save/context.rs | 185 ++++++++++----------------- diskann-record/src/save/mod.rs | 1 + diskann-record/src/value.rs | 2 +- 6 files changed, 137 insertions(+), 135 deletions(-) diff --git a/diskann-record/src/backend/disk.rs b/diskann-record/src/backend/disk.rs index 49397e7cb..6f6341c23 100644 --- a/diskann-record/src/backend/disk.rs +++ b/diskann-record/src/backend/disk.rs @@ -11,6 +11,7 @@ use std::{ collections::HashSet, + fs::File, io::BufReader, path::{Path, PathBuf}, sync::Mutex, @@ -18,7 +19,7 @@ use std::{ use crate::{ load::{self, LoadContext, Reader}, - save::{self, SaveContext, Value, Writer}, + save::{self, Handle, SaveContext, Value, Writer, delegate_write_and_seek}, }; /// The disk-backed [`SaveContext`]. @@ -113,7 +114,7 @@ impl SaveContext for DiskSaveContext { let file = std::fs::File::create_new(&full).map_err(|err| { save::Error::new(err).context(format!("while creating new file {}", full.display())) })?; - Ok(Writer::file(name, file)) + Ok(Writer::new(FileWriter { file }, name)) } /// Finalize the manifest. @@ -162,6 +163,24 @@ impl SaveContext for DiskSaveContext { } } +/// A file-backed [`WriterInner`](save::WriterInner) that streams bytes straight to disk. +/// +/// The bytes are persisted as they are written; [`WriterInner::finish`](save::WriterInner::finish) +/// only needs to mint the [`Handle`] (the buffered bytes are already flushed into the file +/// by [`Writer::finish`]). +#[derive(Debug)] +struct FileWriter { + file: File, +} + +impl save::WriterInner for FileWriter { + fn finish(self: Box, name: String) -> save::Result { + Ok(Handle::new(name)) + } +} + +delegate_write_and_seek!(file, FileWriter); + /// The disk-backed [`LoadContext`]. /// /// Reads the manifest produced by [`DiskSaveContext`] and resolves side-car artifact diff --git a/diskann-record/src/backend/memory.rs b/diskann-record/src/backend/memory.rs index 0bc9c1701..35516deb7 100644 --- a/diskann-record/src/backend/memory.rs +++ b/diskann-record/src/backend/memory.rs @@ -5,9 +5,9 @@ //! In-memory save/load contexts. //! -//! [`InMemorySaveContext`] and [`InMemoryContext`] mirror the disk-backed contexts but +//! [`MemorySaveContext`] and [`MemoryContext`] mirror the disk-backed contexts but //! keep the manifest value and every side-car artifact in memory. Saving through an -//! [`InMemorySaveContext`] yields an [`InMemoryContext`] (its +//! [`MemorySaveContext`] yields an [`MemoryContext`] (its //! [`SaveContext::Output`](crate::save::SaveContext::Output)) that can be loaded directly //! via [`crate::load::load`]. //! @@ -21,28 +21,28 @@ use std::{collections::HashMap, io::Cursor, sync::Mutex}; use crate::{ Value, load::{self, LoadContext, Reader}, - save::{self, SaveContext, Writer}, + save::{self, Handle, SaveContext, Writer, delegate_write_and_seek}, }; /// A save-side [`SaveContext`] that keeps every side-car artifact and the committed /// manifest value in memory. /// -/// [`SaveContext::finish`] consumes the context and returns an [`InMemoryContext`] ready +/// [`SaveContext::finish`] consumes the context and returns an [`MemoryContext`] ready /// to be loaded with [`crate::load::load`]. #[derive(Debug, Default)] -pub struct InMemorySaveContext { +pub struct MemorySaveContext { files: Mutex>>, } -impl InMemorySaveContext { +impl MemorySaveContext { /// Create an empty in-memory save context. pub fn new() -> Self { Self::default() } } -impl SaveContext for InMemorySaveContext { - type Output = InMemoryContext; +impl SaveContext for MemorySaveContext { + type Output = MemoryContext; fn write(&self, key: Option<&str>) -> save::Result> { // Mirror the disk context: a human-readable hint must be a simple relative file @@ -84,33 +84,60 @@ impl SaveContext for InMemorySaveContext { files.insert(name.clone(), Vec::new()); drop(files); - Ok(Writer::memory(name, &self.files)) + Ok(Writer::new( + MemoryWriter { + cursor: Cursor::new(Vec::new()), + parent: self, + }, + name, + )) } - fn finish(self, value: Value<'_>) -> save::Result { + fn finish(self, value: Value<'_>) -> save::Result { let files = self .files .into_inner() .unwrap_or_else(|poison| poison.into_inner()); - Ok(InMemoryContext { + Ok(MemoryContext { files, value: value.into_owned(), }) } } +/// An in-memory [`WriterInner`](save::WriterInner) that buffers bytes in a [`Cursor`] and, +/// on finish, deposits the completed buffer into its parent context's file store. +#[derive(Debug)] +struct MemoryWriter<'a> { + cursor: Cursor>, + parent: &'a MemorySaveContext, +} + +impl save::WriterInner for MemoryWriter<'_> { + fn finish(self: Box, name: String) -> save::Result { + self.parent + .files + .lock() + .unwrap_or_else(|poison| poison.into_inner()) + .insert(name.clone(), self.cursor.into_inner()); + Ok(Handle::new(name)) + } +} + +delegate_write_and_seek!(cursor, MemoryWriter<'_>); + /// A load-side [`LoadContext`] backed entirely by in-memory buffers. /// -/// Produced by [`InMemorySaveContext`] via [`SaveContext::finish`]. Holds the committed +/// Produced by [`MemorySaveContext`] via [`SaveContext::finish`]. Holds the committed /// manifest [`Value`] and every side-car artifact as an in-memory byte buffer, so loading /// never serializes through JSON or touches the filesystem. #[derive(Debug)] -pub struct InMemoryContext { +pub struct MemoryContext { files: HashMap>, value: Value<'static>, } -impl LoadContext for InMemoryContext { +impl LoadContext for MemoryContext { fn value(&self) -> load::Result<&Value<'_>> { Ok(&self.value) } @@ -173,7 +200,7 @@ mod tests { blob: vec![1, 2, 3, 4, 5], }; - let context = save::save(&doc, InMemorySaveContext::new()).unwrap(); + let context = save::save(&doc, MemorySaveContext::new()).unwrap(); let restored: Doc = load::load(&context).unwrap(); assert_eq!(doc, restored); @@ -181,7 +208,7 @@ mod tests { #[test] fn read_rejects_unregistered_artifact() { - let context = InMemoryContext { + let context = MemoryContext { files: HashMap::new(), value: Value::Null, }; diff --git a/diskann-record/src/lib.rs b/diskann-record/src/lib.rs index 3788c0d44..b21767e32 100644 --- a/diskann-record/src/lib.rs +++ b/diskann-record/src/lib.rs @@ -95,7 +95,7 @@ pub mod load; pub mod save; mod backend; -pub use backend::memory::{InMemoryContext, InMemorySaveContext}; +pub use backend::memory::{MemoryContext, MemorySaveContext}; // Canonical wire width for `usize` and `isize` in manifests is 64 bits. Saving a value // on a 64-bit platform and loading it on a 32-bit platform (or vice versa) could silently diff --git a/diskann-record/src/save/context.rs b/diskann-record/src/save/context.rs index 248ab8448..b073a264e 100644 --- a/diskann-record/src/save/context.rs +++ b/diskann-record/src/save/context.rs @@ -14,15 +14,51 @@ //! [`Writer::finish`] flushes the buffer and yields a [`Handle`] that can be inserted //! into a [`super::Record`]. -use std::{ - collections::HashMap, - fs::File, - io::{BufWriter, Cursor}, - sync::Mutex, -}; +use std::io::BufWriter; use crate::save::{Error, Handle, Result, Value}; +/// Generate forwarding [`std::io::Write`] and [`std::io::Seek`] impls for `$T` that +/// delegate every method to its `$field` member. +macro_rules! delegate_write_and_seek { + ($field:ident, $T:ty) => { + impl std::io::Write for $T { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + self.$field.write(buf) + } + fn flush(&mut self) -> std::io::Result<()> { + self.$field.flush() + } + fn write_vectored(&mut self, bufs: &[std::io::IoSlice<'_>]) -> std::io::Result { + self.$field.write_vectored(bufs) + } + fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> { + self.$field.write_all(buf) + } + fn write_fmt(&mut self, args: std::fmt::Arguments<'_>) -> std::io::Result<()> { + self.$field.write_fmt(args) + } + } + + impl std::io::Seek for $T { + fn seek(&mut self, pos: std::io::SeekFrom) -> std::io::Result { + self.$field.seek(pos) + } + fn rewind(&mut self) -> std::io::Result<()> { + self.$field.rewind() + } + fn stream_position(&mut self) -> std::io::Result { + self.$field.stream_position() + } + fn seek_relative(&mut self, offset: i64) -> std::io::Result<()> { + self.$field.seek_relative(offset) + } + } + }; +} + +pub(crate) use delegate_write_and_seek; + /// The backing store for a save operation. /// /// A `SaveContext` decides where side-car artifacts are written ([`SaveContext::write`]) @@ -113,135 +149,54 @@ impl<'a> Context<'a> { } } +/// The backend-specific half of a [`Writer`]. +/// +/// Each [`SaveContext`](super::SaveContext) implementation supplies its own +/// `WriterInner` (e.g. an in-memory cursor or an on-disk file). [`Writer`] wraps it in a +/// [`BufWriter`] and forwards [`std::io::Write`] / [`std::io::Seek`] to it; on +/// [`Writer::finish`] the (flushed) inner is consumed to commit the artifact and yield a +/// [`Handle`]. +pub(crate) trait WriterInner: std::io::Write + std::io::Seek + std::fmt::Debug { + /// Commit the completed artifact under `name`, returning its [`Handle`]. + fn finish(self: Box, name: String) -> Result; +} + /// A borrowed side-car artifact writer produced by [`Context::write`]. /// /// Implements [`std::io::Write`] and [`std::io::Seek`]. Writes are buffered; calling -/// [`Writer::finish`] flushes the buffer (or, for an in-memory context, deposits the -/// completed buffer into the store) and returns a [`Handle`]. +/// [`Writer::finish`] flushes the buffer, commits the artifact through the backing +/// [`WriterInner`], and returns a [`Handle`]. #[derive(Debug)] pub struct Writer<'a> { - inner: Backend<'a>, + inner: BufWriter>, name: String, } -/// The backing store a [`Writer`] writes into. -#[derive(Debug)] -enum Backend<'a> { - /// A file on disk; the bytes are persisted as they are written. - #[cfg_attr(not(feature = "disk"), allow(dead_code))] - File(BufWriter), - /// An in-memory buffer; on [`Writer::finish`] the completed buffer is inserted into - /// `store` under the writer's name. - Memory { - buffer: Cursor>, - store: &'a Mutex>>, - }, -} - impl<'a> Writer<'a> { - /// Construct an in-memory writer that deposits its buffer into `store` on finish. - pub(crate) fn memory(name: String, store: &'a Mutex>>) -> Self { + /// Wrap a backend-specific [`WriterInner`] into a buffered [`Writer`] named `name`. + pub(crate) fn new(inner: T, name: String) -> Self + where + T: WriterInner + 'a, + { Self { - inner: Backend::Memory { - buffer: Cursor::new(Vec::new()), - store, - }, + inner: BufWriter::new(Box::new(inner)), name, } } - /// Construct a file-backed writer that streams bytes straight to `file`. - #[cfg(feature = "disk")] - pub(crate) fn file(name: String, file: File) -> Self { - Self { - inner: Backend::File(BufWriter::new(file)), - name, - } - } -} - -impl Writer<'_> { /// Flush and close the writer, returning a [`Handle`] for the artifact. /// /// Insert the returned handle into a [`Record`](super::Record) (typically via /// [`Record::insert`](super::Record::insert)) so that load-side code can locate the /// artifact through the manifest. pub fn finish(self) -> Result { - match self.inner { - // NOTE: into_inner() will flush the buffer and close the file. - Backend::File(io) => { - io.into_inner() - .map_err(|err| Error::new(err.into_error()))?; - } - Backend::Memory { buffer, store } => { - store - .lock() - .unwrap_or_else(|poison| poison.into_inner()) - .insert(self.name.clone(), buffer.into_inner()); - } - } - Ok(Handle::new(self.name)) - } -} - -impl std::io::Write for Writer<'_> { - fn write(&mut self, buf: &[u8]) -> std::io::Result { - match &mut self.inner { - Backend::File(io) => io.write(buf), - Backend::Memory { buffer, .. } => buffer.write(buf), - } - } - - fn flush(&mut self) -> std::io::Result<()> { - match &mut self.inner { - Backend::File(io) => io.flush(), - Backend::Memory { buffer, .. } => buffer.flush(), - } - } - fn write_vectored(&mut self, bufs: &[std::io::IoSlice<'_>]) -> std::io::Result { - match &mut self.inner { - Backend::File(io) => io.write_vectored(bufs), - Backend::Memory { buffer, .. } => buffer.write_vectored(bufs), - } - } - fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> { - match &mut self.inner { - Backend::File(io) => io.write_all(buf), - Backend::Memory { buffer, .. } => buffer.write_all(buf), - } - } - fn write_fmt(&mut self, args: std::fmt::Arguments<'_>) -> std::io::Result<()> { - match &mut self.inner { - Backend::File(io) => io.write_fmt(args), - Backend::Memory { buffer, .. } => buffer.write_fmt(args), - } + // into_inner() flushes the buffered bytes into the backend before we commit it. + let inner = self + .inner + .into_inner() + .map_err(|err| Error::new(err.into_error()))?; + inner.finish(self.name) } } -impl std::io::Seek for Writer<'_> { - fn seek(&mut self, pos: std::io::SeekFrom) -> std::io::Result { - match &mut self.inner { - Backend::File(io) => io.seek(pos), - Backend::Memory { buffer, .. } => buffer.seek(pos), - } - } - - fn rewind(&mut self) -> std::io::Result<()> { - match &mut self.inner { - Backend::File(io) => io.rewind(), - Backend::Memory { buffer, .. } => buffer.rewind(), - } - } - fn stream_position(&mut self) -> std::io::Result { - match &mut self.inner { - Backend::File(io) => io.stream_position(), - Backend::Memory { buffer, .. } => buffer.stream_position(), - } - } - fn seek_relative(&mut self, offset: i64) -> std::io::Result<()> { - match &mut self.inner { - Backend::File(io) => io.seek_relative(offset), - Backend::Memory { buffer, .. } => buffer.seek_relative(offset), - } - } -} +delegate_write_and_seek!(inner, Writer<'_>); diff --git a/diskann-record/src/save/mod.rs b/diskann-record/src/save/mod.rs index e7b9123c7..449a17711 100644 --- a/diskann-record/src/save/mod.rs +++ b/diskann-record/src/save/mod.rs @@ -29,6 +29,7 @@ pub use crate::value::{Handle, Keys, Record, Value, Versioned}; mod context; pub use context::{Context, SaveContext, Writer}; +pub(crate) use context::{WriterInner, delegate_write_and_seek}; mod error; pub use error::{Error, Result}; diff --git a/diskann-record/src/value.rs b/diskann-record/src/value.rs index f427513ad..0eb396d1c 100644 --- a/diskann-record/src/value.rs +++ b/diskann-record/src/value.rs @@ -203,7 +203,7 @@ impl Value<'_> { /// This is the allocation-based equivalent of round-tripping through the wire format: /// it severs every borrow from the originating data so the result can be stored /// independently of its source (for example inside an in-memory - /// [`crate::InMemoryContext`]). + /// [`crate::MemoryContext`]). pub fn into_owned(self) -> Value<'static> { match self { Self::Null => Value::Null, From 6c51b9140801ec1c0308a042a416e35138864e28 Mon Sep 17 00:00:00 2001 From: Suhas Jayaram Subramanya Date: Wed, 24 Jun 2026 13:26:56 -0700 Subject: [PATCH 12/23] added ReaderInner trait to allow trait object inside Reader + doc updates --- diskann-record/src/backend/disk.rs | 2 +- diskann-record/src/backend/memory.rs | 2 +- diskann-record/src/load/context.rs | 51 +++++++++++++++++++++------- diskann-record/src/save/context.rs | 4 +-- 4 files changed, 43 insertions(+), 16 deletions(-) diff --git a/diskann-record/src/backend/disk.rs b/diskann-record/src/backend/disk.rs index 6f6341c23..91d501161 100644 --- a/diskann-record/src/backend/disk.rs +++ b/diskann-record/src/backend/disk.rs @@ -249,7 +249,7 @@ impl LoadContext for DiskLoadContext { load::Error::new(err).context(format!("while opening artifact file {}", full.display())) })?; - Ok(Reader::new(Box::new(file))) + Ok(Reader::new(file)) } } diff --git a/diskann-record/src/backend/memory.rs b/diskann-record/src/backend/memory.rs index 35516deb7..639aa13cb 100644 --- a/diskann-record/src/backend/memory.rs +++ b/diskann-record/src/backend/memory.rs @@ -144,7 +144,7 @@ impl LoadContext for MemoryContext { fn read(&self, key: &str) -> load::Result> { match self.files.get(key) { - Some(bytes) => Ok(Reader::new(Box::new(Cursor::new(bytes.as_slice())))), + Some(bytes) => Ok(Reader::new(Cursor::new(bytes.as_slice()))), None => Err( load::Error::from(load::error::Kind::MissingFile).context(format!( "handle references artifact {:?} which is not registered in this context", diff --git a/diskann-record/src/load/context.rs b/diskann-record/src/load/context.rs index 9cbcd1d0f..aac43ba3c 100644 --- a/diskann-record/src/load/context.rs +++ b/diskann-record/src/load/context.rs @@ -15,8 +15,8 @@ //! * [`Object::read`] for side-car artifacts referenced by a //! [`save::Handle`](super::save::Handle). //! -//! [`Reader`] implements [`std::io::Read`] over a side-car artifact, regardless of the -//! provider's backing store. +//! [`Reader`] implements [`std::io::Read`] and [`std::io::Seek`] over a side-car +//! artifact, regardless of the provider's backing store. use std::io::BufReader; @@ -30,7 +30,7 @@ use crate::{ /// /// A `LoadContext` supplies the root manifest [`save::Value`] ([`LoadContext::value`]) /// and resolves side-car artifacts referenced by handles ([`LoadContext::read`]). The -/// concrete implementations live under [`crate::backend`]: a disk-backed context (under +/// concrete implementations live under `crate::backend`: a disk-backed context (under /// the `disk` feature) and an in-memory context. Alternative implementations (e.g. a /// virtual filesystem) can be supplied for testing. /// @@ -54,23 +54,35 @@ pub trait LoadContext { fn read(&self, key: &str) -> Result>; } +/// The backend-specific half of a [`Reader`]. +/// +/// Each [`LoadContext`] implementation supplies its own `ReaderInner` (e.g. an in-memory +/// cursor or an on-disk file). The blanket impl covers any type that is both +/// [`std::io::Read`] and [`std::io::Seek`]. +pub(crate) trait ReaderInner: std::io::Read + std::io::Seek {} + +impl ReaderInner for T where T: std::io::Read + std::io::Seek {} + /// A borrowed reader over a side-car artifact. /// -/// Produced by [`Object::read`]. Implements [`std::io::Read`] over whatever backing -/// store the [`LoadContext`] provides, so non-file-backed providers (like an in-memory byte buffer) can supply an -/// arbitrary [`std::io::Read`]. +/// Produced by [`Object::read`]. Implements [`std::io::Read`] and [`std::io::Seek`] over +/// whatever backing store the [`LoadContext`] provides, so non-file-backed providers +/// (like an in-memory byte buffer) can supply an arbitrary seekable reader. pub struct Reader<'a> { - io: BufReader>, + io: BufReader>, } impl<'a> Reader<'a> { - /// Build a reader over an arbitrary borrowed [`std::io::Read`] source. + /// Build a reader over an arbitrary borrowed [`ReaderInner`] source. /// - /// Used by non-file-backed [`LoadContext`] implementations (e.g. the in-memory - /// context) to expose a side-car artifact backed by a [`std::io::Cursor`]. - pub(crate) fn new(io: Box) -> Self { + /// Used by [`LoadContext`] implementations to expose a side-car artifact backed by a + /// file or an in-memory [`std::io::Cursor`]. + pub(crate) fn new(io: T) -> Self + where + T: ReaderInner + 'a, + { Self { - io: BufReader::new(io), + io: BufReader::new(Box::new(io)), } } } @@ -96,6 +108,21 @@ impl std::io::Read for Reader<'_> { } } +impl std::io::Seek for Reader<'_> { + fn seek(&mut self, pos: std::io::SeekFrom) -> std::io::Result { + self.io.seek(pos) + } + fn rewind(&mut self) -> std::io::Result<()> { + self.io.rewind() + } + fn stream_position(&mut self) -> std::io::Result { + self.io.stream_position() + } + fn seek_relative(&mut self, offset: i64) -> std::io::Result<()> { + self.io.seek_relative(offset) + } +} + /////////////////////// // User facing types // /////////////////////// diff --git a/diskann-record/src/save/context.rs b/diskann-record/src/save/context.rs index b073a264e..e6929aa5e 100644 --- a/diskann-record/src/save/context.rs +++ b/diskann-record/src/save/context.rs @@ -63,7 +63,7 @@ pub(crate) use delegate_write_and_seek; /// /// A `SaveContext` decides where side-car artifacts are written ([`SaveContext::write`]) /// and how the final manifest is committed ([`SaveContext::finish`]). The concrete -/// implementations live under [`crate::backend`]: a disk-backed context (under the `disk` +/// implementations live under `crate::backend`: a disk-backed context (under the `disk` /// feature) and an in-memory context. Alternative implementations (e.g. a virtual /// filesystem) can be supplied for testing or to avoid touching the filesystem. /// @@ -165,7 +165,7 @@ pub(crate) trait WriterInner: std::io::Write + std::io::Seek + std::fmt::Debug { /// /// Implements [`std::io::Write`] and [`std::io::Seek`]. Writes are buffered; calling /// [`Writer::finish`] flushes the buffer, commits the artifact through the backing -/// [`WriterInner`], and returns a [`Handle`]. +/// writer, and returns a [`Handle`]. #[derive(Debug)] pub struct Writer<'a> { inner: BufWriter>, From db625999da2517e2bf57fbc34d0d5dd61e57089f Mon Sep 17 00:00:00 2001 From: Suhas Jayaram Subramanya Date: Mon, 29 Jun 2026 10:26:05 -0700 Subject: [PATCH 13/23] Made Write pub(crate); removed Bytes and its serde path; added explicit NaN,+inf,-inf handling for f64 --- diskann-record/src/backend/memory.rs | 15 +++---- diskann-record/src/load/context.rs | 4 +- diskann-record/src/load/mod.rs | 7 +-- diskann-record/src/number.rs | 65 ++++++++++++++++++++++++++++ diskann-record/src/save/context.rs | 4 +- diskann-record/src/save/mod.rs | 8 ++-- diskann-record/src/value.rs | 23 ++++------ 7 files changed, 93 insertions(+), 33 deletions(-) diff --git a/diskann-record/src/backend/memory.rs b/diskann-record/src/backend/memory.rs index 639aa13cb..8e06ba9e4 100644 --- a/diskann-record/src/backend/memory.rs +++ b/diskann-record/src/backend/memory.rs @@ -7,9 +7,8 @@ //! //! [`MemorySaveContext`] and [`MemoryContext`] mirror the disk-backed contexts but //! keep the manifest value and every side-car artifact in memory. Saving through an -//! [`MemorySaveContext`] yields an [`MemoryContext`] (its -//! [`SaveContext::Output`](crate::save::SaveContext::Output)) that can be loaded directly -//! via [`crate::load::load`]. +//! [`MemorySaveContext`] yields an [`MemoryContext`] (its `SaveContext::Output`) that can +//! be loaded directly via the `load` entry point. //! //! Unlike the disk path, this round trip never serializes through JSON: the manifest //! [`Value`] is deep-copied to `'static` via [`Value::into_owned`] and side-car artifacts @@ -24,11 +23,11 @@ use crate::{ save::{self, Handle, SaveContext, Writer, delegate_write_and_seek}, }; -/// A save-side [`SaveContext`] that keeps every side-car artifact and the committed +/// A save-side `SaveContext` that keeps every side-car artifact and the committed /// manifest value in memory. /// -/// [`SaveContext::finish`] consumes the context and returns an [`MemoryContext`] ready -/// to be loaded with [`crate::load::load`]. +/// `SaveContext::finish` consumes the context and returns an [`MemoryContext`] ready +/// to be loaded with the `load` entry point. #[derive(Debug, Default)] pub struct MemorySaveContext { files: Mutex>>, @@ -126,9 +125,9 @@ impl save::WriterInner for MemoryWriter<'_> { delegate_write_and_seek!(cursor, MemoryWriter<'_>); -/// A load-side [`LoadContext`] backed entirely by in-memory buffers. +/// A load-side `LoadContext` backed entirely by in-memory buffers. /// -/// Produced by [`MemorySaveContext`] via [`SaveContext::finish`]. Holds the committed +/// Produced by [`MemorySaveContext`] via `SaveContext::finish`. Holds the committed /// manifest [`Value`] and every side-car artifact as an in-memory byte buffer, so loading /// never serializes through JSON or touches the filesystem. #[derive(Debug)] diff --git a/diskann-record/src/load/context.rs b/diskann-record/src/load/context.rs index aac43ba3c..aa77fdcdc 100644 --- a/diskann-record/src/load/context.rs +++ b/diskann-record/src/load/context.rs @@ -37,7 +37,7 @@ use crate::{ /// The generic [`load`](super::load) entry point is parameterized over this trait, and /// [`Context`] / [`Object`] / `Array` borrow it as an object-safe `&dyn LoadContext` /// so the load tree is agnostic to the concrete context type. -pub trait LoadContext { +pub(crate) trait LoadContext { /// The root value of the manifest. /// /// # Errors @@ -66,7 +66,7 @@ impl ReaderInner for T where T: std::io::Read + std::io::Seek {} /// A borrowed reader over a side-car artifact. /// /// Produced by [`Object::read`]. Implements [`std::io::Read`] and [`std::io::Seek`] over -/// whatever backing store the [`LoadContext`] provides, so non-file-backed providers +/// whatever backing store the `LoadContext` provides, so non-file-backed providers /// (like an in-memory byte buffer) can supply an arbitrary seekable reader. pub struct Reader<'a> { io: BufReader>, diff --git a/diskann-record/src/load/mod.rs b/diskann-record/src/load/mod.rs index 808745f7c..fee2c55bc 100644 --- a/diskann-record/src/load/mod.rs +++ b/diskann-record/src/load/mod.rs @@ -9,7 +9,7 @@ //! primitive-like leaves, [`Loadable`]) and obtain an [`Object`] / [`Context`] from which //! they extract individual fields and side-car artifacts. //! -//! The generic entry point is [`load`]; `load_from_disk` (available under the `disk` +//! The generic entry point is `load`; `load_from_disk` (available under the `disk` //! feature) is the disk-backed convenience wrapper that reads a manifest and dispatches //! into the user type's [`Load`] impl. //! @@ -37,7 +37,8 @@ pub mod error; pub use error::{Error, Result}; mod context; -pub use context::{Context, LoadContext, Object, Reader}; +pub use context::{Context, Object, Reader}; +pub(crate) use context::LoadContext; #[cfg(feature = "disk")] use std::path::Path; @@ -54,7 +55,7 @@ use crate::{Version, save}; /// # Errors /// /// Returns [`Error`] if the context cannot produce a root value or if `T`'s loader fails. -pub fn load<'a, T, C>(context: &'a C) -> Result +pub(crate) fn load<'a, T, C>(context: &'a C) -> Result where T: Loadable<'a>, C: LoadContext, diff --git a/diskann-record/src/number.rs b/diskann-record/src/number.rs index c7c97643e..82d123c70 100644 --- a/diskann-record/src/number.rs +++ b/diskann-record/src/number.rs @@ -28,9 +28,40 @@ pub enum Number { F64(f64), } +impl Number { + /// Returns the string sentinel for a non-finite `f64`, or `None` for any finite or + /// integer value. + /// + /// JSON cannot represent `NaN`/`\pm inf` as numeric literals, so these are encoded as + /// the strings `"nan"`, `"inf"`, and `"neg_inf"` instead. [`Number::from_sentinel`] + /// is the inverse. + pub(crate) fn sentinel(self) -> Option<&'static str> { + match self { + Self::F64(v) if v.is_nan() => Some("nan"), + Self::F64(v) if v == f64::INFINITY => Some("inf"), + Self::F64(v) if v == f64::NEG_INFINITY => Some("neg_inf"), + _ => None, + } + } + + /// Decodes a non-finite `f64` sentinel produced by [`Number::sentinel`], or `None` + /// for any other string. + pub(crate) fn from_sentinel(s: &str) -> Option { + match s { + "nan" => Some(Self::F64(f64::NAN)), + "inf" => Some(Self::F64(f64::INFINITY)), + "neg_inf" => Some(Self::F64(f64::NEG_INFINITY)), + _ => None, + } + } +} + #[cfg(feature = "serde")] impl Serialize for Number { fn serialize(&self, serializer: S) -> Result { + if let Some(sentinel) = self.sentinel() { + return serializer.serialize_str(sentinel); + } match *self { Self::U64(v) => serializer.serialize_u64(v), Self::I64(v) => serializer.serialize_i64(v), @@ -62,6 +93,12 @@ impl<'de> Deserialize<'de> for Number { fn visit_f64(self, v: f64) -> Result { Ok(Number::F64(v)) } + + fn visit_str(self, v: &str) -> Result { + Number::from_sentinel(v).ok_or_else(|| { + de::Error::custom(format!("expected a number or numeric sentinel, got {v:?}")) + }) + } } deserializer.deserialize_any(NumberVisitor) @@ -205,4 +242,32 @@ mod tests { assert_eq!(u16::try_from(Number::U64(300)).unwrap(), 300); assert!(usize::try_from(Number::I64(-1)).is_err()); } + + #[cfg(feature = "disk")] + #[test] + fn non_finite_floats_round_trip_via_sentinels() { + for (value, sentinel) in [ + (f64::NAN, "\"nan\""), + (f64::INFINITY, "\"inf\""), + (f64::NEG_INFINITY, "\"neg_inf\""), + ] { + let json = serde_json::to_string(&Number::F64(value)).unwrap(); + assert_eq!(json, sentinel); + + let back: Number = serde_json::from_str(&json).unwrap(); + match back { + Number::F64(v) if value.is_nan() => assert!(v.is_nan()), + Number::F64(v) => assert_eq!(v, value), + other => panic!("expected F64, got {other:?}"), + } + } + } + + #[cfg(feature = "disk")] + #[test] + fn finite_floats_serialize_as_json_numbers() { + assert_eq!(serde_json::to_string(&Number::F64(1.5)).unwrap(), "1.5"); + assert_eq!(serde_json::to_string(&Number::U64(7)).unwrap(), "7"); + assert_eq!(serde_json::to_string(&Number::I64(-7)).unwrap(), "-7"); + } } diff --git a/diskann-record/src/save/context.rs b/diskann-record/src/save/context.rs index e6929aa5e..cc9a6381b 100644 --- a/diskann-record/src/save/context.rs +++ b/diskann-record/src/save/context.rs @@ -69,7 +69,7 @@ pub(crate) use delegate_write_and_seek; /// /// The generic [`save`](super::save) entry point is parameterized over this trait so /// that the base crate carries no hard dependency on any particular implementation. -pub trait SaveContext { +pub(crate) trait SaveContext { /// The value produced once the manifest has been committed by /// [`SaveContext::finish`]. For the disk-backed context this is `()`. type Output; @@ -119,7 +119,7 @@ where /// `Context` exposes one operation — [`Context::write`] — for allocating a side-car /// artifact. The same context is passed to nested [`Save`](super::Save) impls (typically /// via the [`save_fields!`](crate::save_fields) macro), so a single save tree shares -/// artifact-name bookkeeping. It borrows the backing [`SaveContext`] as an object-safe +/// artifact-name bookkeeping. It borrows the backing `SaveContext` as an object-safe /// `GetWrite` so that the save tree is agnostic to the concrete context type. #[derive(Clone)] pub struct Context<'a> { diff --git a/diskann-record/src/save/mod.rs b/diskann-record/src/save/mod.rs index 449a17711..606b9a122 100644 --- a/diskann-record/src/save/mod.rs +++ b/diskann-record/src/save/mod.rs @@ -9,7 +9,7 @@ //! (or, for primitive-like leaves, [`Saveable`]) and obtain a [`Context`] from which they //! request side-car artifact writers and assemble a [`Record`] of named fields. //! -//! The generic entry point is [`save`]; `save_to_disk` (available under the `disk` +//! The generic entry point is `save`; `save_to_disk` (available under the `disk` //! feature) is the disk-backed convenience wrapper that serializes a value into a //! caller-chosen directory plus a manifest path. //! @@ -28,8 +28,8 @@ pub use crate::value::{Handle, Keys, Record, Value, Versioned}; mod context; -pub use context::{Context, SaveContext, Writer}; -pub(crate) use context::{WriterInner, delegate_write_and_seek}; +pub use context::{Context, Writer}; +pub(crate) use context::{SaveContext, WriterInner, delegate_write_and_seek}; mod error; pub use error::{Error, Result}; @@ -47,7 +47,7 @@ use crate::Version; /// /// Returns [`Error`] if a user impl returns an error or if the context fails to commit /// the manifest. -pub fn save(x: &T, context: C) -> Result +pub(crate) fn save(x: &T, context: C) -> Result where T: Saveable, C: SaveContext, diff --git a/diskann-record/src/value.rs b/diskann-record/src/value.rs index 0eb396d1c..9b327150b 100644 --- a/diskann-record/src/value.rs +++ b/diskann-record/src/value.rs @@ -11,8 +11,8 @@ //! //! Every field stored in a manifest is one of: //! -//! * [`Value::Null`] / [`Value::Bool`] / [`Value::Number`] / [`Value::String`] / -//! [`Value::Bytes`] — primitive scalars. +//! * [`Value::Null`] / [`Value::Bool`] / [`Value::Number`] / [`Value::String`] — +//! primitive scalars. //! * [`Value::Array`] — a homogeneous sequence (used by `Vec` and `&[T]`). //! * [`Value::Object`] — a [`Versioned`] [`Record`] (the canonical encoding for a //! `T: crate::save::Save`). @@ -39,7 +39,7 @@ use crate::{Number, Version, save::Error}; /// The wire-level union of every saveable kind. /// /// See the module-level docs for an overview of when each variant is produced. The -/// borrowing parameter `'a` lets [`Value::String`], [`Value::Bytes`], and nested +/// borrowing parameter `'a` lets [`Value::String`] and nested /// records reuse memory owned by the caller without copying. #[derive(Debug)] pub enum Value<'a> { @@ -47,7 +47,6 @@ pub enum Value<'a> { Bool(bool), Number(Number), String(Cow<'a, str>), - Bytes(Cow<'a, [u8]>), Array(Vec>), Object(Versioned<'a>), Handle(Handle), @@ -61,7 +60,6 @@ impl Serialize for Value<'_> { Self::Bool(b) => ser.serialize_bool(*b), Self::Number(n) => n.serialize(ser), Self::String(s) => ser.serialize_str(s), - Self::Bytes(b) => ser.serialize_bytes(b), Self::Array(a) => a.serialize(ser), Self::Object(v) => v.serialize(ser), Self::Handle(h) => h.serialize(ser), @@ -116,21 +114,19 @@ impl<'de> Deserialize<'de> for Value<'static> { } fn visit_str(self, v: &str) -> Result, E> { + if let Some(n) = Number::from_sentinel(v) { + return Ok(Value::Number(n)); + } Ok(Value::String(Cow::Owned(v.to_owned()))) } fn visit_string(self, v: String) -> Result, E> { + if let Some(n) = Number::from_sentinel(&v) { + return Ok(Value::Number(n)); + } Ok(Value::String(Cow::Owned(v))) } - fn visit_bytes(self, v: &[u8]) -> Result, E> { - Ok(Value::Bytes(Cow::Owned(v.to_owned()))) - } - - fn visit_byte_buf(self, v: Vec) -> Result, E> { - Ok(Value::Bytes(Cow::Owned(v))) - } - fn visit_seq(self, mut seq: A) -> Result, A::Error> where A: SeqAccess<'de>, @@ -210,7 +206,6 @@ impl Value<'_> { Self::Bool(b) => Value::Bool(b), Self::Number(n) => Value::Number(n), Self::String(s) => Value::String(Cow::Owned(s.into_owned())), - Self::Bytes(b) => Value::Bytes(Cow::Owned(b.into_owned())), Self::Array(values) => { Value::Array(values.into_iter().map(Value::into_owned).collect()) } From a404c05a2b02954ecef8ca5da825153354ac076e Mon Sep 17 00:00:00 2001 From: Suhas Jayaram Subramanya Date: Mon, 29 Jun 2026 10:56:03 -0700 Subject: [PATCH 14/23] cleanup failed save in DiskSaveContext; treat hint as None in MemorySaveContext if hint is invalid --- diskann-record/src/backend/disk.rs | 224 ++++++++++++++++++++++----- diskann-record/src/backend/memory.rs | 45 +++--- 2 files changed, 208 insertions(+), 61 deletions(-) diff --git a/diskann-record/src/backend/disk.rs b/diskann-record/src/backend/disk.rs index 91d501161..6c63fc8be 100644 --- a/diskann-record/src/backend/disk.rs +++ b/diskann-record/src/backend/disk.rs @@ -10,7 +10,7 @@ //! communicate only through the filesystem. Both are available under the `disk` feature. use std::{ - collections::HashSet, + collections::{HashMap, HashSet}, fs::File, io::BufReader, path::{Path, PathBuf}, @@ -24,15 +24,21 @@ use crate::{ /// The disk-backed [`SaveContext`]. /// -/// Holds the manifest directory, the manifest path, and the set of artifact file names -/// registered so far. Lookup and insertion go through a [`Mutex`] so that concurrent -/// [`Save`](crate::save::Save) impls cannot accidentally hand out the same artifact name -/// twice. +/// Holds the manifest directory, the manifest path, and the artifact file names registered +/// so far paired with whether their writer has finished. Lookup and insertion go through a +/// [`Mutex`] so that concurrent [`Save`](crate::save::Save) impls cannot accidentally hand +/// out the same artifact name twice. +/// +/// # Cleanup on failure +/// +/// Save can fail part-way, so the [`Drop`] impl ensures cleanup of any artifacts created +/// before the failure. #[derive(Debug)] pub(crate) struct DiskSaveContext { dir: PathBuf, metadata: PathBuf, - files: Mutex>, + files: Mutex>, + committed: bool, } #[derive(serde::Serialize)] @@ -67,9 +73,35 @@ impl DiskSaveContext { Ok(Self { dir, metadata, - files: Mutex::new(HashSet::new()), + files: Mutex::new(HashMap::new()), + committed: false, }) } + + /// Path of the temp manifest written by [`SaveContext::finish`] before the atomic + /// rename into [`Self::metadata`]. + fn temp_metadata(&self) -> PathBuf { + let mut temp = self.metadata.clone().into_os_string(); + temp.push(".temp"); + PathBuf::from(temp) + } +} + +impl Drop for DiskSaveContext { + /// Best-effort cleanup for an uncommitted save. + fn drop(&mut self) { + if self.committed { + return; + } + let files = self + .files + .lock() + .unwrap_or_else(|poison| poison.into_inner()); + for name in files.keys() { + let _ = std::fs::remove_file(self.dir.join(name)); + } + let _ = std::fs::remove_file(self.temp_metadata()); + } } impl SaveContext for DiskSaveContext { @@ -98,7 +130,7 @@ impl SaveContext for DiskSaveContext { None => format!("{:03}", files.len()), }; - if !files.insert(name.clone()) { + if files.contains_key(&name) { return Err(save::Error::message(format!( "generated artifact name {:?} collides with an existing artifact", name, @@ -114,7 +146,10 @@ impl SaveContext for DiskSaveContext { let file = std::fs::File::create_new(&full).map_err(|err| { save::Error::new(err).context(format!("while creating new file {}", full.display())) })?; - Ok(Writer::new(FileWriter { file }, name)) + // Reserve the name as not-yet-finished; `FileWriter::finish` flips it to `true`, and + // `SaveContext::finish` reports any slot that was reserved but never finished. + files.insert(name.clone(), false); + Ok(Writer::new(FileWriter { file, parent: self }, name)) } /// Finalize the manifest. @@ -122,36 +157,47 @@ impl SaveContext for DiskSaveContext { /// Writes the manifest JSON atomically: serializes to a `.temp` file first, /// then renames it into place. Fails if the temp file already exists (an in-flight /// save is in progress, or a previous run aborted between rename steps). - fn finish(self, value: Value<'_>) -> save::Result<()> { - let files = self - .files - .into_inner() - .unwrap_or_else(|poison| poison.into_inner()); - let f = Final { - files: files.iter().map(|k| &**k).collect(), - value: &value, - }; - - // Fail if the temp file already exists - let mut temp = self.metadata.clone().into_os_string(); - temp.push(".temp"); - let temp = PathBuf::from(temp); - let buffer = std::fs::File::create_new(&temp).map_err(|err| { - if err.kind() == std::io::ErrorKind::AlreadyExists { - save::Error::message(format!( - "Temporary file {} already exists. Aborting!", - temp.display() - )) - } else { - save::Error::new(err).context(format!( - "while creating temp manifest file {}", - temp.display() - )) + /// + /// On failure, context is dropped without committing ==> [`Drop`] impl + /// removes the artifacts + temp manifest. Save is marked committed once + /// rename succeeds and artifacts are in place. + fn finish(mut self, value: Value<'_>) -> save::Result<()> { + let temp = self.temp_metadata(); + { + let files = self + .files + .lock() + .unwrap_or_else(|poison| poison.into_inner()); + if let Some((name, _)) = files.iter().find(|(_, finished)| !**finished) { + return Err(save::Error::message(format!( + "artifact {:?} was reserved but never finished", + name, + ))); } - })?; - - serde_json::to_writer_pretty(buffer, &f) - .map_err(|err| save::Error::new(err).context("while serializing manifest to JSON"))?; + let f = Final { + files: files.keys().map(|k| &**k).collect(), + value: &value, + }; + + // Fail if the temp file already exists + let buffer = std::fs::File::create_new(&temp).map_err(|err| { + if err.kind() == std::io::ErrorKind::AlreadyExists { + save::Error::message(format!( + "Temporary file {} already exists. Aborting!", + temp.display() + )) + } else { + save::Error::new(err).context(format!( + "while creating temp manifest file {}", + temp.display() + )) + } + })?; + + serde_json::to_writer_pretty(buffer, &f).map_err(|err| { + save::Error::new(err).context("while serializing manifest to JSON") + })?; + } std::fs::rename(&temp, &self.metadata).map_err(|err| { save::Error::new(err).context(format!( "while renaming temp manifest {} to final path {}", @@ -159,6 +205,8 @@ impl SaveContext for DiskSaveContext { self.metadata.display() )) })?; + // Manifest now in place, artifacts belong to a valid record + self.committed = true; Ok(()) } } @@ -166,20 +214,26 @@ impl SaveContext for DiskSaveContext { /// A file-backed [`WriterInner`](save::WriterInner) that streams bytes straight to disk. /// /// The bytes are persisted as they are written; [`WriterInner::finish`](save::WriterInner::finish) -/// only needs to mint the [`Handle`] (the buffered bytes are already flushed into the file -/// by [`Writer::finish`]). +/// only needs to mint the [`Handle`] and mark the artifact finished in its parent context +/// (the buffered bytes are already flushed into the file by [`Writer::finish`]). #[derive(Debug)] -struct FileWriter { +struct FileWriter<'a> { file: File, + parent: &'a DiskSaveContext, } -impl save::WriterInner for FileWriter { +impl save::WriterInner for FileWriter<'_> { fn finish(self: Box, name: String) -> save::Result { + self.parent + .files + .lock() + .unwrap_or_else(|poison| poison.into_inner()) + .insert(name.clone(), true); Ok(Handle::new(name)) } } -delegate_write_and_seek!(file, FileWriter); +delegate_write_and_seek!(file, FileWriter<'_>); /// The disk-backed [`LoadContext`]. /// @@ -325,6 +379,90 @@ mod tests { assert!(!handle.as_str().is_empty()); } + #[test] + fn drop_without_finish_cleans_up_artifacts() { + let dir = tempfile::tempdir().unwrap(); + let metadata = dir.path().join("meta.json"); + let name = { + let ctx = DiskSaveContext::new(dir.path().into(), metadata.clone()).unwrap(); + let name = SaveContext::write(&ctx, Some("artifact.bin")) + .unwrap() + .finish() + .unwrap() + .as_str() + .to_owned(); + assert!(dir.path().join(&name).exists()); + name + // `ctx` is dropped here without ever being committed via `finish`. + }; + assert!( + !dir.path().join(&name).exists(), + "an uncommitted save must clean up the artifacts it created" + ); + } + + #[test] + fn failed_finish_cleans_up_artifacts_and_temp() { + let dir = tempfile::tempdir().unwrap(); + let metadata = dir.path().join("meta.json"); + let ctx = DiskSaveContext::new(dir.path().into(), metadata.clone()).unwrap(); + let name = SaveContext::write(&ctx, Some("artifact.bin")) + .unwrap() + .finish() + .unwrap() + .as_str() + .to_owned(); + + // Pre-create the temp manifest so `finish` aborts on the `create_new` collision. + let temp = ctx.temp_metadata(); + std::fs::write(&temp, b"stale").unwrap(); + + let err = ctx + .finish(Value::Null) + .expect_err("finish must fail when the temp manifest already exists"); + assert!(format!("{err}").contains("already exists")); + + assert!( + !dir.path().join(&name).exists(), + "a failed finish must clean up the artifacts it created" + ); + assert!( + !metadata.exists(), + "a failed finish must not leave a committed manifest" + ); + } + + #[test] + fn committed_save_preserves_artifacts() { + let dir = tempfile::tempdir().unwrap(); + let metadata = dir.path().join("meta.json"); + let ctx = DiskSaveContext::new(dir.path().into(), metadata.clone()).unwrap(); + let name = SaveContext::write(&ctx, Some("artifact.bin")) + .unwrap() + .finish() + .unwrap() + .as_str() + .to_owned(); + + ctx.finish(Value::Null).unwrap(); + + assert!( + dir.path().join(&name).exists(), + "a committed save must keep its artifacts" + ); + assert!(metadata.exists(), "a committed save must write the manifest"); + assert!( + !ctx_temp(&metadata).exists(), + "a committed save must not leave a temp manifest" + ); + } + + fn ctx_temp(metadata: &Path) -> PathBuf { + let mut temp = metadata.to_owned().into_os_string(); + temp.push(".temp"); + PathBuf::from(temp) + } + fn write_manifest(dir: &Path, files: &[&str]) -> PathBuf { let manifest = serde_json::json!({ "files": files, diff --git a/diskann-record/src/backend/memory.rs b/diskann-record/src/backend/memory.rs index 8e06ba9e4..207580975 100644 --- a/diskann-record/src/backend/memory.rs +++ b/diskann-record/src/backend/memory.rs @@ -28,9 +28,13 @@ use crate::{ /// /// `SaveContext::finish` consumes the context and returns an [`MemoryContext`] ready /// to be loaded with the `load` entry point. +/// +/// # Cleanup on failure +/// +/// Failures in the same process are automatically cleaned up. #[derive(Debug, Default)] pub struct MemorySaveContext { - files: Mutex>>, + files: Mutex>>>, } impl MemorySaveContext { @@ -45,20 +49,14 @@ impl SaveContext for MemorySaveContext { fn write(&self, key: Option<&str>) -> save::Result> { // Mirror the disk context: a human-readable hint must be a simple relative file - // name so the generated artifact name is a single, well-formed key. - if let Some(key) = key { + // name. Absolute paths, parent traversal, and multi-component paths cannot produce + // a single, well-formed key, so they are ignored and treated as if no hint had been + // supplied. + let key = key.filter(|key| { let mut components = std::path::Path::new(key).components(); - match components.next() { - Some(std::path::Component::Normal(_)) if components.next().is_none() => {} - _ => { - return Err(save::Error::message(format!( - "artifact file name hint {:?} must be a relative file name with no path \ - separators", - key, - ))); - } - } - } + matches!(components.next(), Some(std::path::Component::Normal(_))) + && components.next().is_none() + }); let mut files = self .files @@ -78,9 +76,10 @@ impl SaveContext for MemorySaveContext { name, ))); } - // Reserve the name so the count advances and concurrent writers cannot collide; - // the placeholder is overwritten with the real bytes by `Writer::finish`. - files.insert(name.clone(), Vec::new()); + // Reserve the name with an empty slot so the count advances and concurrent writers + // cannot collide; the slot is filled with the real bytes by `Writer::finish`. A slot + // that is never filled is reported by `SaveContext::finish`. + files.insert(name.clone(), None); drop(files); Ok(Writer::new( @@ -97,6 +96,16 @@ impl SaveContext for MemorySaveContext { .files .into_inner() .unwrap_or_else(|poison| poison.into_inner()); + let files = files + .into_iter() + .map(|(name, bytes)| match bytes { + Some(bytes) => Ok((name, bytes)), + None => Err(save::Error::message(format!( + "artifact {:?} was reserved but never finished", + name, + ))), + }) + .collect::>>()?; Ok(MemoryContext { files, value: value.into_owned(), @@ -118,7 +127,7 @@ impl save::WriterInner for MemoryWriter<'_> { .files .lock() .unwrap_or_else(|poison| poison.into_inner()) - .insert(name.clone(), self.cursor.into_inner()); + .insert(name.clone(), Some(self.cursor.into_inner())); Ok(Handle::new(name)) } } From e177357cb5d3769b1143c2b16debadf87b40993c Mon Sep 17 00:00:00 2001 From: Suhas Jayaram Subramanya Date: Mon, 29 Jun 2026 11:19:29 -0700 Subject: [PATCH 15/23] more test coverage --- diskann-record/src/backend/disk.rs | 9 +- diskann-record/src/lib.rs | 235 +++++++++++++++++++++++++++++ diskann-record/src/load/mod.rs | 2 +- 3 files changed, 242 insertions(+), 4 deletions(-) diff --git a/diskann-record/src/backend/disk.rs b/diskann-record/src/backend/disk.rs index 6c63fc8be..6881021fa 100644 --- a/diskann-record/src/backend/disk.rs +++ b/diskann-record/src/backend/disk.rs @@ -30,7 +30,7 @@ use crate::{ /// out the same artifact name twice. /// /// # Cleanup on failure -/// +/// /// Save can fail part-way, so the [`Drop`] impl ensures cleanup of any artifacts created /// before the failure. #[derive(Debug)] @@ -159,7 +159,7 @@ impl SaveContext for DiskSaveContext { /// save is in progress, or a previous run aborted between rename steps). /// /// On failure, context is dropped without committing ==> [`Drop`] impl - /// removes the artifacts + temp manifest. Save is marked committed once + /// removes the artifacts + temp manifest. Save is marked committed once /// rename succeeds and artifacts are in place. fn finish(mut self, value: Value<'_>) -> save::Result<()> { let temp = self.temp_metadata(); @@ -450,7 +450,10 @@ mod tests { dir.path().join(&name).exists(), "a committed save must keep its artifacts" ); - assert!(metadata.exists(), "a committed save must write the manifest"); + assert!( + metadata.exists(), + "a committed save must write the manifest" + ); assert!( !ctx_temp(&metadata).exists(), "a committed save must not leave a temp manifest" diff --git a/diskann-record/src/lib.rs b/diskann-record/src/lib.rs index b21767e32..a5aa1872a 100644 --- a/diskann-record/src/lib.rs +++ b/diskann-record/src/lib.rs @@ -488,6 +488,241 @@ mod tests { Ok(()) } + ////////////////////////// + // Legacy version path // + ////////////////////////// + + // Sample record that requires a legacy version path (disk version older than loader version). + #[derive(Debug, PartialEq)] + struct Upgraded { + // Stored in the legacy (v0) record. + count: u32, + // Absent from the v0 record; reconstructed by `load_legacy` from `count`. + scaled: u32, + } + + impl save::Save for Upgraded { + // Write using an "old" schema: only `count` is written to disk. + const VERSION: Version = Version::new(0, 0, 0); + fn save(&self, context: save::Context<'_>) -> save::Result> { + Ok(save_fields!(self, context, [count])) + } + } + + impl load::Load<'_> for Upgraded { + // New schema: differs from the `0.0.0` stamped on disk, forcing `load_legacy`. + const VERSION: Version = Version::new(1, 0, 0); + fn load(_object: load::Object<'_>) -> load::Result { + panic!("matching-version load must not run for a legacy record"); + } + fn load_legacy(object: load::Object<'_>) -> load::Result { + // Upgrade a v0 record: derive `scaled` from the stored `count`. + load_fields!(object, [count: u32]); + Ok(Self { + count, // The original count value on disk + scaled: count * 10, // "default"/derived value after upgrade + }) + } + } + + #[test] + fn legacy_record_dispatches_to_load_legacy() -> anyhow::Result<()> { + // Save stamps the old `0.0.0` schema; the loader's `1.0.0` differs, so the + // round trip must flow through `load_legacy`, which upgrades the record. + let temp_dir = tempfile::tempdir()?; + let dir = temp_dir.path(); + let metadata = dir.join("metadata.json"); + + save::save_to_disk(&Upgraded { count: 4, scaled: 0 }, dir, &metadata)?; + let restored: Upgraded = load::load_from_disk(&metadata, dir)?; + + assert_eq!( + restored, + Upgraded { + count: 4, + scaled: 40 + } + ); + Ok(()) + } + + // A record whose loader has no upgrade path for the older on-disk schema: + // `load_legacy` refuses with `UnknownVersion`. + #[derive(Debug, PartialEq)] + struct NoUpgrade { + value: i32, + } + + impl save::Save for NoUpgrade { + const VERSION: Version = Version::new(0, 0, 0); + fn save(&self, context: save::Context<'_>) -> save::Result> { + Ok(save_fields!(self, context, [value])) + } + } + + impl load::Load<'_> for NoUpgrade { + const VERSION: Version = Version::new(2, 0, 0); + fn load(_object: load::Object<'_>) -> load::Result { + panic!("matching-version load must not run for a legacy record"); + } + fn load_legacy(_object: load::Object<'_>) -> load::Result { + // Check if version.major is older than 1.0.0 + if _object.version().major < 1 { + return Err(load::error::Kind::UnknownVersion.into()); + } + panic!("should not reach this point"); + } + } + + #[test] + fn legacy_record_without_upgrade_path_is_rejected() -> anyhow::Result<()> { + let temp_dir = tempfile::tempdir()?; + let dir = temp_dir.path(); + let metadata = dir.join("metadata.json"); + + save::save_to_disk(&NoUpgrade { value: 7 }, dir, &metadata)?; + let err = load::load_from_disk::(&metadata, dir) + .expect_err("a legacy record with no upgrade path must fail"); + let msg = format!("{err}"); + assert!( + msg.contains("unknown version"), + "expected UnknownVersion error, got: {msg}" + ); + Ok(()) + } + + /////////////////////////////////// + // Built-in primitive round-trip // + /////////////////////////////////// + + // Covers the built-in `Loadable`/`Saveable` impls that the structural round-trip + // tests above don't reach: the wider integer widths and every `NonZero*` type. + #[derive(Debug, PartialEq)] + struct Primitives { + a: u16, + b: u32, + c: u64, + d: i16, + e: i32, + f: i64, + g: f64, + nz32: std::num::NonZeroU32, + nz64: std::num::NonZeroU64, + nzsize: std::num::NonZeroUsize, + } + + impl save::Save for Primitives { + const VERSION: Version = Version::new(0, 0, 0); + fn save(&self, context: save::Context<'_>) -> save::Result> { + Ok(save_fields!( + self, + context, + [a, b, c, d, e, f, g, nz32, nz64, nzsize] + )) + } + } + + impl load::Load<'_> for Primitives { + const VERSION: Version = Version::new(0, 0, 0); + fn load(object: load::Object<'_>) -> load::Result { + load_fields!( + object, + [ + a: u16, + b: u32, + c: u64, + d: i16, + e: i32, + f: i64, + g: f64, + nz32: std::num::NonZeroU32, + nz64: std::num::NonZeroU64, + nzsize: std::num::NonZeroUsize, + ] + ); + Ok(Self { + a, + b, + c, + d, + e, + f, + g, + nz32, + nz64, + nzsize, + }) + } + fn load_legacy(_object: load::Object<'_>) -> load::Result { + panic!("nope!"); + } + } + + #[test] + fn builtin_primitives_round_trip() -> anyhow::Result<()> { + let value = Primitives { + a: 4242, + b: 4_000_000_000, + c: 1 << 40, + d: -12345, + e: -2_000_000_000, + f: -(1 << 40), + g: -2.5e-9, + nz32: std::num::NonZeroU32::new(7).unwrap(), + nz64: std::num::NonZeroU64::new(1 << 50).unwrap(), + nzsize: std::num::NonZeroUsize::new(99).unwrap(), + }; + + let temp_dir = tempfile::tempdir()?; + let dir = temp_dir.path(); + let metadata = dir.join("metadata.json"); + + save::save_to_disk(&value, dir, &metadata)?; + let restored: Primitives = load::load_from_disk(&metadata, dir)?; + + assert_eq!(value, restored); + Ok(()) + } + + #[test] + fn nonzero_rejects_zero_on_load() -> anyhow::Result<()> { + // A hand-crafted manifest storing `0` in a `NonZeroU32` field must be rejected + // with `NumberOutOfRange` rather than producing an invalid value. + #[derive(Debug)] + struct NzHolder { + _nz: std::num::NonZeroU32, + } + + impl load::Load<'_> for NzHolder { + const VERSION: Version = Version::new(0, 0, 0); + fn load(object: load::Object<'_>) -> load::Result { + load_fields!(object, [nz: std::num::NonZeroU32]); + Ok(Self { _nz: nz }) + } + fn load_legacy(_object: load::Object<'_>) -> load::Result { + panic!("nope!"); + } + } + + let temp_dir = tempfile::tempdir()?; + let dir = temp_dir.path(); + let metadata = dir.join("metadata.json"); + let manifest = serde_json::json!({ + "files": [], + "value": { "$version": "0.0.0", "nz": 0 }, + }); + std::fs::write(&metadata, serde_json::to_vec(&manifest)?)?; + + let err = load::load_from_disk::(&metadata, dir) + .expect_err("zero stored in a NonZero field must be rejected"); + let msg = format!("{err}"); + assert!( + msg.contains("number out of range"), + "expected NumberOutOfRange error, got: {msg}" + ); + Ok(()) + } + /////////////////////////////// // Manifest directory escape // /////////////////////////////// diff --git a/diskann-record/src/load/mod.rs b/diskann-record/src/load/mod.rs index fee2c55bc..3efe47540 100644 --- a/diskann-record/src/load/mod.rs +++ b/diskann-record/src/load/mod.rs @@ -37,8 +37,8 @@ pub mod error; pub use error::{Error, Result}; mod context; -pub use context::{Context, Object, Reader}; pub(crate) use context::LoadContext; +pub use context::{Context, Object, Reader}; #[cfg(feature = "disk")] use std::path::Path; From 11cbc8905ab5f8de0d7ec23264e410827bbc54e2 Mon Sep 17 00:00:00 2001 From: Suhas Jayaram Subramanya Date: Mon, 29 Jun 2026 11:22:12 -0700 Subject: [PATCH 16/23] making cargo fmt happy --- diskann-record/src/lib.rs | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/diskann-record/src/lib.rs b/diskann-record/src/lib.rs index a5aa1872a..e9c9df8ab 100644 --- a/diskann-record/src/lib.rs +++ b/diskann-record/src/lib.rs @@ -533,7 +533,14 @@ mod tests { let dir = temp_dir.path(); let metadata = dir.join("metadata.json"); - save::save_to_disk(&Upgraded { count: 4, scaled: 0 }, dir, &metadata)?; + save::save_to_disk( + &Upgraded { + count: 4, + scaled: 0, + }, + dir, + &metadata, + )?; let restored: Upgraded = load::load_from_disk(&metadata, dir)?; assert_eq!( From d5e40876e26e19a6cc5e7d5aec7cd7e9e2ae74a0 Mon Sep 17 00:00:00 2001 From: Suhas Jayaram Subramanya Date: Tue, 30 Jun 2026 11:01:52 -0700 Subject: [PATCH 17/23] updated README; dropped 'patch' from VERSION --- diskann-record/README.md | 7 +++- diskann-record/src/backend/disk.rs | 2 +- diskann-record/src/backend/memory.rs | 4 +-- diskann-record/src/lib.rs | 54 ++++++++++++++-------------- diskann-record/src/value.rs | 2 +- diskann-record/src/version.rs | 41 +++++++++------------ 6 files changed, 54 insertions(+), 56 deletions(-) diff --git a/diskann-record/README.md b/diskann-record/README.md index 5731e2a68..07f59dd55 100644 --- a/diskann-record/README.md +++ b/diskann-record/README.md @@ -11,4 +11,9 @@ field-by-field plumbing for plain structs. Every record carries a `Version` so l can detect schema changes and either upgrade or fall back through a probing chain. The goal is to allow crates like `diskann` to checkpoint their state without depending on -a particular serialization backend. This crate has minimal dependencies by design. \ No newline at end of file +a particular serialization backend. Currently, this crate implements `Disk` and `Memory` backends +-- `Disk` for persistent storage and `Memory` for in-memory operations. +`Memory` pipes the output of its `SaveContext` directly into the input of its `LoadContext`, so it is not compatible with other backends. +A hypothetical `ObjectStore` backend could be implemented to be compatible with `Disk` backends to support caching and other advanced features. + +This crate has minimal dependencies by design. \ No newline at end of file diff --git a/diskann-record/src/backend/disk.rs b/diskann-record/src/backend/disk.rs index 6881021fa..bca439dd5 100644 --- a/diskann-record/src/backend/disk.rs +++ b/diskann-record/src/backend/disk.rs @@ -469,7 +469,7 @@ mod tests { fn write_manifest(dir: &Path, files: &[&str]) -> PathBuf { let manifest = serde_json::json!({ "files": files, - "value": { "$version": "0.0.0" }, + "value": { "$version": "0.0" }, }); let metadata = dir.join("metadata.json"); std::fs::write(&metadata, serde_json::to_vec(&manifest).unwrap()).unwrap(); diff --git a/diskann-record/src/backend/memory.rs b/diskann-record/src/backend/memory.rs index 207580975..fc334ec3e 100644 --- a/diskann-record/src/backend/memory.rs +++ b/diskann-record/src/backend/memory.rs @@ -177,7 +177,7 @@ mod tests { } impl save::Save for Doc { - const VERSION: Version = Version::new(1, 0, 0); + const VERSION: Version = Version::new(1, 0); fn save(&self, context: save::Context<'_>) -> save::Result> { let mut io = context.write(Some("blob.bin"))?; io.write_all(&self.blob).map_err(save::Error::new)?; @@ -188,7 +188,7 @@ mod tests { } impl load::Load<'_> for Doc { - const VERSION: Version = Version::new(1, 0, 0); + const VERSION: Version = Version::new(1, 0); fn load(object: load::Object<'_>) -> load::Result { crate::load_fields!(object, [name: String, blob: save::Handle]); let mut io = object.read(&blob)?; diff --git a/diskann-record/src/lib.rs b/diskann-record/src/lib.rs index e9c9df8ab..79a907da3 100644 --- a/diskann-record/src/lib.rs +++ b/diskann-record/src/lib.rs @@ -44,14 +44,14 @@ //! struct Config { dim: usize, label: String } //! //! impl save::Save for Config { -//! const VERSION: Version = Version::new(0, 0, 0); +//! const VERSION: Version = Version::new(0, 0); //! fn save(&self, context: save::Context<'_>) -> save::Result> { //! Ok(diskann_record::save_fields!(self, context, [dim, label])) //! } //! } //! //! impl load::Load<'_> for Config { -//! const VERSION: Version = Version::new(0, 0, 0); +//! const VERSION: Version = Version::new(0, 0); //! fn load(object: load::Object<'_>) -> load::Result { //! diskann_record::load_fields!(object, [dim: usize, label: String]); //! Ok(Self { dim, label }) @@ -157,7 +157,7 @@ mod tests { } impl save::Save for Inner { - const VERSION: Version = Version::new(0, 0, 0); + const VERSION: Version = Version::new(0, 0); fn save(&self, context: save::Context<'_>) -> save::Result> { Ok(save_fields!( self, @@ -168,7 +168,7 @@ mod tests { } impl save::Save for Test { - const VERSION: Version = Version::new(0, 0, 0); + const VERSION: Version = Version::new(0, 0); fn save(&self, context: save::Context<'_>) -> save::Result> { // We save `x`, `y`, and `inner` directly into the manifest. // The raw vector data we instead store in an auxiliary file. @@ -183,7 +183,7 @@ mod tests { } impl load::Load<'_> for Test { - const VERSION: Version = Version::new(0, 0, 0); + const VERSION: Version = Version::new(0, 0); fn load(object: load::Object<'_>) -> load::Result { load_fields!( object, @@ -219,7 +219,7 @@ mod tests { } impl load::Load<'_> for Inner { - const VERSION: Version = Version::new(0, 0, 0); + const VERSION: Version = Version::new(0, 0); fn load(object: load::Object<'_>) -> load::Result { load_fields!( object, @@ -293,7 +293,7 @@ mod tests { } impl save::Save for Metric { - const VERSION: Version = Version::new(0, 0, 0); + const VERSION: Version = Version::new(0, 0); fn save(&self, context: save::Context<'_>) -> save::Result> { let mut record = save::Record::empty(); match self { @@ -313,7 +313,7 @@ mod tests { } impl load::Load<'_> for Metric { - const VERSION: Version = Version::new(0, 0, 0); + const VERSION: Version = Version::new(0, 0); fn load(object: load::Object<'_>) -> load::Result { match object.single_key()? { "L2" => Ok(Self::L2), @@ -341,14 +341,14 @@ mod tests { } impl save::Save for MetricBag { - const VERSION: Version = Version::new(0, 0, 0); + const VERSION: Version = Version::new(0, 0); fn save(&self, context: save::Context<'_>) -> save::Result> { Ok(save_fields!(self, context, [primary, alternatives])) } } impl load::Load<'_> for MetricBag { - const VERSION: Version = Version::new(0, 0, 0); + const VERSION: Version = Version::new(0, 0); fn load(object: load::Object<'_>) -> load::Result { load_fields!(object, [primary: Metric, alternatives: Vec]); Ok(Self { @@ -391,14 +391,14 @@ mod tests { } impl save::Save for StructShape { - const VERSION: Version = Version::new(0, 0, 0); + const VERSION: Version = Version::new(0, 0); fn save(&self, context: save::Context<'_>) -> save::Result> { Ok(save_fields!(self, context, [x])) } } impl load::Load<'_> for StructShape { - const VERSION: Version = Version::new(0, 0, 0); + const VERSION: Version = Version::new(0, 0); fn load(object: load::Object<'_>) -> load::Result { load_fields!(object, [x: i32]); Ok(Self { x }) @@ -414,7 +414,7 @@ mod tests { } impl save::Save for EnumShape { - const VERSION: Version = Version::new(0, 0, 0); + const VERSION: Version = Version::new(0, 0); fn save(&self, context: save::Context<'_>) -> save::Result> { let mut record = save::Record::empty(); match self { @@ -428,7 +428,7 @@ mod tests { } impl load::Load<'_> for EnumShape { - const VERSION: Version = Version::new(0, 0, 0); + const VERSION: Version = Version::new(0, 0); fn load(object: load::Object<'_>) -> load::Result { match object.single_key()? { "Only" => { @@ -503,15 +503,15 @@ mod tests { impl save::Save for Upgraded { // Write using an "old" schema: only `count` is written to disk. - const VERSION: Version = Version::new(0, 0, 0); + const VERSION: Version = Version::new(0, 0); fn save(&self, context: save::Context<'_>) -> save::Result> { Ok(save_fields!(self, context, [count])) } } impl load::Load<'_> for Upgraded { - // New schema: differs from the `0.0.0` stamped on disk, forcing `load_legacy`. - const VERSION: Version = Version::new(1, 0, 0); + // New schema: differs from the `0.0` stamped on disk, forcing `load_legacy`. + const VERSION: Version = Version::new(1, 0); fn load(_object: load::Object<'_>) -> load::Result { panic!("matching-version load must not run for a legacy record"); } @@ -527,7 +527,7 @@ mod tests { #[test] fn legacy_record_dispatches_to_load_legacy() -> anyhow::Result<()> { - // Save stamps the old `0.0.0` schema; the loader's `1.0.0` differs, so the + // Save stamps the old `0.0` schema; the loader's `1.0` differs, so the // round trip must flow through `load_legacy`, which upgrades the record. let temp_dir = tempfile::tempdir()?; let dir = temp_dir.path(); @@ -561,19 +561,19 @@ mod tests { } impl save::Save for NoUpgrade { - const VERSION: Version = Version::new(0, 0, 0); + const VERSION: Version = Version::new(0, 0); fn save(&self, context: save::Context<'_>) -> save::Result> { Ok(save_fields!(self, context, [value])) } } impl load::Load<'_> for NoUpgrade { - const VERSION: Version = Version::new(2, 0, 0); + const VERSION: Version = Version::new(2, 0); fn load(_object: load::Object<'_>) -> load::Result { panic!("matching-version load must not run for a legacy record"); } fn load_legacy(_object: load::Object<'_>) -> load::Result { - // Check if version.major is older than 1.0.0 + // Check if version.major is older than 1.0 if _object.version().major < 1 { return Err(load::error::Kind::UnknownVersion.into()); } @@ -619,7 +619,7 @@ mod tests { } impl save::Save for Primitives { - const VERSION: Version = Version::new(0, 0, 0); + const VERSION: Version = Version::new(0, 0); fn save(&self, context: save::Context<'_>) -> save::Result> { Ok(save_fields!( self, @@ -630,7 +630,7 @@ mod tests { } impl load::Load<'_> for Primitives { - const VERSION: Version = Version::new(0, 0, 0); + const VERSION: Version = Version::new(0, 0); fn load(object: load::Object<'_>) -> load::Result { load_fields!( object, @@ -701,7 +701,7 @@ mod tests { } impl load::Load<'_> for NzHolder { - const VERSION: Version = Version::new(0, 0, 0); + const VERSION: Version = Version::new(0, 0); fn load(object: load::Object<'_>) -> load::Result { load_fields!(object, [nz: std::num::NonZeroU32]); Ok(Self { _nz: nz }) @@ -716,7 +716,7 @@ mod tests { let metadata = dir.join("metadata.json"); let manifest = serde_json::json!({ "files": [], - "value": { "$version": "0.0.0", "nz": 0 }, + "value": { "$version": "0.0", "nz": 0 }, }); std::fs::write(&metadata, serde_json::to_vec(&manifest)?)?; @@ -743,7 +743,7 @@ mod tests { } impl load::Load<'_> for HandleOnly { - const VERSION: Version = Version::new(0, 0, 0); + const VERSION: Version = Version::new(0, 0); fn load(object: load::Object<'_>) -> load::Result { load_fields!(object, [blob: save::Handle]); let mut io = object.read(&blob)?; @@ -765,7 +765,7 @@ mod tests { // path-shape check as the thing rejecting the load. "files": [handle_target], "value": { - "$version": "0.0.0", + "$version": "0.0", "blob": { "$handle": handle_target }, }, }); diff --git a/diskann-record/src/value.rs b/diskann-record/src/value.rs index 9b327150b..04dc5504d 100644 --- a/diskann-record/src/value.rs +++ b/diskann-record/src/value.rs @@ -433,7 +433,7 @@ mod tests { #[cfg(feature = "disk")] #[test] fn deserialize_rejects_handle_with_extra_fields() { - let json = r#"{ "$handle": "a.bin", "$version": "0.0.0" }"#; + let json = r#"{ "$handle": "a.bin", "$version": "0.0" }"#; serde_json::from_str::>(json) .expect_err("handle object with extra fields must be rejected"); } diff --git a/diskann-record/src/version.rs b/diskann-record/src/version.rs index 175c3893f..444e9dded 100644 --- a/diskann-record/src/version.rs +++ b/diskann-record/src/version.rs @@ -16,32 +16,30 @@ use serde::{Deserialize, Deserializer, Serialize, Serializer, de}; /// [`Load::load`](crate::load::Load::load) and /// [`Load::load_legacy`](crate::load::Load::load_legacy). /// -/// The framework treats versions as opaque triples and only checks them for equality; +/// The framework treats versions as opaque pairs and only checks them for equality; /// ordering / semver semantics are entirely up to the implementing type. /// +/// The `major.minor` format aids code readability: a reader can tell at a glance +/// that, for example, version `1.0` is compatible with version `1.2`. +/// /// On the wire, a `Version` is encoded as a single string of the form -/// `"major.minor.patch"` (e.g. `"0.0.0"`). +/// `"major.minor"` (e.g. `"0.0"`). #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct Version { pub major: u32, pub minor: u32, - pub patch: u32, } impl Version { - /// Construct a [`Version`] from its three components. - pub const fn new(major: u32, minor: u32, patch: u32) -> Self { - Self { - major, - minor, - patch, - } + /// Construct a [`Version`] from its two components. + pub const fn new(major: u32, minor: u32) -> Self { + Self { major, minor } } } impl std::fmt::Display for Version { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}.{}.{}", self.major, self.minor, self.patch) + write!(f, "{}.{}", self.major, self.minor) } } @@ -52,13 +50,8 @@ impl std::str::FromStr for Version { let mut parts = s.split('.'); let major = parts.next().and_then(|s| s.parse::().ok()); let minor = parts.next().and_then(|s| s.parse::().ok()); - let patch = parts.next().and_then(|s| s.parse::().ok()); - match (major, minor, patch, parts.next()) { - (Some(major), Some(minor), Some(patch), None) => Ok(Version { - major, - minor, - patch, - }), + match (major, minor, parts.next()) { + (Some(major), Some(minor), None) => Ok(Version { major, minor }), _ => Err(ParseVersionError(s.to_owned())), } } @@ -72,7 +65,7 @@ impl std::fmt::Display for ParseVersionError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( f, - "unknown version {:?}: expected three `.`-separated u32 components", + "unknown version {:?}: expected two `.`-separated u32 components", self.0, ) } @@ -96,7 +89,7 @@ impl<'de> Deserialize<'de> for Version { type Value = Version; fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - f.write_str("a version string of the form \"major.minor.patch\"") + f.write_str("a version string of the form \"major.minor\"") } fn visit_str(self, v: &str) -> Result { @@ -114,20 +107,20 @@ mod tests { #[test] fn serializes_as_dotted_string() { - let json = serde_json::to_string(&Version::new(1, 2, 3)).unwrap(); - assert_eq!(json, "\"1.2.3\""); + let json = serde_json::to_string(&Version::new(1, 2)).unwrap(); + assert_eq!(json, "\"1.2\""); } #[test] fn round_trips_through_json() { - let v = Version::new(4, 5, 6); + let v = Version::new(4, 5); let back: Version = serde_json::from_str(&serde_json::to_string(&v).unwrap()).unwrap(); assert_eq!(v, back); } #[test] fn rejects_malformed_strings() { - for bad in ["\"1.2\"", "\"1.2.3.4\"", "\"1.x.3\"", "\"abc\""] { + for bad in ["\"1\"", "\"1.2.3\"", "\"1.x\"", "\"abc\""] { serde_json::from_str::(bad) .expect_err("malformed version string must be rejected"); } From 89b894fdacabf042b0625908d8e6d9a41a3ad3b1 Mon Sep 17 00:00:00 2001 From: Suhas Jayaram Subramanya Date: Tue, 30 Jun 2026 11:20:25 -0700 Subject: [PATCH 18/23] improved test coverage for disk + memory backends --- diskann-record/src/backend/disk.rs | 102 +++++++++++++++++++++++++++ diskann-record/src/backend/memory.rs | 52 ++++++++++++++ diskann-record/src/load/context.rs | 52 ++++++++++++++ 3 files changed, 206 insertions(+) diff --git a/diskann-record/src/backend/disk.rs b/diskann-record/src/backend/disk.rs index bca439dd5..914fc1138 100644 --- a/diskann-record/src/backend/disk.rs +++ b/diskann-record/src/backend/disk.rs @@ -379,6 +379,76 @@ mod tests { assert!(!handle.as_str().is_empty()); } + #[test] + fn write_rejects_preexisting_file_on_disk() { + let dir = tempfile::tempdir().unwrap(); + let ctx = DiskSaveContext::new(dir.path().into(), dir.path().join("meta.json")).unwrap(); + // The first artifact is named with a `000` count prefix; pre-create that exact file + // on disk so the `full.exists()` guard rejects the allocation. + std::fs::write(dir.path().join("000-artifact.bin"), b"stale").unwrap(); + let err = SaveContext::write(&ctx, Some("artifact.bin")) + .expect_err("an artifact whose file already exists on disk must be rejected"); + assert!(format!("{err}").contains("already exists")); + } + + #[test] + fn finish_rejects_unfinished_artifact() { + let dir = tempfile::tempdir().unwrap(); + let ctx = DiskSaveContext::new(dir.path().into(), dir.path().join("meta.json")).unwrap(); + // Reserve an artifact slot but drop the writer without calling `finish`, leaving the + // slot marked as not-yet-finished. + let writer = SaveContext::write(&ctx, Some("artifact.bin")).unwrap(); + drop(writer); + let err = ctx + .finish(Value::Null) + .expect_err("finish must fail when an artifact was reserved but never finished"); + assert!(format!("{err}").contains("was reserved but never finished")); + } + + #[test] + fn write_rejects_name_collision() { + let dir = tempfile::tempdir().unwrap(); + let ctx = DiskSaveContext::new(dir.path().into(), dir.path().join("meta.json")).unwrap(); + // The count prefix normally makes a collision impossible, so seed the bookkeeping map + // directly with the exact name the next `write` will generate (one entry => count 1 => + // `001-artifact.bin`). + ctx.files + .lock() + .unwrap() + .insert("001-artifact.bin".to_string(), true); + let err = SaveContext::write(&ctx, Some("artifact.bin")) + .expect_err("a generated name that is already registered must be rejected"); + assert!(format!("{err}").contains("collides with an existing artifact")); + } + + #[test] + fn write_reports_file_creation_failure() { + let dir = tempfile::tempdir().unwrap(); + let artifacts = dir.path().join("artifacts"); + std::fs::create_dir(&artifacts).unwrap(); + let ctx = DiskSaveContext::new(artifacts.clone(), dir.path().join("meta.json")).unwrap(); + // Remove the validated directory so `create_new` fails with a non-"exists" IO error + // (the `full.exists()` guard passes because the path is gone). + std::fs::remove_dir(&artifacts).unwrap(); + let err = SaveContext::write(&ctx, Some("artifact.bin")) + .expect_err("creating an artifact in a missing directory must fail"); + assert!(format!("{err}").contains("while creating new file")); + } + + #[test] + fn finish_reports_rename_failure() { + let dir = tempfile::tempdir().unwrap(); + // Make the final manifest path an existing directory: the `.temp` file is + // created and serialized fine, but renaming a file onto a directory fails. + let metadata = dir.path().join("meta.json"); + std::fs::create_dir(&metadata).unwrap(); + let ctx = DiskSaveContext::new(dir.path().into(), metadata.clone()).unwrap(); + let err = ctx + .finish(Value::Null) + .expect_err("renaming the temp manifest onto a directory must fail"); + assert!(format!("{err}").contains("while renaming temp manifest")); + } + #[test] fn drop_without_finish_cleans_up_artifacts() { let dir = tempfile::tempdir().unwrap(); @@ -498,4 +568,36 @@ mod tests { }; assert!(format!("{err}").contains("escapes the manifest directory")); } + + #[test] + fn read_reports_missing_artifact_file() { + let dir = tempfile::tempdir().unwrap(); + // Register the artifact in the manifest but never create the file on disk, so the + // path/registration checks pass and only the file `open` fails. + let metadata = write_manifest(dir.path(), &["artifact.bin"]); + let ctx = DiskLoadContext::new(&metadata, dir.path()).unwrap(); + let Err(err) = ctx.read("artifact.bin") else { + panic!("a registered artifact missing from disk must be reported"); + }; + assert!(format!("{err}").contains("while opening artifact file")); + } + + #[test] + fn new_load_reports_missing_manifest() { + let dir = tempfile::tempdir().unwrap(); + let metadata = dir.path().join("does-not-exist.json"); + let err = DiskLoadContext::new(&metadata, dir.path()) + .expect_err("a missing manifest file must be reported"); + assert!(format!("{err}").contains("while trying to open")); + } + + #[test] + fn new_load_rejects_malformed_manifest() { + let dir = tempfile::tempdir().unwrap(); + let metadata = dir.path().join("metadata.json"); + std::fs::write(&metadata, b"this is not json").unwrap(); + let err = DiskLoadContext::new(&metadata, dir.path()) + .expect_err("a malformed manifest must be rejected"); + assert!(format!("{err}").contains("could not deserialize manifest")); + } } diff --git a/diskann-record/src/backend/memory.rs b/diskann-record/src/backend/memory.rs index fc334ec3e..bf2ffa0a6 100644 --- a/diskann-record/src/backend/memory.rs +++ b/diskann-record/src/backend/memory.rs @@ -226,4 +226,56 @@ mod tests { .expect("an unregistered artifact must be rejected"); assert!(format!("{err}").contains("not registered in this context")); } + + #[test] + fn write_rejects_name_collision() { + let ctx = MemorySaveContext::new(); + // The count prefix normally makes a collision impossible, so seed the bookkeeping map + // directly with the exact name the next `write` will generate (one entry => count 1 => + // `001-artifact.bin`). + ctx.files + .lock() + .unwrap() + .insert("001-artifact.bin".to_string(), None); + let err = SaveContext::write(&ctx, Some("artifact.bin")) + .expect_err("a generated name that is already registered must be rejected"); + assert!(format!("{err}").contains("collides with an existing artifact")); + } + + #[test] + fn finish_rejects_unfinished_artifact() { + let ctx = MemorySaveContext::new(); + // Reserve an artifact slot but drop the writer without calling `finish`, leaving the + // slot empty. + let writer = SaveContext::write(&ctx, Some("artifact.bin")).unwrap(); + drop(writer); + let err = ctx + .finish(Value::Null) + .expect_err("finish must fail when an artifact was reserved but never finished"); + assert!(format!("{err}").contains("was reserved but never finished")); + } + + #[test] + fn write_names_anonymous_artifact_with_count_prefix() { + let ctx = MemorySaveContext::new(); + // Passing `None` as the hint exercises the count-only naming branch. + let handle = SaveContext::write(&ctx, None).unwrap().finish().unwrap(); + assert_eq!(handle.as_str(), "000"); + } + + #[test] + fn load_dispatches_to_load_legacy_on_version_mismatch() { + // Build an object whose version does not match `Doc::VERSION` (1.0) so the `Loadable` + // blanket dispatches to `Doc::load_legacy`, which refuses with `UnknownVersion`. + let value = save::Record::empty() + .into_value(Version::new(2, 0)) + .into_owned(); + let context = MemoryContext { + files: HashMap::new(), + value, + }; + let err = load::load::(&context) + .expect_err("a version mismatch must dispatch to load_legacy, which refuses"); + assert!(format!("{err}").contains("unknown version")); + } } diff --git a/diskann-record/src/load/context.rs b/diskann-record/src/load/context.rs index aa77fdcdc..5a4fd51ac 100644 --- a/diskann-record/src/load/context.rs +++ b/diskann-record/src/load/context.rs @@ -379,3 +379,55 @@ impl<'a> Iterator for Iter<'a> { } impl ExactSizeIterator for Iter<'_> {} + +#[cfg(test)] +mod tests { + use super::*; + + /// Minimal [`LoadContext`] used to anchor a [`Context`] in tests; the `Object` + /// methods under test never touch the context's `value` / `read` operations. + struct TestContext; + + impl LoadContext for TestContext { + fn value(&self) -> Result<&save::Value<'_>> { + unimplemented!("not exercised by Object accessor tests") + } + + fn read(&self, _key: &str) -> Result> { + unimplemented!("not exercised by Object accessor tests") + } + } + + #[test] + fn object_reports_populated_keys() { + let mut record = save::Record::empty(); + record.insert("alpha", save::Value::Null).unwrap(); + record.insert("beta", save::Value::Bool(true)).unwrap(); + let value = record.into_value(Version::new(1, 0)); + + let inner = TestContext; + let object = Context::new(&inner, &value) + .as_object() + .expect("a versioned object value must yield an Object"); + + assert_eq!(object.len(), 2); + assert!(!object.is_empty()); + let mut keys: Vec<&str> = object.keys().collect(); + keys.sort_unstable(); + assert_eq!(keys, ["alpha", "beta"]); + } + + #[test] + fn object_reports_empty_record() { + let value = save::Record::empty().into_value(Version::new(1, 0)); + + let inner = TestContext; + let object = Context::new(&inner, &value) + .as_object() + .expect("a versioned object value must yield an Object"); + + assert_eq!(object.len(), 0); + assert!(object.is_empty()); + assert_eq!(object.keys().count(), 0); + } +} From aaaea47deee3351c151115366dfb199dceec3fd0 Mon Sep 17 00:00:00 2001 From: Suhas Jayaram Subramanya Date: Tue, 30 Jun 2026 11:29:02 -0700 Subject: [PATCH 19/23] bugfix NaN + infinity handling (thanks to Jordan) --- diskann-record/src/backend/memory.rs | 49 ++++++++++++++++++++++++++++ diskann-record/src/number.rs | 23 +++++++++++++ diskann-record/src/value.rs | 18 ++++++++++ 3 files changed, 90 insertions(+) diff --git a/diskann-record/src/backend/memory.rs b/diskann-record/src/backend/memory.rs index bf2ffa0a6..8a980a078 100644 --- a/diskann-record/src/backend/memory.rs +++ b/diskann-record/src/backend/memory.rs @@ -278,4 +278,53 @@ mod tests { .expect_err("a version mismatch must dispatch to load_legacy, which refuses"); assert!(format!("{err}").contains("unknown version")); } + + // Regression for the NaN float-field bug, demonstrated end-to-end through the in-memory + // backend (no serde / disk required). A struct carrying NaN-valued `f64` / `f32` fields is + // saved and reloaded; the reload previously FAILED with `NumberOutOfRange` because + // `Number::as_f64` / `as_f32` rejected NaN. + #[derive(Debug)] + struct Floats { + finite: f64, + nan64: f64, + nan32: f32, + } + + impl save::Save for Floats { + const VERSION: Version = Version::new(1, 0); + fn save(&self, context: save::Context<'_>) -> save::Result> { + Ok(crate::save_fields!(self, context, [finite, nan64, nan32])) + } + } + + impl load::Load<'_> for Floats { + const VERSION: Version = Version::new(1, 0); + fn load(object: load::Object<'_>) -> load::Result { + crate::load_fields!(object, [finite: f64, nan64: f64, nan32: f32]); + Ok(Self { + finite, + nan64, + nan32, + }) + } + fn load_legacy(_: load::Object<'_>) -> load::Result { + Err(load::error::Kind::UnknownVersion.into()) + } + } + + #[test] + fn nan_float_fields_round_trip_in_memory() { + let value = Floats { + finite: 1.5, + nan64: f64::NAN, + nan32: f32::NAN, + }; + + let context = save::save(&value, MemorySaveContext::new()).unwrap(); + let restored: Floats = load::load(&context).expect("NaN float fields must reload"); + + assert_eq!(restored.finite, 1.5); + assert!(restored.nan64.is_nan(), "f64 NaN field lost on reload"); + assert!(restored.nan32.is_nan(), "f32 NaN field lost on reload"); + } } diff --git a/diskann-record/src/number.rs b/diskann-record/src/number.rs index 82d123c70..a331c17e3 100644 --- a/diskann-record/src/number.rs +++ b/diskann-record/src/number.rs @@ -236,6 +236,29 @@ mod tests { assert_eq!(Number::F64(-1.0).as_u32(), None); } + // Regression for the NaN narrowing-accessor bug. + #[test] + fn nan_survives_float_accessors() { + assert_eq!( + Number::F64(f64::NAN).as_f64().map(f64::is_nan), + Some(true), + "as_f64 dropped a NaN" + ); + assert_eq!( + Number::F64(f64::NAN).as_f32().map(f32::is_nan), + Some(true), + "as_f32 dropped a NaN" + ); + // Sanity: infinities already work, so this is specific to NaN. + assert_eq!(Number::F64(f64::INFINITY).as_f64(), Some(f64::INFINITY)); + } + + #[test] + fn nan_round_trips_through_try_from() { + let back = f64::try_from(Number::F64(f64::NAN)).expect("NaN must convert back to f64"); + assert!(back.is_nan()); + } + #[test] fn try_from_surfaces_out_of_range() { assert!(u8::try_from(Number::U64(300)).is_err()); diff --git a/diskann-record/src/value.rs b/diskann-record/src/value.rs index 04dc5504d..742912cf8 100644 --- a/diskann-record/src/value.rs +++ b/diskann-record/src/value.rs @@ -445,4 +445,22 @@ mod tests { serde_json::from_str::>(json) .expect_err("object without $version or $handle must be rejected"); } + + // Regression for the string/float-sentinel collision. A `String` value whose contents + // are exactly a non-finite float sentinel (`"nan"`, `"inf"`, `"neg_inf"`) serializes to + // the same wire token as `Number::F64`, so the manifest deserializer (`visit_str`) + // resurrects it as a `Number` instead of a `String`. This corrupts any string field + // that happens to hold one of those three values. + #[cfg(feature = "disk")] + #[test] + fn string_equal_to_float_sentinel_stays_a_string() { + for s in ["nan", "inf", "neg_inf"] { + let json = serde_json::to_string(&Value::String(Cow::Borrowed(s))).unwrap(); + let back: Value<'static> = serde_json::from_str(&json).unwrap(); + match back { + Value::String(v) => assert_eq!(v, s), + other => panic!("string {s:?} deserialized as a non-string value: {other:?}"), + } + } + } } From eb85f2631ddbc6b387e36fda76d309f3d592d95f Mon Sep 17 00:00:00 2001 From: Suhas Jayaram Subramanya Date: Tue, 30 Jun 2026 11:35:41 -0700 Subject: [PATCH 20/23] NaN + inf handling part 2 -- fix test --- diskann-record/src/number.rs | 3 +++ diskann-record/src/value.rs | 19 ++++++++++--------- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/diskann-record/src/number.rs b/diskann-record/src/number.rs index a331c17e3..f6e637c4f 100644 --- a/diskann-record/src/number.rs +++ b/diskann-record/src/number.rs @@ -130,6 +130,9 @@ macro_rules! float { match self { Self::U64(v) => try_cast!(v:u64 => $T), Self::I64(v) => try_cast!(v:i64 => $T), + // NaN is representable in every float target but never equals itself, so the + // `try_cast!` round-trip guard would reject it; pass it through directly. + Self::F64(v) if v.is_nan() => Some(v as $T), Self::F64(v) => try_cast!(v:f64 => $T), } } diff --git a/diskann-record/src/value.rs b/diskann-record/src/value.rs index 742912cf8..c81b9ad73 100644 --- a/diskann-record/src/value.rs +++ b/diskann-record/src/value.rs @@ -446,20 +446,21 @@ mod tests { .expect_err("object without $version or $handle must be rejected"); } - // Regression for the string/float-sentinel collision. A `String` value whose contents - // are exactly a non-finite float sentinel (`"nan"`, `"inf"`, `"neg_inf"`) serializes to - // the same wire token as `Number::F64`, so the manifest deserializer (`visit_str`) - // resurrects it as a `Number` instead of a `String`. This corrupts any string field - // that happens to hold one of those three values. + // The non-finite float sentinels (`"nan"`, `"inf"`, `"neg_inf"`) are reserved wire + // tokens: a JSON string equal to one of them always deserializes back as the + // corresponding `Number::F64`, never as a `Value::String`. This is intentional — those + // exact strings are not used as user string values in this repository, so the + // round-trip ambiguity is resolved in favor of the float sentinel. #[cfg(feature = "disk")] #[test] - fn string_equal_to_float_sentinel_stays_a_string() { - for s in ["nan", "inf", "neg_inf"] { + fn float_sentinel_strings_deserialize_as_numbers() { + for (s, is_nan) in [("nan", true), ("inf", false), ("neg_inf", false)] { let json = serde_json::to_string(&Value::String(Cow::Borrowed(s))).unwrap(); let back: Value<'static> = serde_json::from_str(&json).unwrap(); match back { - Value::String(v) => assert_eq!(v, s), - other => panic!("string {s:?} deserialized as a non-string value: {other:?}"), + Value::Number(Number::F64(v)) if is_nan => assert!(v.is_nan()), + Value::Number(Number::F64(_)) => {} + other => panic!("sentinel {s:?} deserialized as {other:?}, expected Number::F64"), } } } From bf8c31890738979299340bae6d6e18c623e91e59 Mon Sep 17 00:00:00 2001 From: Suhas Jayaram Subramanya Date: Tue, 30 Jun 2026 11:53:12 -0700 Subject: [PATCH 21/23] improve test coverage for number.rs --- diskann-record/src/number.rs | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/diskann-record/src/number.rs b/diskann-record/src/number.rs index f6e637c4f..6d7c91500 100644 --- a/diskann-record/src/number.rs +++ b/diskann-record/src/number.rs @@ -296,4 +296,33 @@ mod tests { assert_eq!(serde_json::to_string(&Number::U64(7)).unwrap(), "7"); assert_eq!(serde_json::to_string(&Number::I64(-7)).unwrap(), "-7"); } + + #[cfg(feature = "disk")] + #[test] + fn json_numbers_select_matching_variant() { + // Unsigned, signed, and floating-point literals each route to the visitor + // method that preserves the writer's chosen kind. + assert!(matches!( + serde_json::from_str("7").unwrap(), + Number::U64(7) + )); + assert!(matches!( + serde_json::from_str("-7").unwrap(), + Number::I64(-7) + )); + match serde_json::from_str("1.5").unwrap() { + Number::F64(v) => assert_eq!(v, 1.5), + other => panic!("expected F64, got {other:?}"), + } + } + + #[cfg(feature = "disk")] + #[test] + fn non_sentinel_string_is_rejected() { + let err = serde_json::from_str::("\"bogus\"").unwrap_err(); + assert!( + err.to_string().contains("numeric sentinel"), + "unexpected error message: {err}" + ); + } } From 740e9776d1db781f520549b5c1008de9bcf2404f Mon Sep 17 00:00:00 2001 From: Suhas Jayaram Subramanya Date: Tue, 30 Jun 2026 12:52:25 -0700 Subject: [PATCH 22/23] make cargo fmt happy --- diskann-record/src/number.rs | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/diskann-record/src/number.rs b/diskann-record/src/number.rs index 6d7c91500..3a2b8b34f 100644 --- a/diskann-record/src/number.rs +++ b/diskann-record/src/number.rs @@ -302,10 +302,7 @@ mod tests { fn json_numbers_select_matching_variant() { // Unsigned, signed, and floating-point literals each route to the visitor // method that preserves the writer's chosen kind. - assert!(matches!( - serde_json::from_str("7").unwrap(), - Number::U64(7) - )); + assert!(matches!(serde_json::from_str("7").unwrap(), Number::U64(7))); assert!(matches!( serde_json::from_str("-7").unwrap(), Number::I64(-7) From c2e83b184289c26b579546f54ec97232fd49c13b Mon Sep 17 00:00:00 2001 From: Suhas Jayaram Subramanya Date: Tue, 30 Jun 2026 12:59:30 -0700 Subject: [PATCH 23/23] readme update --- diskann-record/README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/diskann-record/README.md b/diskann-record/README.md index 07f59dd55..42dd56990 100644 --- a/diskann-record/README.md +++ b/diskann-record/README.md @@ -1,9 +1,9 @@ # DiskANN Record -This crate provides a small framework for persisting structured Rust values as a -manifest (can be serialized to JSON) plus a set of side-car binary artifacts, and -reloading them later. It is can be used by `diskann` providers and indexes to -implement durable, consistent and backward-compatible checkpoints. +This crate provides a small framework for persisting structured Rust values (`struct`, `enum`, +`Option`, `Vec`, and primitives) as a manifest (can be serialized to JSON) plus a set +of side-car binary artifacts, and reloading them later. It is can be used by `diskann` +providers and indexes to implement durable, consistent and backward-compatible checkpoints. Types describe how they map to a versioned record by implementing the `save::Save` and `load::Load` traits; the `save_fields!` and `load_fields!` macros handle the