Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 157 additions & 0 deletions crates/backend/poly/src/eq_mle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1151,6 +1151,112 @@ fn packed_eq_poly<F: Field, EF: ExtensionField<F>>(eval: &[EF], scalar: EF) -> E
EF::ExtensionPacking::from_ext_slice(&buffer)
}

/// Tensor-tail variant of [`eval_eq`]: returns the table
/// `out[(hi << log2(tail.len())) | lo] = eq(eval)[hi] * tail[lo]`,
/// i.e. the eq expansion of `eval` tensored with an arbitrary vector `tail`
/// occupying the lowest `log2(tail.len())` variables. `tail.len()` must be a
/// power of two. With `tail = eval_eq(extra)` this equals
/// `eval_eq(concat(eval, extra))` exactly.
pub fn eval_eq_with_tail<F: ExtensionField<PF<F>>>(eval: &[F], tail: &[F]) -> ArenaVec<F> {
let log_tail = log2_strict_usize(tail.len());
let mut out = unsafe { ArenaVec::uninitialized(1 << (eval.len() + log_tail)) };
let (log_chunks, n_chunks) = parallel_split();
if eval.len() <= 1 + log_chunks {
eval_eq_tail_kernel(eval, &mut out, F::ONE, tail);
return out;
}
let mut buffer = F::zero_vec(n_chunks);
buffer[0] = F::ONE;
fill_buffer(eval[..log_chunks].iter().rev(), &mut buffer);
let middle = &eval[log_chunks..];
let out_chunk_size = out.len() / n_chunks;
par_chunks_zip(&mut out, out_chunk_size, &buffer, |out_chunk, &b| {
eval_eq_tail_kernel(middle, out_chunk, b, tail);
});
out
}

/// Packed-output variant of [`eval_eq_with_tail`]; same table, packed like
/// [`eval_eq_packed`]. Requires `eval.len() + log2(tail.len()) >= packing_log_width`.
pub fn eval_eq_packed_with_tail<F: ExtensionField<PF<F>>>(eval: &[F], tail: &[F]) -> ArenaVec<EFPacking<F>> {
let w = packing_log_width::<F>();
let k = log2_strict_usize(tail.len());
assert!(eval.len() + k >= w);
if k < w {
// Absorb the (w − k) trailing `eval` variables into the tail so the
// packed lanes are fully covered by the (extended) tail.
let absorb = w - k;
let eq_absorb = eval_eq(&eval[eval.len() - absorb..]);
let mut extended_tail = F::zero_vec(1 << w);
for (a, &ea) in eq_absorb.iter().enumerate() {
for (lo, &t) in tail.iter().enumerate() {
extended_tail[(a << k) | lo] = ea * t;
}
}
return eval_eq_packed_with_tail(&eval[..eval.len() - absorb], &extended_tail);
}
let tail_packed: Vec<EFPacking<F>> = pack_extension(tail);
let mut out = unsafe { ArenaVec::uninitialized(1 << (eval.len() + k - w)) };
let (log_chunks, n_chunks) = parallel_split();
if eval.len() <= 1 + log_chunks {
eval_eq_tail_kernel_packed::<F>(eval, &mut out, F::ONE, &tail_packed);
return out;
}
let mut buffer = F::zero_vec(n_chunks);
buffer[0] = F::ONE;
fill_buffer(eval[..log_chunks].iter().rev(), &mut buffer);
let middle = &eval[log_chunks..];
let out_chunk_size = out.len() / n_chunks;
par_chunks_zip(&mut out, out_chunk_size, &buffer, |out_chunk, &b| {
eval_eq_tail_kernel_packed::<F>(middle, out_chunk, b, &tail_packed);
});
out
}

/// Recursive kernel for [`eval_eq_with_tail`]: standard eq split on `eval`,
/// with the leaf writing `scalar * tail` instead of a single scalar.
#[inline]
fn eval_eq_tail_kernel<F: Field>(eval: &[F], out: &mut [F], scalar: F, tail: &[F]) {
debug_assert_eq!(out.len(), tail.len() << eval.len());
match eval.split_first() {
None => {
out.iter_mut().zip(tail).for_each(|(o, &t)| *o = t * scalar);
}
Some((&x, rest)) => {
let (low, high) = out.split_at_mut(out.len() / 2);
let s1 = scalar * x;
let s0 = scalar - s1;
eval_eq_tail_kernel(rest, low, s0, tail);
eval_eq_tail_kernel(rest, high, s1, tail);
}
}
}

/// Recursive kernel for [`eval_eq_packed_with_tail`] (requires the tail to
/// cover at least the packing width, guaranteed by the absorption step).
#[inline]
fn eval_eq_tail_kernel_packed<F: ExtensionField<PF<F>>>(
eval: &[F],
out: &mut [EFPacking<F>],
scalar: F,
tail_packed: &[EFPacking<F>],
) {
debug_assert_eq!(out.len(), tail_packed.len() << eval.len());
match eval.split_first() {
None => {
let b = EFPacking::<F>::from(scalar);
out.iter_mut().zip(tail_packed).for_each(|(o, &t)| *o = t * b);
}
Some((&x, rest)) => {
let (low, high) = out.split_at_mut(out.len() / 2);
let s1 = scalar * x;
let s0 = scalar - s1;
eval_eq_tail_kernel_packed::<F>(rest, low, s0, tail_packed);
eval_eq_tail_kernel_packed::<F>(rest, high, s1, tail_packed);
}
}
}

#[cfg(test)]
mod tests {
use std::time::Instant;
Expand Down Expand Up @@ -1322,4 +1428,55 @@ mod tests {
assert_eq!(out_dual, out_separate, "Mismatch at n_vars={}", n_vars);
}
}

#[test]
fn test_eval_eq_with_tail_matches_point_append() {
let mut rng = StdRng::seed_from_u64(7);
for n_vars in [2usize, 5, 9, 12] {
for k in [2usize, 3, 4] {
let eval: Vec<EF> = (0..n_vars).map(|_| rng.random()).collect();
let extra: Vec<EF> = (0..k).map(|_| rng.random()).collect();
let tail = eval_eq(&extra);
let with_tail = eval_eq_with_tail(&eval, &tail);
let appended = eval_eq(&[eval.clone(), extra.clone()].concat());
assert_eq!(with_tail.as_slice(), appended.as_slice(), "n={n_vars} k={k}");
}
}
}

#[test]
fn test_eval_eq_with_tail_random_tail_brute_force() {
let mut rng = StdRng::seed_from_u64(8);
let n_vars = 6;
let k = 3;
let eval: Vec<EF> = (0..n_vars).map(|_| rng.random()).collect();
let tail: Vec<EF> = (0..1 << k).map(|_| rng.random()).collect();
let with_tail = eval_eq_with_tail(&eval, &tail);
let eq_hi = eval_eq(&eval);
for hi in 0..1usize << n_vars {
for lo in 0..1usize << k {
assert_eq!(with_tail[(hi << k) | lo], eq_hi[hi] * tail[lo]);
}
}
}

#[test]
fn test_eval_eq_packed_with_tail_matches_unpacked() {
let mut rng = StdRng::seed_from_u64(9);
let w = packing_log_width::<EF>();
for n_vars in [2usize, 5, 9, 12] {
// Cover both k < w and k >= w paths regardless of the platform width.
for k in [2usize, 3, 4, 5] {
if n_vars + k < w {
continue;
}
let eval: Vec<EF> = (0..n_vars).map(|_| rng.random()).collect();
let tail: Vec<EF> = (0..1 << k).map(|_| rng.random()).collect();
let unpacked = eval_eq_with_tail(&eval, &tail);
let packed = eval_eq_packed_with_tail(&eval, &tail);
let unpacked_from_packed: Vec<EF> = unpack_extension(&packed);
assert_eq!(unpacked.as_slice(), &unpacked_from_packed[..], "n={n_vars} k={k}");
}
}
}
}
100 changes: 99 additions & 1 deletion crates/backend/poly/src/next_mle.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use ::utils::log2_strict_usize;
use field::{ExtensionField, Field, PrimeCharacteristicRing};
use zk_alloc::ArenaVec;

use crate::{PF, eval_eq_scaled};
use crate::{PF, eval_eq_scaled, eval_eq_with_tail, to_big_endian_in_field};

/// Evaluates the "next" multilinear polynomial at two n-variable points (x, y).
///
Expand Down Expand Up @@ -54,6 +55,39 @@ where
res
}

/// Tensor-tail variant of [`next_mle`]:
/// `Σ_x tail[x] · next_mle(concat(prefix, bits(x)), y)`, where `bits(x)` is
/// big-endian over `log2(tail.len())` variables (matching `eval_eq` indexing).
/// Verifier-side: the loop over the `2^k` cube points is intentional.
pub fn next_mle_with_tail<F: Field>(prefix: &[F], tail: &[F], y: &[F]) -> F {
let k = log2_strict_usize(tail.len());
debug_assert_eq!(prefix.len() + k, y.len());
let mut sum = F::ZERO;
for (x, &t) in tail.iter().enumerate() {
let mut point = prefix.to_vec();
point.extend(to_big_endian_in_field::<F>(x, k));
sum += next_mle(&point, y) * t;
}
sum
}

/// Tensor-tail variant of [`matrix_next_mle_folded`]: the dense vector
/// `w[y] = Σ_x tail[x] · next_mle(concat(prefix, bits(x)), y)`.
///
/// Since `next_mle` is multilinear in its first argument,
/// `w[y] = Σ_j v[j] · next_mle(j, y)` with `v = eval_eq_with_tail(prefix, tail)`,
/// and `next_mle(j, y) = 1` iff `y = j + 1`, plus the wrap-around
/// `next_mle(2^n − 1, 2^n − 1) = 1` (see [`next_mle`]). Hence `w` is the
/// shift-by-one of `v`, with `w[last] += v[last]`.
pub fn matrix_next_mle_folded_with_tail<F: ExtensionField<PF<F>>>(prefix: &[F], tail: &[F]) -> ArenaVec<F> {
let v = eval_eq_with_tail(prefix, tail);
let n = v.len();
let mut res = unsafe { ArenaVec::<F>::zeroed(n) };
res[1..].copy_from_slice(&v[..n - 1]);
res[n - 1] += v[n - 1];
res
}

#[cfg(test)]
mod tests {
use field::PrimeCharacteristicRing;
Expand Down Expand Up @@ -81,4 +115,68 @@ mod tests {
}
}
}

#[test]
fn test_next_mle_with_tail_brute_force() {
use koala_bear::QuinticExtensionFieldKB;
use rand::{RngExt, SeedableRng, rngs::StdRng};

use crate::next_mle_with_tail;
type EF = QuinticExtensionFieldKB;

let mut rng = StdRng::seed_from_u64(11);
for k in [2usize, 3] {
let n_prefix = 5 - k;
let prefix: Vec<EF> = (0..n_prefix).map(|_| rng.random()).collect();
let tail: Vec<EF> = (0..1 << k).map(|_| rng.random()).collect();
let y: Vec<EF> = (0..5).map(|_| rng.random()).collect();
let direct = next_mle_with_tail(&prefix, &tail, &y);
let mut brute = EF::ZERO;
for (x, &t) in tail.iter().enumerate() {
let mut point = prefix.clone();
point.extend(to_big_endian_in_field::<EF>(x, k));
brute += next_mle(&point, &y) * t;
}
assert_eq!(direct, brute);
}
}

#[test]
fn test_matrix_next_mle_folded_with_tail_matches_sum() {
use koala_bear::QuinticExtensionFieldKB;
use rand::{RngExt, SeedableRng, rngs::StdRng};

use crate::{matrix_next_mle_folded_with_tail, next_mle_with_tail};
type EF = QuinticExtensionFieldKB;

let mut rng = StdRng::seed_from_u64(12);
for k in [2usize, 3] {
let n_prefix = 5 - k;
let prefix: Vec<EF> = (0..n_prefix).map(|_| rng.random()).collect();
let tail: Vec<EF> = (0..1 << k).map(|_| rng.random()).collect();

let folded = matrix_next_mle_folded_with_tail(&prefix, &tail);

// Elementwise against the sum of per-cube-point folded matrices.
let mut expected = EF::zero_vec(1 << 5);
for (x, &t) in tail.iter().enumerate() {
let mut point = prefix.clone();
point.extend(to_big_endian_in_field::<EF>(x, k));
for (e, &m) in expected.iter_mut().zip(matrix_next_mle_folded(&point).iter()) {
*e += m * t;
}
}
assert_eq!(folded.as_slice(), &expected[..]);

// Consistency with the pointwise variant: the folded vector's MLE at a
// boolean point y equals next_mle_with_tail(prefix, tail, y).
for y in 0..1usize << 5 {
let y_bools = to_big_endian_in_field::<EF>(y, 5);
assert_eq!(
folded.evaluate(&MultilinearPoint(y_bools.clone())),
next_mle_with_tail(&prefix, &tail, &y_bools)
);
}
}
}
}
3 changes: 3 additions & 0 deletions crates/backend/sumcheck/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,6 @@ pub use sc_computation::*;

mod product_computation;
pub use product_computation::*;

mod univariate_skip;
pub use univariate_skip::*;
Loading
Loading