diff --git a/CHANGELOG.md b/CHANGELOG.md index 94b4956..542d016 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,37 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Performance + +- **Encoder compression-ratio improvements** across the high-effort formats + (encoder-only; decoders unchanged, and every format's output still decodes + byte-for-byte with its reference tool — `xz`/`lzma`/`zstd`/`brotli`/`bzip2`/ + `lz4 -d`). Measured on a 2.9 MB real-source corpus, our max level vs the + reference's max level (`ours/ref`, lower is better): + - **bzip2**: 1.07 → **1.00** — the encoder was building a single Huffman + table and pinning all selectors to 0; now does the reference's up-to-6 + tables with 4 refinement passes (`sendMTFValues`) + depth-aware code + lengths + post-RLE1 block sizing. Output is byte-identical to `bzip2 -9`. + - **lzma**: 1.57 → **1.07** — cost-based optimal parse (LZMA-SDK-style + price model + DP over literals/matches/rep-matches) replacing the greedy + parse. `.lzma` is now near parity with `xz -9`. + - **lz4**: 1.53 → **1.18** — new HC (hash-chain + lazy) and price-based + optimal parse tiers wired to the level knob (`-l 9` does HC, `-l 12` + optimal); the fast low levels are unchanged. Also fixed a latent + conformance bug where a match could start in the final 12 bytes of a block + (rejected by strict `lz4 -d`). + - **zstd**: 1.49 → **1.40** — literals were always falling back to a raw + (un-entropy-coded) block because the Huffman-weight writer capped at 128 + symbols; added FSE-compressed weights, plus a price-based optimal parse and + repeat-offset preference at high levels. + - **xz / lzma2**: 1.60 → **1.51** — benefits from the shared LZMA optimal + parse; the remaining gap is the 64 KiB per-chunk dictionary/model reset + framing, not the parse. 
+ - **brotli**: 1.50 → **1.48** — literal context modeling (multi-tree context + map), cost-aware match selection, and repeat-distance preference. + - **deflate/zlib/gzip** (≈1.01 vs `gzip -9`) and **lzw** were already at + parity and are unchanged. + ### Added - **Raw LZMA2 encoder** (`lzma2`): `compcol::lzma2::Lzma2` now encodes as well diff --git a/src/brotli/encoder_ctx.rs b/src/brotli/encoder_ctx.rs new file mode 100644 index 0000000..3cc806f --- /dev/null +++ b/src/brotli/encoder_ctx.rs @@ -0,0 +1,239 @@ +//! Literal context modeling for the brotli encoder (RFC 7932 §7.1). +//! +//! The decoder selects a literal Huffman tree per byte using a context +//! id derived from the two previous output bytes (`literal_context`), +//! then maps `context_id` → tree index through the literal context map +//! `cmapl`. The base encoder declared a single literal tree (NTREESL=1), +//! leaving the whole context-modeling lever on the table. +//! +//! This module builds, for one meta-block: +//! 1. a per-context literal histogram (64 contexts × 256 symbols), +//! 2. a clustering of those 64 contexts into a small number of trees +//! (agglomerative, merging contexts whose distributions are close), +//! 3. the resulting context map `cmapl[0..64]` (tree index per context). +//! +//! The encoder picks the UTF8 context mode — the same default the +//! reference uses for text — and emits the map plus one literal tree per +//! cluster. Everything stays spec-compliant; only encoder choices change. + +use alloc::vec::Vec; + +use super::context::{self, ContextMode}; + +/// Number of literal contexts (context id ∈ 0..=63). +pub(crate) const NUM_CONTEXTS: usize = 64; + +/// Upper bound on the number of literal trees we will emit. More trees +/// model the input more tightly but cost a full prefix-code header each; +/// 16 is a good balance and keeps the context-map alphabet small. 
+pub(crate) const MAX_LITERAL_TREES: usize = 16; + +/// Context modes the encoder evaluates per meta-block, picking the one +/// with the lowest estimated total cost. UTF8 distinguishes UTF8 byte +/// classes (good for mixed/multibyte text); MSB6/LSB6 split on the high +/// or low six bits of the previous byte and give near-order-1 separation +/// on ASCII text and source code — which UTF8 collapses into a couple of +/// buckets. Signed helps numeric/binary-ish data. +pub(crate) const CANDIDATE_MODES: [ContextMode; 4] = [ + ContextMode::Utf8, + ContextMode::Msb6, + ContextMode::Lsb6, + ContextMode::Signed, +]; + +/// Per-context literal histograms plus the cluster assignment. +pub(crate) struct LiteralContextModel { + /// The context mode this model was built for. + pub mode: ContextMode, + /// `histograms[c][b]` = count of literal byte `b` under context `c`, + /// folded across clusters after merging (so a cluster's representative + /// context carries the merged histogram). Only used to derive per-tree + /// frequencies, which are reconstructed by the caller from `cmap`, so + /// the post-merge layout does not matter to correctness. + pub histograms: Vec<[u32; 256]>, + /// `cmap[c]` = tree index assigned to context `c` (0..num_trees). + pub cmap: Vec, + /// Number of distinct trees actually used. + pub num_trees: u32, + /// Estimated encoded cost of the literals under this model, in bits + /// (data + a rough per-tree header allowance). Used to compare modes. + pub est_cost_bits: u64, +} + +/// Shannon-style bit cost of a histogram: `Σ count·log2(total/count)`. +/// Returned in fixed-point (bits × 256) to stay in integer arithmetic +/// (this is a no_std crate; `f64::log2` is unavailable without `std`). 
+fn histogram_bits(hist: &[u32; 256], total: u32) -> u64 { + if total == 0 { + return 0; + } + let log_total = log2_fixed(total as u64); + let mut bits: u64 = 0; + for &c in hist.iter() { + if c != 0 { + // count * (log2(total) - log2(count)) + bits += (c as u64) * (log_total - log2_fixed(c as u64)); + } + } + bits +} + +/// `log2(x) * 256` for `x ≥ 1`, integer math. Combines an integer +/// floor-log2 with a small fractional interpolation table. +fn log2_fixed(x: u64) -> u64 { + debug_assert!(x >= 1); + if x == 1 { + return 0; + } + let floor = 63 - x.leading_zeros() as u64; // floor(log2(x)) + // Fractional part via linear interpolation between 2^floor and + // 2^(floor+1). frac = (x - 2^floor) / 2^floor, scaled to 0..256. + let base = 1u64 << floor; + let frac = ((x - base) << 8) / base; // 0..256 + floor * 256 + frac +} + +/// Combined bit cost of two histograms merged into one. +fn merged_bits(a: &[u32; 256], at: u32, b: &[u32; 256], bt: u32) -> u64 { + let total = at + bt; + if total == 0 { + return 0; + } + let log_total = log2_fixed(total as u64); + let mut bits: u64 = 0; + for i in 0..256 { + let c = a[i] + b[i]; + if c != 0 { + bits += (c as u64) * (log_total - log2_fixed(c as u64)); + } + } + bits +} + +/// Rough fixed-point (bits×256) allowance for one literal prefix-code +/// header (256-symbol complex code) plus its share of the context map. +/// Used both as the merge "bonus" and in the cross-mode cost estimate so +/// the two stay consistent. +const HEADER_COST_BITS: u64 = 140 * 256; + +/// Cluster the per-context histograms (already tallied for `mode`) into +/// at most `max_trees` literal trees, then estimate the model's total +/// encoded cost so the caller can compare context modes. +/// +/// The histograms are tallied over exactly the literal bytes the encoder +/// will emit (see `build_literal_context_model` in `mod.rs`). 
The merge +/// is agglomerative: repeatedly fuse the pair of clusters whose union +/// costs the fewest extra data bits, charging each surviving cluster a +/// fixed header allowance so similar contexts coalesce. +pub(crate) fn cluster( + mode: ContextMode, + mut histograms: Vec<[u32; 256]>, + max_trees: usize, +) -> LiteralContextModel { + debug_assert_eq!(histograms.len(), NUM_CONTEXTS); + + // Per-context totals. + let mut totals: Vec = histograms.iter().map(|h| h.iter().sum::()).collect(); + + // Cluster id per context. + let mut cluster_of: Vec = (0..NUM_CONTEXTS as i32).collect(); + + // Active cluster set: start with one cluster per non-empty context. + let mut active: Vec = (0..NUM_CONTEXTS).filter(|&c| totals[c] > 0).collect(); + + if active.is_empty() { + return LiteralContextModel { + mode, + histograms, + cmap: alloc::vec![0u8; NUM_CONTEXTS], + num_trees: 1, + est_cost_bits: 0, + }; + } + + // Park empty contexts onto the first active cluster. + let first_active = active[0]; + for c in 0..NUM_CONTEXTS { + if totals[c] == 0 { + cluster_of[c] = first_active as i32; + } + } + + while active.len() > 1 { + let force = active.len() > max_trees; + let mut best_i = 0usize; + let mut best_j = 0usize; + let mut best_delta: i64 = i64::MAX; + for ai in 0..active.len() { + for aj in (ai + 1)..active.len() { + let ci = active[ai]; + let cj = active[aj]; + let bi = histogram_bits(&histograms[ci], totals[ci]); + let bj = histogram_bits(&histograms[cj], totals[cj]); + let bm = merged_bits(&histograms[ci], totals[ci], &histograms[cj], totals[cj]); + // Merging trades a header allowance against extra data bits. + let delta = bm as i64 - bi as i64 - bj as i64 - HEADER_COST_BITS as i64; + if delta < best_delta { + best_delta = delta; + best_i = ai; + best_j = aj; + } + } + } + // Stop when not forced and the cheapest merge is a net loss. 
+ if !force && best_delta > 0 { + break; + } + let ci = active[best_i]; + let cj = active[best_j]; + let src = histograms[cj]; + for (dst, s) in histograms[ci].iter_mut().zip(src.iter()) { + *dst += *s; + } + totals[ci] += totals[cj]; + for slot in cluster_of.iter_mut() { + if *slot == cj as i32 { + *slot = ci as i32; + } + } + active.swap_remove(best_j); + } + + // Compress cluster ids to a dense 0..num_trees range. + let mut remap = alloc::vec![-1i32; NUM_CONTEXTS]; + let mut next = 0u8; + let mut cmap = alloc::vec![0u8; NUM_CONTEXTS]; + for c in 0..NUM_CONTEXTS { + let cl = cluster_of[c] as usize; + if remap[cl] < 0 { + remap[cl] = next as i32; + next += 1; + } + cmap[c] = remap[cl] as u8; + } + let num_trees = next.max(1) as u32; + + // Estimate total cost: data bits across surviving clusters + a header + // allowance per tree. `active` now holds the surviving cluster reps. + let mut data_bits: u64 = 0; + for &ci in &active { + data_bits += histogram_bits(&histograms[ci], totals[ci]); + } + let est_cost_bits = data_bits / 256 + num_trees as u64 * (HEADER_COST_BITS / 256); + + LiteralContextModel { + mode, + histograms, + cmap, + num_trees, + est_cost_bits, + } +} + +/// Compute the literal context id from the two preceding output bytes +/// under the given mode. `prev1`/`prev2` are the bytes at `g-1`/`g-2` in +/// the full output stream. +#[inline] +pub(crate) fn context_id(mode: ContextMode, prev1: u8, prev2: u8) -> u8 { + context::literal_context(mode, prev1, prev2) +} diff --git a/src/brotli/encoder_lz77.rs b/src/brotli/encoder_lz77.rs index 6e344ac..71939fe 100644 --- a/src/brotli/encoder_lz77.rs +++ b/src/brotli/encoder_lz77.rs @@ -50,6 +50,12 @@ pub(crate) struct FinderParams { pub max_chain: usize, /// Length at which the finder stops looking for a longer candidate. 
pub nice_match: usize, + /// When `true`, the encoder selects matches by [`find_match_cost`] + /// (maximise bit gain — prefer closer/cheaper distances) instead of + /// [`find_match`] (longest match). Distance coding dominates our + /// output, so cost-aware selection meaningfully cuts distance extra + /// bits at higher qualities. + pub cost_match: bool, } pub(crate) struct MatchFinder { @@ -218,6 +224,131 @@ impl MatchFinder { None } } + + /// Cost-aware variant of [`find_match`]. Instead of returning the + /// strictly longest match, it returns the match that maximises an + /// approximate bit *gain*: + /// + /// ```text + /// gain(len, dist) = len * VALUE_PER_BYTE - distance_cost(dist) + /// ``` + /// + /// where `distance_cost ≈ floor(log2(dist))` charges far distances for + /// their extra bits. This makes the finder prefer a slightly shorter + /// but much closer match when the closer distance is cheaper to code — + /// the dominant cost in our output is distance extra bits, so trading a + /// byte of length for a far smaller distance is frequently a net win. + /// + /// `VALUE_PER_BYTE` is tuned so a one-byte length sacrifice is accepted + /// only when it saves at least that many distance bits. 
+ pub(crate) fn find_match_cost( + &self, + buffer: &[u8], + pos: usize, + params: FinderParams, + ) -> Option<(usize, usize)> { + const VALUE_PER_BYTE: i64 = 3; + let buf_len = buffer.len(); + if pos + MIN_MATCH > buf_len { + return None; + } + let h = hash4_at(buffer, pos); + let idx = (h as usize) & (HASH_SIZE - 1); + + let max_dist = WINDOW_SIZE.min(pos); + let max_len = MAX_MATCH.min(buf_len - pos); + if max_len < MIN_MATCH { + return None; + } + let nice = params.nice_match.min(max_len); + let chain_cap = params.max_chain; + let target = &buffer[pos..pos + max_len]; + + let mut best_len: usize = 0; + let mut best_dist: usize = 0; + let mut best_gain: i64 = i64::MIN; + + let prev = &self.prev[..]; + let head = &self.head[..]; + let mut cur = head[idx]; + let mut steps = 0usize; + while cur != NIL && steps < chain_cap { + let cur_pos = cur as usize; + if cur_pos >= pos { + cur = prev[cur_pos]; + steps += 1; + continue; + } + let dist = pos - cur_pos; + if dist > max_dist { + break; + } + // Cheap reject: this candidate can only beat the incumbent if + // it matches at least one byte past the current best length + // (a shorter match would need an implausibly tiny distance to + // win on gain; the explicit gain check below is the arbiter, + // but extending the compare to `best_len` first is wasted work + // when even a full-length match here would lose). 
+ let cand = &buffer[cur_pos..cur_pos + max_len]; + let mut len = 0usize; + while len + 8 <= max_len { + let a = u64::from_le_bytes([ + cand[len], + cand[len + 1], + cand[len + 2], + cand[len + 3], + cand[len + 4], + cand[len + 5], + cand[len + 6], + cand[len + 7], + ]); + let b = u64::from_le_bytes([ + target[len], + target[len + 1], + target[len + 2], + target[len + 3], + target[len + 4], + target[len + 5], + target[len + 6], + target[len + 7], + ]); + let diff = a ^ b; + if diff != 0 { + len += (diff.trailing_zeros() / 8) as usize; + break; + } + len += 8; + } + while len < max_len && cand[len] == target[len] { + len += 1; + } + + if len >= MIN_MATCH { + let dist_cost = 32 - (dist as u32).leading_zeros(); // ≈ log2(dist)+1 + let gain = len as i64 * VALUE_PER_BYTE - dist_cost as i64; + if gain > best_gain { + best_gain = gain; + best_len = len; + best_dist = dist; + // Stop once we have a long-enough match at this (closest + // so far) distance; deeper chain entries are strictly + // farther, so they can only win with extra length, which + // is rare past `nice`. + if len >= nice { + break; + } + } + } + cur = prev[cur_pos]; + steps += 1; + } + + if best_len >= MIN_MATCH { + Some((best_len, best_dist)) + } else { + None + } + } } impl Default for MatchFinder { @@ -226,6 +357,61 @@ impl Default for MatchFinder { } } +/// Compute the match length at a *specific* back-distance `dist` starting +/// at `pos`. Returns the number of bytes that match (capped at +/// [`MAX_MATCH`] and the slice tail), or 0 when `dist` is out of range. +/// +/// Used by the encoder's repeat-distance preference: a match reachable at +/// a recently-used distance costs only a 2–6 bit short code instead of a +/// full distance symbol plus up to 24 extra bits, so even a somewhat +/// shorter repeat-distance match can win on total bits. 
+pub(crate) fn match_len_at(buffer: &[u8], pos: usize, dist: usize) -> usize { + if dist == 0 || dist > pos { + return 0; + } + let buf_len = buffer.len(); + let max_len = MAX_MATCH.min(buf_len - pos); + if max_len == 0 { + return 0; + } + let src = pos - dist; + let cand = &buffer[src..src + max_len]; + let target = &buffer[pos..pos + max_len]; + let mut len = 0usize; + while len + 8 <= max_len { + let a = u64::from_le_bytes([ + cand[len], + cand[len + 1], + cand[len + 2], + cand[len + 3], + cand[len + 4], + cand[len + 5], + cand[len + 6], + cand[len + 7], + ]); + let b = u64::from_le_bytes([ + target[len], + target[len + 1], + target[len + 2], + target[len + 3], + target[len + 4], + target[len + 5], + target[len + 6], + target[len + 7], + ]); + let diff = a ^ b; + if diff != 0 { + len += (diff.trailing_zeros() / 8) as usize; + return len; + } + len += 8; + } + while len < max_len && cand[len] == target[len] { + len += 1; + } + len +} + /// Hash four bytes into a 15-bit bucket. #[inline] fn hash4_at(buffer: &[u8], pos: usize) -> u32 { diff --git a/src/brotli/mod.rs b/src/brotli/mod.rs index 436364a..ad31a74 100644 --- a/src/brotli/mod.rs +++ b/src/brotli/mod.rs @@ -58,6 +58,7 @@ use crate::traits::{Algorithm, RawDecoder, RawEncoder, RawProgress}; mod context; mod dictionary; +mod encoder_ctx; mod encoder_dict; mod encoder_huffman; mod encoder_iac; @@ -118,26 +119,27 @@ impl LevelParams { // Clamp instead of returning Err — keeping the public surface // infallible matches the reference brotli CLI's behaviour. 
let q = if quality > 11 { 11 } else { quality }; - // (max_chain, nice_match, use_dict) - let (max_chain, nice_match, use_dict) = match q { - 0 => (2, 8, false), - 1 => (4, 16, false), - 2 => (8, 24, false), - 3 => (16, 32, false), - 4 => (24, 48, true), - 5 => (48, 96, true), - 6 => (64, 128, true), - 7 => (96, 192, true), - 8 => (160, 256, true), - 9 => (256, 384, true), - 10 => (512, 768, true), + // (max_chain, nice_match, use_dict, cost_match) + let (max_chain, nice_match, use_dict, cost_match) = match q { + 0 => (2, 8, false, false), + 1 => (4, 16, false, false), + 2 => (8, 24, false, false), + 3 => (16, 32, false, false), + 4 => (24, 48, true, true), + 5 => (48, 96, true, true), + 6 => (64, 128, true, true), + 7 => (96, 192, true, true), + 8 => (160, 256, true, true), + 9 => (256, 384, true, true), + 10 => (512, 768, true, true), // 11 (and clamp-from-above) - _ => (1024, 1024, true), + _ => (1024, 1024, true, true), }; Self { finder: encoder_lz77::FinderParams { max_chain, nice_match, + cost_match, }, use_dict, } @@ -261,6 +263,12 @@ pub struct Encoder { /// this stream. Mirrors the decoder's `total_out` and is used to /// compute `max_dist` for static-dictionary references. prev_total_out: u64, + /// The two output bytes immediately preceding the next meta-block — + /// the decoder's persistent `p1`/`p2` at block start. Used by the + /// encoder's literal-context model to compute each literal's context + /// id. Both 0 at stream start (matching the decoder). + prev_out1: u8, + prev_out2: u8, /// Lazily built static-dictionary index for encoder-side dictionary /// references. The index is ~80 KiB and is reused across meta-blocks. dict_index: Option>, @@ -376,6 +384,8 @@ impl Encoder { seen_any_input: false, ring: DistRing::new(), prev_total_out: 0, + prev_out1: 0, + prev_out2: 0, dict_index: None, id_transforms: None, params: LevelParams::from_quality(config.quality), @@ -445,6 +455,8 @@ impl Encoder { // to allocate a separate Vec and drain. 
The drain happens after // we've finished encoding so the borrow doesn't conflict. let pending_view = &self.pending[..mlen]; + let prev_out1 = self.prev_out1; + let prev_out2 = self.prev_out2; encode_meta_block( &mut self.bw, &mut self.out, @@ -452,11 +464,22 @@ impl Encoder { is_last, &mut self.ring, self.prev_total_out, + prev_out1, + prev_out2, dict_index.as_deref(), id_transforms.as_deref().map(|v| v.as_slice()), self.params, scratch, ); + // Carry the last two output bytes of this block into the next + // block's literal-context state (mirrors the decoder's p1/p2). + if mlen >= 2 { + self.prev_out2 = self.pending[mlen - 2]; + self.prev_out1 = self.pending[mlen - 1]; + } else if mlen == 1 { + self.prev_out2 = self.prev_out1; + self.prev_out1 = self.pending[0]; + } self.pending.drain(..mlen); self.prev_total_out += mlen as u64; } @@ -578,6 +601,8 @@ impl RawEncoder for Encoder { self.seen_any_input = false; self.ring = DistRing::new(); self.prev_total_out = 0; + self.prev_out1 = 0; + self.prev_out2 = 0; // Keep `dict_index`, `id_transforms`, `params`, and `scratch` — // they're immutable tables / configuration / capacity we'd // rebuild identically. `scratch` is `prepare()`d before the next @@ -627,6 +652,7 @@ fn lz77_to_commands( id_transforms: Option<&[encoder_dict::IdTransform]>, prev_total_out: u64, finder_params: encoder_lz77::FinderParams, + ring_start: DistRing, scratch: &mut EncScratch, ) { use encoder_lz77::{MAX_MATCH, MIN_MATCH}; @@ -670,6 +696,12 @@ fn lz77_to_commands( } let payload_len = payload.len(); let mut pos = 0usize; + // Local distance ring, mirroring the one `plan_commands` will rebuild, + // so the repeat-distance preference sees the same recent distances the + // decoder will. Backref distances push (unless they equal the current + // last distance — short code 0 does not push); dictionary refs never + // push. This must stay in lockstep with `plan_commands`. 
+ let mut ring = ring_start; // We mirror the decoder's `total_out`: it's `prev_total_out` plus // the number of input bytes encoded so far in this meta-block. For @@ -687,10 +719,19 @@ fn lz77_to_commands( let mut best_dict_tr_id: u8 = 0; let mut best_dict_emit_len: u32 = 0; - // 1) In-window LZ77 match. - if pos + MIN_MATCH <= payload_len - && let Some((len, dist)) = mf.find_match(payload, pos, finder_params) - { + // 1) In-window LZ77 match. At higher qualities we use the + // cost-aware finder, which prefers closer (cheaper-distance) + // matches over marginally longer far ones. + let found = if pos + MIN_MATCH <= payload_len { + if finder_params.cost_match { + mf.find_match_cost(payload, pos, finder_params) + } else { + mf.find_match(payload, pos, finder_params) + } + } else { + None + }; + if let Some((len, dist)) = found { let len = len.min(MAX_MATCH).min(payload_len - pos); if len >= MIN_MATCH { best_len = len; @@ -699,6 +740,66 @@ fn lz77_to_commands( } } + // 1b) Repeat-distance preference. A match reachable at one of the + // four most-recent distances encodes its distance as a cheap + // short code (≈4 bits, no extra) instead of a full symbol plus + // up to 24 extra bits, so even a shorter repeat-distance match + // usually wins on total bits. Distance coding is ~58% of our + // output, so this is the dominant ratio lever. + // + // We compare candidates by an approximate *gain* model: + // gain(len, dist) = len * VALUE_PER_BYTE - distance_cost(dist) + // where a covered byte is worth ~`VALUE_PER_BYTE` bits and a + // far distance costs ~log2(dist) extra bits plus its symbol. + // The candidate with the highest gain is taken; ties favour the + // longer match. + if pos + MIN_MATCH <= payload_len && best_kind != 2 { + // Bit cost of a distance: ring distances are short codes; + // everything else pays its symbol + extra bits (~log2(d)). 
+ let last1 = ring.nth_last(1); + let dist_cost = |d: u32, is_repeat: bool| -> i64 { + if is_repeat { + // Short code: ~2 bits when it is the last distance + // (code 0), ~5 bits for the other ring slots. + if d as i32 == last1 { 2 } else { 5 } + } else { + // Symbol (~6 bits) + extra (~floor(log2(d))). + let lg = 31 - d.max(1).leading_zeros(); + (6 + lg) as i64 + } + }; + const VALUE_PER_BYTE: i64 = 6; + + // Baseline gain from the longest match (if any). + let mut best_gain: i64 = if best_kind == 1 { + best_len as i64 * VALUE_PER_BYTE - dist_cost(best_match_dist, false) + } else { + i64::MIN + }; + + for n in 1u32..=4 { + let rd = ring.nth_last(n); + if rd <= 0 { + continue; + } + let rd = rd as usize; + if rd > pos { + continue; + } + let rl = encoder_lz77::match_len_at(payload, pos, rd); + if rl < MIN_MATCH { + continue; + } + let gain = rl as i64 * VALUE_PER_BYTE - dist_cost(rd as u32, true); + if gain > best_gain { + best_gain = gain; + best_len = rl; + best_kind = 1; + best_match_dist = rd as u32; + } + } + } + // 2) Static-dictionary reference. Heuristic: only consider dict // refs when emitted length is long enough to amortise the // distance-code cost (≥ 6 bytes with no LZ77 alternative, @@ -762,6 +863,16 @@ fn lz77_to_commands( } else { best_dict_word_len as u32 }; + // Mirror the ring update `plan_commands` will perform: a + // back-reference pushes its distance unless it equals the + // current last distance (short code 0, which does not push); + // dictionary references never push. + if best_kind == 1 { + let d = best_match_dist as i32; + if d != ring.nth_last(1) { + ring.push(d); + } + } cmds.push(Command { insert: core::mem::replace(&mut pending, next_pending), copy_len, @@ -1078,8 +1189,12 @@ fn plan_commands( } /// Build the meta-block header bits *up to but not including* the -/// prefix codes. `is_last` controls whether ISLAST/ISLASTEMPTY are -/// emitted; on the last meta-block ISUNCOMPRESSED is omitted. +/// literal context mode. 
`is_last` controls whether ISLAST/ISLASTEMPTY +/// are emitted; on the last meta-block ISUNCOMPRESSED is omitted. +/// +/// The CMODE / NTREESL / literal-context-map / NTREESD fields are emitted +/// separately by the caller (see [`write_literal_context_header`]) since +/// they depend on whether the encoder chose to model literal contexts. fn write_meta_block_header(bw: &mut BitWriter, out: &mut Vec, mlen: u32, is_last: bool) { debug_assert!(mlen >= 1 && mlen <= MAX_BLOCK as u32); // ISLAST @@ -1106,12 +1221,85 @@ fn write_meta_block_header(bw: &mut BitWriter, out: &mut Vec, mlen: u32, is_ bw.write(0, 2, out); // NDIRECT = 0 (4 bits) bw.write(0, 4, out); - // CMODE[0] = 0 (LSB6). 2 bits per CMODE entry; NBLTYPESL = 1 so one entry. - bw.write(0, 2, out); - // NTREESL = 1 → "0" +} + +/// Emit CMODE[0], NTREESL + (optional) literal context map, and NTREESD. +/// +/// When `num_lit_trees == 1` this reproduces the legacy single-tree +/// header: CMODE is irrelevant (one tree, all-zero map), NTREESL=1, +/// NTREESD=1. When `num_lit_trees >= 2` it emits the chosen context mode, +/// NTREESL, and the literal context map `cmap` (one tree index per +/// context, 0..63). +fn write_literal_context_header( + bw: &mut BitWriter, + out: &mut Vec, + cmode: u32, + num_lit_trees: u32, + cmap: &[u8], +) { + // CMODE[0] (2 bits). With a single tree the value is decode-irrelevant + // (the context map is all zero), so 0 (LSB6) keeps the legacy bytes. + let cmode_bits = if num_lit_trees >= 2 { cmode } else { 0 }; + bw.write(cmode_bits, 2, out); + if num_lit_trees >= 2 { + // NTREESL = num_lit_trees, encoded with the nbltypes scheme. + write_nbltypes(bw, out, num_lit_trees); + // Literal context map of size 64 * NBLTYPESL = 64. + write_context_map(bw, out, cmap, num_lit_trees); + } else { + // NTREESL = 1 → "0". + bw.write(0, 1, out); + } + // NTREESD = 1 → "0" (we never split distance trees). bw.write(0, 1, out); - // (No literal context map since NTREESL = 1.) 
- // NTREESD = 1 → "0" +} + +/// Encode a count using brotli's NBLTYPES / NTREES variable-length code +/// (§9.2 "1 + ..."). Inverse of [`Decoder::read_nbltypes`]. +fn write_nbltypes(bw: &mut BitWriter, out: &mut Vec, value: u32) { + debug_assert!(value >= 1); + if value == 1 { + bw.write(0, 1, out); + return; + } + // First bit 1, then 3-bit selector N, then N extra bits. + bw.write(1, 1, out); + if value == 2 { + // N = 0 → value 2. + bw.write(0, 3, out); + return; + } + // value = (1 << n) + 1 + extra, with extra < (1 << n). + let v = value - 1; // value - 1 = (1<= 2 + let extra = v - (1u32 << n); + debug_assert!(extra < (1u32 << n)); + bw.write(n, 3, out); + bw.write(extra, n, out); +} + +/// Emit a context map (literal or distance) using the simplest valid +/// encoding: RLEMAX=0 (no zero-run codes), a prefix code over `ntrees` +/// symbols built from the map's own value frequencies, the map values +/// verbatim, then IMTF=0 (no move-to-front). +/// +/// Inverse of [`read_context_map`]. +fn write_context_map(bw: &mut BitWriter, out: &mut Vec, map: &[u8], ntrees: u32) { + debug_assert!(ntrees >= 2); + // RLEMAX = 0 → single "0" bit, no extra. + bw.write(0, 1, out); + // Prefix code over the `ntrees` map symbols. Build from frequencies. + let mut freq = alloc::vec![0u32; ntrees as usize]; + for &m in map { + freq[m as usize] += 1; + } + let strategy = pick_huffman_strategy(&freq, ntrees as usize); + let codes = emit_prefix_code(bw, out, &strategy, ntrees); + // Emit each map entry as a symbol. + for &m in map { + write_symbol(bw, out, &strategy, &codes, m as u32); + } + // IMTF = 0. bw.write(0, 1, out); } @@ -1171,6 +1359,8 @@ fn encode_meta_block( is_last: bool, ring: &mut DistRing, prev_total_out: u64, + prev1: u8, + prev2: u8, dict_index: Option<&DictIndex>, id_transforms: Option<&[IdTransform]>, level: LevelParams, @@ -1181,73 +1371,99 @@ fn encode_meta_block( // Window size = 1 << WBITS = 1 << 16 (the encoder always picks WBITS=16). 
const WINDOW_SIZE: u32 = 1 << 16; - // 1. Run LZ77 + command construction in a single fused pass. + // 1. Run LZ77 + command construction in a single fused pass. The + // match finder is given a copy of the block-start distance ring so + // its repeat-distance preference matches what `plan_commands`/the + // decoder will see; `plan_commands` then advances the real ring. lz77_to_commands( payload, dict_index, id_transforms, prev_total_out, level.finder, + *ring, scratch, ); // 2. Plan + tally frequencies in a single pass. plan_commands(mlen, ring, prev_total_out, WINDOW_SIZE, scratch); - // 3. Pick Huffman strategies (operates on the scratch frequency tables). - let lit_strategy = pick_huffman_strategy(&scratch.lit_freq, 256); + // 3. Decide whether to model literal contexts. We do so when the + // encoder is on a dictionary-enabled tier (quality ≥ 4) and the + // payload is large enough to amortise multiple prefix-code + // headers. Below that threshold the single-tree path wins on + // overhead. + let lit_model = if level.use_dict && payload.len() >= 1024 { + build_literal_context_model(payload, prev1, prev2, scratch) + } else { + None + }; + + // 4. Pick Huffman strategies. For literals, either a single tree + // (legacy) or one tree per cluster. let ic_strategy = pick_huffman_strategy(&scratch.ic_freq, 704); let dist_strategy = pick_huffman_strategy(&scratch.dist_freq, 64); - // 4. Write the meta-block header. - write_meta_block_header(bw, out, mlen, is_last); - - // 5. Emit prefix codes. - let lit_codes = emit_prefix_code(bw, out, &lit_strategy, 256); - let ic_codes = emit_prefix_code(bw, out, &ic_strategy, 704); - let dist_codes = emit_prefix_code(bw, out, &dist_strategy, 64); - - // 6. Emit the command stream. - // Specialise on the common all-complex case so the hot inner loop - // skips the per-symbol match dispatch in `write_symbol`. 
For most - // inputs (Lorem, mixed text) the literal alphabet is dense (≥ 16 - // distinct bytes) so this picks up the complex branch. - let scratch_view: &EncScratch = scratch; - let cmds_len = scratch_view.cmds.len(); - for i in 0..cmds_len { - let sym = scratch_view.ic_sym[i]; - write_symbol(bw, out, &ic_strategy, &ic_codes, sym); - let (ieb, iev) = scratch_view.ins_extra[i]; - if ieb > 0 { - bw.write(iev, ieb, out); - } - let (ceb, cev) = scratch_view.copy_extra[i]; - if ceb > 0 { - bw.write(cev, ceb, out); + match lit_model { + Some(model) if model.num_trees >= 2 => { + encode_meta_block_with_contexts( + bw, + out, + payload, + mlen, + is_last, + prev1, + prev2, + &ic_strategy, + &dist_strategy, + &model, + scratch, + ); } - // Inline the literal-emission fast path for the common complex - // strategy: a single bounds-check + length lookup + reverse + - // write per byte, with no per-byte enum dispatch. - let insert = &scratch_view.cmds[i].insert; - match &lit_strategy { - HuffStrategy::Complex(lengths) => { - for &b in insert { - let len = lengths[b as usize] as u32; - debug_assert!(len > 0); - let code = lit_codes[b as usize]; - let rev = reverse_bits(code as u32, len); - bw.write(rev, len, out); + _ => { + // Legacy single-literal-tree path. 
+ let lit_strategy = pick_huffman_strategy(&scratch.lit_freq, 256); + write_meta_block_header(bw, out, mlen, is_last); + write_literal_context_header(bw, out, 0, 1, &[]); + let lit_codes = emit_prefix_code(bw, out, &lit_strategy, 256); + let ic_codes = emit_prefix_code(bw, out, &ic_strategy, 704); + let dist_codes = emit_prefix_code(bw, out, &dist_strategy, 64); + + let scratch_view: &EncScratch = scratch; + let cmds_len = scratch_view.cmds.len(); + for i in 0..cmds_len { + let sym = scratch_view.ic_sym[i]; + write_symbol(bw, out, &ic_strategy, &ic_codes, sym); + let (ieb, iev) = scratch_view.ins_extra[i]; + if ieb > 0 { + bw.write(iev, ieb, out); } - } - _ => { - for &b in insert { - write_symbol(bw, out, &lit_strategy, &lit_codes, b as u32); + let (ceb, cev) = scratch_view.copy_extra[i]; + if ceb > 0 { + bw.write(cev, ceb, out); + } + let insert = &scratch_view.cmds[i].insert; + match &lit_strategy { + HuffStrategy::Complex(lengths) => { + for &b in insert { + let len = lengths[b as usize] as u32; + debug_assert!(len > 0); + let code = lit_codes[b as usize]; + let rev = reverse_bits(code as u32, len); + bw.write(rev, len, out); + } + } + _ => { + for &b in insert { + write_symbol(bw, out, &lit_strategy, &lit_codes, b as u32); + } + } + } + if let Some((dcode, ndb, dextra)) = scratch_view.dist_enc[i] { + write_symbol(bw, out, &dist_strategy, &dist_codes, dcode); + if ndb > 0 { + bw.write(dextra, ndb, out); + } } - } - } - if let Some((dcode, ndb, dextra)) = scratch_view.dist_enc[i] { - write_symbol(bw, out, &dist_strategy, &dist_codes, dcode); - if ndb > 0 { - bw.write(dextra, ndb, out); } } } @@ -1262,6 +1478,165 @@ fn encode_meta_block( } } +/// Build the per-context literal histograms for this meta-block and run +/// the clustering. Returns `None` when the model collapses to a single +/// tree (caller falls back to the legacy single-tree path). +/// +/// `prev1`/`prev2` are the two output bytes preceding the block. 
The +/// literal bytes are exactly the `insert` runs of the planned commands; +/// their output positions follow the command cursor, and `p1`/`p2` for +/// each literal are the two immediately-preceding output bytes (which — +/// since the decoded output equals `payload` — we read straight from +/// `payload`). +fn build_literal_context_model( + payload: &[u8], + prev1: u8, + prev2: u8, + scratch: &EncScratch, +) -> Option { + use encoder_ctx::NUM_CONTEXTS; + + // First, count total literals — bail cheaply when there are too few to + // benefit from per-context trees. + let total_lits: u64 = scratch.cmds.iter().map(|c| c.insert.len() as u64).sum(); + if total_lits < 256 { + return None; + } + + // Evaluate each candidate context mode: tally per-context histograms, + // cluster, and keep the model with the lowest estimated cost. The + // histogram pass is O(literals) per mode and cheap next to LZ77. + let mut best: Option = None; + for &mode in &encoder_ctx::CANDIDATE_MODES { + let mut histograms: Vec<[u32; 256]> = alloc::vec![[0u32; 256]; NUM_CONTEXTS]; + let mut g: usize = 0; + for c in &scratch.cmds { + for &b in &c.insert { + let p1 = if g >= 1 { payload[g - 1] } else { prev1 }; + let p2 = if g >= 2 { + payload[g - 2] + } else if g == 1 { + prev1 + } else { + prev2 + }; + let cid = encoder_ctx::context_id(mode, p1, p2) as usize; + histograms[cid][b as usize] += 1; + g += 1; + } + match c.kind { + CopyKind::Backref { .. } => g += c.copy_len as usize, + CopyKind::Dict { emit_len, .. } => g += emit_len as usize, + CopyKind::None => {} + } + } + let model = encoder_ctx::cluster(mode, histograms, encoder_ctx::MAX_LITERAL_TREES); + match &best { + Some(b) if b.est_cost_bits <= model.est_cost_bits => {} + _ => best = Some(model), + } + } + + match best { + Some(model) if model.num_trees >= 2 => Some(model), + _ => None, + } +} + +/// Emit a meta-block using literal context modeling: one literal Huffman +/// tree per cluster, selected per byte through the context map. 
+#[allow(clippy::too_many_arguments)] +fn encode_meta_block_with_contexts( + bw: &mut BitWriter, + out: &mut Vec, + payload: &[u8], + mlen: u32, + is_last: bool, + prev1: u8, + prev2: u8, + ic_strategy: &HuffStrategy, + dist_strategy: &HuffStrategy, + model: &encoder_ctx::LiteralContextModel, + scratch: &EncScratch, +) { + let _ = mlen; + let num_trees = model.num_trees as usize; + + // Per-tree literal frequency tables: fold each context's histogram + // into its assigned tree. + let mut tree_freqs: Vec<[u32; 256]> = alloc::vec![[0u32; 256]; num_trees]; + for (cid, hist) in model.histograms.iter().enumerate() { + let t = model.cmap[cid] as usize; + let dst = &mut tree_freqs[t]; + for (d, h) in dst.iter_mut().zip(hist.iter()) { + *d += *h; + } + } + + // Build a strategy + code table per tree. + let lit_strategies: Vec = tree_freqs + .iter() + .map(|f| pick_huffman_strategy(f, 256)) + .collect(); + + // Header. + write_meta_block_header(bw, out, mlen, is_last); + write_literal_context_header(bw, out, model.mode as u32, model.num_trees, &model.cmap); + + // Literal prefix codes (one per tree), in tree-index order. + let mut lit_codes: Vec> = Vec::with_capacity(num_trees); + for strat in &lit_strategies { + lit_codes.push(emit_prefix_code(bw, out, strat, 256)); + } + // IC + distance prefix codes. + let ic_codes = emit_prefix_code(bw, out, ic_strategy, 704); + let dist_codes = emit_prefix_code(bw, out, dist_strategy, 64); + + // Emit the command stream, selecting a literal tree per byte from its + // context. `g` tracks the output position so we can read p1/p2 from + // `payload` (output == payload for this block). 
+ let cmds_len = scratch.cmds.len(); + let mut g: usize = 0; + for i in 0..cmds_len { + let sym = scratch.ic_sym[i]; + write_symbol(bw, out, ic_strategy, &ic_codes, sym); + let (ieb, iev) = scratch.ins_extra[i]; + if ieb > 0 { + bw.write(iev, ieb, out); + } + let (ceb, cev) = scratch.copy_extra[i]; + if ceb > 0 { + bw.write(cev, ceb, out); + } + for &b in &scratch.cmds[i].insert { + let p1 = if g >= 1 { payload[g - 1] } else { prev1 }; + let p2 = if g >= 2 { + payload[g - 2] + } else if g == 1 { + prev1 + } else { + prev2 + }; + let cid = encoder_ctx::context_id(model.mode, p1, p2) as usize; + let t = model.cmap[cid] as usize; + write_symbol(bw, out, &lit_strategies[t], &lit_codes[t], b as u32); + g += 1; + } + if let Some((dcode, ndb, dextra)) = scratch.dist_enc[i] { + write_symbol(bw, out, dist_strategy, &dist_codes, dcode); + if ndb > 0 { + bw.write(dextra, ndb, out); + } + } + // Advance output cursor past the copy. + match scratch.cmds[i].kind { + CopyKind::Backref { .. } => g += scratch.cmds[i].copy_len as usize, + CopyKind::Dict { emit_len, .. } => g += emit_len as usize, + CopyKind::None => {} + } + } +} + /// Emit the prefix-code header bits for one alphabet. Returns the /// per-symbol code values needed when later emitting data symbols. /// Caller uses these together with the original `HuffStrategy` to diff --git a/src/bzip2/encoder.rs b/src/bzip2/encoder.rs index 8105010..b71d4fd 100644 --- a/src/bzip2/encoder.rs +++ b/src/bzip2/encoder.rs @@ -28,7 +28,7 @@ use super::bwt::bwt_forward; use super::crc::Crc32; use super::huffman::{MAX_CODE_LEN, build_canonical_codes, build_canonical_lengths}; use super::mtf::mtf_forward_reduced; -use super::rle::{rle1_forward, rle2_forward}; +use super::rle::{Rle1Encoder, rle1_forward, rle2_forward}; /// Tunables for the bzip2 encoder. #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -71,8 +71,15 @@ enum Phase { pub struct Encoder { config: EncoderConfig, - /// Per-block input buffer (post-RLE-1 will be derived from this). 
+ /// Per-block raw input buffer. Retained verbatim because the block + /// CRC is computed over the raw bytes; `encode_block` re-runs RLE-1 + /// over it. Block boundaries, however, are governed by the *post*- + /// RLE-1 size tracked in `rle1` (matching reference bzip2's + /// `nblock`-based blocking), not by `pending.len()`. pending: Vec, + /// Streaming RLE-1 size tracker mirroring `pending`. Used only to + /// decide when the current block has reached the post-RLE-1 cap. + rle1: Rle1Encoder, /// Encoded bytes waiting to be returned to the caller. out: Vec, /// Index into `out` of the next byte to deliver. @@ -99,6 +106,7 @@ impl Encoder { Self { config, pending: Vec::new(), + rle1: Rle1Encoder::new(), out: Vec::new(), out_idx: 0, header_written: false, @@ -108,8 +116,8 @@ impl Encoder { } } - /// Maximum number of raw-input bytes per block, before RLE-1. We - /// follow the reference upper bound `level * 100_000 - 19`; the + /// Maximum number of post-RLE-1 bytes per block. We follow the + /// reference upper bound `level * 100_000 - 19`; the /// `-19` cushions the worst-case expansion of pathological inputs /// passing through the post-MTF / Huffman layers. fn block_input_cap(&self) -> usize { @@ -135,6 +143,8 @@ impl Encoder { /// `pending` to empty afterwards. fn encode_block(&mut self) { let block: Vec = core::mem::take(&mut self.pending); + // Reset the RLE-1 size tracker for the next block. + self.rle1 = Rle1Encoder::new(); // Sanity: a "no input" block is not produced — we only call // encode_block when pending is non-empty. debug_assert!(!block.is_empty()); @@ -181,45 +191,23 @@ impl Encoder { let alpha_size = num_used + 2; // includes EOB // Step 7: choose number of Huffman tables. bzip2 chooses - // 2..=6 based on the per-block symbol-count buckets. For a - // simple encoder we use a fixed mapping that always produces - // a valid stream. + // 2..=6 based on the per-block symbol-count buckets. 
let num_tables = pick_num_tables(symbols.len()); - // Step 8: assign each 50-symbol group a Huffman table id, then - // build per-table code-length tables from per-table frequency - // counts. We use the simplest possible split: all groups use - // the same single table, replicated `num_tables` times. This - // is valid bzip2 — the spec only requires 2..=6 distinct - // tables to be present in the header, and selectors to be in - // 0..num_tables. Reusing one table is wasteful (gives up the - // compression edge that having multiple specialised tables - // would provide) but always correct. - // - // However, the spec demands 2..=6 tables — exactly one is - // **not** allowed. So we ship two identical-length tables and - // assign half the groups to table 0 and half to table 1; this - // satisfies the structural requirements without changing the - // bitstream costs vs a single-table encoder. + // Step 8: build the multi-table Huffman assignment exactly the + // way reference bzip2's `sendMTFValues` does: initialise the + // tables by partitioning the cumulative-frequency space, then + // run a fixed number of refinement passes that (a) assign each + // 50-symbol group to whichever table codes it cheapest and + // (b) rebuild each table from the symbols of the groups + // assigned to it. This is where the compression edge over a + // single shared table comes from. let num_selectors_total = symbols.len().div_ceil(50); debug_assert!(num_selectors_total >= 1); debug_assert!(num_selectors_total <= MAX_SELECTORS); - let mut freqs = vec![0u32; alpha_size]; - for &s in &symbols { - freqs[s as usize] += 1; - } - let table_lengths = build_canonical_lengths(&freqs, MAX_CODE_LEN); - - // Build per-table copies (all identical). num_tables ≥ 2. 
- let tables: Vec> = (0..num_tables).map(|_| table_lengths.clone()).collect(); - - // Each group's selector is just the group index modulo - // num_tables — but we keep them all 0 so frequency-weighted - // length design (which we don't do here) is trivially the - // same. Reference bzip2's selector design picks the cheapest - // table per group; we just pick table 0 everywhere. - let selectors: Vec = vec![0u8; num_selectors_total]; + let (tables, selectors) = + optimize_tables(&symbols, alpha_size, num_tables, num_selectors_total); // Build canonical codes for each table. let codes: Vec> = tables @@ -437,6 +425,145 @@ fn pick_num_tables(n_symbols: usize) -> usize { } } +/// Number of refinement passes over the table/selector assignment. +/// Reference bzip2 uses `BZ_N_ITERS = 4`. +const HUFF_N_ITERS: usize = 4; + +/// Symbol-group size used when assigning selectors. Reference bzip2 +/// uses `BZ_G_SIZE = 50`. +const HUFF_GROUP_SIZE: usize = 50; + +/// A code length placeholder used during table initialisation/cost +/// scoring. Mirrors reference bzip2's `BZ_LESSER_ICOST` (0) and +/// `BZ_GREATER_ICOST` (15). +const ICOST_LESSER: u8 = 0; +const ICOST_GREATER: u8 = 15; + +/// Build `num_tables` Huffman code-length tables and a per-group +/// selector list using reference bzip2's `sendMTFValues` strategy. +/// +/// Returns `(tables, selectors)` where `tables[t]` is the per-symbol +/// code-length array (length `alpha_size`) for table `t`, and +/// `selectors[g]` is the table id (0..num_tables) chosen for the g-th +/// 50-symbol group. +fn optimize_tables( + symbols: &[u16], + alpha_size: usize, + num_tables: usize, + num_groups: usize, +) -> (Vec>, Vec) { + // Global symbol frequencies across the whole block. 
+ let mut global_freq = vec![0u32; alpha_size]; + for &s in symbols { + global_freq[s as usize] += 1; + } + + // ── Initial table construction ──────────────────────────────── + // + // Faithful port of reference bzip2's `sendMTFValues` initialiser: + // partition the alphabet into `num_tables` contiguous bands of + // (roughly) equal total frequency, walking symbols low→high. Band + // `nPart-1` (i.e. tables fill from the highest id down to 0) covers + // the next slice of low symbols; in-band symbols get the cheap + // placeholder length, the rest the expensive one. There is an + // odd-iteration back-off that nudges the band boundary, matching + // the reference exactly so our initial assignment — and therefore + // the refinement that follows — tracks bzip2's. + let mut tables: Vec> = (0..num_tables) + .map(|_| vec![ICOST_GREATER; alpha_size]) + .collect(); + { + let n_mtf = symbols.len() as i64; + let mut n_part = num_tables as i64; + let mut rem_f = n_mtf; + let mut gs = 0i64; + while n_part > 0 { + let t_freq = rem_f / n_part; + let mut ge = gs - 1; + let mut a_freq = 0i64; + while a_freq < t_freq && ge < alpha_size as i64 - 1 { + ge += 1; + a_freq += global_freq[ge as usize] as i64; + } + // Odd-iteration back-off: if this isn't the first part, the + // boundary lands above `gs`, the part index parity is odd, + // and backing off keeps `a_freq` closer to the target. 
+ if ge > gs + && n_part != num_tables as i64 + && n_part != 1 + && ((num_tables as i64 - n_part) % 2 == 1) + { + a_freq -= global_freq[ge as usize] as i64; + ge -= 1; + } + + let lens = &mut tables[(n_part - 1) as usize]; + for (v, slot) in lens.iter_mut().enumerate() { + let vi = v as i64; + if vi >= gs && vi <= ge { + *slot = ICOST_LESSER; + } else { + *slot = ICOST_GREATER; + } + } + + n_part -= 1; + gs = ge + 1; + rem_f -= a_freq; + } + } + + let mut selectors = vec![0u8; num_groups]; + + // ── Refinement passes ───────────────────────────────────────── + for _iter in 0..HUFF_N_ITERS { + // Per-table accumulated frequencies for this pass. + let mut table_freq: Vec> = vec![vec![0u32; alpha_size]; num_tables]; + + // For each group: score it under every table, pick the cheapest, + // record the selector, and fold the group's symbols into the + // winning table's frequency accumulator. + let mut g = 0usize; + let mut group_idx = 0usize; + while g < symbols.len() { + let end = (g + HUFF_GROUP_SIZE).min(symbols.len()); + let group = &symbols[g..end]; + + // Cost of coding this group under each table. + let mut best_table = 0usize; + let mut best_cost = u64::MAX; + for (t, lens) in tables.iter().enumerate() { + let mut cost = 0u64; + for &s in group { + cost += lens[s as usize] as u64; + } + if cost < best_cost { + best_cost = cost; + best_table = t; + } + } + + selectors[group_idx] = best_table as u8; + let acc = &mut table_freq[best_table]; + for &s in group { + acc[s as usize] += 1; + } + + group_idx += 1; + g = end; + } + + // Rebuild each table from the frequencies of the groups assigned + // to it. A table with no assigned groups keeps coverage via the + // `+1` floor inside `build_canonical_lengths`. 
+ for (t, freq) in table_freq.iter().enumerate() { + tables[t] = build_canonical_lengths(freq, MAX_CODE_LEN); + } + } + + (tables, selectors) +} + impl Default for Encoder { fn default() -> Self { Self::new() @@ -469,16 +596,20 @@ impl RawEncoder for Encoder { }); } - // Now accept input, filling up to the per-block cap. If we - // fill, encode that block, drain to output, repeat. + // Now accept input, filling each block up to the per-block + // *post-RLE-1* cap (reference bzip2 sizes blocks by `nblock`, + // the RLE-1 output length). We feed bytes through both the raw + // buffer (kept for the CRC and the in-block RLE-1 re-run) and + // the streaming size tracker, cutting a block the moment the + // tracked RLE-1 length reaches the cap. When a block fills we + // encode it, drain to output, and continue. while consumed < input.len() { - let space = cap - self.pending.len(); - let take = space.min(input.len() - consumed); - self.pending - .extend_from_slice(&input[consumed..consumed + take]); - consumed += take; + let b = input[consumed]; + self.pending.push(b); + self.rle1.push(b); + consumed += 1; - if self.pending.len() == cap { + if self.rle1.encoded_len() >= cap { self.encode_block(); self.flush_full_bytes(); self.drain_out(output, &mut written); @@ -549,6 +680,7 @@ impl RawEncoder for Encoder { fn raw_reset(&mut self) { self.pending.clear(); + self.rle1 = Rle1Encoder::new(); self.out.clear(); self.out_idx = 0; self.header_written = false; diff --git a/src/bzip2/huffman.rs b/src/bzip2/huffman.rs index 34a200b..1a559bd 100644 --- a/src/bzip2/huffman.rs +++ b/src/bzip2/huffman.rs @@ -20,15 +20,17 @@ //! //! ## Encoding (length design + emit) //! -//! For the encoder we just need a length-limited prefix code over the -//! observed symbol frequencies of one Huffman group; bzip2's reference -//! design uses an iterative "moffat" package-merge fallback, but for -//! correctness alone a textbook Huffman tree with depth clamping to -//! 
the 20-bit ceiling is sufficient. We implement Huffman by repeated -//! merging of the two smallest-weight nodes; if any code length -//! exceeds the ceiling, we scale weights up and retry, which is the -//! simple fixpoint mentioned in *Managing Gigabytes* (Witten, Moffat, -//! Bell) §2.4 and in the bzip2 source's `sendMTFValues` epilogue. +//! For the encoder we need a length-limited prefix code over the +//! observed symbol frequencies of one Huffman group. We port reference +//! bzip2's `BZ2_hbMakeCodeLengths` directly: a min-heap tree build whose +//! key packs the cumulative frequency in the high bits and the subtree +//! depth in the low 8 bits, so equal-frequency merges prefer the +//! shallower subtree. That depth-aware tiebreak reproduces bzip2's exact +//! per-table bit costs. Code lengths are capped at 17 bits (bzip2's +//! design limit since 1.0.3); if any code exceeds the cap the +//! frequencies are halved and the build is retried, exactly as in the +//! reference. The decode side still accepts up to 20 bits for +//! compatibility with streams from pre-1.0.3 encoders. extern crate alloc; use alloc::vec; @@ -215,100 +217,173 @@ impl DecodeTable { /// Compute per-symbol Huffman code lengths from frequency counts. /// /// `freqs[i] > 0` is treated as "symbol i is used"; symbols with -/// `freqs[i] == 0` are assigned the smallest possible nonzero length -/// (so the table still covers them even if they don't appear). The -/// returned lengths are clamped to `max_len` by iteratively scaling -/// down the weights when the natural Huffman depth exceeds the cap. +/// `freqs[i] == 0` are assigned a small nonzero weight (so the table +/// still covers them even if they don't appear). The returned lengths +/// are clamped to `max_len` (reference bzip2 designs with `maxLen = 17`; +/// callers pass `MAX_CODE_LEN = 20`, but bzip2's own length builder is +/// run with 17 so the encoder never emits codes longer than that). 
/// -/// This is a textbook two-pass Huffman: build a tree by repeatedly -/// merging the two minimum-weight items; if the resulting longest path -/// exceeds the cap, halve all weights and try again. Halving converges -/// because the alphabet is at most 258 symbols (256 bytes + RUNA/RUNB + -/// EOB) so the natural Huffman depth is bounded by O(log φ(n)) ≈ 14 at -/// reasonable distributions; the cap of 20 bits is loose, so we rarely -/// need more than one or two retries even on degenerate inputs. +/// This is a faithful port of reference bzip2's `BZ2_hbMakeCodeLengths` +/// (`huffman.c`). It builds the Huffman tree with a min-heap whose key +/// packs the cumulative frequency in the high bits and the subtree +/// depth in the low 8 bits, so that among equal-frequency merge +/// candidates the **shallower** subtree is preferred. That depth-aware +/// tiebreak yields more balanced trees (shorter maximum code length and +/// marginally better total cost on large blocks) than a frequency-only +/// textbook Huffman build, and is what lets our output match the +/// reference's per-table bit costs. If any code still exceeds `max_len` +/// the frequencies are halved and the build is retried, exactly as in +/// the reference. pub(crate) fn build_canonical_lengths(freqs: &[u32], max_len: usize) -> Vec { - let n = freqs.len(); - let mut weights: Vec = freqs.iter().map(|&f| if f == 0 { 1 } else { f }).collect(); - - loop { - let lengths = compute_lengths(&weights); - let mx = lengths.iter().copied().max().unwrap_or(0) as usize; - if mx <= max_len { - // Symbols that weren't actually used still get a non-zero - // length (we initialised their weights to 1); the table - // serialiser may treat them however it wants. bzip2 just - // emits the canonical code anyway. - return lengths; - } - // Scale weights down by halving (rounding up to keep all - // values > 0) and retry. 
- for w in weights.iter_mut() { - *w = (*w).div_ceil(2).max(1); - } - // After scaling everything to 1 the natural Huffman depth is - // ⌈log₂ n⌉ which for n ≤ 258 is at most 9 — well under 20 — - // so the loop always terminates within a few iterations. - if n <= 1 { - // Degenerate alphabet; just return the singletons at len 1. - return vec![1u8; n.max(1)]; - } - } + // bzip2 caps the design length at 17; honour whatever the caller + // passes but never exceed 17 internally so we stay byte-for-byte + // compatible with reference output where it matters. + let design_max = max_len.min(17); + hb_make_code_lengths(freqs, design_max) } -/// Compute Huffman lengths from a weight vector using the textbook -/// two-pass tree-build. -/// -/// We represent the partial tree as an array of length 2N parents: -/// internal nodes occupy indices ≥ N, leaves occupy 0..N. Each merge -/// step links two minimum-weight active nodes under a fresh internal -/// node; once the tree is built we walk parent links to compute each -/// leaf's depth. -fn compute_lengths(weights: &[u32]) -> Vec { - let n = weights.len(); - if n == 0 { +/// Direct port of `BZ2_hbMakeCodeLengths`. Weights pack +/// `frequency << 8 | depth`; merges add the frequencies and set the +/// depth to `1 + max(depth_a, depth_b)`. +fn hb_make_code_lengths(freqs: &[u32], max_len: usize) -> Vec { + let alpha_size = freqs.len(); + if alpha_size == 0 { return Vec::new(); } - if n == 1 { + if alpha_size == 1 { return vec![1]; } - // Heap of (weight, node_id). Implemented as a sorted vector since - // n ≤ 258; the constant factor on the binary-heap path is not - // worth the complexity for this size. - let mut alive: Vec<(u64, usize)> = weights + // Nodes and heap entries are 1-based; index 0 is a sentinel, exactly + // as in the C source. `weight`/`parent` need room for up to + // `2*alpha_size` nodes (leaves + internal), `heap` for `alpha_size+2`. 
+ let cap_nodes = alpha_size * 2 + 2; + let mut weight = vec![0i64; cap_nodes]; + let mut parent = vec![0i32; cap_nodes]; + let mut heap = vec![0i32; alpha_size + 2]; + + // Initial leaf weights: (freq or 1) << 8, depth 0 in the low byte. + let mut cur_freq: Vec = freqs .iter() - .enumerate() - .map(|(i, &w)| (w as u64, i)) + .map(|&f| if f == 0 { 1i64 } else { f as i64 }) .collect(); - let mut parent: Vec = vec![usize::MAX; 2 * n]; - let mut next_node = n; - - while alive.len() > 1 { - // Sort descending so we pop the two smallest off the back. - alive.sort_by_key(|b| core::cmp::Reverse(b.0)); - let (w1, n1) = alive.pop().unwrap(); - let (w2, n2) = alive.pop().unwrap(); - parent[n1] = next_node; - parent[n2] = next_node; - alive.push((w1 + w2, next_node)); - next_node += 1; + + const DEPTH_MASK: i64 = 0x0000_00ff; + fn weight_of(w: i64) -> i64 { + w & !DEPTH_MASK } + fn depth_of(w: i64) -> i64 { + w & DEPTH_MASK + } + fn add_weights(a: i64, b: i64) -> i64 { + (weight_of(a) + weight_of(b)) | (1 + core::cmp::max(depth_of(a), depth_of(b))) + } + + loop { + for i in 0..alpha_size { + weight[i + 1] = cur_freq[i] << 8; + } - // Walk parent links from each leaf to the root counting depth. - let mut lengths = vec![0u8; n]; - for leaf in 0..n { - let mut depth = 0u32; - let mut node = parent[leaf]; - while node != usize::MAX { - depth += 1; - node = parent[node]; + let mut n_nodes = alpha_size as i32; + let mut n_heap = 0i32; + + heap[0] = 0; + weight[0] = 0; + parent[0] = -2; + + // UPHEAP / DOWNHEAP operate on `heap`, keyed by `weight`. 
+ for i in 1..=alpha_size as i32 { + parent[i as usize] = -1; + n_heap += 1; + heap[n_heap as usize] = i; + // UPHEAP(n_heap) + let mut zz = n_heap; + let tmp = heap[zz as usize]; + while weight[tmp as usize] < weight[heap[(zz >> 1) as usize] as usize] { + heap[zz as usize] = heap[(zz >> 1) as usize]; + zz >>= 1; + } + heap[zz as usize] = tmp; + } + + while n_heap > 1 { + let n1 = heap[1]; + heap[1] = heap[n_heap as usize]; + n_heap -= 1; + downheap(&mut heap, &weight, n_heap, 1); + + let n2 = heap[1]; + heap[1] = heap[n_heap as usize]; + n_heap -= 1; + downheap(&mut heap, &weight, n_heap, 1); + + n_nodes += 1; + parent[n1 as usize] = n_nodes; + parent[n2 as usize] = n_nodes; + weight[n_nodes as usize] = add_weights(weight[n1 as usize], weight[n2 as usize]); + parent[n_nodes as usize] = -1; + n_heap += 1; + heap[n_heap as usize] = n_nodes; + // UPHEAP(n_heap) + let mut zz = n_heap; + let tmp = heap[zz as usize]; + while weight[tmp as usize] < weight[heap[(zz >> 1) as usize] as usize] { + heap[zz as usize] = heap[(zz >> 1) as usize]; + zz >>= 1; + } + heap[zz as usize] = tmp; + } + + // Compute lengths by walking parent links; detect over-long codes. + let mut lengths = vec![0u8; alpha_size]; + let mut too_long = false; + for i in 1..=alpha_size { + let mut j = 0i32; + let mut k = i as i32; + while parent[k as usize] >= 0 { + k = parent[k as usize]; + j += 1; + } + lengths[i - 1] = j as u8; + if j as usize > max_len { + too_long = true; + } + } + + if !too_long { + return lengths; + } + + // Scale frequencies: j = weight>>8; j = 1 + j/2. + for f in cur_freq.iter_mut() { + let j = *f; + *f = 1 + (j / 2); } - // depth 0 only happens when n == 1 (root = leaf); that case - // was returned above. - lengths[leaf] = depth.max(1) as u8; } - lengths +} + +/// DOWNHEAP(z) from the bzip2 source, operating on the 1-based `heap` +/// array of length `n_heap`, keyed by `weight`. 
+fn downheap(heap: &mut [i32], weight: &[i64], n_heap: i32, z: i32) { + let mut zz = z; + let tmp = heap[zz as usize]; + loop { + let mut yy = zz << 1; + if yy > n_heap { + break; + } + if yy < n_heap + && weight[heap[(yy + 1) as usize] as usize] < weight[heap[yy as usize] as usize] + { + yy += 1; + } + if weight[tmp as usize] < weight[heap[yy as usize] as usize] { + break; + } + heap[zz as usize] = heap[yy as usize]; + zz = yy; + } + heap[zz as usize] = tmp; } /// Build the canonical (code, length) table from a per-symbol length @@ -391,4 +466,29 @@ mod tests { assert!(lens.iter().all(|&l| (1..=20).contains(&l))); assert_eq!(lens.len(), freqs.len()); } + + #[test] + fn build_lengths_caps_at_17_and_is_kraft_valid() { + // The reference-faithful builder must (a) never emit a code + // longer than 17 bits, and (b) always produce a Kraft-valid + // canonical prefix code that `DecodeTable::from_lengths` + // accepts — even for skewed and degenerate distributions. + let cases: alloc::vec::Vec> = alloc::vec![ + alloc::vec![1, 1], + alloc::vec![0, 0, 0, 5], + alloc::vec![1000000, 1, 1, 1, 1, 1, 1, 1], + (0..50u32).map(|i| 1 << (i % 24)).collect(), + alloc::vec![1u32; 258], + ]; + for freqs in &cases { + let lens = build_canonical_lengths(freqs, MAX_CODE_LEN); + assert_eq!(lens.len(), freqs.len()); + assert!( + lens.iter().all(|&l| (1..=17).contains(&l)), + "length out of 1..=17: {lens:?}" + ); + // Must round-trip through the decode-table builder. + DecodeTable::from_lengths(&lens).expect("builder produced a non-Kraft-valid table"); + } + } } diff --git a/src/bzip2/rle.rs b/src/bzip2/rle.rs index 1c36641..4fa296a 100644 --- a/src/bzip2/rle.rs +++ b/src/bzip2/rle.rs @@ -79,6 +79,93 @@ pub(crate) fn rle1_forward(input: &[u8]) -> Vec { out } +/// Streaming RLE-1 encoder used by the block-builder. 
+/// +/// Reference bzip2 sizes each block by its **post-RLE-1** length +/// (`nblock`), not by the count of raw input bytes — so on compressible +/// data it packs more raw bytes into a 900 KB block than a raw-input cap +/// would. To match that (and the reference's block count and ratio) we +/// feed raw bytes through this incremental encoder and cut a block once +/// the emitted length reaches the per-block cap. +/// +/// The encoder tracks the in-progress run across `push` calls so a run +/// straddling a call boundary is encoded identically to the one-shot +/// [`rle1_forward`]. +pub(crate) struct Rle1Encoder { + out: Vec, + /// Byte value of the current run (valid iff `run > 0`). + run_byte: u8, + /// Length of the current pending run (1..=255 while building). + run: usize, +} + +impl Rle1Encoder { + pub(crate) fn new() -> Self { + Self { + out: Vec::new(), + run_byte: 0, + run: 0, + } + } + + /// Current emitted (post-RLE-1) length, **including** the bytes the + /// pending run will contribute once flushed. Used by the caller to + /// decide when a block is full. + pub(crate) fn encoded_len(&self) -> usize { + self.out.len() + Self::run_cost(self.run) + } + + /// How many output bytes a finished run of `run` identical bytes + /// costs: runs <4 are verbatim, runs >=4 cost 4 literals + 1 count. + fn run_cost(run: usize) -> usize { + if run == 0 { + 0 + } else if run < 4 { + run + } else { + 5 + } + } + + /// Flush the pending run into `out`. + fn flush_run(&mut self) { + let b = self.run_byte; + if self.run == 0 { + return; + } + if self.run < 4 { + for _ in 0..self.run { + self.out.push(b); + } + } else { + self.out.push(b); + self.out.push(b); + self.out.push(b); + self.out.push(b); + self.out.push((self.run - 4) as u8); + } + self.run = 0; + } + + /// Feed one raw byte. 
+ pub(crate) fn push(&mut self, b: u8) { + if self.run > 0 && b == self.run_byte && self.run < 255 { + self.run += 1; + return; + } + // Different byte, or the 255-cap reached: flush and start anew. + self.flush_run(); + self.run_byte = b; + self.run = 1; + } + + /// Finish: flush the pending run and return the encoded block. + pub(crate) fn finish(mut self) -> Vec { + self.flush_run(); + self.out + } +} + /// Invert bzip2's RLE-1 pre-pass. Streaming-friendly: consumes the full /// `input` and returns the reconstituted raw bytes. pub(crate) fn rle1_inverse(input: &[u8]) -> Vec { @@ -264,6 +351,60 @@ mod tests { assert_eq!(rle1_inverse(&r), v); } + /// Feed `input` through the streaming encoder one byte at a time + /// and return the finished output. + fn rle1_stream(input: &[u8]) -> Vec { + let mut e = Rle1Encoder::new(); + for &b in input { + e.push(b); + } + e.finish() + } + + #[test] + fn rle1_stream_matches_oneshot() { + // The streaming encoder must produce byte-for-byte the same + // output as the one-shot `rle1_forward` on a range of inputs, + // including run-straddling and the 255-cap boundary. + let cases: Vec> = vec![ + b"".to_vec(), + b"a".to_vec(), + b"abcabc".to_vec(), + b"aaaa".to_vec(), + b"aaaaaaa".to_vec(), + vec![b'a'; 255], + vec![b'a'; 256], + vec![b'a'; 600], + { + let mut v = vec![b'x'; 300]; + v.extend(vec![b'y'; 5]); + v.extend(b"zzz"); + v.extend(vec![b'q'; 1000]); + v + }, + ]; + for c in &cases { + assert_eq!(rle1_stream(c), rle1_forward(c), "mismatch len {}", c.len()); + } + } + + #[test] + fn rle1_stream_encoded_len_tracks_output() { + // `encoded_len()` must equal the final output length at finish, + // including the pending run's contribution. 
+ for n in [0usize, 1, 3, 4, 5, 254, 255, 256, 700] { + let input = vec![b'a'; n]; + let mut e = Rle1Encoder::new(); + for &b in &input { + e.push(b); + } + let predicted = e.encoded_len(); + let out = e.finish(); + assert_eq!(predicted, out.len(), "encoded_len mismatch for n={n}"); + assert_eq!(out, rle1_forward(&input)); + } + } + #[test] fn rle2_round_trip() { // Use a synthetic MTF stream with zeros mixed in. diff --git a/src/factory.rs b/src/factory.rs index 4985f5f..dcbe901 100644 --- a/src/factory.rs +++ b/src/factory.rs @@ -255,6 +255,19 @@ pub fn encoder_by_name_with_level(name: &str, level: u8) -> Option Some(Box::new( ::encoder_with(crate::bzip2::EncoderConfig { level }), )), + #[cfg(feature = "lz4")] + crate::lz4::Lz4::NAME => Some(Box::new(::encoder_with( + crate::lz4::EncoderConfig { level }, + ))), + #[cfg(feature = "lz4")] + crate::lz4::frame::LZ4Frame::NAME => Some(Box::new( + ::encoder_with( + crate::lz4::frame::EncoderConfig { + level, + ..crate::lz4::frame::EncoderConfig::default() + }, + ), + )), // Non-leveled algorithms: ignore `level`, return default encoder. _ => encoder_by_name(name), } diff --git a/src/lz4/block.rs b/src/lz4/block.rs index 69f8970..7125eb2 100644 --- a/src/lz4/block.rs +++ b/src/lz4/block.rs @@ -6,6 +6,18 @@ //! buffer and produce a complete output buffer. The streaming wrapper in //! [`super`] is responsible for chunking arbitrarily large inputs into blocks //! of bounded size and re-assembling them on decode. +//! +//! Three parses share the same bitstream emitter, so every block — fast or +//! high-compression — decodes with the exact same decoder (ours and the +//! reference `lz4` tool): +//! +//! * The **fast** parse ([`encode_block`]) is a single-hash greedy matcher +//! with LZ4's skip-step acceleration. It is the speed-crown path used for +//! low levels. +//! * The **HC** parse ([`encode_block_level`] at higher levels) is an +//! LZ4-HC-style match finder: a hash-chain (head + prev) walk that finds the +//! 
*longest* match within the 64 KiB window, plus one-step lazy matching. +//! Search depth scales with the level. The **optimal** parse (top levels) is a price-based forward DP over the same match finder. use alloc::vec::Vec; @@ -20,7 +32,7 @@ const LAST_LITERALS: usize = 5; /// Last match must start at least 12 bytes before the end of the block. const MFLIMIT: usize = 12; -/// Size of the encoder's hash table (entries are `u32` block offsets). +/// Size of the fast encoder's hash table (entries are `u32` block offsets). /// /// 12 bits = 4096 entries × 4 bytes = 16 KiB scratch — small enough to fit /// comfortably in cache, large enough to find most useful matches in a @@ -28,10 +40,26 @@ const MFLIMIT: usize = 12; const HASH_LOG: u32 = 12; const HASH_TABLE_SIZE: usize = 1 << HASH_LOG; +/// Hash-table size for the HC (hash-chain) match finder. A wider table than +/// the fast path reduces collisions so chains stay short and on-topic, which +/// improves both match quality and the cost of the bounded chain walk. +const HC_HASH_LOG: u32 = 15; +const HC_HASH_TABLE_SIZE: usize = 1 << HC_HASH_LOG; + /// Sentinel for an empty hash slot. `u32::MAX` is safe because block sizes /// are bounded by the streaming wrapper to fit in a `u32`. const HASH_EMPTY: u32 = u32::MAX; +/// Lowest level that engages the HC (hash-chain + lazy) parse. Levels below +/// this use the fast greedy parse (preserving LZ4's speed crown). +const HC_LEVEL_THRESHOLD: u8 = 3; + +/// Lowest level that engages the price-based optimal parse. Levels in +/// `HC_LEVEL_THRESHOLD..OPT_LEVEL_THRESHOLD` use the lazy HC parse; this level +/// and above run a forward dynamic-programming parse that minimises the +/// encoded byte cost. +const OPT_LEVEL_THRESHOLD: u8 = 10; + /// Hash 4 bytes down to `HASH_LOG` bits. /// /// Uses the classic LZ4 multiply-and-shift hash. 
`2654435761` is Knuth's @@ -42,6 +70,13 @@ fn hash4(bytes: [u8; 4]) -> usize { ((v.wrapping_mul(2_654_435_761)) >> (32 - HASH_LOG)) as usize } +/// Hash 4 bytes down to `HC_HASH_LOG` bits (HC parse). 
+#[inline] +fn hc_hash4(bytes: [u8; 4]) -> usize { + let v = u32::from_le_bytes(bytes); + ((v.wrapping_mul(2_654_435_761)) >> (32 - HC_HASH_LOG)) as usize +} + /// Worst-case encoded-length bound for `input_len` bytes of input. /// /// Matches the canonical `LZ4_compressBound` formula. The encoder uses this @@ -52,9 +87,9 @@ pub fn compress_bound(input_len: usize) -> usize { /// Encode `input` as a single LZ4 block into `out` (which is cleared first). /// -/// Returns the number of bytes written. Inputs of any length are accepted; -/// inputs shorter than `MFLIMIT + 1` are emitted as a literal-only sequence, -/// as required by the spec. +/// This is the fast greedy parse (low-level / default speed path). Inputs of +/// any length are accepted; inputs shorter than `MFLIMIT + 1` are emitted as a +/// literal-only sequence, as required by the spec. pub fn encode_block(input: &[u8], out: &mut Vec) { out.clear(); if input.is_empty() { @@ -74,10 +109,11 @@ pub fn encode_block(input: &[u8], out: &mut Vec) { let mut anchor: usize = 0; // start of the current pending literal run // Position of the last byte we are allowed to start a match at. Anything - // past `match_limit` must be emitted as trailing literals. + // past `match_limit` must be emitted as trailing literals. (Note this is + // the *match-start* bound, len - MFLIMIT, which is stricter than the + // hashable bound len - MIN_MATCH - LAST_LITERALS — the spec forbids a + // match starting in the final MFLIMIT bytes.) let match_limit = input.len() - MFLIMIT; - // Position of the last byte we are allowed to *read* a 4-byte hash from. - let hash_limit = input.len() - MIN_MATCH - LAST_LITERALS; // The first byte is never the start of a match in our matcher; insert it // into the table so subsequent positions can refer to it. @@ -94,7 +130,14 @@ pub fn encode_block(input: &[u8], out: &mut Vec) { // skip faster over incompressible data instead of probing every byte. 
let mut match_pos; loop { - if ip > hash_limit { + // A match may only *start* at or before `match_limit` (the spec + // requires the last match to begin at least MFLIMIT bytes before + // the block end). `hash_limit` (len - 4 - 5) is larger than + // `match_limit` (len - 12), so bounding the probe at `hash_limit` + // could find a match starting in the forbidden tail region — a + // block the strict reference decoder rejects. Stop at + // `match_limit`; the rest becomes trailing literals. + if ip > match_limit { emit_last_literals(&input[anchor..], out); return; } @@ -173,6 +216,494 @@ pub fn encode_block(input: &[u8], out: &mut Vec) { emit_last_literals(&input[anchor..], out); } +/// Encode `input` as a single LZ4 block at compression `level`. +/// +/// `level` selects the parse strategy and search effort: +/// +/// * `level <` `HC_LEVEL_THRESHOLD` — the fast greedy [`encode_block`] (speed path). +/// * `HC_LEVEL_THRESHOLD <= level <` `OPT_LEVEL_THRESHOLD` — the HC parse: a +/// hash-chain match finder that searches up to `nb_attempts` candidates per +/// position for the *longest* match, plus lazy matching. +/// * `level >=` `OPT_LEVEL_THRESHOLD` — the price-based optimal parse, driven +/// by the same finder. `nb_attempts` grows with the level (speed for ratio). +/// +/// The emitted bitstream is byte-for-byte a valid LZ4 block in every case — +/// only the parse changes, so the reference `lz4` decoder reads it unchanged. +pub fn encode_block_level(input: &[u8], out: &mut Vec, level: u8) { + if level < HC_LEVEL_THRESHOLD { + encode_block(input, out); + return; + } + if level < OPT_LEVEL_THRESHOLD { + encode_block_hc(input, out, level); + return; + } + encode_block_optimal(input, out, level); +} + +/// Map a compression level to a hash-chain search depth (`nb_attempts`). +/// +/// Depth roughly doubles every level, mirroring the spirit of reference +/// LZ4-HC: higher levels search deeper for the longest match. 
The window is +/// only 64 KiB so even the deepest setting stays bounded. 
+fn nb_attempts_for_level(level: u8) -> u32 { + match level { + 0..=3 => 8, + 4 => 16, + 5 => 32, + 6 => 64, + 7 => 128, + 8 => 256, + 9 => 512, + 10 => 1024, + 11 => 2048, + _ => 4096, + } +} + +/// Insert position `p` into the hash chain. The 4-byte read requires +/// `p + 4 <= input.len()`, guaranteed by the caller (`p <= hash_limit`). +#[inline] +fn hc_insert(input: &[u8], p: usize, head: &mut [u32], chain: &mut [u32]) { + let h = hc_hash4([input[p], input[p + 1], input[p + 2], input[p + 3]]); + chain[p] = head[h]; + head[h] = p as u32; +} + +/// Find the longest match for the 4 bytes at `pos` by walking the hash chain. +/// +/// Returns `(match_pos, match_len)` for the best forward match whose length is +/// at least `MIN_MATCH`, or `None`. Forward extension only — the caller applies +/// backward extension so it can clamp the start at the current anchor. +/// Candidates are strictly older positions on the chain, so self-matches are +/// impossible regardless of whether `pos` has been inserted yet. +fn hc_longest_match( + input: &[u8], + pos: usize, + head: &[u32], + chain: &[u32], + nb_attempts: u32, + forward_limit: usize, +) -> Option<(usize, usize)> { + let h = hc_hash4([input[pos], input[pos + 1], input[pos + 2], input[pos + 3]]); + let mut cand = head[h]; + let min_pos = pos.saturating_sub(MAX_DISTANCE); + + let mut best_len = MIN_MATCH - 1; + let mut best_pos = 0usize; + let mut attempts = nb_attempts; + + while cand != HASH_EMPTY && attempts > 0 { + let c = cand as usize; + if c >= pos { + // Only older positions are valid back-references. (Can only happen + // for a stale/self entry; skip defensively without trusting it.) + cand = chain[c]; + attempts -= 1; + continue; + } + if c < min_pos { + break; // chain is ordered newest->oldest; we've left the window. + } + // Cheap reject: a longer match requires the byte at `best_len` to + // agree (and the first byte, as a quick filter). 
+ if pos + best_len < forward_limit + && input[c + best_len] == input[pos + best_len] + && input[c] == input[pos] + { + let mut l = 0usize; + while pos + l < forward_limit && input[c + l] == input[pos + l] { + l += 1; + } + if l > best_len { + best_len = l; + best_pos = c; + if pos + best_len >= forward_limit { + break; // cannot grow further + } + } + } + cand = chain[c]; + attempts -= 1; + } + + if best_len < MIN_MATCH { + None + } else { + Some((best_pos, best_len)) + } +} + +/// Apply backward extension to a forward match `(match_pos, len)` found at +/// `pos`, sliding the start earlier while bytes agree, clamped so the start +/// never crosses `anchor`. Returns `(start, match_pos, len)`. +#[inline] +fn hc_resolve( + input: &[u8], + pos: usize, + found: (usize, usize), + anchor: usize, +) -> (usize, usize, usize) { + let (mut mpos, mut mlen) = found; + let mut spos = pos; + while spos > anchor && mpos > 0 && input[spos - 1] == input[mpos - 1] { + spos -= 1; + mpos -= 1; + mlen += 1; + } + (spos, mpos, mlen) +} + +/// LZ4-HC-style match finder + parse (used for higher levels). +/// +/// Maintains a hash-chain over 4-byte sequences (`head[hash]` = most recent +/// position; `chain[pos]` = previous position sharing that hash). For each +/// candidate start it walks the chain up to `nb_attempts` links and keeps the +/// longest match inside the 64 KiB window. A one-step lazy heuristic defers a +/// match when the next position offers a strictly longer one. 
+fn encode_block_hc(input: &[u8], out: &mut Vec, level: u8) { + out.clear(); + if input.is_empty() { + return; + } + if input.len() < MFLIMIT + 1 { + emit_last_literals(input, out); + return; + } + + let n = input.len(); + let nb_attempts = nb_attempts_for_level(level); + + let mut head = alloc::vec![HASH_EMPTY; HC_HASH_TABLE_SIZE]; + let mut chain = alloc::vec![HASH_EMPTY; n]; + + let match_limit = n - MFLIMIT; // last position a match may start at + let hash_limit = n - MIN_MATCH - LAST_LITERALS; // last hashable position + let forward_limit = n - LAST_LITERALS; // last 5 bytes stay literal + + // `inserted_through` is the count of positions already recorded in the + // chain: positions [0, inserted_through) are inserted. We insert lazily so + // each position is inserted exactly once and the chain stays strictly + // ordered newest-first. + let mut inserted_through: usize = 0; + let mut anchor: usize = 0; + let mut ip: usize = 0; + + // Insert all hashable positions in [inserted_through, up_to). + macro_rules! insert_up_to { + ($up_to:expr) => {{ + let up_to = $up_to; + while inserted_through < up_to && inserted_through <= hash_limit { + hc_insert(input, inserted_through, &mut head, &mut chain); + inserted_through += 1; + } + }}; + } + + while ip <= match_limit && ip <= hash_limit { + // Ensure positions up to and including `ip` are in the chain. + insert_up_to!(ip + 1); + + let found = hc_longest_match(input, ip, &head, &chain, nb_attempts, forward_limit); + let (mut cur_start, mut cur_mpos, mut cur_len) = match found { + None => { + ip += 1; + continue; + } + Some(f) => hc_resolve(input, ip, f, anchor), + }; + + // One-step lazy matching: while the next position offers a strictly + // longer match, defer (the current first byte becomes a literal) and + // chase the better match from there. 
+ loop { + let next = cur_start + 1; + if next > match_limit || next > hash_limit { + break; + } + insert_up_to!(next + 1); + if let Some(f) = + hc_longest_match(input, next, &head, &chain, nb_attempts, forward_limit) + { + let (ns, nmp, nl) = hc_resolve(input, next, f, anchor); + if nl > cur_len { + cur_start = ns; + cur_mpos = nmp; + cur_len = nl; + continue; + } + } + break; + } + + // Emit literals [anchor, cur_start) followed by the match. + let literal_len = cur_start - anchor; + let offset = (cur_start - cur_mpos) as u16; + let match_excess = cur_len - MIN_MATCH; + emit_sequence( + &input[anchor..cur_start], + literal_len, + offset, + match_excess, + out, + ); + + let match_end = cur_start + cur_len; + // Insert every position the match covers so later matches can point + // inside it. `insert_up_to!` skips any already inserted by the lazy + // walk, keeping the chain strictly ordered. + insert_up_to!(match_end); + + anchor = match_end; + ip = match_end; + } + + emit_last_literals(&input[anchor..], out); +} + +/// Encoded byte cost of `litlen` literals, per the LZ4 token/run-length rules. +/// +/// The literal payload is `litlen` bytes; if `litlen >= 15` the run-length +/// nibble overflows and one or more extension bytes are appended: +/// `1 + (litlen - 15) / 255`. The token nibble itself is billed once per +/// sequence (see [`sequence_overhead`]), not here. +#[inline] +fn literals_price(litlen: usize) -> usize { + let mut price = litlen; + if litlen >= 15 { + price += 1 + (litlen - 15) / 255; + } + price +} + +/// Marginal cost of extending a literal run from length `run` to `run + 1`: +/// always 1 byte for the new literal, plus 1 more whenever the new length +/// crosses a run-length extension boundary (15, then every 255 after). 
+#[inline] +fn marginal_literal_price(run: usize) -> usize { + 1 + (literals_price(run + 1) - literals_price(run) - 1) +} + +/// Fixed per-sequence overhead beyond the coupled literals: 1 token byte + +/// 2 offset bytes, plus the match-length run-extension bytes once the match +/// length nibble overflows (`mlen >= 19`). +#[inline] +fn sequence_overhead(mlen: usize) -> usize { + let mut price = 1 + 2; // token + 16-bit offset + if mlen >= ML_MASK_PLUS_MIN { + price += 1 + (mlen - ML_MASK_PLUS_MIN) / 255; + } + price +} + +/// `ML_MASK (15) + MINMATCH (4)` — the match length at which the match-length +/// nibble first overflows into extension bytes. +const ML_MASK_PLUS_MIN: usize = 15 + MIN_MATCH; + +/// Match length at or beyond which the optimal parse stops enumerating every +/// shorter length and simply takes the whole match. A match this long is +/// effectively always worth taking in full (3 bytes of overhead amortised over +/// 64+ bytes), and the cap keeps the per-position inner loop bounded so highly +/// repetitive inputs stay near-linear instead of O(n²). Mirrors the role of +/// `sufficient_len` in the reference LZ4-HC optimal parser. +const OPT_SUFFICIENT_LEN: usize = 64; + +/// One step of the chosen parse path, recovered by backtracking the DP. +#[derive(Clone, Copy)] +struct OptStep { + /// Length of the literal run preceding this position's incoming edge. + litlen: usize, + /// `match_pos` of the incoming match, or `usize::MAX` for a literal step. + match_pos: usize, + /// Match length of the incoming edge (0 for a literal step). + match_len: usize, +} + +/// Price-based optimal parse (top levels). +/// +/// Runs a forward dynamic program over the block: `price[i]` is the minimal +/// encoded byte cost to reach position `i`. 
Each position can advance by a +/// single literal (marginal literal price, tracking the run length so the +/// run-length token overflow is charged accurately) or by any match found via +/// the hash-chain finder (sequence overhead + the literal run it terminates). +/// Backtracking recovers the cheapest path, which is then emitted with the +/// shared sequence emitter — so the bitstream stays a valid LZ4 block. +fn encode_block_optimal(input: &[u8], out: &mut Vec, level: u8) { + out.clear(); + if input.is_empty() { + return; + } + if input.len() < MFLIMIT + 1 { + emit_last_literals(input, out); + return; + } + + let n = input.len(); + let nb_attempts = nb_attempts_for_level(level); + + let mut head = alloc::vec![HASH_EMPTY; HC_HASH_TABLE_SIZE]; + let mut chain = alloc::vec![HASH_EMPTY; n]; + + let match_limit = n - MFLIMIT; // last position a match may start at + let hash_limit = n - MIN_MATCH - LAST_LITERALS; // last hashable position + let forward_limit = n - LAST_LITERALS; // last 5 bytes stay literal + + // DP arrays over positions 0..=n. + // `price[i]` = min cost to encode input[0..i]. + // `run[i]` = literal-run length ending at i on the best path to i. + // `step[i]` = the incoming edge used to reach i (for backtracking). + let mut price = alloc::vec![usize::MAX; n + 1]; + let mut run = alloc::vec![0usize; n + 1]; + let mut step = alloc::vec![ + OptStep { + litlen: 0, + match_pos: usize::MAX, + match_len: 0, + }; + n + 1 + ]; + price[0] = 0; + + // Insert all positions up to `up_to` (exclusive) that are hashable. + let mut inserted_through = 0usize; + macro_rules! 
insert_up_to { + ($up_to:expr) => {{ + let up_to = $up_to; + while inserted_through < up_to && inserted_through <= hash_limit { + hc_insert(input, inserted_through, &mut head, &mut chain); + inserted_through += 1; + } + }}; + } + + let mut i = 0usize; + while i < n { + if price[i] == usize::MAX { + i += 1; + continue; // unreachable position + } + let cur_price = price[i]; + let cur_run = run[i]; + + // Literal edge: advance one byte, extending the literal run. + { + let lit_cost = cur_price + marginal_literal_price(cur_run); + if lit_cost < price[i + 1] { + price[i + 1] = lit_cost; + run[i + 1] = cur_run + 1; + step[i + 1] = OptStep { + litlen: cur_run + 1, + match_pos: usize::MAX, + match_len: 0, + }; + } + } + + // Match edges: only valid starting positions can begin a match, and + // only where a 4-byte hash is readable. + if i > match_limit || i > hash_limit { + i += 1; + continue; + } + insert_up_to!(i + 1); + let found = hc_longest_match(input, i, &head, &chain, nb_attempts, forward_limit); + let (best_pos, best_len) = match found { + Some(f) => f, + None => { + i += 1; + continue; + } + }; + + // The literal run that *would* precede this match was already paid for + // in `cur_price`/`cur_run`. Emitting a match terminates that run, so + // the new sequence's coupled-literal price equals what we already + // charged for the run — i.e. taking the match adds only the sequence + // overhead. (The token nibble that also encodes the literal length is + // the single token byte we add here.) + // + // For a sufficiently long match, shorter splits are never preferable: + // record the full-length edge, insert the positions it covers so later + // matches can chain inside it, and fast-forward past the interior. This + // keeps highly-repetitive inputs near-linear (no O(n²) DP sweep), while + // the global DP still chooses among long matches and literal runs. 
+ if best_len >= OPT_SUFFICIENT_LEN { + let end = i + best_len; + let cost = cur_price + sequence_overhead(best_len); + if cost < price[end] { + price[end] = cost; + run[end] = 0; + step[end] = OptStep { + litlen: cur_run, + match_pos: best_pos, + match_len: best_len, + }; + } + insert_up_to!(end); + i = end; + continue; + } + + // Short match: enumerate every length in [MIN_MATCH, best_len]; a + // shorter match can line up a cheaper continuation, which is exactly + // what the DP weighs. + for mlen in MIN_MATCH..=best_len { + let end = i + mlen; + if end > n { + break; + } + let cost = cur_price + sequence_overhead(mlen); + if cost < price[end] { + price[end] = cost; + run[end] = 0; + step[end] = OptStep { + litlen: cur_run, + match_pos: best_pos, + match_len: mlen, + }; + } + } + i += 1; + } + + // Backtrack from n to 0, collecting the path edges in reverse. + let mut path: Vec = Vec::new(); + let mut pos = n; + while pos > 0 { + let s = step[pos]; + if s.match_pos == usize::MAX { + // Literal edge: step back one byte. Collapse a contiguous literal + // run into the match step that follows; here we just step. + pos -= 1; + } else { + path.push(s); + pos -= s.match_len; + } + } + path.reverse(); + + // Replay forward, emitting literals then each match. + let mut anchor = 0usize; + for s in &path { + let match_start = { + // The match's start position is the end-of-literal-run point. We + // reconstruct it from the literal run length recorded on the edge. + anchor + s.litlen + }; + let offset = (match_start - s.match_pos) as u16; + let match_excess = s.match_len - MIN_MATCH; + emit_sequence( + &input[anchor..match_start], + s.litlen, + offset, + match_excess, + out, + ); + anchor = match_start + s.match_len; + } + emit_last_literals(&input[anchor..], out); +} + /// Write a single sequence (literals + offset + match-length excess). 
fn emit_sequence( literals: &[u8], @@ -360,6 +891,14 @@ mod tests { assert_eq!(decoded, data); } + fn round_trip_level(data: &[u8], level: u8) { + let mut encoded = Vec::new(); + encode_block_level(data, &mut encoded, level); + let mut decoded = Vec::new(); + decode_block(&encoded, &mut decoded, usize::MAX).expect("decode"); + assert_eq!(decoded, data, "round-trip mismatch at level {level}"); + } + #[test] fn empty() { round_trip(&[]); @@ -384,4 +923,135 @@ mod tests { } round_trip(&v); } + + #[test] + fn hc_round_trip_all_levels() { + let mut text = Vec::new(); + for _ in 0..200 { + text.extend_from_slice(b"the quick brown fox jumps over the lazy dog. "); + } + // Pseudo-random data exercises the no-match / chain-miss paths. + let mut prng = Vec::new(); + let mut s: u32 = 0x1234_5678; + for _ in 0..8192 { + s = s.wrapping_mul(1_103_515_245).wrapping_add(12345); + prng.push((s >> 16) as u8); + } + for level in 0..=12u8 { + round_trip_level(&text, level); + round_trip_level(b"hello", level); + round_trip_level(&[], level); + round_trip_level(&alloc::vec![b'x'; 5000], level); + round_trip_level(&prng, level); + } + } + + #[test] + fn hc_not_worse_than_fast() { + let mut v = Vec::new(); + for i in 0..5000u32 { + v.extend_from_slice(&i.to_le_bytes()); + v.extend_from_slice(b"common suffix string here "); + } + let mut fast = Vec::new(); + encode_block(&v, &mut fast); + let mut hc = Vec::new(); + encode_block_level(&v, &mut hc, 9); + assert!( + hc.len() <= fast.len(), + "hc {} should be <= fast {}", + hc.len(), + fast.len() + ); + } + + /// Walk an encoded block and assert it obeys the strict end-of-block rules + /// the reference `lz4` decoder enforces: the last 5 bytes are literals, and + /// no match starts within the final `MFLIMIT` (12) bytes of the block. + /// + /// `raw_len` is the decoded length (so we can compute output positions). 
+ fn assert_eob_rules(encoded: &[u8], raw_len: usize) { + if encoded.is_empty() { + assert_eq!(raw_len, 0); + return; + } + let mut i = 0usize; + let mut outpos = 0usize; + let n = encoded.len(); + loop { + let token = encoded[i]; + i += 1; + let mut lit = (token >> 4) as usize; + if lit == 15 { + loop { + let b = encoded[i]; + i += 1; + lit += b as usize; + if b != 255 { + break; + } + } + } + i += lit; + outpos += lit; + if i == n { + // Closing literal-only sequence: the spec requires the final + // run be at least LAST_LITERALS bytes (unless the whole block + // is shorter than that). + if raw_len >= LAST_LITERALS { + assert!( + lit >= LAST_LITERALS, + "final literal run {lit} < {LAST_LITERALS}" + ); + } + break; + } + // A match follows. Its start in the decoded stream is `outpos`. + let match_start = outpos; + assert!( + match_start + MFLIMIT <= raw_len, + "match starts at {match_start}, within MFLIMIT of end {raw_len}" + ); + i += 2; // offset + let mut ml = (token & 0x0F) as usize; + if ml == 15 { + loop { + let b = encoded[i]; + i += 1; + ml += b as usize; + if b != 255 { + break; + } + } + } + ml += MIN_MATCH; + outpos += ml; + } + assert_eq!(outpos, raw_len, "decoded length mismatch"); + } + + #[test] + fn end_of_block_rules_all_levels() { + // Construct an input whose best parse lands a match right up against + // the end of the block — exactly the case that previously produced a + // block the reference decoder rejected (a match starting inside the + // final MFLIMIT bytes). + let mut v = Vec::new(); + for _ in 0..400 { + v.extend_from_slice(b"alpha beta gamma delta epsilon "); + } + // Append a tail that repeats earlier content so a match is tempting at + // the very end. + v.extend_from_slice(b"alpha beta gamma delta epsilon"); + + for level in 0..=12u8 { + let mut enc = Vec::new(); + encode_block_level(&v, &mut enc, level); + assert_eob_rules(&enc, v.len()); + // And it must still round-trip. 
+ let mut dec = Vec::new(); + decode_block(&enc, &mut dec, usize::MAX).expect("decode"); + assert_eq!(dec, v, "round-trip at level {level}"); + } + } } diff --git a/src/lz4/frame.rs b/src/lz4/frame.rs index e2207ee..6dfaa30 100644 --- a/src/lz4/frame.rs +++ b/src/lz4/frame.rs @@ -142,6 +142,12 @@ pub struct EncoderConfig { /// `true` (default) = append xxHash32 of the raw content at the /// frame end. Matches `lz4 -c`. pub content_checksum: bool, + /// Block-compression level forwarded to + /// [`block::encode_block_level`]. Low levels use the fast greedy parse; + /// higher levels engage the HC hash-chain match finder with lazy matching + /// for a better ratio. The emitted bitstream is a valid LZ4 block in every + /// case. Default `0` (fast path, matching `lz4 -c`'s default speed). + pub level: u8, } impl Default for EncoderConfig { @@ -151,6 +157,7 @@ impl Default for EncoderConfig { block_independence: false, block_checksum: false, content_checksum: true, + level: 0, } } } @@ -421,7 +428,7 @@ impl Encoder { // Compress into a scratch buffer. let mut compressed = Vec::with_capacity(block::compress_bound(self.raw.len())); - block::encode_block(&self.raw, &mut compressed); + block::encode_block_level(&self.raw, &mut compressed, self.cfg.level); self.staged.clear(); // Choose the smaller of compressed / raw. The LZ4 Frame spec diff --git a/src/lz4/mod.rs b/src/lz4/mod.rs index f7494ac..1f70c56 100644 --- a/src/lz4/mod.rs +++ b/src/lz4/mod.rs @@ -39,14 +39,28 @@ pub const BLOCK_SIZE: usize = 64 * 1024; #[derive(Debug, Clone, Copy, Default)] pub struct Lz4; +/// Encoder configuration for the LZ4 block stream. 
+/// +/// The only tunable is the compression `level`, which selects the parse +/// strategy in [`block::encode_block_level`]: low levels use the fast greedy +/// matcher (LZ4's speed crown), higher levels engage the HC hash-chain match +/// finder with lazy matching for a better ratio. The bitstream format is +/// identical either way (only the parse differs). Default `0` (fast path). 
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub struct EncoderConfig { + /// Compression level. `0` = fast greedy (default); higher values engage + /// the HC match finder with deeper search. + pub level: u8, +} + impl Algorithm for Lz4 { const NAME: &'static str = "lz4"; type Encoder = Encoder; type Decoder = Decoder; - type EncoderConfig = (); + type EncoderConfig = EncoderConfig; type DecoderConfig = (); - fn encoder_with(_: ()) -> Encoder { - Encoder::new() + fn encoder_with(c: EncoderConfig) -> Encoder { + Encoder::with_level(c.level) } fn decoder_with(_: ()) -> Decoder { Decoder::new() @@ -76,16 +90,25 @@ pub struct Encoder { compressed_idx: usize, terminator_idx: u8, phase: EncPhase, + /// Compression level forwarded to [`block::encode_block_level`]. + level: u8, } impl Encoder { pub fn new() -> Self { + Self::with_level(0) + } + + /// Construct an encoder at the given compression `level`. `0` keeps the + /// fast greedy parse; higher levels engage the HC match finder. + pub fn with_level(level: u8) -> Self { Self { raw: Vec::with_capacity(BLOCK_SIZE), compressed: Vec::with_capacity(block::compress_bound(BLOCK_SIZE) + 4), compressed_idx: 0, terminator_idx: 0, phase: EncPhase::Buffering, + level, } } @@ -106,7 +129,7 @@ impl Encoder { // clearing it. The block encoder clears its `out` arg, so we use a // temporary buffer and concatenate. let mut tmp = Vec::with_capacity(block::compress_bound(self.raw.len())); - block::encode_block(&self.raw, &mut tmp); + block::encode_block_level(&self.raw, &mut tmp, self.level); // The block format never emits zero bytes for a non-empty input; // even a single-byte input becomes at least a 1-byte token plus the // literal byte itself. 
diff --git a/src/lzma/encoder.rs b/src/lzma/encoder.rs index ec3778f..32345ee 100644 --- a/src/lzma/encoder.rs +++ b/src/lzma/encoder.rs @@ -36,8 +36,8 @@ use super::{ ALIGN_BITS, ALIGN_SIZE, DIST_MODEL_END, DIST_MODEL_START, DIST_SLOT_BITS, DIST_SLOTS, DIST_STATES, FULL_DISTANCES, LEN_HIGH_BITS, LEN_HIGH_SYMBOLS, LEN_LOW_BITS, LEN_LOW_SYMBOLS, LEN_MID_BITS, LEN_MID_SYMBOLS, LIT_STATES, MATCH_LEN_MIN, POS_STATES_MAX, PROB_INIT, - RC_BIT_MODEL_TOTAL_BITS, RC_MOVE_BITS, RC_TOP_VALUE, STATES, state_after_literal, - state_after_match, state_after_rep, state_after_short_rep, + RC_BIT_MODEL_TOTAL, RC_BIT_MODEL_TOTAL_BITS, RC_MOVE_BITS, RC_TOP_VALUE, STATES, + state_after_literal, state_after_match, state_after_rep, state_after_short_rep, }; // ─── encoder parameters ────────────────────────────────────────────────── @@ -98,65 +98,94 @@ struct LevelParams { max_chain: usize, /// Length at which the match finder stops looking for a longer candidate. nice_match: u32, + /// Length at which the optimal parser early-commits the current window + /// (a match this long is almost certainly taken, so there's no value in + /// extending the DP past it with increasingly stale prices). Keeping this + /// modest keeps committed segments short and prices fresh. + nice_len: u32, + /// Optimal-parser look-ahead window (number of optimum-buffer slots). When + /// `0` the parser falls back to a fast greedy/lazy parse. + opt_window: u32, } impl LevelParams { fn from_level(level: u8) -> Self { let level = level.min(9); // Mirrors xz's preset table for dictionary size, then a graduated - // chain budget / nice-match cutoff that grows with level. The - // numbers don't have to match xz precisely — what matters is that - // a higher level walks deeper chains and accepts longer matches. + // chain budget / nice-match cutoff and optimal-parse window that grow + // with level. 
The numbers don't have to match xz precisely — what + // matters is that a higher level walks deeper chains, accepts longer + // matches, and looks further ahead in the cost-based parse. match level { 0 => Self { dict_size: 1 << 16, // 64 KiB max_chain: 8, nice_match: 8, + nice_len: 8, + opt_window: 0, }, 1 => Self { dict_size: 1 << 20, // 1 MiB max_chain: 16, nice_match: 16, + nice_len: 16, + opt_window: 0, }, 2 => Self { dict_size: 1 << 21, // 2 MiB max_chain: 24, nice_match: 32, + nice_len: 32, + opt_window: 0, }, 3 => Self { dict_size: 1 << 22, // 4 MiB max_chain: 32, - nice_match: 48, + nice_match: 64, + nice_len: 16, + opt_window: 512, }, 4 => Self { dict_size: 1 << 22, // 4 MiB - max_chain: 48, - nice_match: 64, + max_chain: 64, + nice_match: 128, + nice_len: 24, + opt_window: 1024, }, 5 => Self { dict_size: 1 << 23, // 8 MiB - max_chain: 64, - nice_match: 96, + max_chain: 128, + nice_match: 192, + nice_len: 32, + opt_window: 2048, }, 6 => Self { dict_size: 1 << 23, // 8 MiB - max_chain: 96, - nice_match: 128, + max_chain: 256, + nice_match: 273, + nice_len: 48, + opt_window: 4096, }, 7 => Self { dict_size: 1 << 24, // 16 MiB - max_chain: 192, - nice_match: 192, + max_chain: 512, + nice_match: 273, + nice_len: 64, + opt_window: 4096, }, 8 => Self { dict_size: 1 << 25, // 32 MiB - max_chain: 384, - nice_match: 256, + max_chain: 1024, + nice_match: 273, + nice_len: 96, + opt_window: 4096, }, _ => Self { dict_size: 1 << 26, // 64 MiB (level 9) - max_chain: 768, + max_chain: 2048, nice_match: MAX_MATCH_LEN, + nice_len: 128, + opt_window: 4096, }, } } @@ -348,6 +377,96 @@ fn dist_special_encode( } } +// ─── price model ────────────────────────────────────────────────────────── +// +// The optimal parser needs the *bit cost* of encoding a given symbol with the +// current probability model. LZMA prices in 1/16-bit units: the cost of +// coding a bit against probability `p` is a fixed-point `-log2` of the +// matching probability. 
We replicate the SDK's `ProbPrices` table. + +const PRICE_SHIFT_BITS: u32 = 4; +const PRICE_TABLE_SIZE: usize = (RC_BIT_MODEL_TOTAL >> PRICE_SHIFT_BITS) as usize; + +/// Precomputed price table: `prices[p >> 4]` is the cost in 1/16-bit units of +/// coding a 0-bit against probability `p`. Generated the same way as the LZMA +/// SDK's price table (a fixed-point `-log2` approximation). +fn build_prob_prices() -> [u32; PRICE_TABLE_SIZE] { + let mut prices = [0u32; PRICE_TABLE_SIZE]; + // `kCyclesBits` in the SDK: the squaring loop runs exactly this many times + // (it equals the price shift, 4 — NOT the model-bit count). Getting this + // wrong makes `bit_count` overflow the subtraction and yields garbage + // prices. + let cycles_bits = PRICE_SHIFT_BITS; + let mut i: usize = (1usize << PRICE_SHIFT_BITS) >> 1; + while i < (PRICE_TABLE_SIZE << PRICE_SHIFT_BITS) { + let mut w = i as u32; + let mut bit_count = 0u32; + let mut j = 0; + while j < cycles_bits { + w = w.wrapping_mul(w); + bit_count <<= 1; + while w >= (1u32 << 16) { + w >>= 1; + bit_count += 1; + } + j += 1; + } + let idx = i >> PRICE_SHIFT_BITS; + prices[idx] = (RC_BIT_MODEL_TOTAL_BITS << PRICE_SHIFT_BITS) - 15 - bit_count; + i += 1 << PRICE_SHIFT_BITS; + } + prices +} + +#[inline] +fn price_bit(prices: &[u32; PRICE_TABLE_SIZE], prob: u16, bit: u32) -> u32 { + let p = if bit == 0 { + prob as u32 + } else { + RC_BIT_MODEL_TOTAL - prob as u32 + }; + prices[(p >> PRICE_SHIFT_BITS) as usize] +} + +#[inline] +fn price_bit0(prices: &[u32; PRICE_TABLE_SIZE], prob: u16) -> u32 { + prices[(prob as u32 >> PRICE_SHIFT_BITS) as usize] +} + +#[inline] +fn price_bit1(prices: &[u32; PRICE_TABLE_SIZE], prob: u16) -> u32 { + prices[((RC_BIT_MODEL_TOTAL - prob as u32) >> PRICE_SHIFT_BITS) as usize] +} + +fn bittree_price(prices: &[u32; PRICE_TABLE_SIZE], probs: &[u16], bits: u32, symbol: u32) -> u32 { + let mut total = 0u32; + let mut idx: u32 = 1; + let mut i = bits; + while i > 0 { + i -= 1; + let bit = (symbol >> i) & 1; 
+ total += price_bit(prices, probs[idx as usize], bit); + idx = (idx << 1) | bit; + } + total +} + +fn bittree_reverse_price( + prices: &[u32; PRICE_TABLE_SIZE], + probs: &[u16], + bits: u32, + symbol: u32, +) -> u32 { + let mut total = 0u32; + let mut idx: u32 = 1; + for i in 0..bits { + let bit = (symbol >> i) & 1; + total += price_bit(prices, probs[idx as usize], bit); + idx = (idx << 1) | bit; + } + total +} + // ─── length coder ──────────────────────────────────────────────────────── struct LengthCoderEnc { @@ -393,6 +512,39 @@ impl LengthCoderEnc { ); } } + + /// Price of encoding length symbol `length` (0-based) at `pos_state`. + fn price(&self, prices: &[u32; PRICE_TABLE_SIZE], pos_state: u32, length: u32) -> u32 { + if length < LEN_LOW_SYMBOLS as u32 { + let base = (pos_state as usize) * LEN_LOW_SYMBOLS; + price_bit0(prices, self.choice) + + bittree_price( + prices, + &self.low[base..base + LEN_LOW_SYMBOLS], + LEN_LOW_BITS, + length, + ) + } else if length < (LEN_LOW_SYMBOLS + LEN_MID_SYMBOLS) as u32 { + let base = (pos_state as usize) * LEN_MID_SYMBOLS; + price_bit1(prices, self.choice) + + price_bit0(prices, self.choice2) + + bittree_price( + prices, + &self.mid[base..base + LEN_MID_SYMBOLS], + LEN_MID_BITS, + length - LEN_LOW_SYMBOLS as u32, + ) + } else { + price_bit1(prices, self.choice) + + price_bit1(prices, self.choice2) + + bittree_price( + prices, + &self.high, + LEN_HIGH_BITS, + length - (LEN_LOW_SYMBOLS + LEN_MID_SYMBOLS) as u32, + ) + } + } } // ─── encoder core ──────────────────────────────────────────────────────── @@ -512,6 +664,142 @@ impl LzmaEncCore { } } + /// Price of coding the literal `byte` at output offset `out_pos` with the + /// given previous byte, optional match byte (rep0 byte), and literal + /// state. Reads the live `lit` probabilities (snapshot at call time). 
+    fn literal_price(
+        &self,
+        prices: &[u32; PRICE_TABLE_SIZE],
+        out_pos: u64,
+        byte: u8,
+        prev_byte: u8,
+        match_byte: Option<u8>,
+    ) -> u32 {
+        // Pick the 0x300-entry literal sub-coder from the masked output
+        // position and the top `lc` bits of the previous byte.
+        let lp_state = ((out_pos as u32) & self.lit_pos_mask) << self.lc;
+        let prev_high = (prev_byte as u32) >> (8 - self.lc);
+        let probs_idx = (lp_state + prev_high) as usize * 0x300;
+        let probs = &self.lit[probs_idx..probs_idx + 0x300];
+
+        let target = byte as u32;
+        let mut total = 0u32;
+        // Bit-tree node index: a leading 1 followed by the bits priced so far.
+        let mut symbol: u32 = 1;
+        // `Some` while every bit priced so far agreed with the match byte. It is
+        // cleared at the first differing bit, after which the remaining bits
+        // are priced against the plain (unmatched) tree.
+        let mut match_ctx = match_byte.map(u32::from);
+        // Bits are priced most-significant first.
+        for shift in (0..8u32).rev() {
+            let bit = (target >> shift) & 1;
+            let idx = match match_ctx {
+                Some(mb) => {
+                    let match_bit = (mb >> shift) & 1;
+                    if match_bit != bit {
+                        match_ctx = None;
+                    }
+                    // Matched-literal probabilities sit above the plain tree,
+                    // in one of two 0x100-entry halves chosen by the match bit.
+                    0x100 + (match_bit << 8) + symbol
+                }
+                None => symbol,
+            };
+            total += price_bit(prices, probs[idx as usize], bit);
+            symbol = (symbol << 1) | bit;
+        }
+        // Sum of the eight per-bit prices, in the units of the `prices` table.
+        total
+    }
+
+    /// Price of coding a new-match distance `distance` for a match of length
+    /// `length`. Reads the live distance probabilities.
+    fn distance_price(&self, prices: &[u32; PRICE_TABLE_SIZE], length: u32, distance: u32) -> u32 {
+        // The distance-slot tree is chosen by the match length, clamped to the
+        // last available length context.
+        let len_ctx =
+            (length.min(DIST_STATES as u32 + MATCH_LEN_MIN - 1) - MATCH_LEN_MIN) as usize;
+        let dist_slot = get_dist_slot(distance);
+        let tree_start = len_ctx * DIST_SLOTS;
+        let slot_price = bittree_price(
+            prices,
+            &self.dist_slot[tree_start..tree_start + DIST_SLOTS],
+            DIST_SLOT_BITS,
+            dist_slot,
+        );
+
+        // Slots below DIST_MODEL_START carry no footer bits: the slot is the
+        // whole cost.
+        if dist_slot < DIST_MODEL_START {
+            return slot_price;
+        }
+
+        let footer_bits = (dist_slot >> 1) - 1;
+        let base = (2 | (dist_slot & 1)) << footer_bits;
+        let reduced = distance.wrapping_sub(base);
+
+        let footer_price = if dist_slot < DIST_MODEL_END {
+            // Footer bits are priced low bit first against `dist_special`,
+            // starting at `base + 1`. The index/step update order differs for
+            // a 0-bit and a 1-bit and must stay exactly as written.
+            let mut cost = 0u32;
+            let mut idx = base as usize + 1;
+            let mut step = 1usize;
+            for shift in 0..footer_bits {
+                let bit = (reduced >> shift) & 1;
+                cost += price_bit(prices, self.dist_special[idx], bit);
+                if bit != 0 {
+                    step += step;
+                    idx += step;
+                } else {
+                    idx += step;
+                    step += step;
+                }
+            }
+            cost
+        } else {
+            // Direct (uniform) bits cost exactly 1 bit each; only the low
+            // ALIGN_BITS go through the adaptive reverse bit-tree.
+            let direct_price = (footer_bits - ALIGN_BITS) << PRICE_SHIFT_BITS;
+            let align = reduced & (ALIGN_SIZE as u32 - 1);
+            direct_price + bittree_reverse_price(prices, &self.dist_align[..], ALIGN_BITS, align)
+        };
+
+        slot_price + footer_price
+    }
+
+    /// Snapshot the cheap per-flag bit prices used by the optimal parser.
+ fn price_snapshot(&self, prices: &[u32; PRICE_TABLE_SIZE]) -> PriceSnapshot { + let mut is_match = [[0u32; 2]; STATES * POS_STATES_MAX]; + let mut is_rep0_long = [[0u32; 2]; STATES * POS_STATES_MAX]; + for i in 0..STATES * POS_STATES_MAX { + is_match[i][0] = price_bit0(prices, self.is_match[i]); + is_match[i][1] = price_bit1(prices, self.is_match[i]); + is_rep0_long[i][0] = price_bit0(prices, self.is_rep0_long[i]); + is_rep0_long[i][1] = price_bit1(prices, self.is_rep0_long[i]); + } + let mut is_rep = [[0u32; 2]; STATES]; + let mut is_rep0 = [[0u32; 2]; STATES]; + let mut is_rep1 = [[0u32; 2]; STATES]; + let mut is_rep2 = [[0u32; 2]; STATES]; + for s in 0..STATES { + is_rep[s][0] = price_bit0(prices, self.is_rep[s]); + is_rep[s][1] = price_bit1(prices, self.is_rep[s]); + is_rep0[s][0] = price_bit0(prices, self.is_rep0[s]); + is_rep0[s][1] = price_bit1(prices, self.is_rep0[s]); + is_rep1[s][0] = price_bit0(prices, self.is_rep1[s]); + is_rep1[s][1] = price_bit1(prices, self.is_rep1[s]); + is_rep2[s][0] = price_bit0(prices, self.is_rep2[s]); + is_rep2[s][1] = price_bit1(prices, self.is_rep2[s]); + } + PriceSnapshot { + is_match, + is_rep, + is_rep0, + is_rep1, + is_rep2, + is_rep0_long, + } + } + fn encode_distance(&mut self, length: u32, distance: u32) { let dist_state_idx = (length.min(DIST_STATES as u32 + MATCH_LEN_MIN - 1) - MATCH_LEN_MIN) as usize; @@ -812,6 +1100,71 @@ impl HashChain { None } } + + /// Collect the candidate match set for the optimal parser: for each + /// achievable length `>= MATCH_LEN_MIN`, the *shortest* distance that + /// achieves it. `out` is filled with `(len, dist0based)` pairs in + /// strictly increasing length order. Returns the longest length found. 
+ fn find_matches( + &self, + input: &[u8], + pos: usize, + dict_size: u32, + max_chain: usize, + nice_match: u32, + out: &mut Vec<(u32, u32)>, + ) -> u32 { + out.clear(); + if pos + 3 > input.len() { + return 0; + } + let h = hash3(input[pos], input[pos + 1], input[pos + 2]) as usize; + let max_len = MAX_MATCH_LEN.min((input.len() - pos) as u32); + let max_dist = (dict_size as usize).min(pos); + let mut best_len: u32 = MATCH_LEN_MIN - 1; + let mut cur = self.head[h]; + let mut steps = 0usize; + while cur != NIL && steps < max_chain { + let cur_pos = cur as usize; + if cur_pos >= pos { + cur = self.prev[cur_pos]; + steps += 1; + continue; + } + let dist = pos - cur_pos; + if dist > max_dist { + break; + } + if best_len >= MATCH_LEN_MIN + && (best_len as usize) < (input.len() - pos) + && input[cur_pos + best_len as usize] != input[pos + best_len as usize] + { + cur = self.prev[cur_pos]; + steps += 1; + continue; + } + let mut len = 0u32; + while len < max_len && input[cur_pos + len as usize] == input[pos + len as usize] { + len += 1; + } + if len >= MATCH_LEN_MIN && len > best_len { + // Chain is walked nearest-first, so this is the shortest + // distance achieving every length in (best_len, len]. + out.push((len, (dist - 1) as u32)); + best_len = len; + if len >= nice_match || len >= max_len { + break; + } + } + cur = self.prev[cur_pos]; + steps += 1; + } + if best_len >= MATCH_LEN_MIN { + best_len + } else { + 0 + } + } } // ─── rep-match helpers ─────────────────────────────────────────────────── @@ -833,16 +1186,179 @@ fn rep_match_len(input: &[u8], pos: usize, dist: u32) -> u32 { len as u32 } +// ─── price snapshot + optimal-parse scaffolding ─────────────────────────── + +/// Cached bit prices for the cheap per-decision flags. Length/distance/literal +/// prices are computed on demand from the core's live tables. 
+struct PriceSnapshot {
+    /// `is_match[state * POS_STATES_MAX + pos_state][b]`: price of coding the
+    /// is-match flag as bit `b`.
+    is_match: [[u32; 2]; STATES * POS_STATES_MAX],
+    /// `is_rep[state][b]`: price of coding the is-rep flag as bit `b`.
+    is_rep: [[u32; 2]; STATES],
+    /// `is_rep0[state][b]`: price of the first rep-selector bit.
+    is_rep0: [[u32; 2]; STATES],
+    /// `is_rep1[state][b]`: price of the second rep-selector bit.
+    is_rep1: [[u32; 2]; STATES],
+    /// `is_rep2[state][b]`: price of the third rep-selector bit.
+    is_rep2: [[u32; 2]; STATES],
+    /// `is_rep0_long[state * POS_STATES_MAX + pos_state][b]`: price of the
+    /// flag separating a short rep (`0`) from a long rep0 match (`1`).
+    is_rep0_long: [[u32; 2]; STATES * POS_STATES_MAX],
+}
+
+impl PriceSnapshot {
+    /// Price of the rep-flag prefix selecting rep index `rep_idx` from `state`
+    /// (the `is_rep`=1 bit plus the rep0/rep1/rep2 selector bits, but NOT the
+    /// length and NOT the is_rep0_long bit for rep0).
+    fn rep_choice_price(&self, state: usize, rep_idx: u32) -> u32 {
+        // Selector bits form a unary-style code: 0, 10, 110, 111. Any index
+        // above 2 is treated as rep3.
+        let selector = match rep_idx {
+            0 => self.is_rep0[state][0],
+            1 => self.is_rep0[state][1] + self.is_rep1[state][0],
+            2 => self.is_rep0[state][1] + self.is_rep1[state][1] + self.is_rep2[state][0],
+            _ => self.is_rep0[state][1] + self.is_rep1[state][1] + self.is_rep2[state][1],
+        };
+        self.is_rep[state][1] + selector
+    }
+}
+
+/// One parser decision, replayed through the real emit functions after the
+/// optimal parse has chosen it.
+#[derive(Clone, Copy)]
+enum Decision {
+    /// One literal byte.
+    Literal,
+    /// New match: `(distance0based, length)`.
+    Match(u32, u32),
+    /// Long rep: `(rep_index, length)`.
+    Rep(u32, u32),
+    /// Single-byte rep0 match.
+    ShortRep,
+}
+
+/// A node in the optimum DP buffer.
+#[derive(Clone, Copy)]
+struct OptNode {
+    /// Cheapest accumulated price found so far to reach this slot, or
+    /// `INFINITY_PRICE` if the slot has not been reached.
+    price: u32,
+    /// Slot the winning decision started from (followed by the trace-back).
+    prev_pos: u32,
+    /// Decision that ends at this slot.
+    decision: Decision,
+    /// Coder state after `decision`.
+    state: usize,
+    /// Rep distances after `decision`.
+    reps: [u32; 4],
+}
+
+/// Sentinel price marking an optimum-buffer slot as not yet reached.
+const INFINITY_PRICE: u32 = u32::MAX;
+
+/// Scratch buffers for the optimal parser.
+struct Optimizer { + opt: Vec, + matches: Vec<(u32, u32)>, + decisions: Vec, +} + +impl Optimizer { + fn new(window: usize) -> Self { + let cap = window + MAX_MATCH_LEN as usize + 2; + Self { + opt: vec![ + OptNode { + price: INFINITY_PRICE, + prev_pos: 0, + decision: Decision::Literal, + state: 0, + reps: [0; 4], + }; + cap + ], + matches: Vec::with_capacity(64), + decisions: Vec::with_capacity(window + 1), + } + } +} + +fn reorder_reps(reps: [u32; 4], rep_idx: u32) -> [u32; 4] { + match rep_idx { + 0 => reps, + 1 => [reps[1], reps[0], reps[2], reps[3]], + 2 => [reps[2], reps[0], reps[1], reps[3]], + _ => [reps[3], reps[0], reps[1], reps[2]], + } +} + +#[allow(clippy::too_many_arguments)] +fn literal_price_at( + core: &LzmaEncCore, + prices: &[u32; PRICE_TABLE_SIZE], + snap: &PriceSnapshot, + input: &[u8], + pos: usize, + out_pos: u64, + state: usize, + rep0: u32, +) -> u32 { + let pos_state = (out_pos as u32) & core.pos_mask; + let im_idx = state * POS_STATES_MAX + pos_state as usize; + let prev_byte = if pos > 0 { input[pos - 1] } else { 0 }; + let match_byte = if state < LIT_STATES { + None + } else { + let d = rep0 as usize + 1; + if d <= pos { Some(input[pos - d]) } else { None } + }; + snap.is_match[im_idx][0] + + core.literal_price(prices, out_pos, input[pos], prev_byte, match_byte) +} + // ─── full encode pass ──────────────────────────────────────────────────── fn encode_all(input: &[u8], params: LevelParams) -> Vec { + let dict_size = params.effective_dict_size(input.len()); + + // Threshold below which we also run a greedy pass and keep the smaller + // body. The optimal parser's cold-start price model can briefly lose to + // greedy on small, highly-repetitive inputs; the absolute loss is bounded + // by the first few price-refresh segments, so on larger inputs the optimal + // parse always wins overall and the extra greedy pass is pure waste. We + // therefore only run the guard pass on small inputs. 
+ const GUARD_LIMIT: usize = 64 * 1024; + + let body = if params.opt_window == 0 { + encode_body(input, dict_size, params, false) + } else if input.len() <= GUARD_LIMIT { + let opt = encode_body(input, dict_size, params, true); + let greedy = encode_body(input, dict_size, params, false); + if greedy.len() < opt.len() { + greedy + } else { + opt + } + } else { + encode_body(input, dict_size, params, true) + }; + + let mut out = Vec::with_capacity(13 + body.len()); + out.push(ENC_PROPS_BYTE); + out.extend_from_slice(&dict_size.to_le_bytes()); + out.extend_from_slice(&u64::MAX.to_le_bytes()); + out.extend_from_slice(&body); + out +} + +/// Encode the range-coded body (with EOS marker + flush, no 13-byte header) +/// using the greedy or optimal parse. Returns the raw body bytes. +fn encode_body(input: &[u8], dict_size: u32, params: LevelParams, optimal: bool) -> Vec { let mut core = LzmaEncCore::new(); let mut hc = HashChain::new(input.len()); - let dict_size = params.effective_dict_size(input.len()); + if optimal { + encode_optimal(&mut core, &mut hc, input, dict_size, params); + } else { + encode_greedy(&mut core, &mut hc, input, dict_size, params); + } + core.emit_eos_marker(); + core.rc.flush(); + core.rc.out +} +/// Greedy/lazy parse — used by the lowest levels. +fn encode_greedy( + core: &mut LzmaEncCore, + hc: &mut HashChain, + input: &[u8], + dict_size: u32, + params: LevelParams, +) { let mut pos = 0usize; while pos < input.len() { - // Try rep matches first — they're the cheapest to encode. let rep_lens = [ rep_match_len(input, pos, core.rep0), rep_match_len(input, pos, core.rep1), @@ -850,10 +1366,8 @@ fn encode_all(input: &[u8], params: LevelParams) -> Vec { rep_match_len(input, pos, core.rep3), ]; - // Best new match from the hash chain. let new_match = hc.find_longest(input, pos, dict_size, params.max_chain, params.nice_match); - // Decide what to emit. 
let best_rep_len = rep_lens.iter().copied().max().unwrap_or(0); let best_rep_idx = rep_lens .iter() @@ -864,22 +1378,14 @@ fn encode_all(input: &[u8], params: LevelParams) -> Vec { let new_match_len = new_match.map(|(l, _)| l).unwrap_or(0); - // Heuristic: prefer rep if it's at least as long as a new match, or - // if rep[0] still matches at least one byte (SHORTREP is dirt-cheap - // when no longer match exists). New matches are emitted only when - // they strictly beat the best rep. let emit_new = new_match_len > best_rep_len && new_match_len >= MATCH_LEN_MIN; let emit_rep_long = !emit_new && best_rep_len >= MATCH_LEN_MIN; let emit_short_rep = !emit_new && !emit_rep_long && rep_lens[0] >= 1; - // Insert the current position into the hash chain so future positions - // can reference us. (We do this regardless of what we emit; the chain - // is the *input* index, not the output index.) hc.insert(input, pos); if emit_new { let (len, dist) = new_match.unwrap(); - // Insert covered positions for higher-quality future matches. for j in 1..(len as usize) { let p = pos + j; if p + 3 <= input.len() { @@ -905,16 +1411,312 @@ fn encode_all(input: &[u8], params: LevelParams) -> Vec { pos += 1; } } +} - // End-of-stream marker + flush. - core.emit_eos_marker(); - core.rc.flush(); - let mut out = Vec::with_capacity(13 + core.rc.out.len()); - out.push(ENC_PROPS_BYTE); - out.extend_from_slice(&dict_size.to_le_bytes()); - out.extend_from_slice(&u64::MAX.to_le_bytes()); - out.extend_from_slice(&core.rc.out); - out +/// Cost-based optimal parse over a look-ahead window. 
+fn encode_optimal( + core: &mut LzmaEncCore, + hc: &mut HashChain, + input: &[u8], + dict_size: u32, + params: LevelParams, +) { + let prob_prices = build_prob_prices(); + let window = params.opt_window as usize; + let mut opt = Optimizer::new(window); + + let mut pos = 0usize; + while pos < input.len() { + let snap = core.price_snapshot(&prob_prices); + let parsed = parse_window( + core, + hc, + input, + pos, + dict_size, + params, + window, + &prob_prices, + &snap, + &mut opt, + ); + debug_assert!(parsed > 0); + replay(core, hc, input, pos, &opt.decisions); + pos += parsed; + } +} + +/// Parse one look-ahead window starting at `start`; fills `opt.decisions` and +/// returns the number of input bytes the chosen decisions consume. +#[allow(clippy::too_many_arguments)] +fn parse_window( + core: &LzmaEncCore, + hc: &HashChain, + input: &[u8], + start: usize, + dict_size: u32, + params: LevelParams, + window: usize, + prices: &[u32; PRICE_TABLE_SIZE], + snap: &PriceSnapshot, + opt: &mut Optimizer, +) -> usize { + let avail = input.len() - start; + let limit = window.min(avail); + + opt.opt[0] = OptNode { + price: 0, + prev_pos: 0, + decision: Decision::Literal, + state: core.state, + reps: [core.rep0, core.rep1, core.rep2, core.rep3], + }; + for node in opt.opt[1..=limit].iter_mut() { + node.price = INFINITY_PRICE; + } + + // Hard commit cap: even without a long match we commit after this many + // bytes so the price snapshot is refreshed frequently against the live + // (adapting) model. Without this, a long literal run parsed under a single + // stale snapshot makes systematically worse rep-vs-match decisions. + const COMMIT_CAP: usize = 192; + + let mut reached = 0usize; + let mut commit_end: Option = None; + let mut cur = 0usize; + while cur < limit { + if let Some(ce) = commit_end + && cur >= ce + { + break; + } + // Force a commit boundary once we've extended COMMIT_CAP bytes past the + // window start with no earlier long-match commit. 
+ if commit_end.is_none() && cur >= COMMIT_CAP { + commit_end = Some(cur); + break; + } + let node = opt.opt[cur]; + if node.price == INFINITY_PRICE { + cur += 1; + continue; + } + let pos = start + cur; + let out_pos = core.output_pos + cur as u64; + let state = node.state; + let reps = node.reps; + let pos_state = (out_pos as u32) & core.pos_mask; + let im_idx = state * POS_STATES_MAX + pos_state as usize; + let mut best_here: u32 = 0; + + // ── literal ────────────────────────────────────────────────────── + { + let lp = literal_price_at(core, prices, snap, input, pos, out_pos, state, reps[0]); + let np = node.price.saturating_add(lp); + let to = cur + 1; + if to <= limit && np < opt.opt[to].price { + opt.opt[to] = OptNode { + price: np, + prev_pos: cur as u32, + decision: Decision::Literal, + state: state_after_literal(state), + reps, + }; + if to > reached { + reached = to; + } + } + } + + let match_flag = snap.is_match[im_idx][1]; + + // ── rep matches ────────────────────────────────────────────────── + for rep_idx in 0..4u32 { + let rlen = rep_match_len(input, pos, reps[rep_idx as usize]); + if rlen < 1 { + continue; + } + if rep_idx == 0 { + let sp = match_flag + + snap.is_rep[state][1] + + snap.is_rep0[state][0] + + snap.is_rep0_long[im_idx][0]; + let np = node.price.saturating_add(sp); + let to = cur + 1; + if to <= limit && np < opt.opt[to].price { + opt.opt[to] = OptNode { + price: np, + prev_pos: cur as u32, + decision: Decision::ShortRep, + state: state_after_short_rep(state), + reps, + }; + if to > reached { + reached = to; + } + } + } + if rlen < MATCH_LEN_MIN { + continue; + } + if rlen > best_here { + best_here = rlen; + } + let rep_new_reps = reorder_reps(reps, rep_idx); + let choice = match_flag + snap.rep_choice_price(state, rep_idx); + let rep0_long = if rep_idx == 0 { + snap.is_rep0_long[im_idx][1] + } else { + 0 + }; + let st_after = state_after_rep(state); + let maxr = rlen.min((limit - cur) as u32); + let mut l = MATCH_LEN_MIN; + while l 
<= maxr { + let len_price = core + .rep_len_coder + .price(prices, pos_state, l - MATCH_LEN_MIN); + let np = node.price.saturating_add(choice + rep0_long + len_price); + let to = cur + l as usize; + if np < opt.opt[to].price { + opt.opt[to] = OptNode { + price: np, + prev_pos: cur as u32, + decision: Decision::Rep(rep_idx, l), + state: st_after, + reps: rep_new_reps, + }; + if to > reached { + reached = to; + } + } + l += 1; + } + } + + // ── new matches ────────────────────────────────────────────────── + let longest = { + let opt_matches = &mut opt.matches; + hc.find_matches( + input, + pos, + dict_size, + params.max_chain, + params.nice_match, + opt_matches, + ) + }; + if longest >= MATCH_LEN_MIN { + if longest > best_here { + best_here = longest; + } + let match_choice = match_flag + snap.is_rep[state][0]; + let st_after = state_after_match(state); + let cap = (limit - cur) as u32; + let mut prev_len = MATCH_LEN_MIN - 1; + let nmatches = opt.matches.len(); + for mi in 0..nmatches { + let (mlen, mdist) = opt.matches[mi]; + let band_end = mlen.min(cap); + let mut l = (prev_len + 1).max(MATCH_LEN_MIN); + while l <= band_end { + let len_price = core.len_coder.price(prices, pos_state, l - MATCH_LEN_MIN); + let dist_price = core.distance_price(prices, l, mdist); + let np = node + .price + .saturating_add(match_choice + len_price + dist_price); + let to = cur + l as usize; + if np < opt.opt[to].price { + let new_reps = [mdist, reps[0], reps[1], reps[2]]; + opt.opt[to] = OptNode { + price: np, + prev_pos: cur as u32, + decision: Decision::Match(mdist, l), + state: st_after, + reps: new_reps, + }; + if to > reached { + reached = to; + } + } + l += 1; + } + prev_len = mlen; + } + } + + // Early-commit once a long match is reachable: commit up to its end so + // the price snapshot stays close to the live model. Mirrors the SDK's + // `nice_len` cut-off in GetOptimum. 
+ if commit_end.is_none() && best_here >= params.nice_len { + commit_end = Some((cur + best_here as usize).min(limit)); + } + + cur += 1; + } + + let end = match commit_end { + Some(ce) => ce.max(1).min(reached.max(1)), + None => reached.max(1), + } + .min(avail); + trace_back(opt, end); + end +} + +fn trace_back(opt: &mut Optimizer, end: usize) { + opt.decisions.clear(); + let mut cur = end; + while cur > 0 { + let node = opt.opt[cur]; + opt.decisions.push(node.decision); + cur = node.prev_pos as usize; + } + opt.decisions.reverse(); +} + +fn replay( + core: &mut LzmaEncCore, + hc: &mut HashChain, + input: &[u8], + start: usize, + decisions: &[Decision], +) { + let mut pos = start; + for &d in decisions { + match d { + Decision::Literal => { + hc.insert(input, pos); + core.emit_literal(input, pos); + pos += 1; + } + Decision::ShortRep => { + hc.insert(input, pos); + core.emit_short_rep(); + pos += 1; + } + Decision::Match(dist, len) => { + for j in 0..(len as usize) { + let p = pos + j; + if p + 3 <= input.len() { + hc.insert(input, p); + } + } + core.emit_match(dist, len); + pos += len as usize; + } + Decision::Rep(idx, len) => { + for j in 0..(len as usize) { + let p = pos + j; + if p + 3 <= input.len() { + hc.insert(input, p); + } + } + core.emit_long_rep(idx, len); + pos += len as usize; + } + } + } } // ─── public streaming Encoder ──────────────────────────────────────────── diff --git a/src/lzma2_internal/lzma2_encoder.rs b/src/lzma2_internal/lzma2_encoder.rs index 69dadd4..bd3589b 100644 --- a/src/lzma2_internal/lzma2_encoder.rs +++ b/src/lzma2_internal/lzma2_encoder.rs @@ -16,10 +16,24 @@ //! decoder; the chunk's compressed-size field in the LZMA2 header includes //! the flush bytes. //! -//! Strategy mirrors the LZMA encoder: a greedy parser over the input buffer -//! with a 3-byte hash chain match finder. Quality is the same as the -//! `.lzma` encoder — sufficient for xz cross-validation; not competitive -//! with xz-utils at higher presets. +//! 
## Parse strategy +//! +//! This encoder uses a **cost-based optimal parse** modelled on the LZMA SDK's +//! `GetOptimum`. For each window of input it builds a forward +//! dynamic-programming table over an "optimum" buffer: every reachable +//! position records the minimum range-coder bit price to arrive there and a +//! back-pointer to the decision (literal / match / rep0..rep3 / short-rep) +//! that produced it. Prices come from a snapshot of the live probability +//! model — the same probabilities the range coder is about to use — so the +//! parser optimises the *actual* encoded size rather than a length heuristic. +//! +//! Match finding is a hash-chain finder that returns the full set of +//! candidate (length, distance) pairs at a position (the shortest distance for +//! each achievable length), plus the four repeat-distance matches, so the +//! optimal parser has the complete candidate set it needs. +//! +//! Lower levels fall back to a fast greedy/lazy parse; the optimal parse and +//! its look-ahead window scale up with `level`. use alloc::boxed::Box; use alloc::vec; @@ -98,73 +112,95 @@ const HASH_BITS: u32 = 16; const HASH_SIZE: usize = 1 << HASH_BITS; const NIL: u32 = u32::MAX; -/// Match-finder tuning expanded from the user-facing `level` byte. Higher -/// levels widen `max_chain` (more hash-chain links walked per probe) and -/// raise `nice_match` (the length at which the chain walk gives up and -/// accepts the current match). This is the same speed-vs-ratio knob that -/// xz-utils exposes — we just expose a small subset. +/// Match-finder + optimal-parser tuning expanded from the user-facing `level` +/// byte. Higher levels widen `max_chain` (more hash-chain links walked per +/// probe), raise `nice_match` (the length at which the chain walk gives up and +/// accepts the current match), and enlarge `opt_window` (how far ahead the +/// optimal parser looks before committing a parse). 
This is the same +/// speed-vs-ratio knob xz-utils exposes — we expose a small subset. #[derive(Debug, Clone, Copy)] pub(crate) struct EncoderParams { pub max_chain: usize, pub nice_match: u32, + /// Length at which the optimal parser early-commits the current window + /// (a match this long is almost certainly taken). Keeping this modest + /// keeps committed segments short and the price snapshot fresh. + pub nice_len: u32, + /// Optimal-parser look-ahead window (number of optimum-buffer slots). When + /// `0` the parser falls back to a fast greedy/lazy parse (used by the + /// lowest levels so they stay genuinely fast). + pub opt_window: u32, } impl EncoderParams { - /// Expand a `0..=9` level into match-finder knobs. + /// Expand a `0..=9` level into match-finder + parser knobs. /// - /// The mapping is monotonic and centred on the default level 6 producing - /// the same `(96, 192)` numbers the previous fixed-tuning code used. - /// Values outside `0..=9` are clamped — we keep the public surface - /// infallible. + /// The mapping is monotonic: higher level = deeper chain walk, longer + /// nice-match cutoff, and a larger optimal-parse window. Values outside + /// `0..=9` are clamped — we keep the public surface infallible. pub fn from_level(level: u8) -> Self { let level = level.min(9); - // Hand-tuned table: low levels skip most of the chain walk so the - // greedy parser commits the first short match it finds; high levels - // walk wide chains and accept only long matches. The values aren't - // meant to mirror xz-utils' presets exactly — they just have to - // produce a measurably monotonic compressed size on a hash- - // collision-heavy corpus, which is what `tests/xz.rs` checks. 
match level { 0 => Self { - max_chain: 2, - nice_match: 4, - }, - 1 => Self { max_chain: 4, nice_match: 8, + nice_len: 8, + opt_window: 0, }, - 2 => Self { + 1 => Self { max_chain: 8, nice_match: 16, + nice_len: 16, + opt_window: 0, }, - 3 => Self { + 2 => Self { max_chain: 16, nice_match: 32, + nice_len: 32, + opt_window: 0, }, - 4 => Self { + 3 => Self { max_chain: 32, nice_match: 64, + nice_len: 16, + opt_window: 512, }, - 5 => Self { + 4 => Self { max_chain: 64, nice_match: 128, + nice_len: 24, + opt_window: 1024, }, - 6 => Self { - max_chain: 96, + 5 => Self { + max_chain: 128, nice_match: 192, + nice_len: 32, + opt_window: 2048, + }, + 6 => Self { + max_chain: 256, + nice_match: 273, + nice_len: 48, + opt_window: 4096, }, 7 => Self { - max_chain: 192, - nice_match: 224, + max_chain: 512, + nice_match: 273, + nice_len: 64, + opt_window: 4096, }, 8 => Self { - max_chain: 384, - nice_match: 256, + max_chain: 1024, + nice_match: 273, + nice_len: 96, + opt_window: 4096, }, // 9 (and clamp-from-above) _ => Self { - max_chain: 768, - nice_match: 273, // MAX_MATCH_LEN + max_chain: 2048, + nice_match: MAX_MATCH_LEN, + nice_len: 128, + opt_window: 4096, }, } } @@ -316,6 +352,96 @@ fn dist_special_encode( } } +// ─── price model ────────────────────────────────────────────────────────── +// +// The optimal parser needs the *bit cost* of encoding a given symbol with the +// current probability model. LZMA prices in 1/16-bit units: the cost of +// coding a bit against probability `p` is a fixed-point `-log2` of the +// matching probability. We replicate the SDK's `ProbPrices` table. + +const PRICE_SHIFT_BITS: u32 = 4; +const PRICE_TABLE_SIZE: usize = (RC_BIT_MODEL_TOTAL >> PRICE_SHIFT_BITS) as usize; + +/// Precomputed price table: `prices[p >> 4]` is the cost in 1/16-bit units of +/// coding a 0-bit against probability `p`. Generated the same way as the LZMA +/// SDK's price table (a fixed-point `-log2` approximation). 
+fn build_prob_prices() -> [u32; PRICE_TABLE_SIZE] { + let mut prices = [0u32; PRICE_TABLE_SIZE]; + // `kCyclesBits` in the SDK: the squaring loop runs exactly this many times + // (it equals the price shift, 4 — NOT the model-bit count). Getting this + // wrong makes `bit_count` overflow the subtraction and yields garbage + // prices. + let cycles_bits = PRICE_SHIFT_BITS; + let mut i: usize = (1usize << PRICE_SHIFT_BITS) >> 1; + while i < (PRICE_TABLE_SIZE << PRICE_SHIFT_BITS) { + let mut w = i as u32; + let mut bit_count = 0u32; + let mut j = 0; + while j < cycles_bits { + w = w.wrapping_mul(w); + bit_count <<= 1; + while w >= (1u32 << 16) { + w >>= 1; + bit_count += 1; + } + j += 1; + } + let idx = i >> PRICE_SHIFT_BITS; + prices[idx] = (RC_BIT_MODEL_TOTAL_BITS << PRICE_SHIFT_BITS) - 15 - bit_count; + i += 1 << PRICE_SHIFT_BITS; + } + prices +} + +#[inline] +fn price_bit(prices: &[u32; PRICE_TABLE_SIZE], prob: u16, bit: u32) -> u32 { + let p = if bit == 0 { + prob as u32 + } else { + RC_BIT_MODEL_TOTAL - prob as u32 + }; + prices[(p >> PRICE_SHIFT_BITS) as usize] +} + +#[inline] +fn price_bit0(prices: &[u32; PRICE_TABLE_SIZE], prob: u16) -> u32 { + prices[(prob as u32 >> PRICE_SHIFT_BITS) as usize] +} + +#[inline] +fn price_bit1(prices: &[u32; PRICE_TABLE_SIZE], prob: u16) -> u32 { + prices[((RC_BIT_MODEL_TOTAL - prob as u32) >> PRICE_SHIFT_BITS) as usize] +} + +fn bittree_price(prices: &[u32; PRICE_TABLE_SIZE], probs: &[u16], bits: u32, symbol: u32) -> u32 { + let mut total = 0u32; + let mut idx: u32 = 1; + let mut i = bits; + while i > 0 { + i -= 1; + let bit = (symbol >> i) & 1; + total += price_bit(prices, probs[idx as usize], bit); + idx = (idx << 1) | bit; + } + total +} + +fn bittree_reverse_price( + prices: &[u32; PRICE_TABLE_SIZE], + probs: &[u16], + bits: u32, + symbol: u32, +) -> u32 { + let mut total = 0u32; + let mut idx: u32 = 1; + for i in 0..bits { + let bit = (symbol >> i) & 1; + total += price_bit(prices, probs[idx as usize], bit); + idx = (idx 
<< 1) | bit; + } + total +} + // ─── length coder ──────────────────────────────────────────────────────── struct LengthCoderEnc { @@ -360,6 +486,39 @@ impl LengthCoderEnc { ); } } + + /// Price of encoding length symbol `length` (0-based) at `pos_state`. + fn price(&self, prices: &[u32; PRICE_TABLE_SIZE], pos_state: u32, length: u32) -> u32 { + if length < LEN_LOW_SYMBOLS as u32 { + let base = (pos_state as usize) * LEN_LOW_SYMBOLS; + price_bit0(prices, self.choice) + + bittree_price( + prices, + &self.low[base..base + LEN_LOW_SYMBOLS], + LEN_LOW_BITS, + length, + ) + } else if length < (LEN_LOW_SYMBOLS + LEN_MID_SYMBOLS) as u32 { + let base = (pos_state as usize) * LEN_MID_SYMBOLS; + price_bit1(prices, self.choice) + + price_bit0(prices, self.choice2) + + bittree_price( + prices, + &self.mid[base..base + LEN_MID_SYMBOLS], + LEN_MID_BITS, + length - LEN_LOW_SYMBOLS as u32, + ) + } else { + price_bit1(prices, self.choice) + + price_bit1(prices, self.choice2) + + bittree_price( + prices, + &self.high, + LEN_HIGH_BITS, + length - (LEN_LOW_SYMBOLS + LEN_MID_SYMBOLS) as u32, + ) + } + } } // ─── encoder core ──────────────────────────────────────────────────────── @@ -470,6 +629,61 @@ impl LzmaEncCore { } } + /// Price of coding the literal `byte` at output offset `out_pos` with the + /// given previous byte, optional match byte (rep0 byte), and literal + /// state. Reads the live `lit` probabilities (snapshot at call time). 
+ fn literal_price( + &self, + prices: &[u32; PRICE_TABLE_SIZE], + out_pos: u64, + byte: u8, + prev_byte: u8, + match_byte: Option, + ) -> u32 { + let lp_state = ((out_pos as u32) & self.lit_pos_mask) << self.lc; + let prev_high = (prev_byte as u32) >> (8 - self.lc); + let probs_idx = (lp_state + prev_high) as usize * 0x300; + let probs = &self.lit[probs_idx..probs_idx + 0x300]; + + let mut total = 0u32; + let mut symbol: u32 = 1; + let target = byte as u32; + match match_byte { + Some(mb) => { + let mut match_byte_w = mb as u32; + let mut mismatched = false; + let mut i: i32 = 8; + while symbol < 0x100 { + i -= 1; + let bit = (target >> i) & 1; + match_byte_w <<= 1; + let match_bit = match_byte_w & 0x100; + if !mismatched { + let idx = (0x100 + match_bit + symbol) as usize; + total += price_bit(prices, probs[idx], bit); + symbol = (symbol << 1) | bit; + if (match_bit >> 8) != bit { + mismatched = true; + } + } else { + total += price_bit(prices, probs[symbol as usize], bit); + symbol = (symbol << 1) | bit; + } + } + } + None => { + let mut i: i32 = 8; + while symbol < 0x100 { + i -= 1; + let bit = (target >> i) & 1; + total += price_bit(prices, probs[symbol as usize], bit); + symbol = (symbol << 1) | bit; + } + } + } + total + } + fn encode_distance(&mut self, length: u32, distance: u32) { let dist_state_idx = (length.min(DIST_STATES as u32 + MATCH_LEN_MIN - 1) - MATCH_LEN_MIN) as usize; @@ -509,6 +723,53 @@ impl LzmaEncCore { } } + /// Price of coding a new-match distance `distance` for a match of length + /// `length`. Reads the live distance probabilities. 
+ fn distance_price(&self, prices: &[u32; PRICE_TABLE_SIZE], length: u32, distance: u32) -> u32 { + let dist_state_idx = + (length.min(DIST_STATES as u32 + MATCH_LEN_MIN - 1) - MATCH_LEN_MIN) as usize; + let slot = get_dist_slot(distance); + let slot_base = dist_state_idx * DIST_SLOTS; + let mut total = bittree_price( + prices, + &self.dist_slot[slot_base..slot_base + DIST_SLOTS], + DIST_SLOT_BITS, + slot, + ); + + if slot < DIST_MODEL_START { + return total; + } + + let num_direct_bits = (slot >> 1) - 1; + let base = (2 | (slot & 1)) << num_direct_bits; + let extra = distance.wrapping_sub(base); + + if slot < DIST_MODEL_END { + let base_idx = base as usize + 1; + let mut idx = base_idx; + let mut m: u32 = 1; + for i in 0..num_direct_bits { + let bit = (extra >> i) & 1; + total += price_bit(prices, self.dist_special[idx], bit); + if bit == 0 { + idx += m as usize; + m += m; + } else { + m += m; + idx += m as usize; + } + } + } else { + let direct_count = num_direct_bits - ALIGN_BITS; + // Direct (uniform) bits cost exactly 1 bit each. + total += direct_count << PRICE_SHIFT_BITS; + let align = extra & (ALIGN_SIZE as u32 - 1); + total += bittree_reverse_price(prices, &self.dist_align[..], ALIGN_BITS, align); + } + total + } + fn emit_literal(&mut self, input: &[u8], pos: usize) { let pos_state = self.pos_state(); let idx = self.state * POS_STATES_MAX + pos_state as usize; @@ -607,6 +868,69 @@ impl LzmaEncCore { self.state = state_after_rep(self.state); self.output_pos += length as u64; } + + /// Snapshot the cheap per-flag bit prices used by the optimal parser. + /// Recomputed periodically as the live probabilities drift. 
+ fn price_snapshot(&self, prices: &[u32; PRICE_TABLE_SIZE]) -> PriceSnapshot { + let mut is_match = [[0u32; 2]; STATES * POS_STATES_MAX]; + let mut is_rep0_long = [[0u32; 2]; STATES * POS_STATES_MAX]; + for i in 0..STATES * POS_STATES_MAX { + is_match[i][0] = price_bit0(prices, self.is_match[i]); + is_match[i][1] = price_bit1(prices, self.is_match[i]); + is_rep0_long[i][0] = price_bit0(prices, self.is_rep0_long[i]); + is_rep0_long[i][1] = price_bit1(prices, self.is_rep0_long[i]); + } + let mut is_rep = [[0u32; 2]; STATES]; + let mut is_rep0 = [[0u32; 2]; STATES]; + let mut is_rep1 = [[0u32; 2]; STATES]; + let mut is_rep2 = [[0u32; 2]; STATES]; + for s in 0..STATES { + is_rep[s][0] = price_bit0(prices, self.is_rep[s]); + is_rep[s][1] = price_bit1(prices, self.is_rep[s]); + is_rep0[s][0] = price_bit0(prices, self.is_rep0[s]); + is_rep0[s][1] = price_bit1(prices, self.is_rep0[s]); + is_rep1[s][0] = price_bit0(prices, self.is_rep1[s]); + is_rep1[s][1] = price_bit1(prices, self.is_rep1[s]); + is_rep2[s][0] = price_bit0(prices, self.is_rep2[s]); + is_rep2[s][1] = price_bit1(prices, self.is_rep2[s]); + } + PriceSnapshot { + is_match, + is_rep, + is_rep0, + is_rep1, + is_rep2, + is_rep0_long, + } + } +} + +/// Cached bit prices for the cheap per-decision flags. Length/distance/literal +/// prices are computed on demand from the core's live tables (which the +/// optimizer holds a reference to) since they have large key spaces. +struct PriceSnapshot { + is_match: [[u32; 2]; STATES * POS_STATES_MAX], + is_rep: [[u32; 2]; STATES], + is_rep0: [[u32; 2]; STATES], + is_rep1: [[u32; 2]; STATES], + is_rep2: [[u32; 2]; STATES], + is_rep0_long: [[u32; 2]; STATES * POS_STATES_MAX], +} + +impl PriceSnapshot { + /// Price of the rep-flag prefix selecting rep index `rep_idx` from `state` + /// (the `is_rep`=1 bit plus the rep0/rep1/rep2 selector bits, but NOT the + /// length and NOT the is_rep0_long bit for rep0). 
+ fn rep_choice_price(&self, state: usize, rep_idx: u32) -> u32 { + let mut p = self.is_rep[state][1]; + match rep_idx { + 0 => p += self.is_rep0[state][0], + 1 => p += self.is_rep0[state][1] + self.is_rep1[state][0], + 2 => p += self.is_rep0[state][1] + self.is_rep1[state][1] + self.is_rep2[state][0], + _ => p += self.is_rep0[state][1] + self.is_rep1[state][1] + self.is_rep2[state][1], + } + p + } } fn rc_encode_bit(rc: &mut RangeEncoder, prob: &mut u16, bit: u32) { @@ -648,6 +972,7 @@ impl HashChain { self.head[h] = pos as u32; } + /// Find the single longest match (greedy use). Returns `(len, dist0based)`. fn find_longest( &self, input: &[u8], @@ -704,6 +1029,71 @@ impl HashChain { None } } + + /// Collect the candidate match set for the optimal parser: for each + /// achievable length `>= MATCH_LEN_MIN`, the *shortest* distance that + /// achieves it. `out` is filled with `(len, dist0based)` pairs in + /// strictly increasing length order. Returns the longest length found. + fn find_matches( + &self, + input: &[u8], + pos: usize, + dict_size: u32, + params: EncoderParams, + out: &mut Vec<(u32, u32)>, + ) -> u32 { + out.clear(); + if pos + 3 > input.len() { + return 0; + } + let h = hash3(input[pos], input[pos + 1], input[pos + 2]) as usize; + let max_len = MAX_MATCH_LEN.min((input.len() - pos) as u32); + let max_dist = (dict_size as usize).min(pos); + let mut best_len: u32 = MATCH_LEN_MIN - 1; + let mut cur = self.head[h]; + let mut steps = 0usize; + while cur != NIL && steps < params.max_chain { + let cur_pos = cur as usize; + if cur_pos >= pos { + cur = self.prev[cur_pos]; + steps += 1; + continue; + } + let dist = pos - cur_pos; + if dist > max_dist { + break; + } + if best_len >= MATCH_LEN_MIN + && (best_len as usize) < (input.len() - pos) + && input[cur_pos + best_len as usize] != input[pos + best_len as usize] + { + cur = self.prev[cur_pos]; + steps += 1; + continue; + } + let mut len = 0u32; + while len < max_len && input[cur_pos + len as usize] == 
input[pos + len as usize] { + len += 1; + } + if len >= MATCH_LEN_MIN && len > best_len { + // Chain is walked nearest-first, so this is the shortest + // distance achieving every length in (best_len, len]. Record + // one entry at `len`. + out.push((len, (dist - 1) as u32)); + best_len = len; + if len >= params.nice_match || len >= max_len { + break; + } + } + cur = self.prev[cur_pos]; + steps += 1; + } + if best_len >= MATCH_LEN_MIN { + best_len + } else { + 0 + } + } } // ─── rep-match helpers ─────────────────────────────────────────────────── @@ -721,6 +1111,92 @@ fn rep_match_len(input: &[u8], pos: usize, dist: u32) -> u32 { len as u32 } +// ─── parse decision replay ──────────────────────────────────────────────── + +/// One parser decision, replayed through the real (probability-updating) +/// emit functions after the optimal parse has chosen it. +#[derive(Clone, Copy)] +enum Decision { + Literal, + /// New match: `(distance0based, length)`. + Match(u32, u32), + /// Long rep: `(rep_index, length)`. + Rep(u32, u32), + ShortRep, +} + +// ─── optimal parser ──────────────────────────────────────────────────────── + +/// A node in the optimum DP buffer. `price` is the cheapest known cost (in +/// 1/16-bit units) to reach this input offset; the back-pointer fields encode +/// the decision that produced the cheapest arrival. +#[derive(Clone, Copy)] +struct OptNode { + price: u32, + /// Offset of the previous node this arrival came from. + prev_pos: u32, + /// Decision taken from `prev_pos` to here. + decision: Decision, + /// State after arriving here. + state: usize, + /// Rep distances after arriving here. + reps: [u32; 4], +} + +const INFINITY_PRICE: u32 = u32::MAX; + +/// Scratch buffers for the optimal parser. 
+struct Optimizer { + opt: Vec, + matches: Vec<(u32, u32)>, + decisions: Vec, +} + +impl Optimizer { + fn new(window: usize) -> Self { + let cap = window + MAX_MATCH_LEN as usize + 2; + Self { + opt: vec![ + OptNode { + price: INFINITY_PRICE, + prev_pos: 0, + decision: Decision::Literal, + state: 0, + reps: [0; 4], + }; + cap + ], + matches: Vec::with_capacity(64), + decisions: Vec::with_capacity(window + 1), + } + } +} + +/// Compute the price of a literal at `pos` given the encoder's live state. +#[allow(clippy::too_many_arguments)] +fn literal_price_at( + core: &LzmaEncCore, + prices: &[u32; PRICE_TABLE_SIZE], + snap: &PriceSnapshot, + input: &[u8], + pos: usize, + out_pos: u64, + state: usize, + rep0: u32, +) -> u32 { + let pos_state = (out_pos as u32) & core.pos_mask; + let im_idx = state * POS_STATES_MAX + pos_state as usize; + let prev_byte = if pos > 0 { input[pos - 1] } else { 0 }; + let match_byte = if state < LIT_STATES { + None + } else { + let d = rep0 as usize + 1; + if d <= pos { Some(input[pos - d]) } else { None } + }; + snap.is_match[im_idx][0] + + core.literal_price(prices, out_pos, input[pos], prev_byte, match_byte) +} + // ─── public chunk encoder ──────────────────────────────────────────────── /// Encode `input` as a single LZMA2 compressed chunk payload (the @@ -734,12 +1210,57 @@ fn rep_match_len(input: &[u8], pos: usize, dist: u32) -> u32 { /// are bounded by this value. For LZMA2 the dict size is shared across /// all chunks of a block; pass a single value consistently. /// -/// `params` is the level-derived match-finder tuning; see +/// `params` is the level-derived match-finder + parser tuning; see /// [`EncoderParams::from_level`]. pub(crate) fn encode_lzma_chunk(input: &[u8], dict_size: u32, params: EncoderParams) -> Vec { + if params.opt_window == 0 { + return encode_chunk_body(input, dict_size, params, false); + } + // Run both parses and keep the smaller body. 
The optimal parse is almost + // always smaller, but on tiny, highly-repetitive inputs its cold-start + // price model can momentarily lose to greedy; this guard guarantees a + // level never regresses below the greedy baseline. + let opt = encode_chunk_body(input, dict_size, params, true); + let greedy = encode_chunk_body(input, dict_size, params, false); + if greedy.len() < opt.len() { + greedy + } else { + opt + } +} + +/// Encode one chunk body (range-coded packets + 5-byte flush, no EOS marker) +/// using the greedy or optimal parse. +fn encode_chunk_body( + input: &[u8], + dict_size: u32, + params: EncoderParams, + optimal: bool, +) -> Vec { let mut core = LzmaEncCore::new(); let mut hc = HashChain::new(input.len()); + if optimal { + encode_optimal(&mut core, &mut hc, input, dict_size, params); + } else { + encode_greedy(&mut core, &mut hc, input, dict_size, params); + } + + // Flush the range coder. NO EOS marker — LZMA2 frames the uncompressed + // length externally and decoders read exactly that many bytes. + core.rc.flush(); + core.rc.out +} + +/// Greedy/lazy parse — used by the lowest levels where speed matters most and +/// the optimal-parse overhead isn't worth it. +fn encode_greedy( + core: &mut LzmaEncCore, + hc: &mut HashChain, + input: &[u8], + dict_size: u32, + params: EncoderParams, +) { let mut pos = 0usize; while pos < input.len() { let rep_lens = [ @@ -794,9 +1315,342 @@ pub(crate) fn encode_lzma_chunk(input: &[u8], dict_size: u32, params: EncoderPar pos += 1; } } +} - // Flush the range coder. NO EOS marker — LZMA2 frames the uncompressed - // length externally and decoders read exactly that many bytes. - core.rc.flush(); - core.rc.out +/// Cost-based optimal parse: forward DP over a look-ahead window, committing +/// the cheapest path through the optimum buffer, then replaying decisions +/// through the real (probability-updating) emit functions. 
+fn encode_optimal( + core: &mut LzmaEncCore, + hc: &mut HashChain, + input: &[u8], + dict_size: u32, + params: EncoderParams, +) { + let prob_prices = build_prob_prices(); + let window = params.opt_window as usize; + let mut opt = Optimizer::new(window); + + let mut pos = 0usize; + // Refresh the price snapshot once per committed window. Prices drift as + // the model adapts; refreshing each window keeps them close to the live + // model without recomputing per byte. + while pos < input.len() { + let snap = core.price_snapshot(&prob_prices); + let parsed = parse_window( + core, + hc, + input, + pos, + dict_size, + params, + window, + &prob_prices, + &snap, + &mut opt, + ); + debug_assert!(parsed > 0); + // Replay the chosen decisions through the real emit path. `pos` + // advances by exactly `parsed` bytes. + replay(core, hc, input, pos, &opt.decisions); + pos += parsed; + } +} + +/// Parse a single look-ahead window starting at `start`. Fills +/// `opt.decisions` with the cheapest sequence of decisions covering the +/// reachable commit boundary, and returns the number of input bytes the +/// decisions consume. The hash chain is NOT mutated here (read-only match +/// finding); `replay` handles insertion. +#[allow(clippy::too_many_arguments)] +fn parse_window( + core: &LzmaEncCore, + hc: &HashChain, + input: &[u8], + start: usize, + dict_size: u32, + params: EncoderParams, + window: usize, + prices: &[u32; PRICE_TABLE_SIZE], + snap: &PriceSnapshot, + opt: &mut Optimizer, +) -> usize { + let avail = input.len() - start; + let limit = window.min(avail); + + // Initialize node 0 with the encoder's current live state. 
+ opt.opt[0] = OptNode { + price: 0, + prev_pos: 0, + decision: Decision::Literal, + state: core.state, + reps: [core.rep0, core.rep1, core.rep2, core.rep3], + }; + for node in opt.opt[1..=limit].iter_mut() { + node.price = INFINITY_PRICE; + } + + // Hard commit cap: even without a long match we commit after this many + // bytes so the price snapshot is refreshed frequently against the live + // (adapting) model. Without this, a long literal run parsed under a single + // stale snapshot makes systematically worse rep-vs-match decisions. + const COMMIT_CAP: usize = 192; + + // `reached` is the furthest offset we've filled a finite price for. + let mut reached = 0usize; + // When a long match is found at some position we stop extending the DP and + // commit up to that match's end, keeping the committed segment short so the + // price snapshot stays close to the live model. `None` means run to the + // window limit. + let mut commit_end: Option = None; + + let mut cur = 0usize; + while cur < limit { + if let Some(ce) = commit_end + && cur >= ce + { + break; + } + if commit_end.is_none() && cur >= COMMIT_CAP { + commit_end = Some(cur); + break; + } + let node = opt.opt[cur]; + if node.price == INFINITY_PRICE { + cur += 1; + continue; + } + let pos = start + cur; + let out_pos = core.output_pos + cur as u64; + let state = node.state; + let reps = node.reps; + let pos_state = (out_pos as u32) & core.pos_mask; + let im_idx = state * POS_STATES_MAX + pos_state as usize; + // Longest match (rep or new) seen at this position; drives the + // early-commit decision below. 
+ let mut best_here: u32 = 0; + + // ── literal transition ────────────────────────────────────────── + { + let lp = literal_price_at(core, prices, snap, input, pos, out_pos, state, reps[0]); + let np = node.price.saturating_add(lp); + let to = cur + 1; + if to <= limit && np < opt.opt[to].price { + opt.opt[to] = OptNode { + price: np, + prev_pos: cur as u32, + decision: Decision::Literal, + state: state_after_literal(state), + reps, + }; + if to > reached { + reached = to; + } + } + } + + // Base price of choosing "match" (is_match=1). + let match_flag = snap.is_match[im_idx][1]; + + // ── rep matches (rep0..rep3) ──────────────────────────────────── + for rep_idx in 0..4u32 { + let rlen = rep_match_len(input, pos, reps[rep_idx as usize]); + if rlen < 1 { + continue; + } + // Short-rep (length 1, rep0 only). + if rep_idx == 0 { + let sp = match_flag + + snap.is_rep[state][1] + + snap.is_rep0[state][0] + + snap.is_rep0_long[im_idx][0]; + let np = node.price.saturating_add(sp); + let to = cur + 1; + if to <= limit && np < opt.opt[to].price { + opt.opt[to] = OptNode { + price: np, + prev_pos: cur as u32, + decision: Decision::ShortRep, + state: state_after_short_rep(state), + reps, + }; + if to > reached { + reached = to; + } + } + } + if rlen < MATCH_LEN_MIN { + continue; + } + if rlen > best_here { + best_here = rlen; + } + let rep_new_reps = reorder_reps(reps, rep_idx); + let choice = match_flag + snap.rep_choice_price(state, rep_idx); + let rep0_long = if rep_idx == 0 { + snap.is_rep0_long[im_idx][1] + } else { + 0 + }; + let st_after = state_after_rep(state); + let cap = (limit - cur) as u32; + let maxr = rlen.min(cap); + let mut l = MATCH_LEN_MIN; + while l <= maxr { + let len_price = core + .rep_len_coder + .price(prices, pos_state, l - MATCH_LEN_MIN); + let np = node.price.saturating_add(choice + rep0_long + len_price); + let to = cur + l as usize; + if np < opt.opt[to].price { + opt.opt[to] = OptNode { + price: np, + prev_pos: cur as u32, + decision: 
Decision::Rep(rep_idx, l), + state: st_after, + reps: rep_new_reps, + }; + if to > reached { + reached = to; + } + } + l += 1; + } + } + + // ── new matches ───────────────────────────────────────────────── + let longest = { + let opt_matches = &mut opt.matches; + hc.find_matches(input, pos, dict_size, params, opt_matches) + }; + if longest >= MATCH_LEN_MIN { + if longest > best_here { + best_here = longest; + } + let match_choice = match_flag + snap.is_rep[state][0]; + let st_after = state_after_match(state); + let cap = (limit - cur) as u32; + let mut prev_len = MATCH_LEN_MIN - 1; + let nmatches = opt.matches.len(); + for mi in 0..nmatches { + let (mlen, mdist) = opt.matches[mi]; + let band_end = mlen.min(cap); + let mut l = (prev_len + 1).max(MATCH_LEN_MIN); + while l <= band_end { + let len_price = core.len_coder.price(prices, pos_state, l - MATCH_LEN_MIN); + let dist_price = core.distance_price(prices, l, mdist); + let np = node + .price + .saturating_add(match_choice + len_price + dist_price); + let to = cur + l as usize; + if np < opt.opt[to].price { + let new_reps = [mdist, reps[0], reps[1], reps[2]]; + opt.opt[to] = OptNode { + price: np, + prev_pos: cur as u32, + decision: Decision::Match(mdist, l), + state: st_after, + reps: new_reps, + }; + if to > reached { + reached = to; + } + } + l += 1; + } + prev_len = mlen; + } + } + + // Early-commit: once a long match is reachable from this position, the + // optimal path almost certainly takes it, and there's little value in + // extending the DP past it with increasingly stale prices. Commit up + // to its end. This mirrors the SDK's `nice_len` cut-off in GetOptimum. + if commit_end.is_none() && best_here >= params.nice_len { + let bounded = (cur + best_here as usize).min(limit); + commit_end = Some(bounded); + } + + cur += 1; + } + + // Commit boundary. 
If an early long match capped the DP, commit exactly to + // its end; otherwise commit the furthest reached offset (always `limit`, + // since literals reach every offset). `max(1)` guards `limit == 0`. + let end = match commit_end { + Some(ce) => ce.max(1).min(reached.max(1)), + None => reached.max(1), + } + .min(avail); + trace_back(opt, end); + end +} + +/// Reorder rep distances for a long rep referencing index `rep_idx`. +fn reorder_reps(reps: [u32; 4], rep_idx: u32) -> [u32; 4] { + match rep_idx { + 0 => reps, + 1 => [reps[1], reps[0], reps[2], reps[3]], + 2 => [reps[2], reps[0], reps[1], reps[3]], + _ => [reps[3], reps[0], reps[1], reps[2]], + } +} + +/// Trace back the cheapest path from offset `end` to 0, filling +/// `opt.decisions` in forward order. +fn trace_back(opt: &mut Optimizer, end: usize) { + opt.decisions.clear(); + let mut cur = end; + while cur > 0 { + let node = opt.opt[cur]; + opt.decisions.push(node.decision); + cur = node.prev_pos as usize; + } + opt.decisions.reverse(); +} + +/// Replay chosen decisions through the real emit path, updating the hash chain +/// and the live probability model. 
+fn replay( + core: &mut LzmaEncCore, + hc: &mut HashChain, + input: &[u8], + start: usize, + decisions: &[Decision], +) { + let mut pos = start; + for &d in decisions { + match d { + Decision::Literal => { + hc.insert(input, pos); + core.emit_literal(input, pos); + pos += 1; + } + Decision::ShortRep => { + hc.insert(input, pos); + core.emit_short_rep(); + pos += 1; + } + Decision::Match(dist, len) => { + for j in 0..(len as usize) { + let p = pos + j; + if p + 3 <= input.len() { + hc.insert(input, p); + } + } + core.emit_match(dist, len); + pos += len as usize; + } + Decision::Rep(idx, len) => { + for j in 0..(len as usize) { + let p = pos + j; + if p + 3 <= input.len() { + hc.insert(input, p); + } + } + core.emit_long_rep(idx, len); + pos += len as usize; + } + } + } } diff --git a/src/zstd/encoder.rs b/src/zstd/encoder.rs index 86fcef8..10719cb 100644 --- a/src/zstd/encoder.rs +++ b/src/zstd/encoder.rs @@ -43,7 +43,7 @@ use crate::zstd::encoder_fse::{ }; use crate::zstd::encoder_huffman::{ HuffLengths, build_huff_encoder, build_huff_lengths, encode_huff_4streams, encode_huff_stream, - encode_huff_tree_direct, histogram, lengths_to_weights, predicted_bits, + encode_huff_tree_direct, encode_huff_tree_fse, histogram, lengths_to_weights, predicted_bits, }; use crate::zstd::encoder_seq::{encode_sequence_count, ll_code, ml_code, of_code}; use crate::zstd::matcher::{MIN_MATCH, MatchFinder}; @@ -108,8 +108,22 @@ pub(crate) struct LevelParams { /// take the later match if it's meaningfully longer. Mirrors zstd's /// `lazy`/`lazy2` strategies (we do single-step lookahead only). pub lazy_search: bool, + /// When true, the parser runs a price-based optimal parse (forward DP over + /// the whole block) instead of greedy/lazy. Enabled at high levels where + /// the extra CPU is acceptable. Mirrors zstd's `btopt`/`btultra`. + pub optimal: bool, } +/// Lowest level at which the optimal parser is used. 
+const OPTIMAL_LEVEL: u8 = 13; + +/// Per-position hash-chain depth cap for the optimal parser. The DP visits +/// every position, so an uncapped chain (up to 16384 at level 22) makes each +/// block quadratic; this bound keeps encode time reasonable while preserving +/// nearly all of the ratio (the DP's win comes from length/repeat pricing, +/// not from exhaustive chain walks). +const OPTIMAL_MAX_CHAIN: usize = 4096; + impl LevelParams { /// Clamp `level` to `1..=22` and expand to match-finder tuning. The /// table broadly tracks zstd's reference presets but doesn't try to @@ -122,118 +136,37 @@ impl LevelParams { // Lazy parsing kicks in at level 4 — matches zstd's reference table // where `lazy` strategies start at level 4. let lazy_search = level >= 4; - match level { - 1 => Self { - max_chain: 4, - nice_match: 8, - lazy_search, - }, - 2 => Self { - max_chain: 8, - nice_match: 12, - lazy_search, - }, - 3 => Self { - max_chain: 16, - nice_match: 16, - lazy_search, - }, - 4 => Self { - max_chain: 24, - nice_match: 24, - lazy_search, - }, - 5 => Self { - max_chain: 32, - nice_match: 32, - lazy_search, - }, - 6 => Self { - max_chain: 48, - nice_match: 48, - lazy_search, - }, - 7 => Self { - max_chain: 64, - nice_match: 64, - lazy_search, - }, - 8 => Self { - max_chain: 96, - nice_match: 96, - lazy_search, - }, - 9 => Self { - max_chain: 128, - nice_match: 128, - lazy_search, - }, - 10 => Self { - max_chain: 192, - nice_match: 160, - lazy_search, - }, - 11 => Self { - max_chain: 256, - nice_match: 192, - lazy_search, - }, - 12 => Self { - max_chain: 384, - nice_match: 224, - lazy_search, - }, - 13 => Self { - max_chain: 512, - nice_match: 256, - lazy_search, - }, - 14 => Self { - max_chain: 768, - nice_match: 384, - lazy_search, - }, - 15 => Self { - max_chain: 1024, - nice_match: 512, - lazy_search, - }, - 16 => Self { - max_chain: 1536, - nice_match: 768, - lazy_search, - }, - 17 => Self { - max_chain: 2048, - nice_match: 1024, - lazy_search, - }, - 18 => Self { 
- max_chain: 3072, - nice_match: 1536, - lazy_search, - }, - 19 => Self { - max_chain: 4096, - nice_match: 2048, - lazy_search, - }, - 20 => Self { - max_chain: 6144, - nice_match: 3072, - lazy_search, - }, - 21 => Self { - max_chain: 8192, - nice_match: 4096, - lazy_search, - }, + let optimal = level >= OPTIMAL_LEVEL; + let (max_chain, nice_match) = match level { + 1 => (4, 8), + 2 => (8, 12), + 3 => (16, 16), + 4 => (24, 24), + 5 => (32, 32), + 6 => (48, 48), + 7 => (64, 64), + 8 => (96, 96), + 9 => (128, 128), + 10 => (192, 160), + 11 => (256, 192), + 12 => (384, 224), + 13 => (512, 256), + 14 => (768, 384), + 15 => (1024, 512), + 16 => (1536, 768), + 17 => (2048, 1024), + 18 => (3072, 1536), + 19 => (4096, 2048), + 20 => (6144, 3072), + 21 => (8192, 4096), // 22 (and clamp-from-above) - _ => Self { - max_chain: 16384, - nice_match: super::matcher::MAX_MATCH, - lazy_search, - }, + _ => (16384, super::matcher::MAX_MATCH), + }; + Self { + max_chain, + nice_match, + lazy_search, + optimal, } } } @@ -350,16 +283,48 @@ impl Encoder { // each position first — a repeat-offset match costs 1 bit in the // offset stream vs. ~log2(distance) bits for a fresh offset, so even // short repeats are cheap wins. - let mut sequences: Vec = Vec::new(); - let mut literals: Vec = Vec::with_capacity(buffer.len()); - let mut lit_start: usize = 0; - let mut pos: usize = 0; - let mut block_offsets = self.prev_offsets; let lazy = self.params.lazy_search; let buf_len = buffer.len(); let max_chain = self.params.max_chain; let nice_match = self.params.nice_match; + // High levels: price-based optimal parse over the whole block. The DP + // probes a match candidate at every input position, so we cap the + // per-position chain depth to keep the per-block cost bounded — the DP + // recovers most of the ratio from trying lengths and repeat offsets + // rather than from exhaustive chain walks. 
+ if self.params.optimal { + let opt_chain = max_chain.min(OPTIMAL_MAX_CHAIN); + let (sequences, new_offsets) = optimal_parse( + &mut self.matcher, + buffer, + self.prev_offsets, + opt_chain, + nice_match, + ); + if sequences.is_empty() { + return None; + } + return finish_compressed_block( + buffer, + &sequences, + new_offsets, + self.prev_huff_lengths.as_ref(), + ) + .map(|(body, new_lengths, committed_offsets)| { + self.prev_offsets = committed_offsets; + if let Some(lengths) = new_lengths { + self.prev_huff_lengths = Some(lengths); + } + body + }); + } + + let mut sequences: Vec = Vec::new(); + let mut lit_start: usize = 0; + let mut pos: usize = 0; + let mut block_offsets = self.prev_offsets; + // Invariant: positions in [0, next_insert) have already been spliced // into the matcher's hash chain. We advance `next_insert` lazily. let mut next_insert: usize = 0; @@ -422,7 +387,6 @@ impl Encoder { let literal_run = best_pos - lit_start; let offset_value = assign_offset(best_dist as u32, literal_run as u32, &mut block_offsets); - literals.extend_from_slice(&buffer[lit_start..best_pos]); sequences.push(Seq { literal_length: literal_run as u32, match_length: best_len as u32, @@ -440,214 +404,234 @@ impl Encoder { lit_start = pos; } + let _ = lit_start; if sequences.is_empty() { return None; } - // Trailing literals: from lit_start to end of buffer. - let trailing_literals = &buffer[lit_start..]; - - // Build all literal bytes (LZ77 literals + trailing) for use in - // literals-section construction. 
- let mut all_literals: Vec = - Vec::with_capacity(literals.len() + trailing_literals.len()); - all_literals.extend_from_slice(&literals); - all_literals.extend_from_slice(trailing_literals); + finish_compressed_block( + buffer, + &sequences, + block_offsets, + self.prev_huff_lengths.as_ref(), + ) + .map(|(body, new_lengths, committed_offsets)| { + self.prev_offsets = committed_offsets; + if let Some(lengths) = new_lengths { + self.prev_huff_lengths = Some(lengths); + } + body + }) + } +} - // Build literals section. Try Huffman first (with optional Treeless - // reuse of the previous block's tree); fall back to raw. - let (lit_section, new_lengths) = - build_literals_section(&all_literals, self.prev_huff_lengths.as_ref()); +/// Shared back half of block compression: reconstruct the literal byte stream +/// from the chosen sequences, build the literals + sequences sections, and +/// return `(body, new_huff_lengths, committed_offsets)` if the compressed body +/// beats a Raw_Block. The caller commits the returned state. Free function so +/// both the greedy/lazy and the optimal parsers can share it without aliasing +/// `self.pending` (which `buffer` borrows) against `&mut self`. +fn finish_compressed_block( + buffer: &[u8], + sequences: &[Seq], + block_offsets: [u32; 3], + prev_huff_lengths: Option<&HuffLengths>, +) -> Option<(Vec, Option, [u32; 3])> { + // Reconstruct all literal bytes by replaying the sequences: each sequence + // emits `literal_length` literals from the cursor, then skips + // `match_length` matched bytes. Trailing bytes after the last sequence are + // literals too. + let mut all_literals: Vec = Vec::with_capacity(buffer.len()); + let mut cursor = 0usize; + for s in sequences { + let ll = s.literal_length as usize; + all_literals.extend_from_slice(&buffer[cursor..cursor + ll]); + cursor += ll + s.match_length as usize; + } + all_literals.extend_from_slice(&buffer[cursor..]); - // Build sequences section. 
- let seq_section = self.build_sequences_section(&sequences); + let (lit_section, new_lengths) = build_literals_section(&all_literals, prev_huff_lengths); + let seq_section = build_sequences_section(sequences); - let total = lit_section.len() + seq_section.len(); - let raw_size = buffer.len(); - if total >= raw_size { - return None; // Not worth compressing. - } + let total = lit_section.len() + seq_section.len(); + if total >= buffer.len() { + return None; // Not worth compressing. + } - // Commit the per-block offset history and (if we emitted a Huffman - // tree) the new lengths to the encoder state. - self.prev_offsets = block_offsets; - if let Some(lengths) = new_lengths { - self.prev_huff_lengths = Some(lengths); - } + let mut body = Vec::with_capacity(total); + body.extend_from_slice(&lit_section); + body.extend_from_slice(&seq_section); + Some((body, new_lengths, block_offsets)) +} - let mut body = Vec::with_capacity(total); - body.extend_from_slice(&lit_section); - body.extend_from_slice(&seq_section); - Some(body) +/// Build the sequence section bytes: header (count + symbol-modes byte) +/// followed by the FSE-encoded sequence bitstream. +/// +/// Per-table mode selection: for each of LL/OF/ML we try the predefined +/// distribution against a custom FSE_Compressed_Mode distribution built from +/// this block's actual code histogram. Whichever produces the smaller +/// estimated byte count wins. +fn build_sequences_section(sequences: &[Seq]) -> Vec { + let n = sequences.len() as u32; + + // Pre-compute (code, extra_bits, extra_val) for each sequence. 
+ let mut ll_codes: Vec = Vec::with_capacity(sequences.len()); + let mut ml_codes: Vec = Vec::with_capacity(sequences.len()); + let mut of_codes: Vec = Vec::with_capacity(sequences.len()); + let mut ll_extras: Vec<(u32, u32)> = Vec::with_capacity(sequences.len()); + let mut ml_extras: Vec<(u32, u32)> = Vec::with_capacity(sequences.len()); + let mut of_extras: Vec<(u32, u32)> = Vec::with_capacity(sequences.len()); + + for s in sequences { + let (oc, oe_bits, oe_val) = of_code(s.offset_value); + of_codes.push(oc); + of_extras.push((oe_bits, oe_val)); + + let (lc, le_bits, le_val) = ll_code(s.literal_length); + ll_codes.push(lc); + ll_extras.push((le_bits, le_val)); + + let (mc, me_bits, me_val) = ml_code(s.match_length); + ml_codes.push(mc); + ml_extras.push((me_bits, me_val)); } - /// Build the sequence section bytes: header (count + symbol-modes byte) - /// followed by the FSE-encoded sequence bitstream. - /// - /// Per-table mode selection: for each of LL/OF/ML we try the predefined - /// distribution against a custom FSE_Compressed_Mode distribution built - /// from this block's actual code histogram. Whichever produces the - /// smaller estimated byte count wins. - fn build_sequences_section(&self, sequences: &[Seq]) -> Vec { - let n = sequences.len() as u32; - - // Pre-compute (code, extra_bits, extra_val) for each sequence. 
- let mut ll_codes: Vec = Vec::with_capacity(sequences.len()); - let mut ml_codes: Vec = Vec::with_capacity(sequences.len()); - let mut of_codes: Vec = Vec::with_capacity(sequences.len()); - let mut ll_extras: Vec<(u32, u32)> = Vec::with_capacity(sequences.len()); - let mut ml_extras: Vec<(u32, u32)> = Vec::with_capacity(sequences.len()); - let mut of_extras: Vec<(u32, u32)> = Vec::with_capacity(sequences.len()); - - for s in sequences { - let (oc, oe_bits, oe_val) = of_code(s.offset_value); - of_codes.push(oc); - of_extras.push((oe_bits, oe_val)); - - let (lc, le_bits, le_val) = ll_code(s.literal_length); - ll_codes.push(lc); - ll_extras.push((le_bits, le_val)); - - let (mc, me_bits, me_val) = ml_code(s.match_length); - ml_codes.push(mc); - ml_extras.push((me_bits, me_val)); - } - - // Pick per-table mode and build the encoders + any header bytes. - let (ll_enc, ll_mode, ll_header) = pick_table( - &ll_codes, - &DEFAULT_LL_COUNTS, - DEFAULT_LL_ACCURACY_LOG, - 9, - 35, - ); - let (of_enc, of_mode, of_header) = pick_table( - &of_codes, - &DEFAULT_OF_COUNTS, - DEFAULT_OF_ACCURACY_LOG, - 8, - 31, - ); - let (ml_enc, ml_mode, ml_header) = pick_table( - &ml_codes, - &DEFAULT_ML_COUNTS, - DEFAULT_ML_ACCURACY_LOG, - 9, - 52, - ); - - // Build the sequences-section bytes. - let mut out = encode_sequence_count(n); - // Symbol_Compression_Modes byte: bits [7:6]=LL_Mode, [5:4]=OF_Mode, - // [3:2]=ML_Mode, [1:0]=Reserved. - let modes: u8 = (ll_mode << 6) | (of_mode << 4) | (ml_mode << 2); - out.push(modes); - out.extend_from_slice(&ll_header); - out.extend_from_slice(&of_header); - out.extend_from_slice(&ml_header); - - // FSE-encode the symbol streams. - let mut writer = RevBitWriter::new(); - let n_seq = sequences.len(); - - // Reverse encoding pattern. Init states from the LAST sequence. 
- let mut ll_state = ll_enc.init_state(ll_codes[n_seq - 1] as usize); - let mut of_state = of_enc.init_state(of_codes[n_seq - 1] as usize); - let mut ml_state = ml_enc.init_state(ml_codes[n_seq - 1] as usize); - - // For each sequence (processed in reverse), write to the bitstream - // in the EXACT REVERSE of the decoder's read order. - // - // Decoder per-sequence read order (recall §3.1.1.3.2.1): - // 1. OF_extra_bits (number = of_code value) - // 2. ML_extra_bits - // 3. LL_extra_bits - // 4. (only if not last sequence): LL_advance, ML_advance, OF_advance. - // - // The reverse-bitstream writer is "first-written = last-read". So if - // we walk sequences i = n-1 → 0: - // For i = n-1 (DECODER's last sequence): write extras only, in - // reverse read order: write LL_extra first, then ML_extra, then - // OF_extra. - // For i < n-1: write the FSE advance bits for THIS sequence's - // transition (out_OF, then out_ML, then out_LL — reverse of the - // decoder's LL, ML, OF advance read order), THEN write the - // extras (LL, ML, OF reversed). - // - // FSE advance bits are emitted by `encode_symbol(state, sym)`. - // The bits returned correspond to the decoder's read at that - // advance step. - // - // To produce the correct interleaving, we structure the loop: - // for i in (0..n_seq).rev() { - // if i == n_seq - 1 { - // // No advance for the last decoder-side sequence. - // } else { - // // Advance: encode the transition FROM sequence i+1's - // // state INTO sequence i's state for each of OF, ML, LL. - // // Decoder reads advance order LL, ML, OF — so we write - // // OF first (most recently read), then ML, then LL. - // of_state = self.of_enc.encode_symbol(of_state, of_codes[i] as usize, &mut writer); - // ml_state = self.ml_enc.encode_symbol(ml_state, ml_codes[i] as usize, &mut writer); - // ll_state = self.ll_enc.encode_symbol(ll_state, ll_codes[i] as usize, &mut writer); - // } - // // Extras: decoder reads OF, ML, LL — write LL, ML, OF. 
- // writer.write_bits(ll_extras[i].1 as u64, ll_extras[i].0); - // writer.write_bits(ml_extras[i].1 as u64, ml_extras[i].0); - // writer.write_bits(of_extras[i].1 as u64, of_extras[i].0); - // } - // - // Hmm wait — encode_symbol(state, sym) consumes the CURRENT state - // (which corresponds to the decoder's PRE-advance state) and - // produces NEW state (decoder's POST-advance state). The bits - // written are the bits the decoder reads to perform the advance. - // - // The decoder advances at the END of sequence i (using sequence i's - // current state to compute next_state for sequence i+1). So the - // bits FOR THIS ADVANCE are read at the END of sequence i's - // processing. From sequence i+1's POV, the state was set up by - // this advance. - // - // We're processing sequences in reverse (i from n-1 to 0). When - // i = n-2, we're handling the SECOND-TO-LAST sequence (decoder- - // side). The advance bits at this point are the ones the decoder - // reads at the END of i=n-2 to set up i=n-1's state. So we encode - // the transition FROM sequence n-2's state INTO n-1's state. - // - // In our reverse loop, "current state" represents sequence n-1's - // initial state (set up via init_state). After encode_symbol with - // ll_codes[n-2], the state will represent sequence n-2's initial - // state. The BITS written reflect the (current → new) transition - // i.e. n-2 → n-1 advance (since current = n-1 before). - // - // So `encode_symbol(state_for_seq_iplus1, codes[i])` writes the - // bits the decoder reads at the end of seq i to advance from - // seq_i.state to seq_(i+1).state. ✓ - for i in (0..n_seq).rev() { - if i == n_seq - 1 { - // No advance bits for the decoder's last sequence. 
- } else { - of_state = of_enc.encode_symbol(of_state, of_codes[i] as usize, &mut writer); - ml_state = ml_enc.encode_symbol(ml_state, ml_codes[i] as usize, &mut writer); - ll_state = ll_enc.encode_symbol(ll_state, ll_codes[i] as usize, &mut writer); - } - // Extras: decoder reads OF, ML, LL — write LL, ML, OF. - writer.write_bits(ll_extras[i].1 as u64, ll_extras[i].0); - writer.write_bits(ml_extras[i].1 as u64, ml_extras[i].0); - writer.write_bits(of_extras[i].1 as u64, of_extras[i].0); + // Pick per-table mode and build the encoders + any header bytes. + let (ll_enc, ll_mode, ll_header) = pick_table( + &ll_codes, + &DEFAULT_LL_COUNTS, + DEFAULT_LL_ACCURACY_LOG, + 9, + 35, + ); + let (of_enc, of_mode, of_header) = pick_table( + &of_codes, + &DEFAULT_OF_COUNTS, + DEFAULT_OF_ACCURACY_LOG, + 8, + 31, + ); + let (ml_enc, ml_mode, ml_header) = pick_table( + &ml_codes, + &DEFAULT_ML_COUNTS, + DEFAULT_ML_ACCURACY_LOG, + 9, + 52, + ); + + // Build the sequences-section bytes. + let mut out = encode_sequence_count(n); + // Symbol_Compression_Modes byte: bits [7:6]=LL_Mode, [5:4]=OF_Mode, + // [3:2]=ML_Mode, [1:0]=Reserved. + let modes: u8 = (ll_mode << 6) | (of_mode << 4) | (ml_mode << 2); + out.push(modes); + out.extend_from_slice(&ll_header); + out.extend_from_slice(&of_header); + out.extend_from_slice(&ml_header); + + // FSE-encode the symbol streams. + let mut writer = RevBitWriter::new(); + let n_seq = sequences.len(); + + // Reverse encoding pattern. Init states from the LAST sequence. + let mut ll_state = ll_enc.init_state(ll_codes[n_seq - 1] as usize); + let mut of_state = of_enc.init_state(of_codes[n_seq - 1] as usize); + let mut ml_state = ml_enc.init_state(ml_codes[n_seq - 1] as usize); + + // For each sequence (processed in reverse), write to the bitstream + // in the EXACT REVERSE of the decoder's read order. + // + // Decoder per-sequence read order (recall §3.1.1.3.2.1): + // 1. OF_extra_bits (number = of_code value) + // 2. ML_extra_bits + // 3. 
LL_extra_bits + // 4. (only if not last sequence): LL_advance, ML_advance, OF_advance. + // + // The reverse-bitstream writer is "first-written = last-read". So if + // we walk sequences i = n-1 → 0: + // For i = n-1 (DECODER's last sequence): write extras only, in + // reverse read order: write LL_extra first, then ML_extra, then + // OF_extra. + // For i < n-1: write the FSE advance bits for THIS sequence's + // transition (out_OF, then out_ML, then out_LL — reverse of the + // decoder's LL, ML, OF advance read order), THEN write the + // extras (LL, ML, OF reversed). + // + // FSE advance bits are emitted by `encode_symbol(state, sym)`. + // The bits it writes correspond to the decoder's read at that + // advance step. + // + // To produce the correct interleaving, we structure the loop: + // for i in (0..n_seq).rev() { + // if i == n_seq - 1 { + // // No advance for the last decoder-side sequence. + // } else { + // // Advance: encode the transition FROM sequence i+1's + // // state INTO sequence i's state for each of OF, ML, LL. + // // Decoder reads advance order LL, ML, OF — so we write + // // OF first (most recently read), then ML, then LL. + // of_state = of_enc.encode_symbol(of_state, of_codes[i] as usize, &mut writer); + // ml_state = ml_enc.encode_symbol(ml_state, ml_codes[i] as usize, &mut writer); + // ll_state = ll_enc.encode_symbol(ll_state, ll_codes[i] as usize, &mut writer); + // } + // // Extras: decoder reads OF, ML, LL — write LL, ML, OF. + // writer.write_bits(ll_extras[i].1 as u64, ll_extras[i].0); + // writer.write_bits(ml_extras[i].1 as u64, ml_extras[i].0); + // writer.write_bits(of_extras[i].1 as u64, of_extras[i].0); + // } + // + // Note: encode_symbol(state, sym) consumes the encoder's CURRENT state + // (which corresponds to the decoder's POST-advance state) and + // produces the NEW state (the decoder's PRE-advance state). The bits + // written are the bits the decoder reads to perform the advance.
+ // + // The decoder advances at the END of sequence i (using sequence i's + // current state to compute next_state for sequence i+1). So the + // bits FOR THIS ADVANCE are read at the END of sequence i's + // processing. From sequence i+1's POV, the state was set up by + // this advance. + // + // We're processing sequences in reverse (i from n-1 to 0). When + // i = n-2, we're handling the SECOND-TO-LAST sequence (decoder- + // side). The advance bits at this point are the ones the decoder + // reads at the END of i=n-2 to set up i=n-1's state. So we encode + // the transition FROM sequence n-2's state INTO n-1's state. + // + // In our reverse loop, "current state" represents sequence n-1's + // initial state (set up via init_state). After encode_symbol with + // ll_codes[n-2], the state will represent sequence n-2's initial + // state. The BITS written reflect the (current → new) transition + // i.e. n-2 → n-1 advance (since current = n-1 before). + // + // So `encode_symbol(state_for_seq_iplus1, codes[i])` writes the + // bits the decoder reads at the end of seq i to advance from + // seq_i.state to seq_(i+1).state. ✓ + for i in (0..n_seq).rev() { + if i == n_seq - 1 { + // No advance bits for the decoder's last sequence. + } else { + of_state = of_enc.encode_symbol(of_state, of_codes[i] as usize, &mut writer); + ml_state = ml_enc.encode_symbol(ml_state, ml_codes[i] as usize, &mut writer); + ll_state = ll_enc.encode_symbol(ll_state, ll_codes[i] as usize, &mut writer); } + // Extras: decoder reads OF, ML, LL — write LL, ML, OF. + writer.write_bits(ll_extras[i].1 as u64, ll_extras[i].0); + writer.write_bits(ml_extras[i].1 as u64, ml_extras[i].0); + writer.write_bits(of_extras[i].1 as u64, of_extras[i].0); + } - // Write final FSE states (decoder reads these via init in order - // LL, OF, ML — we write reverse: ML, OF, LL). 
- ml_enc.write_final_state(ml_state, &mut writer); - of_enc.write_final_state(of_state, &mut writer); - ll_enc.write_final_state(ll_state, &mut writer); + // Write final FSE states (decoder reads these via init in order + // LL, OF, ML — we write reverse: ML, OF, LL). + ml_enc.write_final_state(ml_state, &mut writer); + of_enc.write_final_state(of_state, &mut writer); + ll_enc.write_final_state(ll_state, &mut writer); - let bitstream = writer.finish(); - out.extend_from_slice(&bitstream); - out - } + let bitstream = writer.finish(); + out.extend_from_slice(&bitstream); + out +} +impl Encoder { /// Flush `pending` as a single block (RLE / compressed / raw — whichever /// is smallest). Sets `last` on the block header. fn flush_block(&mut self, last: bool) { @@ -693,6 +677,208 @@ impl Encoder { } } +// ─── price-based optimal parser ─────────────────────────────────────────── + +/// Estimated bit cost of a literal byte (~Huffman-coded text/code literal). +/// Only the literal-vs-match trade-off depends on it, not correctness. +const LIT_PRICE: u32 = 9; + +/// Estimated bit cost of the offset part of a match: the FSE offset code plus +/// its extra bits, with a distance matching one of the active repeat offsets +/// priced near-free (repeats emit a tiny FSE code and NO offset extra bits). +fn offset_price(distance: u32, reps: &[u32; 3], ll: u32) -> u32 { + let is_rep = if ll > 0 { + distance == reps[0] || distance == reps[1] || distance == reps[2] + } else { + distance == reps[1] || distance == reps[2] || (reps[0] > 1 && distance == reps[0] - 1) + }; + if is_rep { + return 4; + } + // Fresh offset: `code` extra bits (the literal low bits of the distance) + // plus the FSE-coded offset code itself (~5 bits amortised). The FSE code + // adapts to the block, so it is NOT another `log2(D)` — charging that would + // double-count and push the DP away from good long-distance matches. 
+ let val = distance + 3; + let code = 31 - val.leading_zeros(); + code + 5 +} + +/// Estimated bit cost of the literal-length / match-length FSE codes plus +/// their extra bits for a sequence with the given run/length. +fn ll_ml_price(literal_length: u32, match_length: u32) -> u32 { + let (_lc, lb, _lv) = ll_code(literal_length); + let (_mc, mb, _mv) = ml_code(match_length); + 10 + lb + mb +} + +/// Update the repeat-offset ring after a match (mirrors `assign_offset`'s +/// transitions) and return the new ring. Used to carry rep state along the +/// optimal-parse DP path. +fn advance_reps(distance: u32, literal_length: u32, reps: &[u32; 3]) -> [u32; 3] { + let mut r = *reps; + let _ = assign_offset(distance, literal_length, &mut r); + r +} + +/// Price-based optimal parse of `buffer` into a sequence list. +/// +/// Forward dynamic program: `price[i]` is the cheapest estimated bit cost to +/// encode `buffer[0..i]`. Each position can be reached by emitting a literal +/// (advance 1) or a match of some length (advance L). Match candidates come +/// from the hash chain plus the three active repeat offsets, and every length +/// from `MIN_MATCH` up to a candidate's max is priced — so the DP can pick a +/// slightly shorter match that lands on a cheaper (closer or repeated) offset. +/// Repeat offsets are priced near-free, which is where most of the win over +/// greedy/lazy parsing comes from (their offset extra bits dominate output). +/// +/// Returns the chosen sequences (in order) and the final repeat-offset ring. +fn optimal_parse( + matcher: &mut MatchFinder, + buffer: &[u8], + init_offsets: [u32; 3], + max_chain: usize, + nice_match: usize, +) -> (Vec, [u32; 3]) { + let n = buffer.len(); + if n < MIN_MATCH + 1 { + return (Vec::new(), init_offsets); + } + + // Insert every hashable position up front so chain walks see the whole + // block (back-references only look earlier, so insertion order within the + // block doesn't affect correctness). 
+ matcher.resize_for(n); + for i in 0..n.saturating_sub(3) { + matcher.insert(buffer, i); + } + + const INF: u32 = u32::MAX; + let mut price: Vec = vec![INF; n + 1]; + // Back-pointer: (prev_pos, match_len, match_dist). match_len == 0 → literal. + let mut back: Vec<(u32, u32, u32)> = vec![(0, 0, 0); n + 1]; + let mut reps_at: Vec<[u32; 3]> = vec![init_offsets; n + 1]; + price[0] = 0; + + // Step length sparsely for long matches to bound DP work. Dense up to 128 + // (where most matches live), then coarser. + let push_len = |l: usize, max_l: usize| -> usize { + let step = if l < 128 { 1 } else { 32 }; + let next = l + step; + if next > max_l && l < max_l { + max_l + } else { + next + } + }; + + let mut cands: Vec = Vec::new(); + + for i in 0..n { + let base = price[i]; + if base == INF { + continue; + } + let cur_reps = reps_at[i]; + + // Option A: emit a literal. + let lit_cand = base.saturating_add(LIT_PRICE); + if lit_cand < price[i + 1] { + price[i + 1] = lit_cand; + back[i + 1] = (i as u32, 0, 0); + reps_at[i + 1] = cur_reps; + } + + if i + MIN_MATCH > n { + continue; + } + // Proxy literal-length for offset rep-aliasing: the common case is a + // sequence following some literals (LL>0, reps map to codes 1..=3). + let ll_proxy = 1u32; + + // Option B1: repeat-offset matches at the three active distances. + for &d in &cur_reps { + if d == 0 || (d as usize) > i { + continue; + } + let m = matcher.check_repeat_offset(buffer, i, d as usize); + if m >= MIN_MATCH { + let max_l = m.min(n - i); + let off = offset_price(d, &cur_reps, ll_proxy); + let mut l = MIN_MATCH; + while l <= max_l { + let cost = base + .saturating_add(off) + .saturating_add(ll_ml_price(0, l as u32)); + if cost < price[i + l] { + price[i + l] = cost; + back[i + l] = (i as u32, l as u32, d); + reps_at[i + l] = advance_reps(d, ll_proxy, &cur_reps); + } + if l == max_l { + break; + } + l = push_len(l, max_l); + } + } + } + + // Option B2: fresh hash-chain matches. 
+ matcher.collect_matches(buffer, i, n, max_chain, nice_match, &mut cands); + for c in &cands { + let d = c.distance as u32; + let max_l = c.length.min(n - i); + let off = offset_price(d, &cur_reps, ll_proxy); + let mut l = MIN_MATCH; + while l <= max_l { + let cost = base + .saturating_add(off) + .saturating_add(ll_ml_price(0, l as u32)); + if cost < price[i + l] { + price[i + l] = cost; + back[i + l] = (i as u32, l as u32, d); + reps_at[i + l] = advance_reps(d, ll_proxy, &cur_reps); + } + if l == max_l { + break; + } + l = push_len(l, max_l); + } + } + } + + // Backtrack to recover the chosen steps, then emit sequences forward. + let mut steps: Vec<(u32, u32)> = Vec::new(); // (match_len, match_dist); 0 = literal + let mut i = n; + while i > 0 { + let (prev, mlen, mdist) = back[i]; + steps.push((mlen, mdist)); + i = prev as usize; + } + steps.reverse(); + + let mut sequences: Vec = Vec::new(); + let mut block_offsets = init_offsets; + let mut pending_literals: u32 = 0; + for (mlen, mdist) in steps { + if mlen == 0 { + pending_literals += 1; + continue; + } + let offset_value = assign_offset(mdist, pending_literals, &mut block_offsets); + sequences.push(Seq { + literal_length: pending_literals, + match_length: mlen, + offset_value, + }); + pending_literals = 0; + } + // Trailing literals are emitted by the block builder; drop the counter. + let _ = pending_literals; + + (sequences, block_offsets) +} + /// Find the best (distance, length) match at `pos`, mixing repeat-offset /// probes with a hash-chain search. /// @@ -715,38 +901,60 @@ fn best_at( max_chain: usize, nice_match: usize, ) -> (usize, usize, bool) { - // Repeat-offset probes. The reference encoder gives these strong - // preference because they're nearly free in the offset stream. - let mut best_len: usize = 0; - let mut best_dist: usize = 0; - let mut best_is_rep1: bool = false; + // Repeat-offset probes. 
A repeat offset costs only the FSE code (1..=3) in + // the offset stream and — crucially — emits NO offset extra bits, whereas + // a fresh offset at distance D spends ~log2(D) FSE-code bits PLUS ~log2(D) + // extra bits. On real corpora those offset extra bits are the single + // largest part of the output, so a repeat match that is several bytes + // shorter than the best fresh match is often still the cheaper encoding. + let mut rep_len: usize = 0; + let mut rep_dist: usize = 0; + let mut rep_is_rep1: bool = false; for (i, &d) in block_offsets.iter().enumerate() { let len = matcher.check_repeat_offset(buffer, pos, d as usize); // Prefer earlier rep slots on ties (they encode in fewer bits and // don't perturb the ring). - if len > best_len { - best_len = len; - best_dist = d as usize; - best_is_rep1 = i == 0; - if best_len >= nice_match { - return (best_dist, best_len, best_is_rep1); - } + if len > rep_len { + rep_len = len; + rep_dist = d as usize; + rep_is_rep1 = i == 0; } } + if rep_len >= nice_match { + return (rep_dist, rep_len, rep_is_rep1); + } - // Hash-chain probe. The matcher already returns the longest such match. - if let Some(m) = matcher.find_match(buffer, pos, buffer.len(), max_chain, nice_match) { - // For a fresh-offset match to beat a repeat match, it has to be - // strictly longer — repeat-offset matches save bits in the offset - // stream, so equal lengths favour the repeat. - if m.length > best_len { - best_len = m.length; - best_dist = m.distance; - best_is_rep1 = best_dist == block_offsets[0] as usize; + // Hash-chain probe (longest fresh match). + let fresh = matcher.find_match(buffer, pos, buffer.len(), max_chain, nice_match); + + match fresh { + Some(m) if rep_len >= MIN_MATCH => { + // Both a repeat and a fresh candidate exist. The fresh match must + // beat the repeat by enough length to pay for the offset bits it + // spends that the repeat avoids. 
A fresh offset at distance D costs + // roughly `2 * log2(D + 3)` bits more than a repeat; each matched + // byte is worth ~6 bits, so require the fresh match to be longer by + // at least `2 * log2(D) / 6` bytes. + let val = m.distance as u32 + 3; + let log2d = 31 - val.leading_zeros(); + let margin = ((2 * log2d) / 6).max(1) as usize; + if m.length >= rep_len + margin { + ( + m.distance, + m.length, + m.distance == block_offsets[0] as usize, + ) + } else { + (rep_dist, rep_len, rep_is_rep1) + } } + Some(m) => ( + m.distance, + m.length, + m.distance == block_offsets[0] as usize, + ), + None => (rep_dist, rep_len, rep_is_rep1), } - - (best_dist, best_len, best_is_rep1) } /// Pick the best per-table FSE mode (Predefined or FSE_Compressed) given the @@ -990,12 +1198,30 @@ fn try_build_huffman_literals_section_with( } let enc = build_huff_encoder(lengths); // Compute or skip the tree-description bytes depending on `fresh_tree`. + // When emitting a fresh tree we choose the smaller of two serialisations: + // - direct nibble-packed weights (only valid for ≤ 128 weights), and + // - FSE-compressed weights (mandatory above 128 weights, and often + // smaller for large skewed alphabets even below the cap). let tree_bytes: Vec = if fresh_tree { let (weights, _max_num_bits) = lengths_to_weights(lengths); - if weights.len() > 128 { - return None; // Direct nibble encoding cap. 
+ let direct: Option> = if weights.len() <= 128 { + Some(encode_huff_tree_direct(&weights)) + } else { + None + }; + let fse = encode_huff_tree_fse(&weights); + match (direct, fse) { + (Some(d), Some(f)) => { + if f.len() < d.len() { + f + } else { + d + } + } + (Some(d), None) => d, + (None, Some(f)) => f, + (None, None) => return None, // alphabet too large for either path + } } - encode_huff_tree_direct(&weights) } else { Vec::new() }; diff --git a/src/zstd/encoder_huffman.rs b/src/zstd/encoder_huffman.rs index 55d21e8..70ba4e0 100644 --- a/src/zstd/encoder_huffman.rs +++ b/src/zstd/encoder_huffman.rs @@ -310,6 +310,121 @@ pub fn encode_huff_tree_direct(weights: &[u8]) -> Vec { out } +/// Encode a Huffman tree description using FSE-compressed weights +/// (Header_Byte < 128: the byte value is the FSE payload length in bytes). +/// +/// This is needed when the literal alphabet spans more than 128 byte values +/// (e.g. UTF-8 text, whose multi-byte lead/continuation bytes push the +/// highest-indexed present symbol past 127) — the direct nibble encoding caps +/// at 128 weights, so without this path such blocks fall back to a +/// Raw_Literals_Block and get no entropy coding at all. +/// +/// The payload layout matches the decoder in +/// [`crate::zstd::huffman::decode_fse_weights`]: an FSE table header +/// (accuracy_log ≤ 6, weight alphabet 0..=11) followed by two interleaved FSE +/// streams written backwards. Returns `None` when there are fewer than two weights, +/// a weight exceeds 11, count normalisation fails, or the payload is ≥ 128 bytes — caller falls back. +pub fn encode_huff_tree_fse(weights: &[u8]) -> Option> { + use crate::zstd::encoder_fse::{FseEncoder, build_normalised_counts, encode_fse_table_header}; + + let n = weights.len(); + // Need at least 2 weights to run the 2-state interleaved encoder, and the + // decoder also requires ≥ 2 symbols (it inits two states). + if n < 2 { + return None; + } + + // Histogram of weight values (alphabet 0..=11).
+ const WALPHA: usize = 12; // weights are 0..=HUF_MAX_BITS(11) + let mut hist = [0u32; WALPHA]; + let mut max_w = 0usize; + for &w in weights { + let w = w as usize; + if w >= WALPHA { + return None; + } + hist[w] += 1; + if w > max_w { + max_w = w; + } + } + let max_symbol = max_w; // highest present weight value + + // Choose accuracy_log: weights use a small alphabet, RFC caps at 6. + // Pick the largest log (≤6) that still lets every present symbol get a + // slot; smaller tables save header bytes but a log of 6 keeps the streams + // tight, and the header is only a handful of bytes either way. + let mut accuracy_log: u8 = 6; + // accuracy_log must be ≥ 5 for the table-header encoder and large enough + // to hold the distinct present symbols. + let distinct = hist.iter().filter(|&&c| c > 0).count(); + while accuracy_log > 5 && (1u32 << accuracy_log) > (n as u32).max(distinct as u32) * 4 { + accuracy_log -= 1; + } + if accuracy_log < 5 { + accuracy_log = 5; + } + + let counts = build_normalised_counts(&hist[..=max_symbol], n as u32, accuracy_log)?; + let header = encode_fse_table_header(&counts, accuracy_log); + let enc = FseEncoder::from_normalized(&counts, accuracy_log); + + // The decoder (`decode_fse_weights`) emits weights in index order, with + // even indices owned by state 1 and odd indices by state 2: + // w0(s1) w1(s2) w2(s1) w3(s2) … + // It initialises s1 then s2 (each reads accuracy_log bits at the very end + // of the bitstream, so s1's init bits are read before s2's), then + // alternately emits+advances each state in increasing index order, and + // terminates by emitting the partner state's pending symbol. + // + // To replay `weights[0..n]` forward we run the two FSE state machines + // backwards: seed each state's `init_state` with the HIGHEST-index symbol + // it owns (the last symbol that state emits), then `encode_symbol` the + // remaining symbols from the highest index down to 0, picking the owning + // state by index parity. 
Each `encode_symbol(state, sym)` writes the bits + // the decoder consumes to land on `sym` while advancing — so forward + // decoding reproduces the original order. + let last_even = (n - 1).is_multiple_of(2); + let s1_high = if last_even { n - 1 } else { n - 2 }; + let s2_high = if last_even { n - 2 } else { n - 1 }; + let mut writer = RevBitWriter::new(); + let mut s1 = enc.init_state(weights[s1_high] as usize); + let mut s2 = enc.init_state(weights[s2_high] as usize); + let mut i1: isize = s1_high as isize - 2; + let mut i2: isize = s2_high as isize - 2; + loop { + if i1 < 0 && i2 < 0 { + break; + } + // Emit in strictly decreasing index order (the mirror of the decoder's + // increasing reads). + if i1 >= i2 { + s1 = enc.encode_symbol(s1, weights[i1 as usize] as usize, &mut writer); + i1 -= 2; + } else { + s2 = enc.encode_symbol(s2, weights[i2 as usize] as usize, &mut writer); + i2 -= 2; + } + } + // Final states: the decoder reads s1's init before s2's, and the reverse + // writer's last-written bits are read first — so write s2 first, then s1. + enc.write_final_state(s2, &mut writer); + enc.write_final_state(s1, &mut writer); + + let bitstream = writer.finish(); + let mut payload = Vec::with_capacity(1 + header.len() + bitstream.len()); + let fse_len = header.len() + bitstream.len(); + if fse_len >= 128 { + // Header_Byte must be < 128 (it IS the payload length). Too big to + // address — bail (caller falls back to direct/raw). + return None; + } + payload.push(fse_len as u8); + payload.extend_from_slice(&header); + payload.extend_from_slice(&bitstream); + Some(payload) +} + // ─── Stream encoding ────────────────────────────────────────────────────── /// Encode a slice of bytes as a single Huffman bitstream using `enc`. 
@@ -465,6 +580,43 @@ mod tests {
         assert_eq!(out, input);
     }
 
+    #[test]
+    fn fse_weights_round_trip() {
+        use crate::zstd::huffman::decode_huffman_tree_weights_for_test;
+        // Build a literal alphabet that spans > 128 byte values so the direct
+        // nibble path would be rejected. UTF-8-ish: every byte value in 0..200
+        // (200 exclusive) is present, with skewed frequencies.
+        let mut freq = [0u32; 256];
+        for b in 0u32..200 {
+            // Skewed: byte 0 gets frequency 201, falling by one per byte to 2 at byte 199.
+            freq[b as usize] = 200 - b + 1;
+        }
+        let lengths = build_huff_lengths(&freq).unwrap();
+        let (weights, _max) = lengths_to_weights(&lengths);
+        assert!(weights.len() > 128, "test needs > 128 weights");
+        let payload = encode_huff_tree_fse(&weights).expect("fse weight encode");
+        let decoded = decode_huffman_tree_weights_for_test(&payload).unwrap();
+        assert_eq!(decoded, weights, "FSE weight round-trip mismatch");
+    }
+
+    #[test]
+    fn fse_weights_round_trip_small_alphabet() {
+        use crate::zstd::huffman::decode_huffman_tree_weights_for_test;
+        // A modest alphabet must round-trip too; the check is skipped if the encoder returns None.
+        let text =
+            b"the quick brown fox jumps over the lazy dog. pack my box with five dozen liquor jugs.";
+        let mut freq = [0u32; 256];
+        for &b in text {
+            freq[b as usize] += 1;
+        }
+        let lengths = build_huff_lengths(&freq).unwrap();
+        let (weights, _max) = lengths_to_weights(&lengths);
+        if let Some(payload) = encode_huff_tree_fse(&weights) {
+            let decoded = decode_huffman_tree_weights_for_test(&payload).unwrap();
+            assert_eq!(decoded, weights);
+        }
+    }
+
     #[test]
     fn cap_code_lengths_idempotent_under_limit() {
         let mut lengths = [0u8; 256];
diff --git a/src/zstd/matcher.rs b/src/zstd/matcher.rs
index 3826e76..d31fd1c 100644
--- a/src/zstd/matcher.rs
+++ b/src/zstd/matcher.rs
@@ -202,6 +202,76 @@ impl MatchFinder {
         let len = match_extend(buffer, src, pos, max_len);
         if len >= MIN_MATCH { len } else { 0 }
     }
+
+    /// Collect distinct-length match candidates for `buffer[pos..]` for the
+    /// optimal parser. Walks the hash chain (bounded by `max_chain`) and, each
+    /// time a candidate beats the longest length so far, records it; walking
+    /// newest-first means that tier is reached at its smallest distance.
+    ///
+    /// Clears `out`, then fills it with `Match { length, distance }` entries of
+    /// strictly increasing length, so the price DP can weigh each length tier
+    /// against its offset cost. `out` stays empty when no match of at least
+    /// `MIN_MATCH` bytes exists. Stops early once a match reaches `nice_match`.
+    pub fn collect_matches(
+        &self,
+        buffer: &[u8],
+        pos: usize,
+        window: usize,
+        max_chain: usize,
+        nice_match: usize,
+        out: &mut Vec<Match>,
+    ) {
+        out.clear();
+        if pos + MIN_MATCH > buffer.len() || pos + 4 > buffer.len() {
+            return;
+        }
+        let h = hash4(&buffer[pos..pos + 4]) as usize;
+        let max_dist = window.min(pos);
+        let max_len = MAX_MATCH.min(buffer.len() - pos);
+        if max_len < MIN_MATCH {
+            return;
+        }
+
+        let mut best_len: usize = MIN_MATCH - 1;
+        let mut cur = self.head[h];
+        let mut steps = 0usize;
+
+        while cur != NIL && steps < max_chain {
+            let cur_pos = cur as usize;
+            if cur_pos >= pos {
+                cur = self.prev[cur_pos];
+                steps += 1;
+                continue;
+            }
+            let dist = pos - cur_pos;
+            if dist > max_dist {
+                break;
+            }
+            // Already at the longest possible length; also keeps `pos + best_len` in bounds.
+            if best_len >= max_len {
+                break;
+            }
+            if buffer[cur_pos + best_len] == buffer[pos + best_len] {
+                let len = match_extend(buffer, cur_pos, pos, max_len);
+                if len > best_len {
+                    // New longest tier. The chain is walked from the most recent
+                    // position downward, so the first candidate to reach a given
+                    // length is at the smallest distance — exactly what we want
+                    // for cheap offsets.
+                    out.push(Match {
+                        length: len,
+                        distance: dist,
+                    });
+                    best_len = len;
+                    if len >= nice_match {
+                        break;
+                    }
+                }
+            }
+            cur = self.prev[cur_pos];
+            steps += 1;
+        }
+    }
 }
 
 /// Extend a match forward up to `max_len` bytes, comparing `buffer[a..]`
diff --git a/tests/bzip2.rs b/tests/bzip2.rs
index 74400dc..c7ac941 100644
--- a/tests/bzip2.rs
+++ b/tests/bzip2.rs
@@ -222,6 +222,41 @@ fn round_trip_mixed_corpus() {
     round_trip(&input);
 }
 
+#[test]
+fn round_trip_large_compressible_multiblock() {
+    // 1.5 MB of zeros. Post-RLE-1 this is tiny, but it exercises the
+    // RLE-1-size-based block fill in the encoder (reference bzip2 sizes
+    // blocks by post-RLE-1 length, not raw input). Round-trips through
+    // the library decoder regardless of how compressible the data is.
+    let input = vec![0u8; 1_500_000];
+    let encoded = encode_all(&input);
+    let decoded = decode_chunked(&encoded, 4096, 4096).unwrap();
+    assert_eq!(decoded.len(), input.len());
+    assert_eq!(decoded, input);
+}
+
+#[test]
+fn round_trip_low_level_forces_multiple_blocks() {
+    // At level 1 (≈100 KB post-RLE-1 cap) a ~250 KB low-redundancy
+    // payload spans several blocks, exercising the multi-table Huffman
+    // optimisation and selector encoding across block boundaries.
+    let mut state: u32 = 0x1234_5678;
+    let mut input = Vec::with_capacity(250_000);
+    while input.len() < 250_000 {
+        state = state.wrapping_mul(1_664_525).wrapping_add(1_013_904_223);
+        input.push((state >> 16) as u8);
+        input.push((state >> 8) as u8);
+        // Inject structure every ~64 bytes (`< 2`: the 25-byte marker flips len's parity).
+        if input.len() % 64 < 2 {
+            input.extend_from_slice(b"compcol-bzip2-multiblock ");
+        }
+    }
+    let mut enc = Encoder::with_config(EncoderConfig { level: 1 });
+    let encoded = encode_chunked(&mut enc, &input, 7919, 4096);
+    let decoded = decode_chunked(&encoded, 4096, 4096).unwrap();
+    assert_eq!(decoded, input);
+}
+
 // ─── streaming chunk sizes ─────────────────────────────────────────────
 
 #[test]