diff --git a/rust/lance-arrow/src/ipc.rs b/rust/lance-arrow/src/ipc.rs index 1c6364c4525..8b6e5cf41fe 100644 --- a/rust/lance-arrow/src/ipc.rs +++ b/rust/lance-arrow/src/ipc.rs @@ -270,7 +270,7 @@ pub fn read_ipc_stream_single_at( /// Modern IPC streams have an 8-byte prefix `[continuation: 4][size: 4]`. /// Legacy streams have a 4-byte prefix `[size: 4]`. Returns `(prefix_len, meta_size)`. fn parse_ipc_message_prefix(buf: &Buffer) -> Result<(usize, usize), ArrowError> { - let has_continuation = buf.len() >= 4 && buf[..4] == [0xff; 4]; + let has_continuation = buf.len() >= 4 && buf[..4] == IPC_CONTINUATION; if has_continuation { if buf.len() < 8 { return Err(ArrowError::ParseError( @@ -358,6 +358,134 @@ pub fn read_ipc_stream_single(data: &Bytes) -> Result { } } +// --------------------------------------------------------------------------- +// Aligned IPC sections +// --------------------------------------------------------------------------- + +/// Byte alignment that each IPC section's stream start is padded to. +/// +/// When several IPC streams are concatenated into one larger blob (e.g. a +/// cache entry), a section that starts at an arbitrary offset would leave its +/// array data misaligned. [`FileDecoder`] with `require_alignment = false` +/// then silently copies each buffer into a freshly aligned allocation on +/// every read, defeating zero-copy. Padding each section start to a 64-byte +/// boundary keeps the decoded buffers borrowed directly from the input. +pub const IPC_SECTION_ALIGNMENT: usize = 64; + +/// Number of zero-padding bytes needed to advance `pos` to the next +/// [`IPC_SECTION_ALIGNMENT`] boundary. +fn section_padding(pos: usize) -> usize { + (IPC_SECTION_ALIGNMENT - (pos % IPC_SECTION_ALIGNMENT)) % IPC_SECTION_ALIGNMENT +} + +/// A [`Write`] adapter that counts the bytes written through it. 
+struct CountingWriter<'a> { + inner: &'a mut dyn Write, + count: usize, +} + +impl Write for CountingWriter<'_> { + fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> { + let n = self.inner.write(buf)?; + self.count += n; + Ok(n) + } + + fn flush(&mut self) -> std::io::Result<()> { + self.inner.flush() + } +} + +/// Write zero padding so the next byte lands on an [`IPC_SECTION_ALIGNMENT`] +/// boundary, advancing `pos` past it. +fn write_section_padding(writer: &mut dyn Write, pos: &mut usize) -> Result<(), ArrowError> { + let pad = section_padding(*pos); + if pad > 0 { + const ZEROS: [u8; IPC_SECTION_ALIGNMENT] = [0u8; IPC_SECTION_ALIGNMENT]; + writer + .write_all(&ZEROS[..pad]) + .map_err(|e| ArrowError::IoError(e.to_string(), e))?; + *pos += pad; + } + Ok(()) +} + +/// Write `batch` as a 64-byte-aligned single-batch Arrow IPC section. +/// +/// `pos` is the absolute byte offset of `writer` within the enclosing blob. +/// Zero padding is written first so the IPC stream begins on an +/// [`IPC_SECTION_ALIGNMENT`] boundary, then the stream itself. `pos` is +/// advanced past both the padding and the stream so the caller can write +/// further aligned sections. +/// +/// Paired with [`read_ipc_section_at`]. For the decoded buffers to be borrowed +/// zero-copy, the blob must ultimately be read back from a buffer whose base +/// address is at least 64-byte aligned. +pub fn write_ipc_section( + writer: &mut dyn Write, + pos: &mut usize, + batch: &RecordBatch, +) -> Result<(), ArrowError> { + write_section_padding(writer, pos)?; + + let mut counting = CountingWriter { + inner: writer, + count: 0, + }; + write_ipc_stream(batch, &mut counting)?; + *pos += counting.count; + Ok(()) +} + +/// Read a single [`RecordBatch`] from an aligned IPC section at `offset`. +/// +/// Skips the alignment padding written by [`write_ipc_section`], then reads +/// the stream, advancing `offset` past the section (padding + stream + EOS). 
+/// +/// Zero-copy: array buffers borrow from `data`'s allocation when `data`'s base +/// address is at least 64-byte aligned (see [`write_ipc_section`]). +pub fn read_ipc_section_at(data: &Bytes, offset: &mut usize) -> Result<RecordBatch, ArrowError> { + *offset += section_padding(*offset); + read_ipc_stream_single_at(data, offset) +} + +/// Write `batches` as a single 64-byte-aligned multi-batch Arrow IPC section. +/// +/// Like [`write_ipc_section`] but emits every batch from `iter` into one IPC +/// stream (schema + N batches + EOS). `iter` must yield at least one batch. +/// Paired with [`read_ipc_section_batches_at`]. +pub fn write_ipc_section_batches<I>( + writer: &mut dyn Write, + pos: &mut usize, + iter: I, +) -> Result<(), ArrowError> +where + I: IntoIterator<Item = RecordBatch>, +{ + write_section_padding(writer, pos)?; + + let mut counting = CountingWriter { + inner: writer, + count: 0, + }; + write_ipc_stream_batches(iter, &mut counting)?; + *pos += counting.count; + Ok(()) +} + +/// Read all [`RecordBatch`]es from an aligned multi-batch IPC section at +/// `offset`, advancing `offset` past the section (padding + stream + EOS). +/// +/// Zero-copy: array buffers borrow from `data`'s allocation when `data`'s base +/// address is at least 64-byte aligned (see [`write_ipc_section_batches`]). +pub fn read_ipc_section_batches_at( + data: &Bytes, + offset: &mut usize, +) -> Result<Vec<RecordBatch>, ArrowError> { + *offset += section_padding(*offset); + read_ipc_stream_at(data, offset) +} + #[cfg(test)] mod tests { use arrow_array::{ArrayRef, record_batch}; @@ -403,4 +531,90 @@ mod tests { assert_col_zero_copy(batch.column(1)); } } + + /// Allocate a [`Bytes`] whose base address is 64-byte aligned, modelling a + /// backend that reads cache entries into an aligned buffer. A plain + /// `Bytes::from(vec)` only guarantees the allocator's alignment for `u8`. 
+ fn aligned_bytes(payload: &[u8]) -> Bytes { + let mut v = vec![0u8; payload.len() + IPC_SECTION_ALIGNMENT]; + let pad = section_padding(v.as_ptr() as usize); + v[pad..pad + payload.len()].copy_from_slice(payload); + Bytes::from(v).slice(pad..pad + payload.len()) + } + + #[test] + fn test_aligned_ipc_sections_are_zero_copy() { + // A LargeBinary column exercises the i64-offset buffer whose 8-byte + // alignment requirement triggers a realigning memcpy when misaligned. + let blocks = arrow_array::LargeBinaryArray::from_vec(vec![&b"hello"[..], b"world"]); + let section_a = RecordBatch::try_from_iter([("a", Arc::new(blocks) as ArrayRef)]).unwrap(); + let section_b = record_batch!(("b", Int64, [10i64, 20, 30, 40, 50])).unwrap(); + + let mut buf = Vec::new(); + // Arbitrary, deliberately non-64-aligned preamble so the first section + // must be padded rather than landing at offset 0 by luck. + buf.extend_from_slice(&[0xABu8; 7]); + let mut pos = buf.len(); + // The first section's stream begins after padding the 7-byte preamble + // up to the next 64-byte boundary. 
+ assert_eq!(7 + section_padding(7), IPC_SECTION_ALIGNMENT); + write_ipc_section(&mut buf, &mut pos, &section_a).unwrap(); + write_ipc_section(&mut buf, &mut pos, &section_b).unwrap(); + + let data = aligned_bytes(&buf); + assert_eq!( + section_padding(data.as_ptr() as usize), + 0, + "base not aligned" + ); + + let mut offset = 7; + let read_a = read_ipc_section_at(&data, &mut offset).unwrap(); + let read_b = read_ipc_section_at(&data, &mut offset).unwrap(); + assert_eq!(read_a, section_a); + assert_eq!(read_b, section_b); + + let data_base = data.as_ptr() as usize; + let data_end = data_base + data.len(); + for batch in [&read_a, &read_b] { + for buffer in batch.column(0).to_data().buffers() { + let ptr = buffer.as_ptr() as usize; + assert!( + ptr >= data_base && ptr < data_end, + "section buffer at {ptr:#x} was realigned out of the input \ + [{data_base:#x}..{data_end:#x}) — misaligned section", + ); + } + } + } + + #[test] + fn test_aligned_multi_batch_section_roundtrip_zero_copy() { + // A multi-batch section (e.g. IVF SQ storage chunks) must round-trip + // every batch and decode the first batch's buffers zero-copy. 
+ let b1 = record_batch!(("v", Int64, [1i64, 2, 3])).unwrap(); + let b2 = record_batch!(("v", Int64, [4i64, 5])).unwrap(); + let b3 = record_batch!(("v", Int64, [6i64])).unwrap(); + + let mut buf = vec![0xCDu8; 5]; + let mut pos = buf.len(); + write_ipc_section_batches(&mut buf, &mut pos, [b1.clone(), b2.clone(), b3.clone()]) + .unwrap(); + + let data = aligned_bytes(&buf); + let mut offset = 5; + let read = read_ipc_section_batches_at(&data, &mut offset).unwrap(); + assert_eq!(read, vec![b1, b2, b3]); + assert_eq!(offset, buf.len(), "offset should land at section end"); + + let data_base = data.as_ptr() as usize; + let data_end = data_base + data.len(); + for buffer in read[0].column(0).to_data().buffers() { + let ptr = buffer.as_ptr() as usize; + assert!( + ptr >= data_base && ptr < data_end, + "first batch buffer at {ptr:#x} was realigned out of the input", + ); + } + } } diff --git a/rust/lance-core/src/cache/codec.rs b/rust/lance-core/src/cache/codec.rs index 34e5264bb28..bba54840829 100644 --- a/rust/lance-core/src/cache/codec.rs +++ b/rust/lance-core/src/cache/codec.rs @@ -5,12 +5,184 @@ //! //! Implement [`CacheCodecImpl`] on concrete types, then use //! [`CacheCodec::from_impl`] to produce a type-erased codec for the cache. +//! +//! # Wire format +//! +//! Every serialized entry begins with a small hand-framed **envelope** so the +//! reader can validate it before trusting the body: +//! +//! ```text +//! [magic: 4B = b"LCE1"] +//! [envelope_version: u8] +//! [type_id_len: u16 LE][type_id: utf8] # stable, author-assigned +//! [type_version: u32 LE] # per-type body schema version +//! +//! ``` +//! +//! The envelope is deliberately *not* protobuf: it is the most +//! stability-critical part, must parse robustly against arbitrary bytes +//! (including data written by older, pre-stabilization builds), and never +//! changes shape. Bodies use protobuf headers, where field-number evolution +//! pays off. +//! +//! # Decode outcome +//! +//! 
Deserialization never propagates a parse failure as a hard error into the +//! cache path. Anything the reader cannot confidently interpret — absent or +//! wrong magic, an unknown `envelope_version`, a `type_id` mismatch, an +//! unsupported `type_version`, or a body decode error — becomes +//! [`CacheDecode::Miss`]. A backend turns `Miss` into a normal cache miss and +//! recomputes the value. This is what lets data written by an older format +//! self-heal: it simply fails the magic check and is regenerated. +use std::io::Write; use std::sync::Arc; use bytes::Bytes; -use crate::Result; +use crate::{Error, Result}; + +use super::{CacheEntryReader, CacheEntryWriter}; + +// --------------------------------------------------------------------------- +// Envelope +// --------------------------------------------------------------------------- + +/// Magic bytes that prefix every stabilized cache entry. +/// +/// An ASCII tag (`0x4C 0x43 0x45 0x31`) chosen so it cannot collide with any +/// pre-stabilization blob: those began with either a small little-endian +/// length (tens of bytes) or a small tag byte, never these values. +/// +/// Exported so backends can cheaply identify Lance cache entries (e.g. when +/// scanning a persistent store at startup) without hardcoding the bytes — +/// prefer [`has_cache_envelope`] over comparing against this directly. +pub const MAGIC: [u8; 4] = *b"LCE1"; + +/// Returns `true` if `data` begins with the cache-entry [`MAGIC`]. +/// +/// A cheap prefix check for backends that need to recognize Lance cache +/// entries without fully [`deserialize`](CacheCodec::deserialize)-ing them. A +/// `true` result only means the framing looks like ours; the entry can still +/// decode to a [`Miss`](CacheDecode::Miss) (e.g. wrong `type_id`). +pub fn has_cache_envelope(data: &[u8]) -> bool { + data.get(..MAGIC.len()) == Some(&MAGIC[..]) +} + +/// Version of the envelope framing itself. 
Bumped only if the outer frame +/// (magic/version/type_id/type_version layout) ever changes — expected never. +const ENVELOPE_VERSION: u8 = 1; + +/// Parsed envelope borrowed from the input bytes. +struct ParsedEnvelope<'a> { + type_id: &'a str, + type_version: u32, + /// Offset of the first body byte within the input. + body_offset: usize, +} + +/// Parse and validate the envelope at the start of `data`. +/// +/// Returns `None` for anything that is not a well-formed envelope this build +/// understands (wrong/absent magic, unknown `envelope_version`, truncation, +/// non-utf8 `type_id`). Callers translate `None` into [`CacheDecode::Miss`]. +fn parse_envelope(data: &Bytes) -> Option<ParsedEnvelope<'_>> { + let bytes = data.as_ref(); + let mut off = 0usize; + + let magic = bytes.get(off..off + 4)?; + if magic != MAGIC { + return None; + } + off += 4; + + if *bytes.get(off)? != ENVELOPE_VERSION { + return None; + } + off += 1; + + let type_id_len = u16::from_le_bytes(bytes.get(off..off + 2)?.try_into().ok()?) as usize; + off += 2; + + let type_id = std::str::from_utf8(bytes.get(off..off + type_id_len)?).ok()?; + off += type_id_len; + + let type_version = u32::from_le_bytes(bytes.get(off..off + 4)?.try_into().ok()?); + off += 4; + + Some(ParsedEnvelope { + type_id, + type_version, + body_offset: off, + }) +} + +/// Write the envelope for `type_id`/`type_version`, returning the number of +/// bytes written (the body's starting offset). 
+fn write_envelope(writer: &mut dyn Write, type_id: &str, type_version: u32) -> Result<usize> { + let type_id_len = u16::try_from(type_id.len()).map_err(|_| { + Error::io(format!( + "cache codec type_id too long ({} bytes, max {})", + type_id.len(), + u16::MAX + )) + })?; + + writer.write_all(&MAGIC)?; + writer.write_all(&[ENVELOPE_VERSION])?; + writer.write_all(&type_id_len.to_le_bytes())?; + writer.write_all(type_id.as_bytes())?; + writer.write_all(&type_version.to_le_bytes())?; + + Ok(4 + 1 + 2 + type_id.len() + 4) +} + +// --------------------------------------------------------------------------- +// CacheDecode — first-class cache-miss outcome +// --------------------------------------------------------------------------- + +/// Why a cache entry could not be decoded into the expected type. +/// +/// Carried by [`CacheDecode::Miss`] so backends can emit targeted metrics +/// (e.g. distinguish "evicting due to a stale format" from "type collision") +/// without re-parsing. Every reason maps to the same behavior — recompute via +/// the loader — so callers that don't care can ignore it. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum CacheMissReason { + /// Absent or wrong magic, unknown `envelope_version`, truncated framing, or + /// a non-utf8 `type_id`. Typically an entry written by a pre-stabilization + /// or otherwise foreign build. + InvalidEnvelope, + /// Well-formed envelope, but its `type_id` names a different entry type than + /// the codec reading it. + TypeMismatch, + /// Written by a newer build whose `type_version` this build does not + /// understand and must not attempt to interpret. + VersionTooNew, + /// Envelope validated, but the body failed to decode (truncation, a + /// malformed protobuf header, an IPC error, etc.). + BodyError, +} + +/// Outcome of deserializing a cache entry. +/// +/// `Miss` means the bytes could not be confidently decoded into `T`; the +/// [`CacheMissReason`] says why. 
A backend treats any `Miss` exactly like a key +/// that was never present: recompute via the loader. +#[derive(Debug)] +pub enum CacheDecode { + Hit(T), + Miss(CacheMissReason), +} + +impl CacheDecode { + pub fn hit(self) -> Option { + match self { + Self::Hit(v) => Some(v), + Self::Miss(_) => None, + } + } +} // --------------------------------------------------------------------------- // CacheCodecImpl — trait for serializable cache entry types @@ -18,31 +190,40 @@ use crate::Result; /// Serialization trait for cache entries. /// -/// **Experimental**: the serialized format is not stable and may change -/// between releases without notice. +/// **Experimental**: the serialized format is not yet covered by a stability +/// guarantee and may change between releases. When it does stabilize, the +/// rules are: `TYPE_ID`, protobuf field numbers, and enum values are +/// append-only forever; format changes that protobuf cannot express +/// transparently bump [`CURRENT_VERSION`](Self::CURRENT_VERSION). /// -/// Implement this on concrete types that need to survive serialization -/// through a persistent cache backend. Then wire it into a [`CacheKey`](super::CacheKey) -/// via [`CacheCodec::from_impl`]: +/// Implement this on concrete types that need to survive serialization through +/// a persistent cache backend, then wire it into a +/// [`CacheKey`](super::CacheKey) via [`CacheCodec::from_impl`]. /// -/// ```ignore -/// impl CacheCodecImpl for MyData { -/// fn serialize(&self, w: &mut dyn Write) -> Result<()> { /* ... */ } -/// fn deserialize(data: &Bytes) -> Result { /* ... */ } -/// } -/// -/// impl CacheKey for MyDataKey { -/// type ValueType = MyData; -/// fn codec() -> Option { -/// Some(CacheCodec::from_impl::()) -/// } -/// // ... -/// } -/// ``` +/// The envelope (magic/version/type_id/type_version) is written and validated +/// by the [`CacheCodec`] wrapper. 
[`serialize`](Self::serialize) writes only +/// the body — a header followed by sections in a fixed, version-keyed order — +/// and [`deserialize`](Self::deserialize) reads them back in that same order. +/// The read sequence mirroring the write sequence for each `type_version` is +/// the invariant the implementor owns. pub trait CacheCodecImpl: Send + Sync { - fn serialize(&self, writer: &mut dyn std::io::Write) -> Result<()>; + /// Stable identity for this entry type. **Must not change once shipped.** + /// This is a deliberate author-assigned string, not `std::any::type_name` + /// (which is not stable across compiler versions). + const TYPE_ID: &'static str; + + /// Body schema version this build writes. Bump when the body layout + /// changes in a way protobuf field additions cannot express transparently + /// (adding/removing/reordering sections, a raw-blob encoding change, etc.). + const CURRENT_VERSION: u32; + + /// Write the body: a header, then sections in a fixed order. + fn serialize(&self, writer: &mut CacheEntryWriter<'_>) -> Result<()>; - fn deserialize(data: &Bytes) -> Result + /// Reconstruct from the body. Branch on + /// [`reader.version()`](CacheEntryReader::version) for backward compat; + /// sections are read in write order. + fn deserialize(reader: &mut CacheEntryReader<'_>) -> Result where Self: Sized; } @@ -55,25 +236,31 @@ pub(crate) type ArcAny = Arc; /// Type-erased codec for serializing and deserializing cache entries. /// -/// `CacheCodec` is two plain function pointers — it is `Copy` and has no -/// heap allocation. Construct one via [`CacheCodec::from_impl`] for types -/// that implement [`CacheCodecImpl`], or [`CacheCodec::new`] for custom -/// cases (e.g. when the orphan rule prevents a direct impl). +/// `CacheCodec` carries the entry's stable `type_id`/`version` plus two plain +/// function pointers — it is `Copy` and has no heap allocation. 
Construct one +/// via [`CacheCodec::from_impl`] for types that implement [`CacheCodecImpl`], +/// or [`CacheCodec::new`] for custom cases (e.g. when the orphan rule prevents +/// a direct impl). #[derive(Copy, Clone)] pub struct CacheCodec { - pub(crate) serialize: fn(&ArcAny, &mut dyn std::io::Write) -> Result<()>, - pub(crate) deserialize: fn(&Bytes) -> Result, + type_id: &'static str, + version: u32, + serialize_body: fn(&ArcAny, &mut CacheEntryWriter<'_>) -> Result<()>, + deserialize_body: fn(&mut CacheEntryReader<'_>) -> Result, } impl std::fmt::Debug for CacheCodec { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("CacheCodec").finish_non_exhaustive() + f.debug_struct("CacheCodec") + .field("type_id", &self.type_id) + .field("version", &self.version) + .finish_non_exhaustive() } } fn serialize_via_impl( any: &ArcAny, - writer: &mut dyn std::io::Write, + writer: &mut CacheEntryWriter<'_>, ) -> Result<()> { let val = any .downcast_ref::() @@ -81,44 +268,278 @@ fn serialize_via_impl( val.serialize(writer) } -fn deserialize_via_impl(data: &Bytes) -> Result { - let val = T::deserialize(data)?; +fn deserialize_via_impl( + reader: &mut CacheEntryReader<'_>, +) -> Result { + let val = T::deserialize(reader)?; Ok(Arc::new(val) as ArcAny) } impl CacheCodec { - /// Create a `CacheCodec` from plain function pointers. + /// Create a `CacheCodec` from explicit body function pointers. /// /// Prefer [`from_impl`](Self::from_impl) when the value type implements /// [`CacheCodecImpl`]. Use this for types where a direct impl isn't - /// possible (e.g. orphan rule prevents it). + /// possible (e.g. the orphan rule prevents it). `type_id` and `version` + /// play the same role as the corresponding [`CacheCodecImpl`] constants. 
pub fn new( - serialize: fn(&ArcAny, &mut dyn std::io::Write) -> Result<()>, - deserialize: fn(&Bytes) -> Result, + type_id: &'static str, + version: u32, + serialize_body: fn(&ArcAny, &mut CacheEntryWriter<'_>) -> Result<()>, + deserialize_body: fn(&mut CacheEntryReader<'_>) -> Result, ) -> Self { Self { - serialize, - deserialize, + type_id, + version, + serialize_body, + deserialize_body, } } /// Create a `CacheCodec` from a [`CacheCodecImpl`] implementation. - /// - /// For **sized** types stored directly in the cache. The codec - /// downcasts `&dyn Any` to `&T` for serialization and returns `Arc` - /// from deserialization. pub fn from_impl() -> Self { Self { - serialize: serialize_via_impl::, - deserialize: deserialize_via_impl::, + type_id: T::TYPE_ID, + version: T::CURRENT_VERSION, + serialize_body: serialize_via_impl::, + deserialize_body: deserialize_via_impl::, } } - pub fn serialize(&self, value: &ArcAny, writer: &mut dyn std::io::Write) -> Result<()> { - (self.serialize)(value, writer) + /// Serialize `value` into `writer`: envelope first, then the body. + pub fn serialize(&self, value: &ArcAny, writer: &mut dyn Write) -> Result<()> { + let body_offset = write_envelope(writer, self.type_id, self.version)?; + let mut entry_writer = CacheEntryWriter::with_pos(writer, body_offset); + (self.serialize_body)(value, &mut entry_writer) + } + + /// Deserialize an entry from `data`. + /// + /// Never fails: any non-fatal failure to interpret the bytes becomes a + /// [`CacheDecode::Miss`] with the reason why (see [`CacheMissReason`]). + /// Reading from an in-memory [`Bytes`] cannot do I/O, so there is no fault + /// channel — a miss is the only non-`Hit` outcome. 
+ pub fn deserialize(&self, data: &Bytes) -> CacheDecode { + let Some(envelope) = parse_envelope(data) else { + log::debug!("cache entry rejected: missing or invalid envelope"); + return CacheDecode::Miss(CacheMissReason::InvalidEnvelope); + }; + + if envelope.type_id != self.type_id { + log::debug!( + "cache entry type_id mismatch: got {:?}, expected {:?}", + envelope.type_id, + self.type_id + ); + return CacheDecode::Miss(CacheMissReason::TypeMismatch); + } + + // A version newer than this build writes was produced by a newer build + // whose body layout we cannot assume to understand. Older/equal versions + // are the impl's responsibility to handle (branching on reader.version()). + if envelope.type_version > self.version { + log::debug!( + "cache entry {:?} has unsupported type_version {} (this build writes {})", + self.type_id, + envelope.type_version, + self.version + ); + return CacheDecode::Miss(CacheMissReason::VersionTooNew); + } + + let mut reader = CacheEntryReader::new(data, envelope.body_offset, envelope.type_version); + match (self.deserialize_body)(&mut reader) { + Ok(value) => CacheDecode::Hit(value), + Err(e) => { + log::debug!( + "cache entry {:?} v{} failed to decode: {e}", + self.type_id, + envelope.type_version + ); + CacheDecode::Miss(CacheMissReason::BodyError) + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + /// A trivial codec used to exercise the envelope and miss semantics + /// without pulling in arrow-backed payloads. 
+ #[derive(Debug, PartialEq)] + struct Widget { + n: u32, + } + + impl CacheCodecImpl for Widget { + const TYPE_ID: &'static str = "test.Widget"; + const CURRENT_VERSION: u32 = 1; + + fn serialize(&self, writer: &mut CacheEntryWriter<'_>) -> Result<()> { + writer.write_raw(&self.n.to_le_bytes()) + } + + fn deserialize(reader: &mut CacheEntryReader<'_>) -> Result { + let bytes = reader.read_raw()?; + let n = u32::from_le_bytes( + bytes + .as_ref() + .try_into() + .map_err(|_| Error::io("bad widget".to_string()))?, + ); + Ok(Self { n }) + } + } + + fn serialize_widget(widget: &Widget) -> Bytes { + let codec = CacheCodec::from_impl::(); + let any: ArcAny = Arc::new(Widget { n: widget.n }); + let mut buf = Vec::new(); + codec.serialize(&any, &mut buf).unwrap(); + Bytes::from(buf) + } + + /// The miss reason, or `None` if the decode was a hit. + fn miss_reason(data: &Bytes) -> Option { + match deserialize_widget(data) { + CacheDecode::Hit(_) => None, + CacheDecode::Miss(reason) => Some(reason), + } } - pub fn deserialize(&self, data: &Bytes) -> Result { - (self.deserialize)(data) + fn deserialize_widget(data: &Bytes) -> CacheDecode { + let codec = CacheCodec::from_impl::(); + match codec.deserialize(data) { + CacheDecode::Hit(any) => { + CacheDecode::Hit(Arc::try_unwrap(any.downcast::().unwrap()).unwrap()) + } + CacheDecode::Miss(reason) => CacheDecode::Miss(reason), + } + } + + #[test] + fn envelope_roundtrip_hits() { + let bytes = serialize_widget(&Widget { n: 0xDEADBEEF }); + // Sanity: the entry starts with the magic. 
+ assert_eq!(&bytes[..4], b"LCE1"); + let decoded = deserialize_widget(&bytes).hit().unwrap(); + assert_eq!(decoded, Widget { n: 0xDEADBEEF }); + } + + #[test] + fn has_cache_envelope_detects_magic() { + let bytes = serialize_widget(&Widget { n: 1 }); + assert!(has_cache_envelope(&bytes)); + assert!(has_cache_envelope(&MAGIC)); // exactly the magic, nothing after + assert!(!has_cache_envelope(b"LCE")); // too short + assert!(!has_cache_envelope(b"JUNK and more")); + assert!(!has_cache_envelope(&[])); + } + + #[test] + fn wrong_magic_is_miss() { + let mut bytes = serialize_widget(&Widget { n: 7 }).to_vec(); + bytes[0] = b'X'; + assert_eq!( + miss_reason(&Bytes::from(bytes)), + Some(CacheMissReason::InvalidEnvelope) + ); + } + + #[test] + fn pre_stabilization_blob_is_miss() { + // An old unstable blob led with a small u64 LE length prefix (a JSON + // header of tens of bytes) — no magic. It must self-heal to a miss. + let mut blob = Vec::new(); + blob.extend_from_slice(&(42u64).to_le_bytes()); + blob.extend_from_slice(&[0u8; 42]); + assert_eq!( + miss_reason(&Bytes::from(blob)), + Some(CacheMissReason::InvalidEnvelope) + ); + + // A different unstable shape led with a small u8 tag (0/1/2). + assert_eq!( + miss_reason(&Bytes::from(vec![0u8, 1, 2, 3])), + Some(CacheMissReason::InvalidEnvelope) + ); + } + + #[test] + fn unknown_envelope_version_is_miss() { + let mut bytes = serialize_widget(&Widget { n: 7 }).to_vec(); + bytes[4] = 0xFF; // envelope_version byte + assert_eq!( + miss_reason(&Bytes::from(bytes)), + Some(CacheMissReason::InvalidEnvelope) + ); + } + + #[test] + fn type_id_mismatch_is_miss() { + // Hand-build an envelope with a foreign type_id but valid framing. 
+ let mut buf = Vec::new(); + write_envelope(&mut buf, "some.OtherType", 1).unwrap(); + buf.extend_from_slice(&(4u64).to_le_bytes()); + buf.extend_from_slice(&99u32.to_le_bytes()); + assert_eq!( + miss_reason(&Bytes::from(buf)), + Some(CacheMissReason::TypeMismatch) + ); + } + + #[test] + fn unsupported_future_type_version_is_miss() { + // An entry written by a newer build (higher type_version) must miss + // rather than be misread by this build. + let mut buf = Vec::new(); + write_envelope(&mut buf, Widget::TYPE_ID, Widget::CURRENT_VERSION + 1).unwrap(); + lance_arrow::ipc::write_len_prefixed_bytes(&mut buf, &9u32.to_le_bytes()).unwrap(); + assert_eq!( + miss_reason(&Bytes::from(buf)), + Some(CacheMissReason::VersionTooNew) + ); + } + + #[test] + fn truncated_envelope_is_miss() { + let bytes = serialize_widget(&Widget { n: 7 }); + for cut in [0, 1, 4, 5, 7, 9] { + assert_eq!( + miss_reason(&bytes.slice(..cut.min(bytes.len()))), + Some(CacheMissReason::InvalidEnvelope), + "truncating to {cut} bytes should miss as InvalidEnvelope" + ); + } + } + + #[test] + fn body_decode_error_is_miss() { + // Valid envelope, but the body is too short for the widget. + let mut buf = Vec::new(); + write_envelope(&mut buf, Widget::TYPE_ID, Widget::CURRENT_VERSION).unwrap(); + buf.extend_from_slice(&(1u64).to_le_bytes()); + buf.push(0u8); + assert_eq!( + miss_reason(&Bytes::from(buf)), + Some(CacheMissReason::BodyError) + ); + } + + #[test] + fn reader_exposes_envelope_version() { + // type_version travels through the envelope to reader.version(). + let mut buf = Vec::new(); + write_envelope(&mut buf, Widget::TYPE_ID, 7).unwrap(); + let body_off = buf.len(); + // A widget body so the codec can decode it. 
+ lance_arrow::ipc::write_len_prefixed_bytes(&mut buf, &5u32.to_le_bytes()).unwrap(); + let data = Bytes::from(buf); + + let mut r = CacheEntryReader::new(&data, body_off, 7); + assert_eq!(r.version(), 7); + assert_eq!(r.read_raw().unwrap().as_ref(), 5u32.to_le_bytes()); } } diff --git a/rust/lance-core/src/cache/entry_io.rs b/rust/lance-core/src/cache/entry_io.rs new file mode 100644 index 00000000000..fe91b11ca7d --- /dev/null +++ b/rust/lance-core/src/cache/entry_io.rs @@ -0,0 +1,202 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! Streaming readers/writers for cache entry bodies. +//! +//! [`CacheCodecImpl`](super::CacheCodecImpl) bodies are written and read +//! through these wrappers. They keep serialization streaming (no buffering of +//! the whole entry) and reads zero-copy (sections borrow from the input +//! [`Bytes`]), while tracking the byte position needed to keep Arrow IPC +//! sections 64-byte aligned (see [`lance_arrow::ipc`]). +//! +//! Body layout primitives: +//! +//! ```text +//! HEADER : [header_len: u32 LE][header proto bytes] +//! ARROW_IPC : [pad to 64B][self-delimiting IPC stream] +//! RAW_BLOB : [len: u64 LE][bytes] +//! ``` + +use std::io::Write; + +use arrow_array::RecordBatch; +use bytes::Bytes; +use prost::Message; + +use crate::{Error, Result}; + +/// Writes a cache entry body: a header followed by sections, streaming +/// directly to the underlying writer. +/// +/// The envelope is written by the [`CacheCodec`](super::CacheCodec) wrapper +/// before this writer is handed to +/// [`CacheCodecImpl::serialize`](super::CacheCodecImpl::serialize). +pub struct CacheEntryWriter<'a> { + writer: &'a mut dyn Write, + /// Absolute byte offset within the entry, used to align IPC sections. + pos: usize, +} + +impl<'a> CacheEntryWriter<'a> { + /// Create a writer positioned at the start of an entry (offset 0). + /// + /// Use this for nested serialization into a standalone buffer. 
The + /// envelope-aware entry point is [`CacheCodec::serialize`](super::CacheCodec::serialize). + pub fn new(writer: &'a mut dyn Write) -> Self { + Self { writer, pos: 0 } + } + + /// Create a writer whose section alignment accounts for `pos` bytes + /// already written ahead of the body (i.e. the envelope). + pub(crate) fn with_pos(writer: &'a mut dyn Write, pos: usize) -> Self { + Self { writer, pos } + } + + /// Write a single discriminant byte (e.g. a variant tag). + pub fn write_u8(&mut self, value: u8) -> Result<()> { + self.writer.write_all(&[value])?; + self.pos += 1; + Ok(()) + } + + /// Write a protobuf header as `[len: u32 LE][bytes]`. + pub fn write_header(&mut self, header: &P) -> Result<()> { + let bytes = header.encode_to_vec(); + let len = u32::try_from(bytes.len()) + .map_err(|_| Error::io(format!("cache header too large: {} bytes", bytes.len())))?; + self.writer.write_all(&len.to_le_bytes())?; + self.writer.write_all(&bytes)?; + self.pos += 4 + bytes.len(); + Ok(()) + } + + /// Write `batch` as a 64-byte-aligned Arrow IPC section. + pub fn write_ipc(&mut self, batch: &RecordBatch) -> Result<()> { + lance_arrow::ipc::write_ipc_section(self.writer, &mut self.pos, batch) + .map_err(|e| Error::io(e.to_string())) + } + + /// Write `batches` as a single 64-byte-aligned multi-batch Arrow IPC + /// section. The iterator must yield at least one batch. + pub fn write_ipc_batches(&mut self, batches: I) -> Result<()> + where + I: IntoIterator, + { + lance_arrow::ipc::write_ipc_section_batches(self.writer, &mut self.pos, batches) + .map_err(|e| Error::io(e.to_string())) + } + + /// Write a raw blob as `[len: u64 LE][bytes]`. + /// + /// Only for byte payloads that already have their own stable, portable + /// encoding (e.g. a roaring bitmap, a varint-packed stream). 
+ pub fn write_raw(&mut self, bytes: &[u8]) -> Result<()> { + lance_arrow::ipc::write_len_prefixed_bytes(self.writer, bytes) + .map_err(|e| Error::io(e.to_string()))?; + self.pos += 8 + bytes.len(); + Ok(()) + } + + /// The underlying writer, for a payload that carries its own framing. + /// + /// Use this only when the codec writes a self-delimiting or whole-body + /// payload — e.g. streaming a roaring bitmap as the entire body, where the + /// length prefix of [`write_raw`](Self::write_raw) would be redundant and + /// buffering to measure that length would force an extra copy. For + /// structured bodies prefer [`write_header`](Self::write_header) / + /// [`write_ipc`](Self::write_ipc) / [`write_raw`](Self::write_raw), which + /// give you versioning and 64-byte IPC alignment. + /// + /// Bytes written through this do **not** advance the section-alignment + /// position, so it must not be interleaved with [`write_ipc`](Self::write_ipc). + pub fn raw_writer(&mut self) -> &mut dyn Write { + self.writer + } +} + +/// Reads a cache entry body, tracking an offset into the input and exposing +/// the entry's `type_version` so implementors can branch for backward compat. +/// +/// All reads are zero-copy: returned [`Bytes`] and the buffers behind decoded +/// [`RecordBatch`]es borrow from the input allocation. +pub struct CacheEntryReader<'a> { + data: &'a Bytes, + offset: usize, + version: u32, +} + +impl<'a> CacheEntryReader<'a> { + /// Create a reader over `data`, starting at body byte `offset`, for an + /// entry written at `version`. + pub fn new(data: &'a Bytes, offset: usize, version: u32) -> Self { + Self { + data, + offset, + version, + } + } + + /// The `type_version` from the envelope. Branch on this for backward compat. + pub fn version(&self) -> u32 { + self.version + } + + /// Read a single discriminant byte written by [`CacheEntryWriter::write_u8`]. 
+ pub fn read_u8(&mut self) -> Result { + let bytes = self.data.as_ref(); + let v = *bytes + .get(self.offset) + .ok_or_else(|| Error::io("cache entry: truncated, missing tag byte".to_string()))?; + self.offset += 1; + Ok(v) + } + + /// Read a protobuf header written by [`CacheEntryWriter::write_header`]. + pub fn read_header(&mut self) -> Result

{ + let bytes = self.data.as_ref(); + let len_end = self + .offset + .checked_add(4) + .filter(|&e| e <= bytes.len()) + .ok_or_else(|| Error::io("cache header: truncated length prefix".to_string()))?; + let len = u32::from_le_bytes(bytes[self.offset..len_end].try_into().unwrap()) as usize; + let data_end = len_end + .checked_add(len) + .filter(|&e| e <= bytes.len()) + .ok_or_else(|| Error::io("cache header: truncated body".to_string()))?; + let msg = P::decode(&bytes[len_end..data_end]) + .map_err(|e| Error::io(format!("cache header decode failed: {e}")))?; + self.offset = data_end; + Ok(msg) + } + + /// Read one [`RecordBatch`] from a 64-byte-aligned IPC section. + pub fn read_ipc(&mut self) -> Result { + lance_arrow::ipc::read_ipc_section_at(self.data, &mut self.offset) + .map_err(|e| Error::io(e.to_string())) + } + + /// Read all [`RecordBatch`]es from a 64-byte-aligned multi-batch IPC + /// section written by [`CacheEntryWriter::write_ipc_batches`]. + pub fn read_ipc_batches(&mut self) -> Result> { + lance_arrow::ipc::read_ipc_section_batches_at(self.data, &mut self.offset) + .map_err(|e| Error::io(e.to_string())) + } + + /// Read a raw blob written by [`CacheEntryWriter::write_raw`], zero-copy. + pub fn read_raw(&mut self) -> Result { + lance_arrow::ipc::read_len_prefixed_bytes_at(self.data, &mut self.offset) + .map_err(|e| Error::io(e.to_string())) + } + + /// The not-yet-consumed body bytes as a zero-copy slice. + /// + /// For a payload that carries its own framing and is parsed with the + /// codec's own cursor — the read counterpart of + /// [`CacheEntryWriter::raw_writer`]. For structured bodies prefer + /// [`read_header`](Self::read_header) / [`read_ipc`](Self::read_ipc) / + /// [`read_raw`](Self::read_raw). + pub fn body(&self) -> Bytes { + self.data.slice(self.offset..) 
+ } +} diff --git a/rust/lance-core/src/cache/mod.rs b/rust/lance-core/src/cache/mod.rs index f62837fe3cc..61c0c58446e 100644 --- a/rust/lance-core/src/cache/mod.rs +++ b/rust/lance-core/src/cache/mod.rs @@ -47,10 +47,14 @@ pub mod backend; pub mod codec; +mod entry_io; mod moka; pub use backend::{CacheBackend, CacheEntry, InternalCacheKey}; -pub use codec::{CacheCodec, CacheCodecImpl}; +pub use codec::{ + CacheCodec, CacheCodecImpl, CacheDecode, CacheMissReason, MAGIC, has_cache_envelope, +}; +pub use entry_io::{CacheEntryReader, CacheEntryWriter}; pub use moka::MokaCacheBackend; use std::borrow::Cow; diff --git a/rust/lance-index/build.rs b/rust/lance-index/build.rs index 0617de8c806..b47744f7b5a 100644 --- a/rust/lance-index/build.rs +++ b/rust/lance-index/build.rs @@ -6,6 +6,9 @@ use std::io::Result; fn main() -> Result<()> { println!("cargo:rerun-if-changed=protos"); + // Cache-entry protos are library-internal serialization, not part of the + // on-disk format spec, so they live here rather than in the shared `protos/`. + println!("cargo:rerun-if-changed=protos-cache"); #[cfg(feature = "protoc")] // Use vendored protobuf compiler if requested. @@ -17,8 +20,12 @@ fn main() -> Result<()> { prost_build.protoc_arg("--experimental_allow_proto3_optional"); prost_build.enable_type_names(); prost_build.compile_protos( - &["./protos/index.proto", "./protos/index_old.proto"], - &["./protos"], + &[ + "./protos/index.proto", + "./protos/index_old.proto", + "./protos-cache/cache.proto", + ], + &["./protos", "./protos-cache"], )?; let rust_toolchain = env::var("RUSTUP_TOOLCHAIN") diff --git a/rust/lance-index/protos-cache/cache.proto b/rust/lance-index/protos-cache/cache.proto new file mode 100644 index 00000000000..5bbea1b9c48 --- /dev/null +++ b/rust/lance-index/protos-cache/cache.proto @@ -0,0 +1,186 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +// Protobuf headers for serialized index cache entries. 
+// +// These messages describe the *cache* serialization format, not the on-disk +// Lance format spec, so they live with the library (lance-index) rather than in +// the top-level `protos/` spec folder. +// +// Field numbers and enum values are append-only across all messages here: never +// renumber or reuse them. A change the proto cannot express transparently +// (adding/removing/reordering the IPC/raw sections that follow a header) must +// bump the relevant codec's `CURRENT_VERSION` instead. + +syntax = "proto3"; + +package lance.index.cache; + +// --------------------------------------------------------------------------- +// Full-text search (FTS) posting lists +// --------------------------------------------------------------------------- + +// Header for a serialized `CompressedPostingList` cache entry. +message CompressedPostingHeader { + float max_score = 1; + uint32 length = 2; + PostingTailCodec posting_tail_codec = 3; + PositionStorage position_storage = 4; + // Only meaningful when position_storage == POSITION_STORAGE_SHARED. + PositionStreamCodec position_stream_codec = 5; +} + +// Header for a serialized `PlainPostingList` cache entry. Followed by an Arrow +// IPC section of (row_ids: UInt64, frequencies: Float32), then — when +// position_storage == POSITION_STORAGE_LEGACY — an IPC section of the per-doc +// position list. Plain postings never carry a shared position stream. +message PlainPostingHeader { + // Absent when the posting has no precomputed block-max score (the in-memory + // `max_score` is `None`); present otherwise. + optional float max_score = 1; + // POSITION_STORAGE_NONE or POSITION_STORAGE_LEGACY only. + PositionStorage position_storage = 2; +} + +// Header for a serialized standalone `Positions` cache entry. Followed by the +// position sections framed per `position_storage`, which is never +// POSITION_STORAGE_NONE for a standalone entry. 
+message PositionsHeader { + PositionStorage position_storage = 1; + // Only meaningful when position_storage == POSITION_STORAGE_SHARED. + PositionStreamCodec position_stream_codec = 2; +} + +// Tail-block encoding of a compressed posting list. +enum PostingTailCodec { + POSTING_TAIL_CODEC_FIXED32 = 0; + POSTING_TAIL_CODEC_VARINT_DELTA = 1; +} + +// Encoding of a shared position stream's byte buffer. +enum PositionStreamCodec { + POSITION_STREAM_CODEC_VARINT_DOC_DELTA = 0; + POSITION_STREAM_CODEC_PACKED_DELTA = 1; +} + +// Which (if any) positions accompany the posting list, and how they are framed +// in the sections after the header. +enum PositionStorage { + POSITION_STORAGE_NONE = 0; + // Legacy per-doc positions as a single Arrow IPC section. + POSITION_STORAGE_LEGACY = 1; + // Shared stream: an Arrow IPC section of block offsets, then a raw blob of + // the (codec-encoded) position bytes. + POSITION_STORAGE_SHARED = 2; +} + +// --------------------------------------------------------------------------- +// Scalar indices +// --------------------------------------------------------------------------- + +// Header for a serialized `BTreeIndexState` cache entry, followed by a single +// Arrow IPC section holding the page-lookup batch. +message BTreeIndexHeader { + uint64 batch_size = 1; + // Whether an explicit page-range -> file mapping is present. Distinguishes a + // non-range-partitioned index (false) from a range-partitioned one whose map + // happens to be empty (true with no entries). + bool has_ranges_to_files = 2; + repeated RangeToFile ranges_to_files = 3; +} + +// One entry of a `BTreeIndexState` page-range -> file mapping. The range is +// inclusive on both ends (a `RangeInclusive`). 
+message RangeToFile { + uint32 start = 1; + uint32 end = 2; + uint32 page_offset = 3; + string path = 4; +} + +// --------------------------------------------------------------------------- +// Vector indices (IVF partitions) +// --------------------------------------------------------------------------- + +// Headers for serialized IVF partition cache entries (`PartitionEntry`). +// +// Each header is followed by 64-byte-aligned Arrow IPC sections in a fixed, +// version-keyed order (sub-index, then any quantizer-specific arrays, then the +// quantizer storage batches). + +// Distance metric a quantizer's storage was built for. +enum DistanceType { + DISTANCE_TYPE_L2 = 0; + DISTANCE_TYPE_COSINE = 1; + DISTANCE_TYPE_DOT = 2; + DISTANCE_TYPE_HAMMING = 3; +} + +// Rotation applied by a RabitQ quantizer. +enum RotationType { + ROTATION_TYPE_MATRIX = 0; + ROTATION_TYPE_FAST = 1; +} + +// Estimator a RabitQ quantizer uses at query time. +enum RabitQueryEstimator { + RABIT_QUERY_ESTIMATOR_RESIDUAL_QUERY = 0; + RABIT_QUERY_ESTIMATOR_RAW_QUERY = 1; +} + +// Product quantizer. Sections: sub-index IPC, codebook IPC, storage IPC. +message PqPartitionHeader { + DistanceType distance_type = 1; + uint32 nbits = 2; + uint64 num_sub_vectors = 3; + uint64 dimension = 4; + bool transposed = 5; +} + +// Flat (float) and flat-binary quantizers. Sections: sub-index IPC, storage IPC. +message FlatPartitionHeader { + DistanceType distance_type = 1; + uint64 dim = 2; +} + +// Scalar quantizer. Sections: sub-index IPC, storage IPC (possibly multi-batch). +message SqPartitionHeader { + DistanceType distance_type = 1; + uint32 num_bits = 2; + uint64 dim = 3; + double bounds_start = 4; + double bounds_end = 5; +} + +// Header for a serialized IVF index state (`IvfIndexState`), followed by +// three raw blobs: the IVF model protobuf, the quantizer's extra-metadata +// buffer (may be empty), and the auxiliary IVF model protobuf. 
+message IvfStateHeader { + string index_file_path = 1; + string uuid = 2; + string distance_type = 3; + repeated string sub_index_metadata = 4; + string sub_index_type = 5; + string quantization_type = 6; + // Per-quantizer `Q::Metadata` as JSON. Kept as a string because the metadata + // type is generic over the quantizer; the proto envelope still provides + // additive evolution for the surrounding fields. + string quantizer_metadata_json = 7; + string cache_key_prefix = 8; + uint64 index_file_size = 9; + uint64 aux_file_size = 10; +} + +// RabitQ quantizer. Sections: sub-index IPC, rotate-matrix IPC (Matrix rotation +// only), storage IPC. +message RabitPartitionHeader { + DistanceType distance_type = 1; + uint32 num_bits = 2; + uint32 code_dim = 3; + RotationType rotation_type = 4; + // Fast-rotation sign vector; present only when rotation_type == + // ROTATION_TYPE_FAST (the Matrix case stores its rotation as an IPC section). + optional bytes fast_rotation_signs = 5; + // Estimator the RabitQ storage uses at query time (residual vs raw query). + RabitQueryEstimator query_estimator = 6; +} diff --git a/rust/lance-index/src/lib.rs b/rust/lance-index/src/lib.rs index 888070a3c1f..10d4b2b8001 100644 --- a/rust/lance-index/src/lib.rs +++ b/rust/lance-index/src/lib.rs @@ -68,6 +68,13 @@ pub mod pbold { include!(concat!(env!("OUT_DIR"), "/lance.table.rs")); } +/// Protobuf headers for serialized index cache entries (FTS posting lists, +/// scalar indices, and IVF vector partitions). 
+pub mod cache_pb { + #![allow(clippy::use_self)] + include!(concat!(env!("OUT_DIR"), "/lance.index.cache.rs")); +} + /// Generic methods common across all types of secondary indices /// #[async_trait] diff --git a/rust/lance-index/src/scalar/bitmap.rs b/rust/lance-index/src/scalar/bitmap.rs index 1ae2faf6e6b..c2a6e80e82b 100644 --- a/rust/lance-index/src/scalar/bitmap.rs +++ b/rust/lance-index/src/scalar/bitmap.rs @@ -18,14 +18,13 @@ use bytes::Bytes; use datafusion::physical_plan::SendableRecordBatchStream; use datafusion_common::ScalarValue; use futures::{StreamExt, TryStreamExt, stream}; -use lance_arrow::ipc::{ - read_ipc_stream_single_at, read_len_prefixed_bytes_at, write_ipc_stream, - write_len_prefixed_bytes, -}; use lance_core::deepsize::DeepSizeOf; use lance_core::{ Error, ROW_ID, Result, - cache::{CacheCodec, CacheCodecImpl, CacheKey, LanceCache, WeakLanceCache}, + cache::{ + CacheCodec, CacheCodecImpl, CacheEntryReader, CacheEntryWriter, CacheKey, LanceCache, + WeakLanceCache, + }, error::LanceOptionExt, utils::tokio::get_num_compute_intensive_cpus, }; @@ -212,6 +211,32 @@ impl BitmapIndexState { frag_reuse_index, ))) } + + /// Build a state directly from its parts, for codec tests in sibling + /// modules (e.g. the label-list index, which nests a bitmap state). 
+ #[cfg(test)] + pub(crate) fn new_for_test( + index_map: BTreeMap, + null_map: RowAddrTreeMap, + value_type: DataType, + ) -> Result { + Ok(Self { + lookup_batch: build_lookup_batch(&index_map, &value_type)?, + null_map: Arc::new(null_map), + value_type, + index_map: Arc::new(index_map), + }) + } + + #[cfg(test)] + pub(crate) fn lookup_batch(&self) -> &RecordBatch { + &self.lookup_batch + } + + #[cfg(test)] + pub(crate) fn null_map(&self) -> &RowAddrTreeMap { + &self.null_map + } } fn build_lookup_batch( @@ -251,25 +276,27 @@ fn parse_lookup_batch(batch: &RecordBatch) -> Result, offsets: UInt64)] + /// RAW_BLOB : null_map (roaring tree map, portable encoding) + /// ARROW_IPC : (keys: , offsets: UInt64) /// ``` - /// The value type is recovered from the IPC stream schema. - fn serialize(&self, writer: &mut dyn std::io::Write) -> Result<()> { + /// The value type is recovered from the IPC section schema. + fn serialize(&self, w: &mut CacheEntryWriter<'_>) -> Result<()> { let mut null_bytes = Vec::with_capacity(self.null_map.serialized_size()); self.null_map.serialize_into(&mut null_bytes)?; - write_len_prefixed_bytes(writer, &null_bytes)?; - write_ipc_stream(&self.lookup_batch, writer)?; + w.write_raw(&null_bytes)?; + w.write_ipc(&self.lookup_batch)?; Ok(()) } - fn deserialize(data: &bytes::Bytes) -> Result { - let mut offset = 0; - let null_bytes = read_len_prefixed_bytes_at(data, &mut offset)?; + fn deserialize(r: &mut CacheEntryReader<'_>) -> Result { + let null_bytes = r.read_raw()?; let null_map = Arc::new(RowAddrTreeMap::deserialize_from(null_bytes.as_ref())?); - let lookup_batch = read_ipc_stream_single_at(data, &mut offset)?; + let lookup_batch = r.read_ipc()?; let value_type = lookup_batch.schema().field(0).data_type().clone(); let index_map = Arc::new(parse_lookup_batch(&lookup_batch)?); Ok(Self { @@ -1821,8 +1848,12 @@ mod tests { fn assert_state_roundtrips(state: &BitmapIndexState) { let mut buf = Vec::new(); - state.serialize(&mut buf).unwrap(); - let 
restored = BitmapIndexState::deserialize(&bytes::Bytes::from(buf)).unwrap(); + state + .serialize(&mut CacheEntryWriter::new(&mut buf)) + .unwrap(); + let data = bytes::Bytes::from(buf); + let mut reader = CacheEntryReader::new(&data, 0, BitmapIndexState::CURRENT_VERSION); + let restored = BitmapIndexState::deserialize(&mut reader).unwrap(); assert_eq!(restored.lookup_batch, state.lookup_batch); assert_eq!(&*restored.null_map, &*state.null_map); assert_eq!(restored.value_type, state.value_type); @@ -1856,6 +1887,53 @@ mod tests { assert_state_roundtrips(&empty_state); } + /// The lookup batch must decode zero-copy through the full envelope-bearing + /// [`CacheCodec`] even though the envelope pushes the IPC section to a + /// non-aligned starting offset. + #[test] + fn test_bitmap_index_state_lookup_is_zero_copy() { + const ALIGN: usize = 64; + let mut index_map = BTreeMap::new(); + for k in 0..32i32 { + index_map.insert( + OrderableScalarValue(ScalarValue::Int32(Some(k))), + k as usize, + ); + } + let state = BitmapIndexState { + lookup_batch: build_lookup_batch(&index_map, &DataType::Int32).unwrap(), + null_map: Arc::new(RowAddrTreeMap::new()), + value_type: DataType::Int32, + index_map: Arc::new(index_map), + }; + + let codec = CacheCodec::from_impl::(); + let any: Arc = Arc::new(state); + let mut buf = Vec::new(); + codec.serialize(&any, &mut buf).unwrap(); + + // Model a backend reading into a 64-byte-aligned buffer. 
+ let mut v = vec![0u8; buf.len() + ALIGN]; + let pad = (ALIGN - (v.as_ptr() as usize % ALIGN)) % ALIGN; + v[pad..pad + buf.len()].copy_from_slice(&buf); + let data = bytes::Bytes::from(v).slice(pad..pad + buf.len()); + + let restored = codec.deserialize(&data).hit().unwrap(); + let restored = restored.downcast::().unwrap(); + + let base = data.as_ptr() as usize; + let end = base + data.len(); + for col in restored.lookup_batch.columns() { + for buffer in col.to_data().buffers() { + let ptr = buffer.as_ptr() as usize; + assert!( + ptr >= base && ptr < end, + "lookup batch buffer was realigned out of the input — misaligned IPC section", + ); + } + } + } + #[tokio::test] async fn test_bitmap_lazy_loading_and_cache() { // Create a temporary directory for the index diff --git a/rust/lance-index/src/scalar/btree.rs b/rust/lance-index/src/scalar/btree.rs index 6128248308e..11ed38ca248 100644 --- a/rust/lance-index/src/scalar/btree.rs +++ b/rust/lance-index/src/scalar/btree.rs @@ -15,6 +15,7 @@ use super::{ OldIndexDataFilter, SargableQuery, ScalarIndex, ScalarIndexParams, SearchResult, compute_next_prefix, }; +use crate::cache_pb::{BTreeIndexHeader, RangeToFile}; use crate::{Index, IndexType}; use crate::{ frag_reuse::FragReuseIndex, @@ -52,11 +53,13 @@ use futures::{ future::BoxFuture, stream::{self}, }; -use lance_arrow::ipc::{read_ipc_stream_single_at, write_ipc_stream}; use lance_core::deepsize::DeepSizeOf; use lance_core::{ Error, ROW_ID, Result, - cache::{CacheCodec, CacheCodecImpl, CacheKey, LanceCache, WeakLanceCache}, + cache::{ + CacheCodec, CacheCodecImpl, CacheEntryReader, CacheEntryWriter, CacheKey, LanceCache, + WeakLanceCache, + }, error::LanceOptionExt, utils::{ tokio::get_num_compute_intensive_cpus, @@ -1402,106 +1405,58 @@ impl BTreeIndexState { } impl CacheCodecImpl for BTreeIndexState { - /// Wire format (no stability guarantees yet — the cache is rebuilt from - /// source on any version mismatch): + const TYPE_ID: &'static str = 
"lance.scalar.BTreeIndexState"; + const CURRENT_VERSION: u32 = 1; + + /// Wire format: /// ```text - /// u64 batch_size (LE) - /// u8 has_ranges (0 = None, 1 = Some) - /// if has_ranges: - /// u32 entry_count (LE) - /// per entry: u32 start | u32 end | u32 offset | u32 path_len | path bytes - /// lookup batch (Arrow IPC stream) + /// HEADER : BTreeIndexHeader proto (batch_size + page-range mapping) + /// ARROW_IPC : page-lookup batch /// ``` - fn serialize(&self, writer: &mut dyn std::io::Write) -> Result<()> { - writer.write_all(&self.batch_size.to_le_bytes())?; - match &self.ranges_to_files { - None => writer.write_all(&[0u8])?, - Some(ranges) => { - writer.write_all(&[1u8])?; - let count = u32::try_from(ranges.len()).map_err(|_| { - Error::io("BTreeIndexState: ranges_to_files exceeds u32::MAX entries") - })?; - writer.write_all(&count.to_le_bytes())?; - for (range, (path, page_offset)) in ranges.iter() { - writer.write_all(&range.start().to_le_bytes())?; - writer.write_all(&range.end().to_le_bytes())?; - writer.write_all(&page_offset.to_le_bytes())?; - let path_len = u32::try_from(path.len()).map_err(|_| { - Error::io("BTreeIndexState: ranges_to_files path exceeds u32::MAX bytes") - })?; - writer.write_all(&path_len.to_le_bytes())?; - writer.write_all(path.as_bytes())?; - } - } - } - write_ipc_stream(&self.lookup_batch, writer)?; + fn serialize(&self, w: &mut CacheEntryWriter<'_>) -> Result<()> { + let ranges_to_files = match &self.ranges_to_files { + None => Vec::new(), + Some(ranges) => ranges + .iter() + .map(|(range, (path, page_offset))| RangeToFile { + start: *range.start(), + end: *range.end(), + page_offset: *page_offset, + path: path.clone(), + }) + .collect(), + }; + let header = BTreeIndexHeader { + batch_size: self.batch_size, + has_ranges_to_files: self.ranges_to_files.is_some(), + ranges_to_files, + }; + w.write_header(&header)?; + w.write_ipc(&self.lookup_batch)?; Ok(()) } - fn deserialize(data: &bytes::Bytes) -> Result { - let mut offset = 0; - 
let batch_size = read_u64_le(data, &mut offset)?; - let has_ranges = read_u8(data, &mut offset)?; - let ranges_to_files = match has_ranges { - 0 => None, - 1 => { - let count = read_u32_le(data, &mut offset)? as usize; - let mut entries = Vec::with_capacity(count); - for _ in 0..count { - let start = read_u32_le(data, &mut offset)?; - let end = read_u32_le(data, &mut offset)?; - let page_offset = read_u32_le(data, &mut offset)?; - let path_len = read_u32_le(data, &mut offset)? as usize; - let path = read_bytes(data, &mut offset, path_len)?; - let path = std::str::from_utf8(&path) - .map_err(|e| Error::io(format!("BTreeIndexState path: {e}")))? - .to_string(); - entries.push((start..=end, (path, page_offset))); - } - Some(Arc::new(entries.into_iter().collect())) - } - other => { - return Err(Error::io(format!( - "BTreeIndexState: invalid has_ranges tag {other}" - ))); - } + fn deserialize(r: &mut CacheEntryReader<'_>) -> Result { + let header: BTreeIndexHeader = r.read_header()?; + let ranges_to_files = if header.has_ranges_to_files { + let map: RangeInclusiveMap = header + .ranges_to_files + .into_iter() + .map(|entry| (entry.start..=entry.end, (entry.path, entry.page_offset))) + .collect(); + Some(Arc::new(map)) + } else { + None }; - let lookup_batch = read_ipc_stream_single_at(data, &mut offset)?; + let lookup_batch = r.read_ipc()?; Ok(Self { lookup_batch, - batch_size, + batch_size: header.batch_size, ranges_to_files, }) } } -fn read_bytes(data: &bytes::Bytes, offset: &mut usize, len: usize) -> Result { - if data.len() < *offset + len { - return Err(Error::io(format!( - "BTreeIndexState: short read of {len} bytes at offset {offset} (have {})", - data.len() - ))); - } - let slice = data.slice(*offset..*offset + len); - *offset += len; - Ok(slice) -} - -fn read_u8(data: &bytes::Bytes, offset: &mut usize) -> Result { - let bytes = read_bytes(data, offset, 1)?; - Ok(bytes[0]) -} - -fn read_u32_le(data: &bytes::Bytes, offset: &mut usize) -> Result { - let bytes = 
read_bytes(data, offset, 4)?; - Ok(u32::from_le_bytes(bytes.as_ref().try_into().unwrap())) -} - -fn read_u64_le(data: &bytes::Bytes, offset: &mut usize) -> Result { - let bytes = read_bytes(data, offset, 8)?; - Ok(u64::from_le_bytes(bytes.as_ref().try_into().unwrap())) -} - /// Cache key for a [`BTreeIndexState`]. The cache it is used with is already /// namespaced per-index, so the key string is a constant. struct BTreeIndexStateKey; @@ -3286,7 +3241,23 @@ mod tests { }; use crate::scalar::registry::ScalarIndexPlugin; use arrow_array::RecordBatch; - use lance_core::cache::{CacheCodecImpl, CacheKey}; + use lance_core::cache::{CacheCodecImpl, CacheEntryReader, CacheEntryWriter, CacheKey}; + + /// Serialize a `BTreeIndexState` body (no envelope) for tests. + fn serialize_state(state: &BTreeIndexState) -> Vec { + let mut buf = Vec::new(); + state + .serialize(&mut CacheEntryWriter::new(&mut buf)) + .unwrap(); + buf + } + + /// Deserialize a `BTreeIndexState` body (no envelope) for tests. + fn deserialize_state(buf: Vec) -> lance_core::Result { + let data = bytes::Bytes::from(buf); + let mut reader = CacheEntryReader::new(&data, 0, BTreeIndexState::CURRENT_VERSION); + BTreeIndexState::deserialize(&mut reader) + } use rangemap::RangeInclusiveMap; lance_testing::define_stage_event_progress!( @@ -5888,9 +5859,7 @@ mod tests { } fn assert_state_roundtrips(state: &BTreeIndexState) { - let mut buf = Vec::new(); - state.serialize(&mut buf).unwrap(); - let restored = BTreeIndexState::deserialize(&bytes::Bytes::from(buf)).unwrap(); + let restored = deserialize_state(serialize_state(state)).unwrap(); assert_eq!(restored.lookup_batch, state.lookup_batch); assert_eq!(restored.batch_size, state.batch_size); assert_eq!(restored.ranges_to_files, state.ranges_to_files); @@ -5959,9 +5928,7 @@ mod tests { batch_size: index.batch_size, ranges_to_files: index.ranges_to_files.clone(), }; - let mut buf = Vec::new(); - state.serialize(&mut buf).unwrap(); - let restored = 
BTreeIndexState::deserialize(&bytes::Bytes::from(buf)).unwrap(); + let restored = deserialize_state(serialize_state(&state)).unwrap(); let reconstructed = restored .reconstruct(test_store.clone(), &LanceCache::no_cache(), None) .unwrap(); @@ -5997,18 +5964,57 @@ mod tests { assert_eq!(expected, actual); } + /// The lookup batch must decode zero-copy through the full envelope even + /// though the proto header pushes the IPC section to a non-aligned offset. + #[test] + fn test_btree_index_state_lookup_is_zero_copy() { + use lance_core::cache::CacheCodec; + const ALIGN: usize = 64; + + let ranges: RangeInclusiveMap = + [(0..=99, ("part_0_page_file.lance".to_string(), 0))] + .into_iter() + .collect(); + let state = BTreeIndexState { + lookup_batch: sample_lookup_batch(), + batch_size: 8192, + ranges_to_files: Some(Arc::new(ranges)), + }; + + let codec = CacheCodec::from_impl::(); + let any: Arc = Arc::new(state); + let mut buf = Vec::new(); + codec.serialize(&any, &mut buf).unwrap(); + + let mut v = vec![0u8; buf.len() + ALIGN]; + let pad = (ALIGN - (v.as_ptr() as usize % ALIGN)) % ALIGN; + v[pad..pad + buf.len()].copy_from_slice(&buf); + let data = bytes::Bytes::from(v).slice(pad..pad + buf.len()); + + let restored = codec.deserialize(&data).hit().unwrap(); + let restored = restored.downcast::().unwrap(); + + let base = data.as_ptr() as usize; + let end = base + data.len(); + for col in restored.lookup_batch.columns() { + for buffer in col.to_data().buffers() { + let ptr = buffer.as_ptr() as usize; + assert!( + ptr >= base && ptr < end, + "lookup batch buffer was realigned out of the input — misaligned IPC section", + ); + } + } + } + #[test] - fn test_btree_index_state_rejects_invalid_has_ranges_tag() { - // u64 batch_size (any) then a bad has_ranges tag. + fn test_btree_index_state_rejects_truncated_header() { + // A header length prefix that overruns the buffer must error rather + // than panic or silently misread it. 
let mut buf = Vec::new(); - buf.extend_from_slice(&1000u64.to_le_bytes()); - buf.push(7u8); - let err = BTreeIndexState::deserialize(&bytes::Bytes::from(buf)).unwrap_err(); - let msg = err.to_string(); - assert!( - msg.contains("has_ranges") && msg.contains("7"), - "expected error to mention the bad has_ranges tag, got: {msg}" - ); + buf.extend_from_slice(&100u32.to_le_bytes()); // claims a 100-byte header + buf.extend_from_slice(&[0u8; 4]); // but only 4 bytes follow + assert!(deserialize_state(buf).is_err()); } #[tokio::test] diff --git a/rust/lance-index/src/scalar/btree/flat.rs b/rust/lance-index/src/scalar/btree/flat.rs index 212ef6490be..045b4c95c55 100644 --- a/rust/lance-index/src/scalar/btree/flat.rs +++ b/rust/lance-index/src/scalar/btree/flat.rs @@ -13,9 +13,8 @@ use datafusion_common::DFSchema; use datafusion_expr::execution_props::ExecutionProps; use datafusion_physical_expr::create_physical_expr; use lance_arrow::RecordBatchExt; -use lance_arrow::ipc::{read_ipc_stream_single_at, read_len_prefixed_bytes_at, write_ipc_stream}; use lance_core::Result; -use lance_core::cache::CacheCodecImpl; +use lance_core::cache::{CacheCodecImpl, CacheEntryReader, CacheEntryWriter}; use lance_core::deepsize::DeepSizeOf; use lance_core::utils::address::RowAddress; use lance_select::{NullableRowAddrSet, RowAddrTreeMap, RowSetOps}; @@ -236,32 +235,38 @@ impl FlatIndex { } impl CacheCodecImpl for FlatIndex { - fn serialize(&self, writer: &mut dyn std::io::Write) -> Result<()> { + const TYPE_ID: &'static str = "lance.scalar.FlatIndex"; + const CURRENT_VERSION: u32 = 1; + + fn serialize(&self, w: &mut CacheEntryWriter<'_>) -> Result<()> { // Format: - // [len-prefixed all_addrs_map][len-prefixed null_addrs_map][batch IPC stream] - writer.write_all(&(self.all_addrs_map.serialized_size() as u64).to_le_bytes())?; - self.all_addrs_map.serialize_into(&mut *writer)?; + // RAW_BLOB : all_addrs_map (roaring tree map) + // RAW_BLOB : null_addrs_map (roaring tree map) + // ARROW_IPC : 
data batch + let mut all_addrs_bytes = Vec::with_capacity(self.all_addrs_map.serialized_size()); + self.all_addrs_map.serialize_into(&mut all_addrs_bytes)?; + w.write_raw(&all_addrs_bytes)?; - writer.write_all(&(self.null_addrs_map.serialized_size() as u64).to_le_bytes())?; - self.null_addrs_map.serialize_into(&mut *writer)?; + let mut null_addrs_bytes = Vec::with_capacity(self.null_addrs_map.serialized_size()); + self.null_addrs_map.serialize_into(&mut null_addrs_bytes)?; + w.write_raw(&null_addrs_bytes)?; - write_ipc_stream(self.data.as_ref(), writer)?; + w.write_ipc(self.data.as_ref())?; Ok(()) } - fn deserialize(data: &bytes::Bytes) -> Result + fn deserialize(r: &mut CacheEntryReader<'_>) -> Result where Self: Sized, { - let mut offset = 0; - let all_addrs_bytes = read_len_prefixed_bytes_at(data, &mut offset)?; + let all_addrs_bytes = r.read_raw()?; let all_addrs_map = RowAddrTreeMap::deserialize_from(all_addrs_bytes.as_ref())?; - let null_addrs_bytes = read_len_prefixed_bytes_at(data, &mut offset)?; + let null_addrs_bytes = r.read_raw()?; let null_addrs_map = RowAddrTreeMap::deserialize_from(null_addrs_bytes.as_ref())?; - let batch = read_ipc_stream_single_at(data, &mut offset)?; + let batch = r.read_ipc()?; let df_schema = DFSchema::try_from(batch.schema())?; @@ -309,8 +314,12 @@ mod tests { fn assert_roundtrips(index: &FlatIndex) { let mut buf = Vec::new(); - index.serialize(&mut buf).unwrap(); - let restored = FlatIndex::deserialize(&bytes::Bytes::from(buf)).unwrap(); + index + .serialize(&mut CacheEntryWriter::new(&mut buf)) + .unwrap(); + let data = bytes::Bytes::from(buf); + let mut reader = CacheEntryReader::new(&data, 0, FlatIndex::CURRENT_VERSION); + let restored = FlatIndex::deserialize(&mut reader).unwrap(); assert_eq!(restored.data, index.data); assert_eq!(restored.all_addrs_map, index.all_addrs_map); @@ -335,6 +344,41 @@ mod tests { assert_roundtrips(&FlatIndex::try_new(empty).unwrap()); } + /// The data batch must decode zero-copy through the 
full envelope-bearing + /// [`CacheCodec`], even though the two roaring blobs and the envelope push + /// the IPC section to a non-aligned starting offset. + #[test] + fn test_flat_index_data_is_zero_copy() { + use lance_core::cache::CacheCodec; + const ALIGN: usize = 64; + + let index = example_index(); + let codec = CacheCodec::from_impl::(); + let any: Arc = Arc::new(index); + let mut buf = Vec::new(); + codec.serialize(&any, &mut buf).unwrap(); + + let mut v = vec![0u8; buf.len() + ALIGN]; + let pad = (ALIGN - (v.as_ptr() as usize % ALIGN)) % ALIGN; + v[pad..pad + buf.len()].copy_from_slice(&buf); + let data = bytes::Bytes::from(v).slice(pad..pad + buf.len()); + + let restored = codec.deserialize(&data).hit().unwrap(); + let restored = restored.downcast::().unwrap(); + + let base = data.as_ptr() as usize; + let end = base + data.len(); + for col in restored.data.columns() { + for buffer in col.to_data().buffers() { + let ptr = buffer.as_ptr() as usize; + assert!( + ptr >= base && ptr < end, + "data batch buffer was realigned out of the input — misaligned IPC section", + ); + } + } + } + #[tokio::test] async fn test_equality() { check_index(&SargableQuery::Equals(ScalarValue::from(100)), &[0]).await; diff --git a/rust/lance-index/src/scalar/inverted/cache_codec.rs b/rust/lance-index/src/scalar/inverted/cache_codec.rs index 74cfc98ef7b..4ce75ec4406 100644 --- a/rust/lance-index/src/scalar/inverted/cache_codec.rs +++ b/rust/lance-index/src/scalar/inverted/cache_codec.rs @@ -4,16 +4,24 @@ //! Cache codec impls for FTS index entries. //! //! Serializes [`PostingList`] and [`Positions`] cache values for persistent -//! cache backends. The format is a small variant tag plus a JSON header for -//! scalar metadata, with Arrow-backed payload sections written as zero-copy -//! Arrow IPC streams via [`lance_arrow::ipc`]. The raw byte buffer inside -//! [`SharedPositionStream`] is written via [`write_len_prefixed_bytes`] and -//! 
read back via [`read_len_prefixed_bytes_at`] -- both zero-copy slices into -//! the input `Bytes` allocation. +//! cache backends, behind the stabilized envelope written by +//! [`CacheCodec`](lance_core::cache::CacheCodec). //! -//! This is the FTS counterpart of `partition_serde.rs` for vector indices. +//! Every variant uses a protobuf header (see `protos-cache/cache.proto`, with the +//! tail/position codecs and position-storage kind as proto enums) followed by +//! 64-byte-aligned Arrow IPC sections and, where applicable, raw blobs: +//! +//! - the compressed posting list: an IPC section for `blocks`, then the +//! position sections (legacy IPC, or shared block-offsets IPC + a raw blob of +//! the [`SharedPositionStream`] byte buffer, which has its own portable +//! encoding); +//! - the plain posting list: an IPC section of `(row_ids, frequencies)`, then +//! an optional legacy position IPC section; +//! - the standalone [`Positions`] codec: the position sections alone. +//! +//! All sections read back zero-copy via [`lance_arrow::ipc`]. This is the FTS +//! counterpart of `partition_serde.rs` for vector indices. 
-use std::io::Write; use std::sync::Arc; use arrow_array::cast::AsArray; @@ -22,14 +30,15 @@ use arrow_array::{ Array, Float32Array, LargeBinaryArray, ListArray, RecordBatch, UInt32Array, UInt64Array, }; use arrow_schema::{DataType, Field, Schema}; -use bytes::Bytes; -use lance_arrow::ipc::{ - read_ipc_stream_single_at, read_len_prefixed_bytes_at, write_ipc_stream, - write_len_prefixed_bytes, -}; -use lance_core::cache::CacheCodecImpl; +use lance_arrow::ipc::{read_len_prefixed_bytes_at, write_len_prefixed_bytes}; +use lance_core::cache::{CacheCodecImpl, CacheEntryReader, CacheEntryWriter}; use lance_core::{Error, Result}; -use serde::{Deserialize, Serialize}; + +use crate::cache_pb::{ + CompressedPostingHeader, PlainPostingHeader, PositionStorage as PbPositionStorage, + PositionStreamCodec as PbPositionStreamCodec, PositionsHeader, + PostingTailCodec as PbPostingTailCodec, +}; use super::index::{ CompressedPositionStorage, CompressedPostingList, PlainPostingList, PositionStreamCodec, @@ -43,86 +52,43 @@ use super::index::{ const POSTING_VARIANT_PLAIN: u8 = 0; const POSTING_VARIANT_COMPRESSED: u8 = 1; -const POSITIONS_TAG_NONE: u8 = 0; -const POSITIONS_TAG_LEGACY: u8 = 1; -const POSITIONS_TAG_SHARED: u8 = 2; - -const POSTING_TAIL_CODEC_FIXED32: u8 = 0; -const POSTING_TAIL_CODEC_VARINT_DELTA: u8 = 1; - -const POSITION_STREAM_CODEC_VARINT_DOC_DELTA: u8 = 0; -const POSITION_STREAM_CODEC_PACKED_DELTA: u8 = 1; - // --------------------------------------------------------------------------- -// Codec enum byte mappings +// Codec enum mappings // --------------------------------------------------------------------------- -fn posting_tail_codec_to_u8(c: PostingTailCodec) -> u8 { - match c { - PostingTailCodec::Fixed32 => POSTING_TAIL_CODEC_FIXED32, - PostingTailCodec::VarintDelta => POSTING_TAIL_CODEC_VARINT_DELTA, - } -} +// Posting lists carry their discriminants as protobuf enums in the header; +// these map to/from the in-memory Rust enums. 
-fn u8_to_posting_tail_codec(v: u8) -> Result { - match v { - POSTING_TAIL_CODEC_FIXED32 => Ok(PostingTailCodec::Fixed32), - POSTING_TAIL_CODEC_VARINT_DELTA => Ok(PostingTailCodec::VarintDelta), - _ => Err(Error::io(format!("unknown posting tail codec: {v}"))), +fn posting_tail_codec_to_proto(c: PostingTailCodec) -> PbPostingTailCodec { + match c { + PostingTailCodec::Fixed32 => PbPostingTailCodec::Fixed32, + PostingTailCodec::VarintDelta => PbPostingTailCodec::VarintDelta, } } -fn position_stream_codec_to_u8(c: PositionStreamCodec) -> u8 { +fn proto_to_posting_tail_codec(c: PbPostingTailCodec) -> PostingTailCodec { match c { - PositionStreamCodec::VarintDocDelta => POSITION_STREAM_CODEC_VARINT_DOC_DELTA, - PositionStreamCodec::PackedDelta => POSITION_STREAM_CODEC_PACKED_DELTA, + PbPostingTailCodec::Fixed32 => PostingTailCodec::Fixed32, + PbPostingTailCodec::VarintDelta => PostingTailCodec::VarintDelta, } } -fn u8_to_position_stream_codec(v: u8) -> Result { - match v { - POSITION_STREAM_CODEC_VARINT_DOC_DELTA => Ok(PositionStreamCodec::VarintDocDelta), - POSITION_STREAM_CODEC_PACKED_DELTA => Ok(PositionStreamCodec::PackedDelta), - _ => Err(Error::io(format!("unknown position stream codec: {v}"))), +fn position_stream_codec_to_proto(c: PositionStreamCodec) -> PbPositionStreamCodec { + match c { + PositionStreamCodec::VarintDocDelta => PbPositionStreamCodec::VarintDocDelta, + PositionStreamCodec::PackedDelta => PbPositionStreamCodec::PackedDelta, } } -// --------------------------------------------------------------------------- -// Header / tag I/O helpers (mirrors partition_serde.rs) -// --------------------------------------------------------------------------- - -fn write_json_header(writer: &mut dyn Write, header: &impl Serialize) -> Result<()> { - let bytes = serde_json::to_vec(header)?; - write_len_prefixed_bytes(writer, &bytes)?; - Ok(()) -} - -fn read_json_header(data: &Bytes, offset: &mut usize) -> Result { - let bytes = read_len_prefixed_bytes_at(data, 
offset).map_err(|e| Error::io(e.to_string()))?; - serde_json::from_slice(&bytes) - .map_err(|e| Error::io(format!("failed to deserialize cache header: {e}"))) -} - -fn write_u8(writer: &mut dyn Write, value: u8) -> Result<()> { - writer - .write_all(&[value]) - .map_err(|e| Error::io(format!("failed to write tag byte: {e}"))) -} - -fn read_u8(data: &Bytes, offset: &mut usize) -> Result { - let bytes = data.as_ref(); - if *offset >= bytes.len() { - return Err(Error::io( - "truncated cache entry: missing tag byte".to_string(), - )); +fn proto_to_position_stream_codec(c: PbPositionStreamCodec) -> PositionStreamCodec { + match c { + PbPositionStreamCodec::VarintDocDelta => PositionStreamCodec::VarintDocDelta, + PbPositionStreamCodec::PackedDelta => PositionStreamCodec::PackedDelta, } - let v = bytes[*offset]; - *offset += 1; - Ok(v) } // --------------------------------------------------------------------------- -// Position storage serde (shared by PostingList variants and Positions) +// Position storage sections (shared by PostingList variants and Positions) // --------------------------------------------------------------------------- const POSITION_LIST_COLUMN: &str = "position_list"; @@ -131,33 +97,36 @@ const ROW_IDS_COLUMN: &str = "row_ids"; const FREQUENCIES_COLUMN: &str = "frequencies"; const BLOCKS_COLUMN: &str = "blocks"; -#[derive(Serialize, Deserialize)] -struct SharedPositionsHeader { - codec: u8, +fn legacy_positions_batch(list: &ListArray) -> Result { + let schema = Arc::new(Schema::new(vec![Field::new( + POSITION_LIST_COLUMN, + list.data_type().clone(), + list.is_nullable(), + )])); + Ok(RecordBatch::try_new(schema, vec![Arc::new(list.clone())])?) +} + +fn read_legacy_positions(r: &mut CacheEntryReader<'_>) -> Result { + let batch = r.read_ipc()?; + Ok(batch + .column(0) + .as_any() + .downcast_ref::() + .ok_or_else(|| Error::io("legacy position column is not a ListArray".to_string()))? 
+ .clone()) } -fn write_position_storage( - writer: &mut dyn Write, +/// Write the position sections (the bytes after the header) for `storage`. The +/// caller's header proto carries the storage kind and shared-stream codec. +fn write_position_sections( + w: &mut CacheEntryWriter<'_>, storage: &CompressedPositionStorage, ) -> Result<()> { match storage { CompressedPositionStorage::LegacyPerDoc(list) => { - write_u8(writer, POSITIONS_TAG_LEGACY)?; - let schema = Arc::new(Schema::new(vec![Field::new( - POSITION_LIST_COLUMN, - list.data_type().clone(), - list.is_nullable(), - )])); - let batch = RecordBatch::try_new(schema, vec![Arc::new(list.clone())])?; - write_ipc_stream(&batch, writer)?; + w.write_ipc(&legacy_positions_batch(list)?)?; } CompressedPositionStorage::SharedStream(stream) => { - write_u8(writer, POSITIONS_TAG_SHARED)?; - let header = SharedPositionsHeader { - codec: position_stream_codec_to_u8(stream.codec()), - }; - write_json_header(writer, &header)?; - let offsets = UInt32Array::from(stream.block_offsets().to_vec()); let schema = Arc::new(Schema::new(vec![Field::new( BLOCK_OFFSETS_COLUMN, @@ -165,55 +134,42 @@ fn write_position_storage( false, )])); let batch = RecordBatch::try_new(schema, vec![Arc::new(offsets)])?; - write_ipc_stream(&batch, writer)?; - - write_len_prefixed_bytes(writer, stream.bytes())?; + w.write_ipc(&batch)?; + w.write_raw(stream.bytes())?; } } Ok(()) } -fn read_position_storage( - data: &Bytes, - offset: &mut usize, - tag: u8, -) -> Result { - match tag { - POSITIONS_TAG_LEGACY => { - let batch = - read_ipc_stream_single_at(data, offset).map_err(|e| Error::io(e.to_string()))?; - let list = batch - .column(0) - .as_any() - .downcast_ref::() - .ok_or_else(|| Error::io("legacy position column is not a ListArray".to_string()))? 
- .clone(); - Ok(CompressedPositionStorage::LegacyPerDoc(list)) - } - POSITIONS_TAG_SHARED => { - let header: SharedPositionsHeader = read_json_header(data, offset)?; - let codec = u8_to_position_stream_codec(header.codec)?; - - let batch = - read_ipc_stream_single_at(data, offset).map_err(|e| Error::io(e.to_string()))?; +/// Read the position sections for the given `storage` kind and (for shared +/// streams) `stream_codec`. Returns `None` only when `storage` is +/// [`PbPositionStorage::None`]. +fn read_position_sections( + r: &mut CacheEntryReader<'_>, + storage: PbPositionStorage, + stream_codec: PositionStreamCodec, +) -> Result> { + match storage { + PbPositionStorage::None => Ok(None), + PbPositionStorage::Legacy => Ok(Some(CompressedPositionStorage::LegacyPerDoc( + read_legacy_positions(r)?, + ))), + PbPositionStorage::Shared => { + let batch = r.read_ipc()?; let block_offsets = batch .column(0) .as_primitive_opt::() .ok_or_else(|| Error::io("block_offsets column is not UInt32".to_string()))? .values() .to_vec(); - - // Zero copy: read_len_prefixed_bytes_at returns a Bytes slice - // backed by the same allocation as `data`, and SharedPositionStream - // now stores its byte buffer as Bytes -- no copy on read. - let bytes = - read_len_prefixed_bytes_at(data, offset).map_err(|e| Error::io(e.to_string()))?; - - Ok(CompressedPositionStorage::SharedStream( - SharedPositionStream::new(codec, block_offsets, bytes), - )) + // Zero copy: read_raw returns a Bytes slice backed by the same + // allocation as the input, and SharedPositionStream stores its byte + // buffer as Bytes -- no copy on read. 
+ let bytes = r.read_raw()?; + Ok(Some(CompressedPositionStorage::SharedStream( + SharedPositionStream::new(stream_codec, block_offsets, bytes), + ))) } - other => Err(Error::io(format!("unknown positions tag: {other}"))), } } @@ -221,50 +177,45 @@ fn read_position_storage( // PostingList codec // --------------------------------------------------------------------------- -#[derive(Serialize, Deserialize)] -struct PlainPostingHeader { - max_score: Option, -} - -#[derive(Serialize, Deserialize)] -struct CompressedPostingHeader { - max_score: f32, - length: u32, - posting_tail_codec: u8, -} - impl CacheCodecImpl for PostingList { - fn serialize(&self, writer: &mut dyn Write) -> Result<()> { + const TYPE_ID: &'static str = "lance.fts.PostingList"; + const CURRENT_VERSION: u32 = 1; + + fn serialize(&self, w: &mut CacheEntryWriter<'_>) -> Result<()> { match self { Self::Plain(plain) => { - write_u8(writer, POSTING_VARIANT_PLAIN)?; - serialize_plain(writer, plain) + w.write_u8(POSTING_VARIANT_PLAIN)?; + serialize_plain(w, plain) } Self::Compressed(compressed) => { - write_u8(writer, POSTING_VARIANT_COMPRESSED)?; - serialize_compressed(writer, compressed) + w.write_u8(POSTING_VARIANT_COMPRESSED)?; + serialize_compressed(w, compressed) } } } - fn deserialize(data: &Bytes) -> Result { - let mut offset = 0; - let variant = read_u8(data, &mut offset)?; + fn deserialize(r: &mut CacheEntryReader<'_>) -> Result { + let variant = r.read_u8()?; match variant { - POSTING_VARIANT_PLAIN => Ok(Self::Plain(deserialize_plain(data, &mut offset)?)), - POSTING_VARIANT_COMPRESSED => { - Ok(Self::Compressed(deserialize_compressed(data, &mut offset)?)) - } + POSTING_VARIANT_PLAIN => Ok(Self::Plain(deserialize_plain(r)?)), + POSTING_VARIANT_COMPRESSED => Ok(Self::Compressed(deserialize_compressed(r)?)), other => Err(Error::io(format!("unknown PostingList variant: {other}"))), } } } -fn serialize_plain(writer: &mut dyn Write, plain: &PlainPostingList) -> Result<()> { +fn serialize_plain(w: &mut 
CacheEntryWriter<'_>, plain: &PlainPostingList) -> Result<()> { + // Plain postings carry only per-doc legacy positions (or none). + let position_storage = if plain.positions.is_some() { + PbPositionStorage::Legacy + } else { + PbPositionStorage::None + }; let header = PlainPostingHeader { max_score: plain.max_score, + position_storage: position_storage as i32, }; - write_json_header(writer, &header)?; + w.write_header(&header)?; let row_ids = UInt64Array::new(plain.row_ids.clone(), None); let frequencies = Float32Array::new(plain.frequencies.clone(), None); @@ -273,26 +224,18 @@ fn serialize_plain(writer: &mut dyn Write, plain: &PlainPostingList) -> Result<( Field::new(FREQUENCIES_COLUMN, DataType::Float32, false), ])); let batch = RecordBatch::try_new(schema, vec![Arc::new(row_ids), Arc::new(frequencies)])?; - write_ipc_stream(&batch, writer)?; - - match &plain.positions { - Some(list) => { - // Plain postings can only carry per-doc legacy positions; reuse - // the shared encoder. - write_position_storage( - writer, - &CompressedPositionStorage::LegacyPerDoc(list.clone()), - )?; - } - None => write_u8(writer, POSITIONS_TAG_NONE)?, + w.write_ipc(&batch)?; + + if let Some(list) = &plain.positions { + w.write_ipc(&legacy_positions_batch(list)?)?; } Ok(()) } -fn deserialize_plain(data: &Bytes, offset: &mut usize) -> Result { - let header: PlainPostingHeader = read_json_header(data, offset)?; +fn deserialize_plain(r: &mut CacheEntryReader<'_>) -> Result { + let header: PlainPostingHeader = r.read_header()?; - let batch = read_ipc_stream_single_at(data, offset).map_err(|e| Error::io(e.to_string()))?; + let batch = r.read_ipc()?; let row_ids = batch .column(0) .as_primitive_opt::() @@ -306,19 +249,13 @@ fn deserialize_plain(data: &Bytes, offset: &mut usize) -> Result None, - POSITIONS_TAG_LEGACY => match read_position_storage(data, offset, positions_tag)? 
{ - CompressedPositionStorage::LegacyPerDoc(list) => Some(list), - CompressedPositionStorage::SharedStream(_) => { - unreachable!("shared stream tag was read as legacy variant (this is a bug)") - } - }, - other => { - return Err(Error::io(format!( - "Plain posting list cannot have positions tag {other}" - ))); + let positions = match header.position_storage() { + PbPositionStorage::None => None, + PbPositionStorage::Legacy => Some(read_legacy_positions(r)?), + PbPositionStorage::Shared => { + return Err(Error::io( + "Plain posting list cannot have a shared position stream".to_string(), + )); } }; @@ -330,13 +267,33 @@ fn deserialize_plain(data: &Bytes, offset: &mut usize) -> Result Result<()> { +/// The compressed posting list is serialized with a protobuf header followed +/// by 64-byte-aligned Arrow IPC sections (for the `blocks`, and for shared +/// position block-offsets) and a raw blob (for the shared position byte +/// stream, which already has its own portable encoding). +fn serialize_compressed( + w: &mut CacheEntryWriter<'_>, + posting: &CompressedPostingList, +) -> Result<()> { + let (position_storage, position_stream_codec) = match &posting.positions { + None => (PbPositionStorage::None, PbPositionStreamCodec::default()), + Some(CompressedPositionStorage::LegacyPerDoc(_)) => { + (PbPositionStorage::Legacy, PbPositionStreamCodec::default()) + } + Some(CompressedPositionStorage::SharedStream(stream)) => ( + PbPositionStorage::Shared, + position_stream_codec_to_proto(stream.codec()), + ), + }; + let header = CompressedPostingHeader { max_score: posting.max_score, length: posting.length, - posting_tail_codec: posting_tail_codec_to_u8(posting.posting_tail_codec), + posting_tail_codec: posting_tail_codec_to_proto(posting.posting_tail_codec) as i32, + position_storage: position_storage as i32, + position_stream_codec: position_stream_codec as i32, }; - write_json_header(writer, &header)?; + w.write_header(&header)?; let schema = 
Arc::new(Schema::new(vec![Field::new( BLOCKS_COLUMN, @@ -344,20 +301,19 @@ fn serialize_compressed(writer: &mut dyn Write, posting: &CompressedPostingList) false, )])); let batch = RecordBatch::try_new(schema, vec![Arc::new(posting.blocks.clone())])?; - write_ipc_stream(&batch, writer)?; + w.write_ipc(&batch)?; - match &posting.positions { - Some(storage) => write_position_storage(writer, storage)?, - None => write_u8(writer, POSITIONS_TAG_NONE)?, + if let Some(storage) = &posting.positions { + write_position_sections(w, storage)?; } Ok(()) } -fn deserialize_compressed(data: &Bytes, offset: &mut usize) -> Result { - let header: CompressedPostingHeader = read_json_header(data, offset)?; - let posting_tail_codec = u8_to_posting_tail_codec(header.posting_tail_codec)?; +fn deserialize_compressed(r: &mut CacheEntryReader<'_>) -> Result { + let header: CompressedPostingHeader = r.read_header()?; + let posting_tail_codec = proto_to_posting_tail_codec(header.posting_tail_codec()); - let batch = read_ipc_stream_single_at(data, offset).map_err(|e| Error::io(e.to_string()))?; + let batch = r.read_ipc()?; let blocks = batch .column(0) .as_any() @@ -365,12 +321,8 @@ fn deserialize_compressed(data: &Bytes, offset: &mut usize) -> Result Result Result<()> { + const TYPE_ID: &'static str = "lance.fts.PostingListGroup"; + const CURRENT_VERSION: u32 = 1; + + fn serialize(&self, w: &mut CacheEntryWriter<'_>) -> Result<()> { + let writer = w.raw_writer(); let count = u32::try_from(self.posting_lists.len()) .map_err(|_| Error::io("posting list group too large to serialize".to_string()))?; writer @@ -398,13 +354,16 @@ impl CacheCodecImpl for PostingListGroup { .map_err(|e| Error::io(format!("failed to write group count: {e}")))?; for posting in &self.posting_lists { let mut buf = Vec::new(); - posting.serialize(&mut buf)?; + let mut sub = CacheEntryWriter::new(&mut buf); + posting.serialize(&mut sub)?; write_len_prefixed_bytes(writer, &buf)?; } Ok(()) } - fn deserialize(data: &Bytes) -> 
Result { + fn deserialize(r: &mut CacheEntryReader<'_>) -> Result { + let body = r.body(); + let data = &body; let mut offset = 0; if data.len() < 4 { return Err(Error::io( @@ -413,11 +372,13 @@ impl CacheCodecImpl for PostingListGroup { } let count = u32::from_le_bytes(data[0..4].try_into().unwrap()) as usize; offset += 4; + let version = r.version(); let mut posting_lists = Vec::with_capacity(count); for _ in 0..count { let entry = read_len_prefixed_bytes_at(data, &mut offset) .map_err(|e| Error::io(e.to_string()))?; - posting_lists.push(PostingList::deserialize(&entry)?); + let mut sub = CacheEntryReader::new(&entry, 0, version); + posting_lists.push(PostingList::deserialize(&mut sub)?); } Ok(Self::new(posting_lists)) } @@ -428,20 +389,35 @@ impl CacheCodecImpl for PostingListGroup { // --------------------------------------------------------------------------- impl CacheCodecImpl for Positions { - fn serialize(&self, writer: &mut dyn Write) -> Result<()> { - write_position_storage(writer, &self.0) + const TYPE_ID: &'static str = "lance.fts.Positions"; + const CURRENT_VERSION: u32 = 1; + + fn serialize(&self, w: &mut CacheEntryWriter<'_>) -> Result<()> { + let (position_storage, position_stream_codec) = match &self.0 { + CompressedPositionStorage::LegacyPerDoc(_) => { + (PbPositionStorage::Legacy, PbPositionStreamCodec::default()) + } + CompressedPositionStorage::SharedStream(stream) => ( + PbPositionStorage::Shared, + position_stream_codec_to_proto(stream.codec()), + ), + }; + let header = PositionsHeader { + position_storage: position_storage as i32, + position_stream_codec: position_stream_codec as i32, + }; + w.write_header(&header)?; + write_position_sections(w, &self.0) } - fn deserialize(data: &Bytes) -> Result { - let mut offset = 0; - let tag = read_u8(data, &mut offset)?; - if tag == POSITIONS_TAG_NONE { - return Err(Error::io( - "Positions cache entry cannot encode the None variant".to_string(), - )); - } - let storage = read_position_storage(data, 
&mut offset, tag)?; - Ok(Self(storage)) + fn deserialize(r: &mut CacheEntryReader<'_>) -> Result { + let header: PositionsHeader = r.read_header()?; + let stream_codec = proto_to_position_stream_codec(header.position_stream_codec()); + read_position_sections(r, header.position_storage(), stream_codec)? + .map(Self) + .ok_or_else(|| { + Error::io("Positions cache entry cannot encode the None variant".to_string()) + }) } } @@ -455,7 +431,8 @@ mod tests { use arrow_array::LargeBinaryArray; use arrow_array::builder::{Int32Builder, ListBuilder}; use bytes::Bytes; - use lance_core::cache::CacheCodecImpl; + use lance_core::Result; + use lance_core::cache::{CacheCodecImpl, CacheEntryReader, CacheEntryWriter}; use super::super::index::{ CompressedPositionStorage, CompressedPostingList, PlainPostingList, PositionStreamCodec, @@ -502,16 +479,26 @@ mod tests { } } - fn roundtrip_posting_list(entry: &PostingList) -> PostingList { + /// Serialize a codec body (no envelope) into a standalone buffer. + fn body_bytes(entry: &T) -> Bytes { let mut buf = Vec::new(); - entry.serialize(&mut buf).unwrap(); - PostingList::deserialize(&Bytes::from(buf)).unwrap() + let mut w = CacheEntryWriter::new(&mut buf); + entry.serialize(&mut w).unwrap(); + Bytes::from(buf) + } + + /// Deserialize a codec body (no envelope) at the current build's version. 
+ fn from_body(data: &Bytes) -> Result { + let mut r = CacheEntryReader::new(data, 0, T::CURRENT_VERSION); + T::deserialize(&mut r) + } + + fn roundtrip_posting_list(entry: &PostingList) -> PostingList { + from_body::(&body_bytes(entry)).unwrap() } fn roundtrip_positions(entry: &Positions) -> Positions { - let mut buf = Vec::new(); - entry.serialize(&mut buf).unwrap(); - Positions::deserialize(&Bytes::from(buf)).unwrap() + from_body::(&body_bytes(entry)).unwrap() } fn assert_slice_points_into_bytes(slice: &[u8], bytes: &Bytes) { @@ -652,13 +639,9 @@ mod tests { expected_stream.clone(), )), ); - let mut buf = Vec::new(); - PostingList::Compressed(posting) - .serialize(&mut buf) - .unwrap(); - let serialized = Bytes::from(buf); + let serialized = body_bytes(&PostingList::Compressed(posting)); - let restored = PostingList::deserialize(&serialized).unwrap(); + let restored = from_body::(&serialized).unwrap(); let PostingList::Compressed(restored) = restored else { panic!("expected Compressed variant"); }; @@ -695,9 +678,7 @@ mod tests { vec![plain.clone(), compressed, plain], ] { let group = PostingListGroup::new(members.clone()); - let mut buf = Vec::new(); - group.serialize(&mut buf).unwrap(); - let restored = PostingListGroup::deserialize(&Bytes::from(buf)).unwrap(); + let restored = from_body::(&body_bytes(&group)).unwrap(); assert_eq!(restored.posting_lists.len(), members.len()); for (a, b) in members.iter().zip(restored.posting_lists.iter()) { match (a, b) { @@ -743,9 +724,193 @@ mod tests { None, ); let entry = PostingList::Plain(plain); - let mut buf = Vec::new(); - entry.serialize(&mut buf).unwrap(); + let mut buf = body_bytes(&entry).to_vec(); buf.truncate(buf.len() / 2); - assert!(PostingList::deserialize(&Bytes::from(buf)).is_err()); + assert!(from_body::(&Bytes::from(buf)).is_err()); + } + + /// Tests covering the stabilized envelope + compressed proto format, + /// exercised through the full type-erased [`CacheCodec`] (envelope + body). 
+ mod stable_format { + use std::sync::Arc; + + use arrow_array::Array; + use lance_core::cache::CacheCodec; + use prost::Message; + + use super::*; + use crate::cache_pb::{CompressedPostingHeader, PostingTailCodec as PbPostingTailCodec}; + + type ArcAny = Arc; + + fn codec() -> CacheCodec { + CacheCodec::from_impl::() + } + + /// Serialize an entry through the full codec (envelope + body). + fn serialize_entry(entry: PostingList) -> Vec { + let any: ArcAny = Arc::new(entry); + let mut buf = Vec::new(); + codec().serialize(&any, &mut buf).unwrap(); + buf + } + + /// A `Bytes` whose base address is 64-byte aligned, modelling a backend + /// that reads cache entries into an aligned buffer. + fn aligned_bytes(payload: &[u8]) -> Bytes { + const ALIGN: usize = 64; + let mut v = vec![0u8; payload.len() + ALIGN]; + let pad = (ALIGN - (v.as_ptr() as usize % ALIGN)) % ALIGN; + v[pad..pad + payload.len()].copy_from_slice(payload); + Bytes::from(v).slice(pad..pad + payload.len()) + } + + fn compressed_with_shared_positions() -> PostingList { + let blocks = + LargeBinaryArray::from_opt_vec(vec![Some(&[9u8; 48][..]), Some(&[1u8; 48])]); + let stream = SharedPositionStream::new( + PositionStreamCodec::PackedDelta, + vec![0u32, 4, 11], + Bytes::from((0u8..64).collect::>()), + ); + PostingList::Compressed(CompressedPostingList::new( + blocks, + 7.0, + 3, + PostingTailCodec::VarintDelta, + Some(CompressedPositionStorage::SharedStream(stream)), + )) + } + + /// The compressed `blocks` (an aligned IPC section) and the shared + /// position blob (a raw section) must both be borrowed zero-copy from + /// the input even though the envelope pushes them to a non-zero, + /// non-aligned starting offset. 
+ #[test] + fn compressed_sections_are_zero_copy_through_envelope() { + let serialized = aligned_bytes(&serialize_entry(compressed_with_shared_positions())); + let restored = codec().deserialize(&serialized).hit().unwrap(); + let restored = restored.downcast::().unwrap(); + let PostingList::Compressed(restored) = restored.as_ref() else { + panic!("expected Compressed"); + }; + + let base = serialized.as_ptr() as usize; + let end = base + serialized.len(); + let points_in = |ptr: usize| ptr >= base && ptr < end; + + // blocks IPC section decoded in place (no realigning memcpy). + for buf in restored.blocks.to_data().buffers() { + assert!( + points_in(buf.as_ptr() as usize), + "blocks buffer was realigned out of the input — misaligned IPC section", + ); + } + // shared position raw blob borrowed in place. + let Some(CompressedPositionStorage::SharedStream(stream)) = &restored.positions else { + panic!("expected shared stream"); + }; + assert!(points_in(stream.bytes().as_ptr() as usize)); + } + + /// The plain posting's row-id/frequency IPC section must also decode + /// zero-copy through the envelope + proto header. + #[test] + fn plain_sections_are_zero_copy_through_envelope() { + let plain = PostingList::Plain(PlainPostingList::new( + ScalarBuffer::from((0u64..64).collect::>()), + ScalarBuffer::from(vec![1.0f32; 64]), + Some(2.0), + None, + )); + let serialized = aligned_bytes(&serialize_entry(plain)); + let restored = codec().deserialize(&serialized).hit().unwrap(); + let restored = restored.downcast::().unwrap(); + let PostingList::Plain(restored) = restored.as_ref() else { + panic!("expected Plain"); + }; + + let base = serialized.as_ptr() as usize; + let end = base + serialized.len(); + // The row_ids ScalarBuffer must borrow from the input allocation. 
+ let ptr = restored.row_ids.as_ptr() as usize; + assert!( + ptr >= base && ptr < end, + "row_ids buffer was realigned out of the input — misaligned IPC section", + ); + } + + /// Additive proto fields (lever #1) must not break decoding: an unknown + /// field number appended to the header is ignored. + #[test] + fn header_proto_ignores_unknown_fields() { + let header = CompressedPostingHeader { + max_score: 1.5, + length: 9, + posting_tail_codec: PbPostingTailCodec::VarintDelta as i32, + ..Default::default() + }; + let mut bytes = header.encode_to_vec(); + // Append an unknown field #15, varint wire type (0), value 7. + bytes.push(15 << 3); + bytes.push(7); + let decoded = CompressedPostingHeader::decode(bytes.as_slice()).unwrap(); + assert_eq!(decoded.length, 9); + assert_eq!(decoded.max_score, 1.5); + } + + /// An entry written by a different codec (foreign TYPE_ID) misses. + #[test] + fn foreign_type_id_is_miss() { + // A PostingListGroup entry carries a different TYPE_ID in its + // envelope; reading it as a PostingList must miss, not misread it. + let group = PostingListGroup::new(vec![]); + let any: ArcAny = Arc::new(group); + let mut buf = Vec::new(); + CacheCodec::from_impl::() + .serialize(&any, &mut buf) + .unwrap(); + assert!(codec().deserialize(&Bytes::from(buf)).hit().is_none()); + } + + /// An entry written by a newer build (higher type_version) misses. + #[test] + fn future_type_version_is_miss() { + let mut buf = serialize_entry(compressed_with_shared_positions()); + // Patch the envelope's type_version (magic[4] + ver[1] + len[2] + + // type_id[N]) to a value beyond what this build understands. + let type_id_len = u16::from_le_bytes([buf[5], buf[6]]) as usize; + let version_off = 4 + 1 + 2 + type_id_len; + buf[version_off..version_off + 4].copy_from_slice(&u32::MAX.to_le_bytes()); + assert!(codec().deserialize(&Bytes::from(buf)).hit().is_none()); + } + + /// A pre-stabilization blob (no magic) self-heals to a miss. 
+ #[test] + fn pre_stabilization_blob_is_miss() { + // Old format led with a u64 LE length prefix, never our magic. + let mut blob = (30u64).to_le_bytes().to_vec(); + blob.extend_from_slice(&[0u8; 30]); + assert!(codec().deserialize(&Bytes::from(blob)).hit().is_none()); + } + + /// A structurally-valid envelope whose body leads with an out-of-range + /// variant tag self-heals to a `BodyError` miss rather than panicking or + /// misreading the remaining bytes. + #[test] + fn unknown_posting_variant_is_miss() { + use lance_core::cache::{CacheDecode, CacheMissReason}; + + let mut buf = serialize_entry(compressed_with_shared_positions()); + // The variant tag is the first body byte, right after the envelope + // (magic[4] + ver[1] + type_id_len[2] + type_id[N] + type_version[4]). + let type_id_len = u16::from_le_bytes([buf[5], buf[6]]) as usize; + let variant_off = 4 + 1 + 2 + type_id_len + 4; + buf[variant_off] = 2; // neither PLAIN (0) nor COMPRESSED (1) + match codec().deserialize(&Bytes::from(buf)) { + CacheDecode::Miss(reason) => assert_eq!(reason, CacheMissReason::BodyError), + CacheDecode::Hit(_) => panic!("expected a BodyError miss, got a hit"), + } + } } } diff --git a/rust/lance-index/src/scalar/label_list.rs b/rust/lance-index/src/scalar/label_list.rs index cf357d89585..8e07a607bff 100644 --- a/rust/lance-index/src/scalar/label_list.rs +++ b/rust/lance-index/src/scalar/label_list.rs @@ -18,8 +18,9 @@ use datafusion::execution::RecordBatchStream; use datafusion::physical_plan::{SendableRecordBatchStream, stream::RecordBatchStreamAdapter}; use datafusion_common::ScalarValue; use futures::{StreamExt, TryStream, TryStreamExt, stream::BoxStream}; -use lance_arrow::ipc::{read_len_prefixed_bytes_at, write_len_prefixed_bytes}; -use lance_core::cache::{CacheCodec, CacheCodecImpl, CacheKey, LanceCache}; +use lance_core::cache::{ + CacheCodec, CacheCodecImpl, CacheEntryReader, CacheEntryWriter, CacheKey, LanceCache, +}; use lance_core::deepsize::DeepSizeOf; use 
lance_core::error::LanceOptionExt; use lance_core::{Error, ROW_ID, Result}; @@ -532,27 +533,30 @@ impl LabelListIndexState { } impl CacheCodecImpl for LabelListIndexState { + const TYPE_ID: &'static str = "lance.scalar.LabelListIndexState"; + const CURRENT_VERSION: u32 = 1; + /// Wire format: /// ```text - /// [u64 list_nulls_len][list_nulls bytes] - /// [bitmap state bytes (self-delimiting)] + /// RAW_BLOB : list_nulls (roaring tree map, portable encoding) + /// /// ``` - fn serialize(&self, writer: &mut dyn std::io::Write) -> Result<()> { + fn serialize(&self, w: &mut CacheEntryWriter<'_>) -> Result<()> { let mut nulls_bytes = Vec::with_capacity(self.list_nulls.serialized_size()); self.list_nulls.serialize_into(&mut nulls_bytes)?; - write_len_prefixed_bytes(writer, &nulls_bytes)?; - self.bitmap_state.serialize(writer)?; + w.write_raw(&nulls_bytes)?; + // The bitmap state writes its own self-delimiting body inline. + self.bitmap_state.serialize(w)?; Ok(()) } - fn deserialize(data: &bytes::Bytes) -> Result { - let mut offset = 0; - let nulls_bytes = read_len_prefixed_bytes_at(data, &mut offset)?; + fn deserialize(r: &mut CacheEntryReader<'_>) -> Result { + let nulls_bytes = r.read_raw()?; let list_nulls = Arc::new(RowAddrTreeMap::deserialize_from(nulls_bytes.as_ref())?); // The bitmap state is self-delimiting (length-prefixed null map + - // Arrow IPC stream with EOS marker), so we can hand the remaining - // tail to it directly. - let bitmap_state = BitmapIndexState::deserialize(&data.slice(offset..))?; + // Arrow IPC stream with EOS marker); it continues reading the body + // from where the null map left off. 
+ let bitmap_state = BitmapIndexState::deserialize(r)?; Ok(Self { bitmap_state, list_nulls, @@ -728,3 +732,91 @@ impl ScalarIndexPlugin for LabelListIndexPlugin { Ok(()) } } + +#[cfg(test)] +mod tests { + use std::collections::BTreeMap; + + use datafusion_common::ScalarValue; + use lance_core::cache::CacheCodec; + use lance_core::utils::address::RowAddress; + + use super::super::bitmap::BitmapIndexState; + use super::super::btree::OrderableScalarValue; + use super::*; + + fn sample_state() -> LabelListIndexState { + let mut index_map = BTreeMap::new(); + for k in 0..32i32 { + index_map.insert( + OrderableScalarValue(ScalarValue::Int32(Some(k))), + k as usize, + ); + } + let mut bitmap_nulls = RowAddrTreeMap::new(); + bitmap_nulls.insert(RowAddress::new_from_parts(0, 3).into()); + let bitmap_state = + BitmapIndexState::new_for_test(index_map, bitmap_nulls, DataType::Int32).unwrap(); + + let mut list_nulls = RowAddrTreeMap::new(); + list_nulls.insert(RowAddress::new_from_parts(0, 9).into()); + LabelListIndexState { + bitmap_state, + list_nulls: Arc::new(list_nulls), + } + } + + #[test] + fn test_label_list_state_codec_roundtrip() { + let state = sample_state(); + let mut buf = Vec::new(); + state + .serialize(&mut CacheEntryWriter::new(&mut buf)) + .unwrap(); + let data = Bytes::from(buf); + let mut reader = CacheEntryReader::new(&data, 0, LabelListIndexState::CURRENT_VERSION); + let restored = LabelListIndexState::deserialize(&mut reader).unwrap(); + + assert_eq!(&*restored.list_nulls, &*state.list_nulls); + assert_eq!( + restored.bitmap_state.lookup_batch(), + state.bitmap_state.lookup_batch() + ); + assert_eq!( + restored.bitmap_state.null_map(), + state.bitmap_state.null_map() + ); + } + + /// The nested bitmap lookup batch must decode zero-copy through the full + /// envelope, proving the leading `list_nulls` RAW_BLOB does not knock the + /// nested IPC section off its 64-byte boundary. 
+ #[test] + fn test_label_list_nested_lookup_is_zero_copy() { + const ALIGN: usize = 64; + let codec = CacheCodec::from_impl::(); + let any: Arc = Arc::new(sample_state()); + let mut buf = Vec::new(); + codec.serialize(&any, &mut buf).unwrap(); + + let mut v = vec![0u8; buf.len() + ALIGN]; + let pad = (ALIGN - (v.as_ptr() as usize % ALIGN)) % ALIGN; + v[pad..pad + buf.len()].copy_from_slice(&buf); + let data = Bytes::from(v).slice(pad..pad + buf.len()); + + let restored = codec.deserialize(&data).hit().unwrap(); + let restored = restored.downcast::().unwrap(); + + let base = data.as_ptr() as usize; + let end = base + data.len(); + for col in restored.bitmap_state.lookup_batch().columns() { + for buffer in col.to_data().buffers() { + let ptr = buffer.as_ptr() as usize; + assert!( + ptr >= base && ptr < end, + "nested bitmap lookup buffer was realigned — misaligned IPC section", + ); + } + } + } +} diff --git a/rust/lance-select/src/mask.rs b/rust/lance-select/src/mask.rs index b76e0de9a2b..f9df7720441 100644 --- a/rust/lance-select/src/mask.rs +++ b/rust/lance-select/src/mask.rs @@ -13,7 +13,7 @@ use itertools::Itertools; use lance_core::deepsize::DeepSizeOf; use roaring::{MultiOps, RoaringBitmap, RoaringTreemap}; -use lance_core::cache::CacheCodecImpl; +use lance_core::cache::{CacheCodecImpl, CacheEntryReader, CacheEntryWriter}; use lance_core::utils::address::RowAddress; use lance_core::{Error, Result}; @@ -692,12 +692,17 @@ impl RowAddrTreeMap { } impl CacheCodecImpl for RowAddrTreeMap { - fn serialize(&self, writer: &mut dyn Write) -> Result<()> { - self.serialize_into(writer) + const TYPE_ID: &'static str = "lance.RowAddrTreeMap"; + const CURRENT_VERSION: u32 = 1; + + fn serialize(&self, w: &mut CacheEntryWriter<'_>) -> Result<()> { + // A roaring bitmap has its own stable, portable serialization; it is + // the whole body, so write it raw rather than length-prefixed. 
+ self.serialize_into(w.raw_writer()) } - fn deserialize(data: &bytes::Bytes) -> Result { - Self::deserialize_from(data.as_ref()) + fn deserialize(r: &mut CacheEntryReader<'_>) -> Result { + Self::deserialize_from(r.body().as_ref()) } } diff --git a/rust/lance-table/src/format/index.rs b/rust/lance-table/src/format/index.rs index 33ee464fe76..f603536a3eb 100644 --- a/rust/lance-table/src/format/index.rs +++ b/rust/lance-table/src/format/index.rs @@ -15,6 +15,7 @@ use roaring::RoaringBitmap; use uuid::Uuid; use super::pb; +use lance_core::cache::{CacheEntryReader, CacheEntryWriter}; use lance_core::{Error, Result}; /// Metadata about a single file within an index segment. @@ -235,24 +236,26 @@ impl From<&IndexMetadata> for pb::IndexMetadata { /// orphan rule prevents `impl CacheCodecImpl for Vec`. type ArcAny = Arc; +/// Stable type identifier for the `Vec` cache entry. +const INDEX_METADATA_TYPE_ID: &str = "lance.table.IndexMetadataList"; +/// Body schema version written by this build. +const INDEX_METADATA_VERSION: u32 = 1; + fn serialize_index_metadata( any: &ArcAny, - writer: &mut dyn std::io::Write, + writer: &mut CacheEntryWriter<'_>, ) -> lance_core::Result<()> { - use prost::Message; let vec = any .downcast_ref::>() .expect("index_metadata_codec: wrong type (this is a bug in the cache layer)"); let section = pb::IndexSection { indices: vec.iter().map(pb::IndexMetadata::from).collect(), }; - writer.write_all(§ion.encode_to_vec())?; - Ok(()) + writer.write_header(§ion) } -fn deserialize_index_metadata(data: &bytes::Bytes) -> lance_core::Result { - use prost::Message; - let section = pb::IndexSection::decode(data.as_ref())?; +fn deserialize_index_metadata(reader: &mut CacheEntryReader<'_>) -> lance_core::Result { + let section: pb::IndexSection = reader.read_header()?; let indices: Vec = section .indices .into_iter() @@ -262,7 +265,12 @@ fn deserialize_index_metadata(data: &bytes::Bytes) -> lance_core::Result } pub fn index_metadata_codec() -> 
lance_core::cache::CacheCodec { - lance_core::cache::CacheCodec::new(serialize_index_metadata, deserialize_index_metadata) + lance_core::cache::CacheCodec::new( + INDEX_METADATA_TYPE_ID, + INDEX_METADATA_VERSION, + serialize_index_metadata, + deserialize_index_metadata, + ) } /// List all files in an index directory with their sizes. @@ -348,7 +356,8 @@ mod tests { let bytes = store.get(&key).unwrap(); let recovered = codec .deserialize(&bytes::Bytes::copy_from_slice(bytes)) - .unwrap(); + .hit() + .expect("entry should decode as a hit"); let recovered = recovered .downcast::>() .expect("downcast should succeed"); diff --git a/rust/lance/src/dataset/tests/dataset_index.rs b/rust/lance/src/dataset/tests/dataset_index.rs index beb6e2b99fd..d5c4493c8a8 100644 --- a/rust/lance/src/dataset/tests/dataset_index.rs +++ b/rust/lance/src/dataset/tests/dataset_index.rs @@ -2078,11 +2078,7 @@ mod fts_serializing_backend { ) -> Option { let guard = self.serialized.lock().await; if let Some((bytes, stored_codec, _)) = guard.get(key) { - return Some( - stored_codec - .deserialize(&bytes.clone()) - .expect("deserialization should succeed"), - ); + return stored_codec.deserialize(&bytes.clone()).hit(); } drop(guard); self.passthrough.get(key, codec).await diff --git a/rust/lance/src/index/vector/ivf/partition_serde.rs b/rust/lance/src/index/vector/ivf/partition_serde.rs index 83ced18c598..ad737620a94 100644 --- a/rust/lance/src/index/vector/ivf/partition_serde.rs +++ b/rust/lance/src/index/vector/ivf/partition_serde.rs @@ -3,32 +3,17 @@ //! Serialization and zero-copy deserialization for IVF partition cache entries. //! -//! The format is: -//! -//! ```text -//! [header_len: u64 LE] -//! [header: JSON bytes] -//! [sub_index Arrow IPC stream] -//! [... quantizer-specific IPC streams ...] -//! [storage Arrow IPC stream] -//! ``` -//! -//! Each IPC section is a self-delimiting Arrow IPC stream (schema + batches + EOS -//! 
marker), written directly to the underlying writer without buffering. On -//! deserialization, each message is read into a per-message buffer and zero-copy -//! decoded via [`lance_arrow::ipc`]. +//! Each entry is a protobuf header (see `lance-index/protos-cache/cache.proto`, with the +//! distance and rotation types as proto enums) followed by 64-byte-aligned +//! Arrow IPC sections in a fixed, version-keyed order: the sub-index, then any +//! quantizer-specific arrays (PQ codebook, RabitQ Matrix rotation), then the +//! quantizer storage batches. Sections decode zero-copy via [`lance_arrow::ipc`]. -use std::io::Write; use std::sync::Arc; use arrow_array::{FixedSizeListArray, RecordBatch}; use arrow_schema::{DataType, Field, Schema}; -use bytes::Bytes; -use lance_arrow::ipc::{ - read_ipc_stream_at, read_ipc_stream_single_at, read_len_prefixed_bytes_at, write_ipc_stream, - write_ipc_stream_batches, write_len_prefixed_bytes, -}; -use lance_core::cache::CacheCodecImpl; +use lance_core::cache::{CacheCodecImpl, CacheEntryReader, CacheEntryWriter}; use lance_core::{Error, Result}; use lance_index::vector::bq::RQRotationType; use lance_index::vector::bq::builder::RabitQuantizer; @@ -38,11 +23,15 @@ use lance_index::vector::pq::ProductQuantizer; use lance_index::vector::pq::storage::ProductQuantizationMetadata; use lance_index::vector::quantizer::{Quantization, QuantizerStorage}; use lance_index::vector::sq::ScalarQuantizer; -use lance_index::vector::sq::storage::ScalarQuantizationMetadata; use lance_index::vector::storage::VectorStore; use lance_index::vector::v3::subindex::IvfSubIndex; use lance_linalg::distance::DistanceType; -use serde::{Deserialize, Serialize}; + +use lance_index::cache_pb::{ + DistanceType as PbDistanceType, FlatPartitionHeader, PqPartitionHeader, RabitPartitionHeader, + RabitQueryEstimator as PbRabitQueryEstimator, RotationType as PbRotationType, + SqPartitionHeader, +}; use super::v2::PartitionEntry; @@ -68,7 +57,7 @@ type ArcAny = Arc; fn 
serialize_partition_entry( any: &ArcAny, - writer: &mut dyn Write, + writer: &mut CacheEntryWriter<'_>, ) -> lance_core::Result<()> where S: IvfSubIndex + 'static, @@ -81,14 +70,16 @@ where concrete.serialize(writer) } -fn deserialize_partition_entry(data: &Bytes) -> lance_core::Result +fn deserialize_partition_entry( + reader: &mut CacheEntryReader<'_>, +) -> lance_core::Result where S: IvfSubIndex + 'static, Q: Quantization + 'static, Concrete: Quantization + 'static, PartitionEntry: CacheCodecImpl, { - let concrete = PartitionEntry::::deserialize(data)?; + let concrete = PartitionEntry::::deserialize(reader)?; let any: ArcAny = Arc::new(concrete); Ok(any .downcast::>() @@ -109,6 +100,8 @@ where PartitionEntry: CacheCodecImpl, { lance_core::cache::CacheCodec::new( + as CacheCodecImpl>::TYPE_ID, + as CacheCodecImpl>::CURRENT_VERSION, serialize_partition_entry::, deserialize_partition_entry::, ) @@ -118,51 +111,64 @@ where // Common helpers // --------------------------------------------------------------------------- -fn distance_type_to_u8(dt: DistanceType) -> u8 { +// Distance and rotation discriminants travel as proto enums in the header; +// these map to/from the in-memory Rust enums. 
+ +fn distance_type_to_proto(dt: DistanceType) -> PbDistanceType { + match dt { + DistanceType::L2 => PbDistanceType::L2, + DistanceType::Cosine => PbDistanceType::Cosine, + DistanceType::Dot => PbDistanceType::Dot, + DistanceType::Hamming => PbDistanceType::Hamming, + } +} + +fn proto_to_distance_type(dt: PbDistanceType) -> DistanceType { match dt { - DistanceType::L2 => 0, - DistanceType::Cosine => 1, - DistanceType::Dot => 2, - DistanceType::Hamming => 3, + PbDistanceType::L2 => DistanceType::L2, + PbDistanceType::Cosine => DistanceType::Cosine, + PbDistanceType::Dot => DistanceType::Dot, + PbDistanceType::Hamming => DistanceType::Hamming, } } -fn u8_to_distance_type(v: u8) -> Result { - match v { - 0 => Ok(DistanceType::L2), - 1 => Ok(DistanceType::Cosine), - 2 => Ok(DistanceType::Dot), - 3 => Ok(DistanceType::Hamming), - _ => Err(Error::io(format!("unknown distance type: {v}"))), +fn rotation_type_to_proto(rt: RQRotationType) -> PbRotationType { + match rt { + RQRotationType::Matrix => PbRotationType::Matrix, + RQRotationType::Fast => PbRotationType::Fast, } } -fn rotation_type_to_u8(rt: RQRotationType) -> u8 { +fn proto_to_rotation_type(rt: PbRotationType) -> RQRotationType { match rt { - RQRotationType::Matrix => 0, - RQRotationType::Fast => 1, + PbRotationType::Matrix => RQRotationType::Matrix, + PbRotationType::Fast => RQRotationType::Fast, } } -fn u8_to_rotation_type(v: u8) -> Result { - match v { - 0 => Ok(RQRotationType::Matrix), - 1 => Ok(RQRotationType::Fast), - _ => Err(Error::io(format!("unknown rotation type: {v}"))), +fn query_estimator_to_proto(qe: RabitQueryEstimator) -> PbRabitQueryEstimator { + match qe { + RabitQueryEstimator::ResidualQuery => PbRabitQueryEstimator::ResidualQuery, + RabitQueryEstimator::RawQuery => PbRabitQueryEstimator::RawQuery, } } -/// Write a JSON-serializable header using [`write_len_prefixed_bytes`]. 
-fn write_json_header(writer: &mut dyn Write, header: &impl Serialize) -> Result<()> { - let header_json = serde_json::to_vec(header)?; - write_len_prefixed_bytes(writer, &header_json)?; - Ok(()) +fn proto_to_query_estimator(qe: PbRabitQueryEstimator) -> RabitQueryEstimator { + match qe { + PbRabitQueryEstimator::ResidualQuery => RabitQueryEstimator::ResidualQuery, + PbRabitQueryEstimator::RawQuery => RabitQueryEstimator::RawQuery, + } } -/// Read a JSON header written by [`write_json_header`]. -fn read_json_header(data: &Bytes, offset: &mut usize) -> Result { - let bytes = read_len_prefixed_bytes_at(data, offset).map_err(|e| Error::io(e.to_string()))?; - serde_json::from_slice(&bytes).map_err(|e| Error::io(e.to_string())) +/// Read a storage section expected to hold exactly one batch. +fn read_single_storage_batch(r: &mut CacheEntryReader<'_>) -> Result { + let mut batches = r.read_ipc_batches()?; + match batches.len() { + 1 => Ok(batches.remove(0)), + n => Err(Error::io(format!( + "expected exactly 1 storage batch, got {n}" + ))), + } } /// Wrap a `FixedSizeListArray` in a single-column `RecordBatch` with the given @@ -202,17 +208,11 @@ fn batch_to_codebook(batch: &RecordBatch) -> Result { // PQ // --------------------------------------------------------------------------- -#[derive(Serialize, Deserialize)] -struct PqPartitionHeader { - distance_type: u8, - nbits: u32, - num_sub_vectors: usize, - dimension: usize, - transposed: bool, -} - impl CacheCodecImpl for PartitionEntry { - fn serialize(&self, writer: &mut dyn Write) -> Result<()> { + const TYPE_ID: &'static str = "lance.vector.ivf.PartitionEntry.PQ"; + const CURRENT_VERSION: u32 = 1; + + fn serialize(&self, w: &mut CacheEntryWriter<'_>) -> Result<()> { let metadata = self.storage.metadata(); let distance_type = self.storage.distance_type(); @@ -221,32 +221,28 @@ impl CacheCodecImpl for PartitionEntry { })?; let header = PqPartitionHeader { - distance_type: distance_type_to_u8(distance_type), + 
distance_type: distance_type_to_proto(distance_type) as i32, nbits: metadata.nbits, - num_sub_vectors: metadata.num_sub_vectors, - dimension: metadata.dimension, + num_sub_vectors: metadata.num_sub_vectors as u64, + dimension: metadata.dimension as u64, transposed: metadata.transposed, }; - write_json_header(writer, &header)?; - write_ipc_stream(&self.index.to_batch()?, writer)?; - write_ipc_stream(&codebook_to_batch(codebook)?, writer)?; - write_ipc_stream_batches(self.storage.to_batches()?, writer)?; + w.write_header(&header)?; + w.write_ipc(&self.index.to_batch()?)?; + w.write_ipc(&codebook_to_batch(codebook)?)?; + w.write_ipc_batches(self.storage.to_batches()?)?; Ok(()) } - fn deserialize(data: &Bytes) -> Result { - let mut offset = 0; - let header: PqPartitionHeader = read_json_header(data, &mut offset)?; - let distance_type = u8_to_distance_type(header.distance_type)?; + fn deserialize(r: &mut CacheEntryReader<'_>) -> Result { + let header: PqPartitionHeader = r.read_header()?; + let distance_type = proto_to_distance_type(header.distance_type()); - let sub_index_batch = - read_ipc_stream_single_at(data, &mut offset).map_err(|e| Error::io(e.to_string()))?; - let codebook_batch = - read_ipc_stream_single_at(data, &mut offset).map_err(|e| Error::io(e.to_string()))?; - let storage_batch = - read_ipc_stream_single_at(data, &mut offset).map_err(|e| Error::io(e.to_string()))?; + let sub_index_batch = r.read_ipc()?; + let codebook_batch = r.read_ipc()?; + let storage_batch = read_single_storage_batch(r)?; let index = S::load(sub_index_batch)?; let codebook = batch_to_codebook(&codebook_batch)?; @@ -254,8 +250,8 @@ impl CacheCodecImpl for PartitionEntry { let metadata = ProductQuantizationMetadata { codebook_position: 0, nbits: header.nbits, - num_sub_vectors: header.num_sub_vectors, - dimension: header.dimension, + num_sub_vectors: header.num_sub_vectors as usize, + dimension: header.dimension as usize, codebook: Some(codebook), codebook_tensor: Vec::new(), 
transposed: header.transposed, @@ -276,41 +272,35 @@ impl CacheCodecImpl for PartitionEntry { // Flat (Float32) // --------------------------------------------------------------------------- -#[derive(Serialize, Deserialize)] -struct FlatPartitionHeader { - distance_type: u8, - dim: usize, -} - impl CacheCodecImpl for PartitionEntry { - fn serialize(&self, writer: &mut dyn Write) -> Result<()> { - let metadata = self.storage.metadata(); - let distance_type = self.storage.distance_type(); + const TYPE_ID: &'static str = "lance.vector.ivf.PartitionEntry.Flat"; + const CURRENT_VERSION: u32 = 1; + fn serialize(&self, w: &mut CacheEntryWriter<'_>) -> Result<()> { + let metadata = self.storage.metadata(); let header = FlatPartitionHeader { - distance_type: distance_type_to_u8(distance_type), - dim: metadata.dim, + distance_type: distance_type_to_proto(self.storage.distance_type()) as i32, + dim: metadata.dim as u64, }; - write_json_header(writer, &header)?; - write_ipc_stream(&self.index.to_batch()?, writer)?; - write_ipc_stream_batches(self.storage.to_batches()?, writer)?; + w.write_header(&header)?; + w.write_ipc(&self.index.to_batch()?)?; + w.write_ipc_batches(self.storage.to_batches()?)?; Ok(()) } - fn deserialize(data: &Bytes) -> Result { - let mut offset = 0; - let header: FlatPartitionHeader = read_json_header(data, &mut offset)?; - let distance_type = u8_to_distance_type(header.distance_type)?; + fn deserialize(r: &mut CacheEntryReader<'_>) -> Result { + let header: FlatPartitionHeader = r.read_header()?; + let distance_type = proto_to_distance_type(header.distance_type()); - let sub_index_batch = - read_ipc_stream_single_at(data, &mut offset).map_err(|e| Error::io(e.to_string()))?; - let storage_batch = - read_ipc_stream_single_at(data, &mut offset).map_err(|e| Error::io(e.to_string()))?; + let sub_index_batch = r.read_ipc()?; + let storage_batch = read_single_storage_batch(r)?; let index = S::load(sub_index_batch)?; - let metadata = FlatMetadata { dim: 
header.dim }; + let metadata = FlatMetadata { + dim: header.dim as usize, + }; let storage = ::Storage::try_from_batch( storage_batch, &metadata, @@ -327,34 +317,34 @@ impl CacheCodecImpl for PartitionEntry { // --------------------------------------------------------------------------- impl CacheCodecImpl for PartitionEntry { - fn serialize(&self, writer: &mut dyn Write) -> Result<()> { - let metadata = self.storage.metadata(); - let distance_type = self.storage.distance_type(); + const TYPE_ID: &'static str = "lance.vector.ivf.PartitionEntry.FlatBin"; + const CURRENT_VERSION: u32 = 1; + fn serialize(&self, w: &mut CacheEntryWriter<'_>) -> Result<()> { + let metadata = self.storage.metadata(); let header = FlatPartitionHeader { - distance_type: distance_type_to_u8(distance_type), - dim: metadata.dim, + distance_type: distance_type_to_proto(self.storage.distance_type()) as i32, + dim: metadata.dim as u64, }; - write_json_header(writer, &header)?; - write_ipc_stream(&self.index.to_batch()?, writer)?; - write_ipc_stream_batches(self.storage.to_batches()?, writer)?; + w.write_header(&header)?; + w.write_ipc(&self.index.to_batch()?)?; + w.write_ipc_batches(self.storage.to_batches()?)?; Ok(()) } - fn deserialize(data: &Bytes) -> Result { - let mut offset = 0; - let header: FlatPartitionHeader = read_json_header(data, &mut offset)?; - let distance_type = u8_to_distance_type(header.distance_type)?; + fn deserialize(r: &mut CacheEntryReader<'_>) -> Result { + let header: FlatPartitionHeader = r.read_header()?; + let distance_type = proto_to_distance_type(header.distance_type()); - let sub_index_batch = - read_ipc_stream_single_at(data, &mut offset).map_err(|e| Error::io(e.to_string()))?; - let storage_batch = - read_ipc_stream_single_at(data, &mut offset).map_err(|e| Error::io(e.to_string()))?; + let sub_index_batch = r.read_ipc()?; + let storage_batch = read_single_storage_batch(r)?; let index = S::load(sub_index_batch)?; - let metadata = FlatMetadata { dim: header.dim }; 
+ let metadata = FlatMetadata { + dim: header.dim as usize, + }; let storage = ::Storage::try_from_batch( storage_batch, &metadata, @@ -370,56 +360,41 @@ impl CacheCodecImpl for PartitionEntry { // SQ // --------------------------------------------------------------------------- -#[derive(Serialize, Deserialize)] -struct SqPartitionHeader { - distance_type: u8, - num_bits: u16, - dim: usize, - bounds_start: f64, - bounds_end: f64, -} - impl CacheCodecImpl for PartitionEntry { - fn serialize(&self, writer: &mut dyn Write) -> Result<()> { - let metadata = self.storage.metadata(); - let distance_type = self.storage.distance_type(); + const TYPE_ID: &'static str = "lance.vector.ivf.PartitionEntry.SQ"; + const CURRENT_VERSION: u32 = 1; + fn serialize(&self, w: &mut CacheEntryWriter<'_>) -> Result<()> { + let metadata = self.storage.metadata(); let header = SqPartitionHeader { - distance_type: distance_type_to_u8(distance_type), - num_bits: metadata.num_bits, - dim: metadata.dim, + distance_type: distance_type_to_proto(self.storage.distance_type()) as i32, + num_bits: metadata.num_bits as u32, + dim: metadata.dim as u64, bounds_start: metadata.bounds.start, bounds_end: metadata.bounds.end, }; - write_json_header(writer, &header)?; - write_ipc_stream(&self.index.to_batch()?, writer)?; - // SQ storage may contain multiple batches; stream them all in one IPC stream. - write_ipc_stream_batches(self.storage.to_batches()?, writer)?; + w.write_header(&header)?; + w.write_ipc(&self.index.to_batch()?)?; + // SQ storage may contain multiple batches; write them all in one section. 
+ w.write_ipc_batches(self.storage.to_batches()?)?; Ok(()) } - fn deserialize(data: &Bytes) -> Result { - let mut offset = 0; - let header: SqPartitionHeader = read_json_header(data, &mut offset)?; - let distance_type = u8_to_distance_type(header.distance_type)?; + fn deserialize(r: &mut CacheEntryReader<'_>) -> Result { + let header: SqPartitionHeader = r.read_header()?; + let distance_type = proto_to_distance_type(header.distance_type()); - let sub_index_batch = - read_ipc_stream_single_at(data, &mut offset).map_err(|e| Error::io(e.to_string()))?; - let storage_batches = - read_ipc_stream_at(data, &mut offset).map_err(|e| Error::io(e.to_string()))?; + let sub_index_batch = r.read_ipc()?; + let storage_batches = r.read_ipc_batches()?; let index = S::load(sub_index_batch)?; - let metadata = ScalarQuantizationMetadata { - dim: header.dim, - num_bits: header.num_bits, - bounds: header.bounds_start..header.bounds_end, - }; + let num_bits = header.num_bits as u16; let storage = ::Storage::try_new( - metadata.num_bits, + num_bits, distance_type, - metadata.bounds, + header.bounds_start..header.bounds_end, storage_batches, None, )?; @@ -432,88 +407,69 @@ impl CacheCodecImpl for PartitionEntry { // RabitQ // --------------------------------------------------------------------------- -#[derive(Serialize, Deserialize)] -struct RabitPartitionHeader { - distance_type: u8, - num_bits: u8, - code_dim: u32, - #[serde(default = "default_rabit_query_estimator")] - query_estimator: RabitQueryEstimator, - /// 0 = Matrix, 1 = Fast - rotation_type: u8, - /// Fast rotation signs (only set when rotation_type == Fast). 
- fast_rotation_signs: Option>, -} - -fn default_rabit_query_estimator() -> RabitQueryEstimator { - RabitQueryEstimator::ResidualQuery -} - impl CacheCodecImpl for PartitionEntry { - fn serialize(&self, writer: &mut dyn Write) -> Result<()> { - let metadata = self.storage.metadata(); - let distance_type = self.storage.distance_type(); + const TYPE_ID: &'static str = "lance.vector.ivf.PartitionEntry.Rabit"; + const CURRENT_VERSION: u32 = 1; + fn serialize(&self, w: &mut CacheEntryWriter<'_>) -> Result<()> { + let metadata = self.storage.metadata(); let header = RabitPartitionHeader { - distance_type: distance_type_to_u8(distance_type), - num_bits: metadata.num_bits, + distance_type: distance_type_to_proto(self.storage.distance_type()) as i32, + num_bits: metadata.num_bits as u32, code_dim: metadata.code_dim, - query_estimator: metadata.query_estimator, - rotation_type: rotation_type_to_u8(metadata.rotation_type), + rotation_type: rotation_type_to_proto(metadata.rotation_type) as i32, + query_estimator: query_estimator_to_proto(metadata.query_estimator) as i32, fast_rotation_signs: metadata.fast_rotation_signs.clone(), }; - write_json_header(writer, &header)?; + w.write_header(&header)?; + w.write_ipc(&self.index.to_batch()?)?; - write_ipc_stream(&self.index.to_batch()?, writer)?; - - // Write the rotation matrix IPC stream only for Matrix rotation; the - // Fast rotation case stores its signs compactly in the JSON header. + // Write the rotation matrix IPC section only for Matrix rotation; the + // Fast rotation case stores its signs compactly in the proto header. 
if metadata.rotation_type == RQRotationType::Matrix { let mat = metadata.rotate_mat.as_ref().ok_or_else(|| { Error::io( "RabitQ Matrix metadata missing rotate_mat during serialization".to_string(), ) })?; - write_ipc_stream(&fsl_to_batch(mat, "rotate_mat")?, writer)?; + w.write_ipc(&fsl_to_batch(mat, "rotate_mat")?)?; } - write_ipc_stream_batches(self.storage.to_batches()?, writer)?; + w.write_ipc_batches(self.storage.to_batches()?)?; Ok(()) } - fn deserialize(data: &Bytes) -> Result { - let mut offset = 0; - let header: RabitPartitionHeader = read_json_header(data, &mut offset)?; - let distance_type = u8_to_distance_type(header.distance_type)?; - let rotation_type = u8_to_rotation_type(header.rotation_type)?; + fn deserialize(r: &mut CacheEntryReader<'_>) -> Result { + let header: RabitPartitionHeader = r.read_header()?; + let distance_type = proto_to_distance_type(header.distance_type()); + let rotation_type = proto_to_rotation_type(header.rotation_type()); - let sub_index_batch = - read_ipc_stream_single_at(data, &mut offset).map_err(|e| Error::io(e.to_string()))?; + let sub_index_batch = r.read_ipc()?; let rotate_mat = if rotation_type == RQRotationType::Matrix { - let mat_batch = read_ipc_stream_single_at(data, &mut offset) - .map_err(|e| Error::io(e.to_string()))?; + let mat_batch = r.read_ipc()?; Some(batch_to_fsl(&mat_batch)?) } else { None }; - let storage_batch = - read_ipc_stream_single_at(data, &mut offset).map_err(|e| Error::io(e.to_string()))?; + let storage_batch = read_single_storage_batch(r)?; let index = S::load(sub_index_batch)?; + // Read the proto enum accessor before moving fields out of `header`. 
+ let query_estimator = proto_to_query_estimator(header.query_estimator()); let metadata = RabitQuantizationMetadata { rotate_mat, rotate_mat_position: None, fast_rotation_signs: header.fast_rotation_signs, rotation_type, code_dim: header.code_dim, - num_bits: header.num_bits, + num_bits: header.num_bits as u8, // The storage batch already has packed codes; skip re-packing. packed: true, - query_estimator: header.query_estimator, + query_estimator, }; let storage = ::Storage::try_from_batch( storage_batch, @@ -551,6 +507,21 @@ mod tests { use lance_index::vector::flat::storage::FlatFloatStorage; use lance_index::vector::sq::storage::ScalarQuantizationStorage; + /// Serialize a codec body (no envelope) for tests. + fn ser_body(entry: &T) -> Vec { + let mut buf = Vec::new(); + entry + .serialize(&mut CacheEntryWriter::new(&mut buf)) + .unwrap(); + buf + } + + /// Deserialize a codec body (no envelope) at the current build's version. + fn de_body(bytes: Vec) -> Result { + let data = bytes::Bytes::from(bytes); + T::deserialize(&mut CacheEntryReader::new(&data, 0, T::CURRENT_VERSION)) + } + // ----- PQ helpers ------------------------------------------------------- fn make_test_codebook(dim: usize, num_sub_vectors: usize) -> FixedSizeListArray { @@ -618,12 +589,9 @@ mod tests { storage, }; - let mut serialized = Vec::new(); - entry.serialize(&mut serialized).unwrap(); - let deserialized = PartitionEntry::::deserialize( - &bytes::Bytes::from(serialized), - ) - .unwrap(); + let serialized = ser_body(&entry); + let deserialized = + de_body::>(serialized).unwrap(); assert_eq!(entry.storage, deserialized.storage); } @@ -671,12 +639,8 @@ mod tests { storage, }; - let mut bytes = Vec::new(); - entry.serialize(&mut bytes).unwrap(); - let restored = PartitionEntry::::deserialize( - &bytes::Bytes::from(bytes), - ) - .unwrap(); + let bytes = ser_body(&entry); + let restored = de_body::>(bytes).unwrap(); assert_eq!( restored.storage.distance_type(), entry.storage.distance_type() @@ 
-694,12 +658,9 @@ mod tests { storage, }; - let mut serialized = Vec::new(); - entry.serialize(&mut serialized).unwrap(); - let deserialized = PartitionEntry::::deserialize( - &bytes::Bytes::from(serialized), - ) - .unwrap(); + let serialized = ser_body(&entry); + let deserialized = + de_body::>(serialized).unwrap(); assert_eq!(entry.storage, deserialized.storage); } @@ -712,13 +673,9 @@ mod tests { index: FlatIndex::default(), storage, }; - let mut bytes = Vec::new(); - entry.serialize(&mut bytes).unwrap(); + let mut bytes = ser_body(&entry); bytes.truncate(3); - assert!( - PartitionEntry::::deserialize(&bytes::Bytes::from(bytes)) - .is_err() - ); + assert!(de_body::>(bytes).is_err()); } // ----- Flat helpers ----------------------------------------------------- @@ -756,11 +713,8 @@ mod tests { storage, }; - let mut bytes = Vec::new(); - entry.serialize(&mut bytes).unwrap(); - let restored = - PartitionEntry::::deserialize(&bytes::Bytes::from(bytes)) - .unwrap(); + let bytes = ser_body(&entry); + let restored = de_body::>(bytes).unwrap(); assert_eq!( restored.storage.metadata().dim, @@ -786,11 +740,8 @@ mod tests { index: FlatIndex::default(), storage, }; - let mut bytes = Vec::new(); - entry.serialize(&mut bytes).unwrap(); - let restored = - PartitionEntry::::deserialize(&bytes::Bytes::from(bytes)) - .unwrap(); + let bytes = ser_body(&entry); + let restored = de_body::>(bytes).unwrap(); assert_eq!(restored.storage.distance_type(), dt); } } @@ -803,11 +754,8 @@ mod tests { storage, }; - let mut bytes = Vec::new(); - entry.serialize(&mut bytes).unwrap(); - let restored = - PartitionEntry::::deserialize(&bytes::Bytes::from(bytes)) - .unwrap(); + let bytes = ser_body(&entry); + let restored = de_body::>(bytes).unwrap(); let restored_batch = restored.storage.to_batches().unwrap().next().unwrap(); let schema = restored_batch.schema(); @@ -828,11 +776,8 @@ mod tests { storage, }; - let mut bytes = Vec::new(); - entry.serialize(&mut bytes).unwrap(); - let restored = - 
PartitionEntry::::deserialize(&bytes::Bytes::from(bytes)) - .unwrap(); + let bytes = ser_body(&entry); + let restored = de_body::>(bytes).unwrap(); let restored_batch = restored.storage.to_batches().unwrap().next().unwrap(); let schema = restored_batch.schema(); @@ -884,11 +829,8 @@ mod tests { storage, }; - let mut bytes = Vec::new(); - entry.serialize(&mut bytes).unwrap(); - let restored = - PartitionEntry::::deserialize(&bytes::Bytes::from(bytes)) - .unwrap(); + let bytes = ser_body(&entry); + let restored = de_body::>(bytes).unwrap(); let m = entry.storage.metadata(); let rm = restored.storage.metadata(); @@ -914,12 +856,8 @@ mod tests { index: FlatIndex::default(), storage, }; - let mut bytes = Vec::new(); - entry.serialize(&mut bytes).unwrap(); - let restored = PartitionEntry::::deserialize( - &bytes::Bytes::from(bytes), - ) - .unwrap(); + let bytes = ser_body(&entry); + let restored = de_body::>(bytes).unwrap(); assert_eq!(restored.storage.distance_type(), dt); } } @@ -960,11 +898,8 @@ mod tests { index: FlatIndex::default(), storage, }; - let mut bytes = Vec::new(); - entry.serialize(&mut bytes).unwrap(); - let restored = - PartitionEntry::::deserialize(&bytes::Bytes::from(bytes)) - .unwrap(); + let bytes = ser_body(&entry); + let restored = de_body::>(bytes).unwrap(); assert_eq!(restored.storage.len(), 30); let orig_ids: Vec = entry.storage.row_ids().copied().collect(); @@ -978,14 +913,27 @@ mod tests { num_rows: usize, code_dim: usize, distance_type: DistanceType, + ) -> ::Storage { + make_rabit_storage( + num_rows, + code_dim, + distance_type, + RQRotationType::Fast, + RabitQueryEstimator::ResidualQuery, + ) + } + + fn make_rabit_storage( + num_rows: usize, + code_dim: usize, + distance_type: DistanceType, + rotation_type: RQRotationType, + query_estimator: RabitQueryEstimator, ) -> ::Storage { use lance_arrow::FixedSizeListArrayExt; - let quantizer = RabitQuantizer::new_with_rotation::( - 1, - code_dim as i32, - RQRotationType::Fast, - ); + let 
quantizer = + RabitQuantizer::new_with_rotation::(1, code_dim as i32, rotation_type); let values: Vec = (0..num_rows * code_dim) .map(|i| (i % 100) as f32 / 100.0 - 0.5) .collect(); @@ -997,7 +945,8 @@ mod tests { .as_fixed_size_list() .clone(); - let metadata = quantizer.metadata(None); + let mut metadata = quantizer.metadata(None); + metadata.query_estimator = query_estimator; let batch = RecordBatch::try_from_iter(vec![ ( lance_core::ROW_ID, @@ -1044,11 +993,8 @@ mod tests { storage, }; - let mut bytes = Vec::new(); - entry.serialize(&mut bytes).unwrap(); - let restored = - PartitionEntry::::deserialize(&bytes::Bytes::from(bytes)) - .unwrap(); + let bytes = ser_body(&entry); + let restored = de_body::>(bytes).unwrap(); let m = entry.storage.metadata(); let rm = restored.storage.metadata(); @@ -1082,22 +1028,125 @@ mod tests { fn test_rabitq_distance_types() { for dt in [DistanceType::L2, DistanceType::Cosine, DistanceType::Dot] { let storage = make_rabit_storage_fast(10, 32, dt); - let expected_distance_type = if dt == DistanceType::Cosine { - DistanceType::L2 - } else { - dt - }; let entry = PartitionEntry:: { index: FlatIndex::default(), storage, }; - let mut bytes = Vec::new(); - entry.serialize(&mut bytes).unwrap(); - let restored = PartitionEntry::::deserialize( - &bytes::Bytes::from(bytes), - ) + let bytes = ser_body(&entry); + let restored = de_body::>(bytes).unwrap(); + // The codec round-trips the distance type faithfully. + assert_eq!( + restored.storage.distance_type(), + entry.storage.distance_type() + ); + } + } + + #[test] + fn test_roundtrip_rabitq_raw_query_estimator() { + // The query estimator is a non-default value here; it must survive the + // round trip so raw-query search keeps working after a cache reload. 
+ let storage = make_rabit_storage( + 40, + 32, + DistanceType::L2, + RQRotationType::Fast, + RabitQueryEstimator::RawQuery, + ); + assert_eq!( + storage.metadata().query_estimator, + RabitQueryEstimator::RawQuery + ); + let entry = PartitionEntry:: { + index: FlatIndex::default(), + storage, + }; + + let bytes = ser_body(&entry); + let restored = de_body::>(bytes).unwrap(); + assert_eq!( + restored.storage.metadata().query_estimator, + RabitQueryEstimator::RawQuery + ); + } + + /// Matrix rotation writes an extra `rotate_mat` IPC section between the + /// sub-index and storage sections; exercise that the codec preserves it. + #[test] + fn test_roundtrip_flat_rabitq_matrix() { + let storage = make_rabit_storage( + 40, + 32, + DistanceType::L2, + RQRotationType::Matrix, + RabitQueryEstimator::ResidualQuery, + ); + let entry = PartitionEntry:: { + index: FlatIndex::default(), + storage, + }; + + let bytes = ser_body(&entry); + let restored = de_body::>(bytes).unwrap(); + + let m = entry.storage.metadata(); + let rm = restored.storage.metadata(); + assert_eq!(rm.rotation_type, RQRotationType::Matrix); + assert_eq!(rm.code_dim, m.code_dim); + assert_eq!(rm.num_bits, m.num_bits); + // The rotation matrix itself must survive the round trip. + let orig_mat = m + .rotate_mat + .as_ref() + .expect("matrix rotation has rotate_mat"); + let rest_mat = rm + .rotate_mat + .as_ref() + .expect("restored matrix rotation has rotate_mat"); + assert_eq!( + orig_mat.values().as_primitive::().values(), + rest_mat.values().as_primitive::().values(), + ); + } + + /// SQ storage (a multi-batch IPC section) must decode zero-copy through the + /// full envelope even though the proto header and sub-index section push it + /// to a non-aligned starting offset. 
+ #[test] + fn test_partition_storage_is_zero_copy_through_envelope() { + use lance_core::cache::CacheCodec; + const ALIGN: usize = 64; + + let entry = PartitionEntry:: { + index: FlatIndex::default(), + storage: make_sq_storage(64, 32, DistanceType::L2), + }; + let codec = CacheCodec::from_impl::>(); + let any: Arc = Arc::new(entry); + let mut buf = Vec::new(); + codec.serialize(&any, &mut buf).unwrap(); + + let mut v = vec![0u8; buf.len() + ALIGN]; + let pad = (ALIGN - (v.as_ptr() as usize % ALIGN)) % ALIGN; + v[pad..pad + buf.len()].copy_from_slice(&buf); + let data = bytes::Bytes::from(v).slice(pad..pad + buf.len()); + + let restored = codec.deserialize(&data).hit().unwrap(); + let restored = restored + .downcast::>() .unwrap(); - assert_eq!(restored.storage.distance_type(), expected_distance_type); + + let base = data.as_ptr() as usize; + let end = base + data.len(); + let first = restored.storage.to_batches().unwrap().next().unwrap(); + for col in first.columns() { + for buffer in col.to_data().buffers() { + let ptr = buffer.as_ptr() as usize; + assert!( + ptr >= base && ptr < end, + "storage buffer was realigned out of the input — misaligned IPC section", + ); + } } } @@ -1135,17 +1184,12 @@ mod tests { let entry = IvfStateEntryBox(Arc::new(state)); - let mut bytes = Vec::new(); - CacheCodecImpl::serialize(&entry, &mut bytes).unwrap(); - - let restored = - ::deserialize(&bytes::Bytes::from(bytes.clone())) - .unwrap(); + let bytes = ser_body(&entry); + let restored = de_body::(bytes.clone()).unwrap(); // Re-serialize the restored entry and compare bytes — a stronger check // than field-by-field comparison and avoids needing to downcast. 
- let mut restored_bytes = Vec::new(); - CacheCodecImpl::serialize(&restored, &mut restored_bytes).unwrap(); + let restored_bytes = ser_body(&restored); assert_eq!(bytes, restored_bytes); } } diff --git a/rust/lance/src/index/vector/ivf/v2.rs b/rust/lance/src/index/vector/ivf/v2.rs index 29d9e224970..45002e93c15 100644 --- a/rust/lance/src/index/vector/ivf/v2.rs +++ b/rust/lance/src/index/vector/ivf/v2.rs @@ -3,7 +3,6 @@ //! IVF - Inverted File index. -use std::io::Write as IoWrite; use std::marker::PhantomData; use std::{ any::Any, @@ -26,8 +25,10 @@ use futures::future::BoxFuture; use futures::prelude::stream::{self, TryStreamExt}; use futures::{StreamExt, TryFutureExt}; use lance_arrow::RecordBatchExt; -use lance_arrow::ipc::write_len_prefixed_bytes; -use lance_core::cache::{CacheCodec, CacheCodecImpl, CacheKey, LanceCache, WeakLanceCache}; +use lance_core::cache::{ + CacheCodec, CacheCodecImpl, CacheEntryReader, CacheEntryWriter, CacheKey, LanceCache, + WeakLanceCache, +}; use lance_core::deepsize::DeepSizeOf; use lance_core::utils::tokio::{get_num_compute_intensive_cpus, spawn_cpu}; use lance_core::utils::tracing::{IO_TYPE_LOAD_VECTOR_PART, TRACE_IO_EVENTS}; @@ -35,6 +36,7 @@ use lance_core::{Error, ROW_ID, Result}; use lance_encoding::decoder::{DecoderPlugins, FilterExpression}; use lance_file::LanceEncodingsIo; use lance_file::reader::{CachedFileMetadata, FileReader, FileReaderOptions}; +use lance_index::cache_pb::IvfStateHeader; use lance_index::frag_reuse::FragReuseIndex; use lance_index::metrics::{LocalMetricsCollector, MetricsCollector, NoOpMetricsCollector}; use lance_index::vector::VectorIndexCacheEntry; @@ -214,28 +216,6 @@ impl DeepSizeOf for IvfIndexState { } } -/// Serialization header for the `IvfIndexState` wire format. -/// -/// Kept as a flat, non-generic struct so the JSON header format is stable -/// regardless of `Q`. `quantizer_metadata_json` holds the serialized -/// `Q::Metadata`; large blobs (PQ codebook, RQ matrix) follow as raw bytes. 
-#[derive(serde::Serialize, serde::Deserialize)] -struct IvfIndexStateHeader { - index_file_path: String, - uuid: String, - distance_type: String, - sub_index_metadata: Vec, - sub_index_type: String, - quantization_type: String, - quantizer_metadata_json: String, - #[serde(default)] - cache_key_prefix: String, - #[serde(default)] - index_file_size: u64, - #[serde(default)] - aux_file_size: u64, -} - /// Object-safe interface for a type-erased `IvfIndexState`. /// /// Stored as `Arc` inside [`IvfStateEntryBox`], which is @@ -243,7 +223,7 @@ struct IvfIndexStateHeader { /// wrapper lets the cache infrastructure work with a sized type while the /// hot paths call `reconstruct` without knowing `Q`. pub(crate) trait IvfStateEntry: DeepSizeOf + Send + Sync + 'static { - fn serialize_state(&self, writer: &mut dyn IoWrite) -> Result<()>; + fn serialize_state(&self, w: &mut CacheEntryWriter<'_>) -> Result<()>; fn reconstruct<'a>( &'a self, @@ -267,42 +247,39 @@ impl DeepSizeOf for IvfStateEntryBox { } } -/// Wire format (unchanged from the non-generic `IvfIndexState`): -/// `[header_json_len: u64 LE][header JSON][ivf_pb_len: u64 LE][ivf protobuf] -/// [extra_len: u64 LE][extra bytes][aux_ivf_pb_len: u64 LE][aux_ivf protobuf]` +/// Wire format: +/// ```text +/// HEADER : IvfStateHeader proto (paths, types, quantizer metadata JSON) +/// RAW_BLOB : IVF model protobuf +/// RAW_BLOB : quantizer extra-metadata buffer (may be empty) +/// RAW_BLOB : auxiliary IVF model protobuf +/// ``` impl CacheCodecImpl for IvfStateEntryBox { - fn serialize(&self, writer: &mut dyn IoWrite) -> Result<()> { - self.0.serialize_state(writer) - } + const TYPE_ID: &'static str = "lance.vector.ivf.IvfState"; + const CURRENT_VERSION: u32 = 1; - fn deserialize(data: &bytes::Bytes) -> Result { - use lance_arrow::ipc::read_len_prefixed_bytes_at; + fn serialize(&self, w: &mut CacheEntryWriter<'_>) -> Result<()> { + self.0.serialize_state(w) + } - // Parse the common wire format, then dispatch on 
quantization_type to + fn deserialize(r: &mut CacheEntryReader<'_>) -> Result { + // Parse the common header, then dispatch on quantization_type to // construct the right IvfIndexState. - let mut offset = 0; - let header_bytes = read_len_prefixed_bytes_at(data, &mut offset)?; - let header: IvfIndexStateHeader = serde_json::from_slice(&header_bytes) - .map_err(|e| lance_core::Error::io(format!("IvfIndexState header: {e}")))?; + let header: IvfStateHeader = r.read_header()?; - let ivf_bytes = read_len_prefixed_bytes_at(data, &mut offset)?; + let ivf_bytes = r.read_raw()?; let ivf = IvfModel::try_from( pb::Ivf::decode(ivf_bytes.as_ref()) .map_err(|e| lance_core::Error::io(format!("IvfIndexState IVF decode: {e}")))?, )?; - let extra_bytes = read_len_prefixed_bytes_at(data, &mut offset)?; + let extra_bytes = r.read_raw()?; - // aux_ivf was added after initial deployment; fall back to ivf on - // clean EOF (legacy format without the field). - let aux_ivf = if offset + 8 <= data.len() { - let aux_ivf_bytes = read_len_prefixed_bytes_at(data, &mut offset)?; + let aux_ivf_bytes = r.read_raw()?; + let aux_ivf = IvfModel::try_from(pb::Ivf::decode(aux_ivf_bytes.as_ref()).map_err(|e| { lance_core::Error::io(format!("IvfIndexState aux IVF decode: {e}")) - })?)? - } else { - ivf.clone() - }; + })?)?; let distance_type = DistanceType::try_from(header.distance_type.as_str())?; let sub_index_type = SubIndexType::try_from(header.sub_index_type.as_str())?; @@ -311,7 +288,7 @@ impl CacheCodecImpl for IvfStateEntryBox { // Helper: parse Q::Metadata from the JSON+extra_bytes in the header, // then build an IvfStateEntryBox wrapping IvfIndexState. 
fn make_entry( - header: IvfIndexStateHeader, + header: IvfStateHeader, ivf: IvfModel, aux_ivf: IvfModel, extra_bytes: bytes::Bytes, @@ -397,13 +374,13 @@ impl CacheCodecImpl for IvfStateEntryBox { } impl IvfStateEntry for IvfIndexState { - fn serialize_state(&self, writer: &mut dyn IoWrite) -> Result<()> { + fn serialize_state(&self, w: &mut CacheEntryWriter<'_>) -> Result<()> { let quantizer_metadata_json = serde_json::to_string(&self.metadata) .map_err(|e| lance_core::Error::io(format!("IvfIndexState metadata: {e}")))?; let extra = self.metadata.extra_metadata()?; let extra = extra.as_deref().unwrap_or(&[]); - let header = IvfIndexStateHeader { + let header = IvfStateHeader { index_file_path: self.index_file_path.clone(), uuid: self.uuid.to_string(), distance_type: self.distance_type.to_string(), @@ -415,15 +392,13 @@ impl IvfStateEntry for IvfIndexState { index_file_size: self.index_file_size, aux_file_size: self.aux_file_size, }; - let header_json = serde_json::to_vec(&header) - .map_err(|e| lance_core::Error::io(format!("IvfIndexState header: {e}")))?; let ivf_bytes = pb::Ivf::try_from(&self.ivf)?.encode_to_vec(); let aux_ivf_bytes = pb::Ivf::try_from(&self.aux_ivf)?.encode_to_vec(); - write_len_prefixed_bytes(writer, &header_json)?; - write_len_prefixed_bytes(writer, &ivf_bytes)?; - write_len_prefixed_bytes(writer, extra)?; - write_len_prefixed_bytes(writer, &aux_ivf_bytes)?; + w.write_header(&header)?; + w.write_raw(&ivf_bytes)?; + w.write_raw(extra)?; + w.write_raw(&aux_ivf_bytes)?; Ok(()) } @@ -6220,11 +6195,9 @@ mod tests { // Try serialized store first let guard = self.serialized.lock().await; if let Some((bytes, stored_codec, _)) = guard.get(key) { - return Some( - stored_codec - .deserialize(&bytes::Bytes::copy_from_slice(bytes)) - .expect("deserialization should succeed"), - ); + return stored_codec + .deserialize(&bytes::Bytes::copy_from_slice(bytes)) + .hit(); } drop(guard); // Fall through to passthrough