diff --git a/CHANGELOG.md b/CHANGELOG.md index 13a5fbc..55beffd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added + +- *(zstd)* decode and validate frames carrying a `Content_Checksum` (the `zstd` + CLI writes one by default). Previously any such frame was refused with + `Unsupported`, so default `zstd` output only decoded with `--no-check`. Adds a + streaming XXH64 implementation; the decompressed output is hashed and checked + against the 4-byte frame trailer, reporting `ChecksumMismatch` on corruption. + ### Fixed - *(decoder bridge)* a decoder that buffers a whole block internally (notably diff --git a/src/zstd/decoder.rs b/src/zstd/decoder.rs index 896528b..e76967d 100644 --- a/src/zstd/decoder.rs +++ b/src/zstd/decoder.rs @@ -4,9 +4,9 @@ //! `Compressed_Block` (Block_Type=2). See the module-level `mod.rs` docs for //! a full list of supported literal / sequence sub-modes. //! -//! The decoder also refuses frames whose Frame_Header sets the -//! `Content_Checksum_Flag` — we do not implement XXH64 in this crate, so we -//! cannot validate the trailing 4-byte checksum. +//! Frames whose Frame_Header sets the `Content_Checksum_Flag` are decoded and +//! the trailing 4-byte XXH64 checksum is validated against the decompressed +//! output (see [`crate::zstd`]). use alloc::vec::Vec; @@ -14,6 +14,7 @@ use crate::error::Error; use crate::traits::{RawDecoder, RawProgress}; use crate::zstd::literals::{LiteralsState, decode_literals}; use crate::zstd::sequences::{SequencesState, decode_sequences, execute_sequences}; +use crate::zstd::xxhash::Xxh64; const MAGIC: [u8; 4] = [0x28, 0xB5, 0x2F, 0xFD]; @@ -52,9 +53,8 @@ enum DecPhase { /// Emitting the bytes decoded out of a Compressed_Block (held in /// `emit_buf`). 
CompressedEmit, - /// Reading 4-byte Content_Checksum trailer (only entered if we somehow - /// allowed a checksummed frame — currently we refuse such frames in - /// `Fhd`). + /// Reading and validating the 4-byte Content_Checksum trailer (entered + /// after the last block when the Frame_Header set `Content_Checksum_Flag`). ContentChecksum, /// Frame fully consumed; subsequent input is ignored (we do not handle /// concatenated frames). @@ -116,6 +116,11 @@ pub struct Decoder { /// Carry-over state for sequence FSE tables (Repeat_Mode) and the /// previous-offsets stack. seq_state: SequencesState, + + /// Running XXH64 over the decompressed output, fed at every emit site. + /// Only consulted (and the per-byte update only performed) when + /// `has_content_checksum` is set; finalized against the frame trailer. + content_hash: Xxh64, } impl Decoder { @@ -142,6 +147,7 @@ impl Decoder { history_emitted: 0, lit_state: LiteralsState::default(), seq_state: SequencesState::new(), + content_hash: Xxh64::new(), } } @@ -182,12 +188,6 @@ impl Decoder { self.single_segment = ss_flag != 0; self.has_content_checksum = cchk_flag != 0; - // We don't implement XXH64 in this build, so checksummed frames are - // unsupported (per task spec). - if self.has_content_checksum { - return Err(self.poison(Error::Unsupported)); - } - self.dict_id_field_size = match dict_id_flag { 0 => 0, 1 => 1, @@ -490,6 +490,9 @@ impl RawDecoder for Decoder { }); } output[written..written + n].copy_from_slice(&input[consumed..consumed + n]); + if self.has_content_checksum { + self.content_hash.update(&output[written..written + n]); + } // Mirror into history so subsequent Compressed_Blocks can // back-reference these bytes. self.history @@ -526,6 +529,9 @@ impl RawDecoder for Decoder { for slot in &mut output[written..written + n] { *slot = self.rle_byte; } + if self.has_content_checksum { + self.content_hash.update(&output[written..written + n]); + } // Mirror into history. 
for _ in 0..n { self.history.push(self.rle_byte); @@ -582,6 +588,9 @@ impl RawDecoder for Decoder { output[written..written + n].copy_from_slice( &self.history[self.history_emitted..self.history_emitted + n], ); + if self.has_content_checksum { + self.content_hash.update(&output[written..written + n]); + } self.history_emitted += n; written += n; if self.history_emitted == self.history.len() { @@ -589,8 +598,9 @@ impl RawDecoder for Decoder { } } DecPhase::ContentChecksum => { - // Currently unreachable — we reject checksummed frames - // in `parse_fhd`. Kept as a state for future XXH64 work. + // The 4-byte trailer is the low 32 bits of XXH64 over the + // decompressed content (little-endian). Validate it against + // the running hash we fed at every emit site. if !self.fill_scratch(input, &mut consumed) { return Ok(RawProgress { consumed, @@ -598,6 +608,11 @@ impl RawDecoder for Decoder { done: false, }); } + let expected = u32::from_le_bytes(self.scratch[..4].try_into().unwrap()); + let actual = self.content_hash.digest() as u32; + if expected != actual { + return Err(self.poison(Error::ChecksumMismatch)); + } self.phase = DecPhase::Done; } DecPhase::Done => { diff --git a/src/zstd/mod.rs b/src/zstd/mod.rs index 4cb9e98..9612b25 100644 --- a/src/zstd/mod.rs +++ b/src/zstd/mod.rs @@ -50,12 +50,15 @@ //! RLE_Mode for sequence FSE tables, multi-frame output, content checksum, //! or dictionaries. //! -//! # What does NOT work +//! # Content checksum //! -//! - **Content_Checksum_Flag** in the Frame_Header. The 4-byte trailer is the -//! low 32 bits of XXH64 over the decompressed data; we do not ship an -//! XXH64 implementation, so any frame that advertises a content checksum -//! is refused with [`crate::Error::Unsupported`]. +//! Frames with `Content_Checksum_Flag` set (the `zstd` CLI writes one by +//! default) are decoded and the 4-byte trailer — the low 32 bits of XXH64 over +//! the decompressed data — is validated; a mismatch is reported as +//! 
[`crate::Error::ChecksumMismatch`]. Our encoder does not yet emit a content +//! checksum. +//! +//! # What does NOT work //! //! - **Skippable_Frame** magic numbers (`0x184D2A50..=0x184D2A5F`) are //! detected and rejected as unsupported rather than silently skipped. @@ -80,6 +83,7 @@ mod huffman; mod literals; mod matcher; mod sequences; +mod xxhash; pub use decoder::Decoder; pub use encoder::{Encoder, EncoderConfig}; diff --git a/src/zstd/xxhash.rs b/src/zstd/xxhash.rs new file mode 100644 index 0000000..9ba5c57 --- /dev/null +++ b/src/zstd/xxhash.rs @@ -0,0 +1,197 @@ +//! Streaming XXH64, the hash zstd uses for the optional frame +//! `Content_Checksum`. +//! +//! A zstd frame whose `Content_Checksum_Flag` is set (the `zstd` CLI writes one +//! by default) appends the low 32 bits of `XXH64(decompressed_content, seed=0)`, +//! little-endian, after the last block. The decoder feeds every decompressed +//! byte through [`Xxh64::update`] and compares [`Xxh64::digest`] against that +//! trailer. +//! +//! This is the canonical XXH64 (Yann Collet) with seed 0; verified against the +//! reference test vectors below and, end-to-end, against checksums produced by +//! the `zstd` CLI. + +const PRIME64_1: u64 = 0x9E37_79B1_85EB_CA87; +const PRIME64_2: u64 = 0xC2B2_AE3D_27D4_EB4F; +const PRIME64_3: u64 = 0x1656_67B1_9E37_79F9; +const PRIME64_4: u64 = 0x85EB_CA77_C2B2_AE63; +const PRIME64_5: u64 = 0x27D4_EB2F_1656_67C5; + +/// Running XXH64 state (seed fixed at 0, which is all zstd needs). +#[derive(Clone)] +pub(crate) struct Xxh64 { + /// Four parallel accumulators, used once `total_len >= 32`. + acc: [u64; 4], + /// Total bytes consumed across all `update` calls. + total_len: u64, + /// Partial stripe carried between `update` calls (`0..32` valid bytes). 
+ buf: [u8; 32], + buf_len: usize, +} + +impl Xxh64 { + pub(crate) fn new() -> Self { + Self { + acc: [ + PRIME64_1.wrapping_add(PRIME64_2), + PRIME64_2, + 0, + 0u64.wrapping_sub(PRIME64_1), + ], + total_len: 0, + buf: [0u8; 32], + buf_len: 0, + } + } + + #[inline] + fn round(acc: u64, lane: u64) -> u64 { + acc.wrapping_add(lane.wrapping_mul(PRIME64_2)) + .rotate_left(31) + .wrapping_mul(PRIME64_1) + } + + #[inline] + fn merge_round(acc: u64, lane: u64) -> u64 { + let acc = acc ^ Self::round(0, lane); + acc.wrapping_mul(PRIME64_1).wrapping_add(PRIME64_4) + } + + #[inline] + fn read_u64(b: &[u8]) -> u64 { + u64::from_le_bytes(b[..8].try_into().unwrap()) + } + + /// Consume one full 32-byte stripe into the four accumulators. + #[inline] + fn process_stripe(acc: &mut [u64; 4], stripe: &[u8]) { + acc[0] = Self::round(acc[0], Self::read_u64(&stripe[0..8])); + acc[1] = Self::round(acc[1], Self::read_u64(&stripe[8..16])); + acc[2] = Self::round(acc[2], Self::read_u64(&stripe[16..24])); + acc[3] = Self::round(acc[3], Self::read_u64(&stripe[24..32])); + } + + /// Feed `data` into the running hash. + pub(crate) fn update(&mut self, mut data: &[u8]) { + self.total_len = self.total_len.wrapping_add(data.len() as u64); + + // Top off a partially filled stripe first. + if self.buf_len > 0 { + let need = 32 - self.buf_len; + if data.len() < need { + self.buf[self.buf_len..self.buf_len + data.len()].copy_from_slice(data); + self.buf_len += data.len(); + return; + } + let (head, rest) = data.split_at(need); + self.buf[self.buf_len..].copy_from_slice(head); + let buf = self.buf; + Self::process_stripe(&mut self.acc, &buf); + self.buf_len = 0; + data = rest; + } + + // Bulk stripes straight from the input. + let mut chunks = data.chunks_exact(32); + for stripe in &mut chunks { + Self::process_stripe(&mut self.acc, stripe); + } + + // Carry the trailing partial stripe. 
+ let rem = chunks.remainder(); + if !rem.is_empty() { + self.buf[..rem.len()].copy_from_slice(rem); + self.buf_len = rem.len(); + } + } + + /// Finalize without disturbing the running state, returning the full 64-bit + /// digest. zstd compares the low 32 bits. + pub(crate) fn digest(&self) -> u64 { + let mut h = if self.total_len >= 32 { + let mut h = self.acc[0] + .rotate_left(1) + .wrapping_add(self.acc[1].rotate_left(7)) + .wrapping_add(self.acc[2].rotate_left(12)) + .wrapping_add(self.acc[3].rotate_left(18)); + h = Self::merge_round(h, self.acc[0]); + h = Self::merge_round(h, self.acc[1]); + h = Self::merge_round(h, self.acc[2]); + h = Self::merge_round(h, self.acc[3]); + h + } else { + // Short input: only the seed-derived constant participates. + PRIME64_5 + }; + + h = h.wrapping_add(self.total_len); + + // Consume the leftover (< 32) bytes: 8 at a time, then 4, then 1. + let mut p = &self.buf[..self.buf_len]; + while p.len() >= 8 { + let k1 = Self::round(0, Self::read_u64(p)); + h = (h ^ k1) + .rotate_left(27) + .wrapping_mul(PRIME64_1) + .wrapping_add(PRIME64_4); + p = &p[8..]; + } + if p.len() >= 4 { + let k = u32::from_le_bytes(p[..4].try_into().unwrap()) as u64; + h = (h ^ k.wrapping_mul(PRIME64_1)) + .rotate_left(23) + .wrapping_mul(PRIME64_2) + .wrapping_add(PRIME64_3); + p = &p[4..]; + } + for &b in p { + h = (h ^ (b as u64).wrapping_mul(PRIME64_5)) + .rotate_left(11) + .wrapping_mul(PRIME64_1); + } + + // Final avalanche. + h ^= h >> 33; + h = h.wrapping_mul(PRIME64_2); + h ^= h >> 29; + h = h.wrapping_mul(PRIME64_3); + h ^= h >> 32; + h + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn xxh64(data: &[u8]) -> u64 { + let mut h = Xxh64::new(); + h.update(data); + h.digest() + } + + #[test] + fn reference_vectors() { + // Canonical XXH64 vectors (seed 0) from the reference implementation. 
+ assert_eq!(xxh64(b""), 0xEF46_DB37_51D8_E999); + assert_eq!(xxh64(b"a"), 0xD24E_C4F1_A98C_6E5B); + assert_eq!(xxh64(b"abc"), 0x44BC_2CF5_AD77_0999); + // 64 bytes ⇒ exercises the multi-stripe accumulator path. + let long: alloc::vec::Vec<u8> = (0..64u8).collect(); + assert_eq!(xxh64(&long), 0xF7C6_7301_DB67_13F0); + } + + #[test] + fn streaming_matches_one_shot() { + let data: alloc::vec::Vec<u8> = (0..250u32).map(|i| (i.wrapping_mul(37)) as u8).collect(); + let one = xxh64(&data); + // Feed in awkward chunk sizes that straddle stripe boundaries. + for chunk in [1usize, 3, 7, 8, 16, 31, 32, 33] { + let mut h = Xxh64::new(); + for part in data.chunks(chunk) { + h.update(part); + } + assert_eq!(h.digest(), one, "chunk size {chunk}"); + } + } +} diff --git a/tests/zstd.rs b/tests/zstd.rs index 0a81da2..873fdbe 100644 --- a/tests/zstd.rs +++ b/tests/zstd.rs @@ -443,18 +443,47 @@ fn decode_rejects_bad_magic() { assert_eq!(err, Error::BadHeader); } -#[test] -fn decode_rejects_checksum_flag() { - // Build a frame with Content_Checksum_Flag set (bit 2). We can't actually - // verify the checksum (no XXH64), so the decoder must refuse. +/// A minimal valid frame carrying `Content_Checksum_Flag`: magic, FHD with the +/// checksum bit set, a Window_Descriptor, one Last Raw_Block of `payload`, then +/// the 4-byte trailer (caller supplies it so tests can corrupt it). +fn checksummed_raw_frame(payload: &[u8], trailer: [u8; 4]) -> Vec<u8> { let mut f = Vec::new(); f.extend_from_slice(&[0x28, 0xB5, 0x2F, 0xFD]); - f.push(0x04); // FHD: Content_Checksum_Flag = 1 - f.push(0x50); + f.push(0x04); // FHD: Content_Checksum_Flag = 1, SS = 0, no dict, no FCS + f.push(0x00); // Window_Descriptor (minimal) + // Block_Header: Last_Block = 1, Block_Type = 0 (Raw), Block_Size = len. 
+ let bh: u32 = 1 | ((payload.len() as u32) << 3); + f.push((bh & 0xFF) as u8); + f.push(((bh >> 8) & 0xFF) as u8); + f.push(((bh >> 16) & 0xFF) as u8); + f.extend_from_slice(payload); + f.extend_from_slice(&trailer); + f +} + +#[test] +fn decode_validates_correct_checksum() { + // Trailer = low 32 bits of XXH64("hello", seed 0), little-endian. + let frame = checksummed_raw_frame(b"hello", [0xA3, 0x6D, 0x9F, 0x88]); let mut dec = Decoder::new(); let mut out = [0u8; 16]; - let err = dec.decode(&f, &mut out).unwrap_err(); - assert_eq!(err, Error::Unsupported); + let (p, _st) = dec.decode(&frame, &mut out).unwrap(); + let (pf, _stf) = dec.finish(&mut out[p.written..]).unwrap(); + assert_eq!(&out[..p.written + pf.written], b"hello"); +} + +#[test] +fn decode_rejects_bad_checksum() { + // Same frame with the trailer corrupted: must be rejected, not ignored. + let frame = checksummed_raw_frame(b"hello", [0xA3, 0x6D, 0x9F, 0x00]); + let mut dec = Decoder::new(); + let mut out = [0u8; 16]; + // The mismatch surfaces once the trailer is read (here, in one shot). + let err = dec + .decode(&frame, &mut out) + .and_then(|p| dec.finish(&mut out[p.0.written..])) + .unwrap_err(); + assert_eq!(err, Error::ChecksumMismatch); } #[test]