Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
d748af2
remove FFT + Merkle padding optimization
TomWambsgans Apr 22, 2026
60b04f4
reverse variable ordering and apply split-eq trick
TomWambsgans Apr 23, 2026
fcda679
w
TomWambsgans Apr 23, 2026
1779e72
wip
TomWambsgans Apr 23, 2026
74ec0c9
Merge branch 'main' into whir-split-eq
TomWambsgans Apr 24, 2026
6b5d7f4
w
TomWambsgans Apr 24, 2026
f310d3f
simplify
TomWambsgans Apr 24, 2026
685269c
packing on build_all_compressed_groups (only improves avx)
TomWambsgans Apr 24, 2026
92b2d16
wip
TomWambsgans Apr 24, 2026
b691be7
w
TomWambsgans Apr 24, 2026
bef2f44
faster prepare_evals_for_fft_helper
TomWambsgans Apr 24, 2026
b9f0989
w
TomWambsgans Apr 24, 2026
6ba37ce
simplify
TomWambsgans Apr 24, 2026
8fd20ec
wip
TomWambsgans Apr 24, 2026
ad975ff
Merge branch 'main' into whir-split-eq
TomWambsgans Apr 24, 2026
fb31a0f
padding aware FFT
TomWambsgans Apr 24, 2026
fc62b98
even faster fft
TomWambsgans Apr 24, 2026
243b840
simplify
TomWambsgans Apr 24, 2026
6280964
faster merkle
TomWambsgans Apr 25, 2026
54b849b
avoid unnecessary allocation in initial Merkle tree
TomWambsgans Apr 25, 2026
5e341dd
Merge branch 'main' into whir-split-eq
TomWambsgans Apr 25, 2026
3bf8d69
Merge branch 'main' into whir-split-eq
TomWambsgans Apr 25, 2026
f346b1e
Merge branch 'main' into whir-split-eq
TomWambsgans Apr 25, 2026
e6c2329
add TODO comment for a potential opti in `compute_eval_eq_base_packed…
TomWambsgans Apr 25, 2026
681d173
Merge remote-tracking branch 'origin/main' into whir-split-eq
TomWambsgans Apr 25, 2026
8f1619e
fmt
TomWambsgans Apr 25, 2026
4062fa6
Perf simd compress hi dot (mirror of https://github.com/Plonky3/Plonk…
TomWambsgans Apr 27, 2026
fba3779
merge branch main
TomWambsgans Apr 29, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion crates/backend/fiat-shamir/src/verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ where
// SAFETY: We've confirmed PF<EF> == KoalaBear
let paths: PrunedMerklePaths<KoalaBear, KoalaBear> = unsafe { std::mem::transmute(paths) };
let perm = default_koalabear_poseidon1_16();
let hash_fn = |data: &[KoalaBear]| symetric::hash_slice::<_, _, 16, 8, DIGEST_LEN_FE>(&perm, data);
let hash_fn =
|data: &[KoalaBear]| symetric::hash_iter::<_, _, _, 16, 8, DIGEST_LEN_FE>(&perm, data.iter().copied());
let combine_fn = |left: &[KoalaBear; DIGEST_LEN_FE], right: &[KoalaBear; DIGEST_LEN_FE]| {
symetric::compress(&perm, [*left, *right])
};
Expand Down
25 changes: 8 additions & 17 deletions crates/backend/poly/src/eq_mle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -369,11 +369,8 @@ pub fn compute_eval_eq_base_packed<F, EF, const INITIALIZED: bool>(
}

#[inline]
pub fn compute_eval_eq_base_packed_batched<F, EF>(
evals: &[MultilinearPoint<F>],
out: &mut [EF::ExtensionPacking],
scalars: &[EF],
) where
pub fn compute_eval_eq_base_batched<F, EF>(evals: &[MultilinearPoint<F>], out: &mut [EF], scalars: &[EF])
where
F: Field,
EF: ExtensionField<F>,
{
Expand All @@ -383,22 +380,21 @@ pub fn compute_eval_eq_base_packed_batched<F, EF>(
}

let n = evals[0].len();
let packing_width = F::Packing::WIDTH;
let log_packing_width = log2_strict_usize(packing_width);
let log_packing_width = log2_strict_usize(F::Packing::WIDTH);
assert!(log_packing_width <= n);
assert_eq!(out.len(), 1 << (n - log_packing_width));
assert_eq!(out.len(), 1 << n);

let k = n.min(LOG_BATCHED_TILE_SIZE);

if k <= log_packing_width || k >= n {
for (eval, &scalar) in evals.iter().zip(scalars) {
compute_eval_eq_base_packed::<F, EF, true>(eval, out, scalar);
compute_eval_eq_base::<F, EF, true>(eval, out, scalar);
}
return;
}

let n_prefix_levels = n - k;
let tile_packed_size = 1 << (k - log_packing_width);
let tile_size = 1 << k;

let per_query: Vec<_> = evals
.iter()
Expand All @@ -412,19 +408,14 @@ pub fn compute_eval_eq_base_packed_batched<F, EF>(
})
.collect();

out.par_chunks_exact_mut(tile_packed_size)
out.par_chunks_exact_mut(tile_size)
.enumerate()
.for_each(|(tile_idx, out_tile)| {
for (eq_prefix, middle, eq_suffix) in &per_query {
// Here we could precompute the eq poly, trading some memory for less computation
// (2x faster on M4 max, but 2x slower on machines with smaller caches.
// TODO implement both and choose based on cache size?)
base_eval_eq_packed_with_packed_output::<F, EF, true>(
middle,
out_tile,
*eq_suffix,
EF::ExtensionPacking::from(eq_prefix[tile_idx]),
);
base_eval_eq_packed::<F, EF, true>(middle, out_tile, *eq_suffix, eq_prefix[tile_idx]);
}
});
}
Expand Down
9 changes: 9 additions & 0 deletions crates/backend/poly/src/point.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,15 @@ where
}
}

impl<F: Clone> MultilinearPoint<F> {
    /// Returns a new point whose coordinates are those of `self` in reverse order.
    /// `self` is left untouched.
    #[must_use]
    pub fn reversed(&self) -> Self {
        // Build the reversed coordinate vector directly from a reversed iterator.
        Self(self.0.iter().rev().cloned().collect())
    }
}

impl<F> From<Vec<F>> for MultilinearPoint<F> {
fn from(v: Vec<F>) -> Self {
Self(v)
Expand Down
2 changes: 1 addition & 1 deletion crates/backend/symetric/src/merkle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ where
return false;
}

let mut root = crate::hash_slice::<_, _, WIDTH, RATE, DIGEST_ELEMS>(comp, opened_values);
let mut root = crate::hash_iter::<_, _, _, WIDTH, RATE, DIGEST_ELEMS>(comp, opened_values.iter().copied());

for &sibling in opening_proof.iter() {
let (left, right) = if index & 1 == 0 {
Expand Down
75 changes: 9 additions & 66 deletions crates/backend/symetric/src/sponge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,35 +24,9 @@ where
state[..OUT].try_into().unwrap()
}

/// Precompute sponge state after absorbing `n_zero_chunks` all-zero RATE-chunks.
pub fn precompute_zero_suffix_state<T, Comp, const WIDTH: usize, const RATE: usize, const OUT: usize>(
comp: &Comp,
n_zero_chunks: usize,
) -> [T; WIDTH]
where
T: Default + Copy,
Comp: Compression<[T; WIDTH]>,
{
debug_assert!(RATE == OUT);
debug_assert!(WIDTH == OUT + RATE);
debug_assert!(n_zero_chunks >= 2);
let mut state = [T::default(); WIDTH];
comp.compress_mut(&mut state);
for _ in 0..n_zero_chunks - 2 {
for s in &mut state[WIDTH - RATE..] {
*s = T::default();
}
comp.compress_mut(&mut state);
}
state
}

/// RTL = Right-to-left
/// LTR = Left-to-right
#[inline(always)]
pub fn hash_rtl_iter<T, Comp, I, const WIDTH: usize, const RATE: usize, const OUT: usize>(
comp: &Comp,
rtl_iter: I,
) -> [T; OUT]
pub fn hash_iter<T, Comp, I, const WIDTH: usize, const RATE: usize, const OUT: usize>(comp: &Comp, iter: I) -> [T; OUT]
where
T: Default + Copy,
Comp: Compression<[T; WIDTH]>,
Expand All @@ -61,48 +35,17 @@ where
debug_assert!(RATE == OUT);
debug_assert!(WIDTH == OUT + RATE);
let mut state = [T::default(); WIDTH];
let mut iter = rtl_iter.into_iter();
for pos in (0..WIDTH).rev() {
state[pos] = iter.next().unwrap();
let mut iter = iter.into_iter();
for s in &mut state {
*s = iter.next().unwrap();
}
comp.compress_mut(&mut state);
absorb_rtl_chunks::<T, Comp, _, WIDTH, RATE, OUT>(comp, &mut state, &mut iter)
}

/// RTL = Right-to-left
#[inline(always)]
pub fn hash_rtl_iter_with_initial_state<T, Comp, I, const WIDTH: usize, const RATE: usize, const OUT: usize>(
comp: &Comp,
mut iter: I,
initial_state: &[T; WIDTH],
) -> [T; OUT]
where
T: Default + Copy,
Comp: Compression<[T; WIDTH]>,
I: Iterator<Item = T>,
{
let mut state = *initial_state;
absorb_rtl_chunks::<T, Comp, _, WIDTH, RATE, OUT>(comp, &mut state, &mut iter)
}

/// RTL = Right-to-left
#[inline(always)]
fn absorb_rtl_chunks<T, Comp, I, const WIDTH: usize, const RATE: usize, const OUT: usize>(
comp: &Comp,
state: &mut [T; WIDTH],
iter: &mut I,
) -> [T; OUT]
where
T: Default + Copy,
Comp: Compression<[T; WIDTH]>,
I: Iterator<Item = T>,
{
while let Some(elem) = iter.next() {
state[WIDTH - 1] = elem;
for pos in (WIDTH - RATE..WIDTH - 1).rev() {
state[pos] = iter.next().unwrap();
state[WIDTH - RATE] = elem;
for s in &mut state[WIDTH - RATE + 1..] {
*s = iter.next().unwrap();
}
comp.compress_mut(state);
comp.compress_mut(&mut state);
}
state[..OUT].try_into().unwrap()
}
44 changes: 0 additions & 44 deletions crates/rec_aggregation/hashing.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,50 +20,6 @@
BIG_ENDIAN = 0


def batch_hash_slice_rtl(num_queries, all_data_to_hash, all_resulting_hashes, num_chunks):
    # Dispatch on the runtime chunk count to a const-specialized body:
    # `batch_hash_slice_rtl_const` takes `num_chunks: Const`, so every supported
    # count needs its own call site passing a compile-time-constant value.
    # NOTE(review): only the counts listed below are supported — presumably the
    # set produced by the surrounding protocol; confirm if that set changes.
    if num_chunks == DIM * 2:
        batch_hash_slice_rtl_const(num_queries, all_data_to_hash, all_resulting_hashes, DIM * 2)
        return
    if num_chunks == 16:
        batch_hash_slice_rtl_const(num_queries, all_data_to_hash, all_resulting_hashes, 16)
        return
    if num_chunks == 8:
        batch_hash_slice_rtl_const(num_queries, all_data_to_hash, all_resulting_hashes, 8)
        return
    if num_chunks == 20:
        batch_hash_slice_rtl_const(num_queries, all_data_to_hash, all_resulting_hashes, 20)
        return
    if num_chunks == 1:
        batch_hash_slice_rtl_const(num_queries, all_data_to_hash, all_resulting_hashes, 1)
        return
    if num_chunks == 4:
        batch_hash_slice_rtl_const(num_queries, all_data_to_hash, all_resulting_hashes, 4)
        return
    if num_chunks == 5:
        batch_hash_slice_rtl_const(num_queries, all_data_to_hash, all_resulting_hashes, 5)
        return
    # Diagnostic: surface the unsupported count before failing hard.
    print(num_chunks)
    assert False, "batch_hash_slice called with unsupported len"


def batch_hash_slice_rtl_const(num_queries, all_data_to_hash, all_resulting_hashes, num_chunks: Const):
    # Hash each query's leaf data right-to-left and store the resulting digest.
    # `num_chunks` is a compile-time constant here, so the @inline `slice_hash_rtl`
    # body (and its `unroll` loop) can be fully specialized per chunk count.
    for i in range(0, num_queries):
        data = all_data_to_hash[i]
        res = slice_hash_rtl(data, num_chunks)
        all_resulting_hashes[i] = res
    return


@inline
def slice_hash_rtl(data, num_chunks):
    # Right-to-left chained hash over `num_chunks` chunks of DIGEST_LEN elements:
    # first compress the last two chunks together, then fold each earlier chunk
    # (walking right-to-left through `data`) into the running state with
    # poseidon16_compress. `states` holds one intermediate digest per compression;
    # the final state (at offset (num_chunks - 2) * DIGEST_LEN) is the result.
    # NOTE(review): assumes num_chunks >= 2 — with num_chunks == 1 the first
    # compress would read at a negative offset; confirm callers never pass 1 here.
    states = Array((num_chunks - 1) * DIGEST_LEN)

    poseidon16_compress(data + (num_chunks - 2) * DIGEST_LEN, data + (num_chunks - 1) * DIGEST_LEN, states)
    for j in unroll(1, num_chunks - 1):
        poseidon16_compress(states + (j - 1) * DIGEST_LEN, data + (num_chunks - 2 - j) * DIGEST_LEN, states + j * DIGEST_LEN)
    return states + (num_chunks - 2) * DIGEST_LEN


@inline
def slice_hash(data, num_chunks):
states = Array((num_chunks - 1) * DIGEST_LEN)
Expand Down
2 changes: 1 addition & 1 deletion crates/rec_aggregation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,7 @@ def decompose_and_verify_merkle_query(a, domain_size, prev_root, num_chunks):

leaf_data = Array(num_chunks * DIGEST_LEN)
hint_witness("merkle_leaf", leaf_data)
leaf_hash = slice_hash_rtl(leaf_data, num_chunks)
leaf_hash = slice_hash(leaf_data, num_chunks)

merkle_path = Array(domain_size * DIGEST_LEN)
hint_witness("merkle_path", merkle_path)
Expand Down
38 changes: 28 additions & 10 deletions crates/rec_aggregation/whir.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,13 +102,19 @@ def whir_open(

folding_randomness_global = Array(n_vars * DIM)

start: Mut = folding_randomness_global
# WHIR sumcheck folds LSB-first, so chronological challenges are in reverse polynomial-var
# order. Write each chronological challenge to position (n_vars - 1 - chrono_idx) so the
# final cumulative reads as [x_0, x_1, ..., x_{n_vars-1}] (matches MSB-fold layout).
chrono_idx: Mut = 0
for i in range(0, n_rounds + 1):
for j in range(0, folding_factors[i]):
copy_5(all_folding_randomness[i] + j * DIM, start + j * DIM)
start += folding_factors[i] * DIM
target_pos = n_vars - 1 - chrono_idx
copy_5(all_folding_randomness[i] + j * DIM, folding_randomness_global + target_pos * DIM)
chrono_idx += 1
for j in range(0, n_final_vars):
copy_5(all_folding_randomness[n_rounds + 1] + j * DIM, start + j * DIM)
target_pos = n_vars - 1 - chrono_idx
copy_5(all_folding_randomness[n_rounds + 1] + j * DIM, folding_randomness_global + target_pos * DIM)
chrono_idx += 1

all_ood_recovered_evals = Array(num_oods[0] * DIM)
for i in range(0, num_oods[0]):
Expand All @@ -122,16 +128,16 @@ def whir_open(
num_oods[0],
)

# LSB-fold: at round r the polynomial has remaining vars [x_0, ..., x_{n_vars_remaining-1}],
# so the relevant cumulative slice is the FIRST n_vars_remaining elements (no pointer advance).
n_vars_remaining: Mut = n_vars
my_folding_randomness: Mut = folding_randomness_global
for i in range(0, n_rounds):
n_vars_remaining -= folding_factors[i]
my_ood_recovered_evals = Array(num_oods[i + 1] * DIM)
combination_randomness_powers = all_combination_randomness_powers[i]
my_folding_randomness += folding_factors[i] * DIM
for j in range(0, num_oods[i + 1]):
expanded_from_univariate = expand_from_univariate_ext(all_ood_points[i] + j * DIM, n_vars_remaining)
poly_eq_extension_dynamic_to(expanded_from_univariate, my_folding_randomness, my_ood_recovered_evals + j * DIM, n_vars_remaining)
poly_eq_extension_dynamic_to(expanded_from_univariate, folding_randomness_global, my_ood_recovered_evals + j * DIM, n_vars_remaining)
summed_ood = Array(DIM)
dot_product_ee_dynamic(
my_ood_recovered_evals,
Expand All @@ -144,7 +150,7 @@ def whir_open(
circle_value_i = all_circle_values[i]
for j in range(0, num_queries[i]): # unroll ?
expanded_from_univariate = expand_from_univariate_base(circle_value_i[j], n_vars_remaining)
poly_eq_base_extension_to(expanded_from_univariate, my_folding_randomness, s6s + j * DIM, n_vars_remaining)
poly_eq_base_extension_to(expanded_from_univariate, folding_randomness_global, s6s + j * DIM, n_vars_remaining)
s7 = Array(DIM)
dot_product_ee_dynamic(
s6s,
Expand All @@ -154,10 +160,17 @@ def whir_open(
)
s = add_extension_ret(s, s7)
s = add_extension_ret(summed_ood, s)
# WHIR sumcheck folds LSB-first: the final_sumcheck challenges arrive chronologically as
# [r_1 = x_{m-1}, ..., r_m = x_0]. eval_multilinear_coeffs_rev evaluates f(x_j = point[j]);
# feeding the chronological challenges directly would set x_j = r_{j+1} = x_{m-j-1} — the
# wrong variable (we need x_j = r_{m-j}) — so reverse them first.
final_sumcheck_chals_rev = Array(n_final_vars * DIM)
final_sumcheck_chals = all_folding_randomness[n_rounds + 1]
for j in range(0, n_final_vars):
copy_5(final_sumcheck_chals + (n_final_vars - 1 - j) * DIM, final_sumcheck_chals_rev + j * DIM)
final_value = match_range(
n_final_vars,
range(MAX_NUM_VARIABLES_TO_SEND_COEFFS - WHIR_SUBSEQUENT_FOLDING_FACTOR, MAX_NUM_VARIABLES_TO_SEND_COEFFS + 1),
lambda n: eval_multilinear_coeffs_rev(final_coeffcients, all_folding_randomness[n_rounds + 1], n),
lambda n: eval_multilinear_coeffs_rev(final_coeffcients, final_sumcheck_chals_rev, n),
)
# copy_5(mul_extension_ret(s, final_value), end_sum);

Expand Down Expand Up @@ -301,7 +314,12 @@ def sample_stir_indexes_and_fold(

folds = Array(num_queries * DIM)

poly_eq = compute_eq_mle_extension_dynamic(folding_randomness, folding_factor)
# WHIR sumcheck folds LSB-first; the leaf is laid out so its first var is the polynomial's
# last LSB-folded var. evaluate (poly_eq) is MSB-first, so reverse the per-round challenges.
folding_randomness_reversed = Array(folding_factor * DIM)
for j in range(0, folding_factor):
copy_5(folding_randomness + (folding_factor - 1 - j) * DIM, folding_randomness_reversed + j * DIM)
poly_eq = compute_eq_mle_extension_dynamic(folding_randomness_reversed, folding_factor)

if merkle_leaves_in_basefield == 1:
for i in range(0, num_queries):
Expand Down
2 changes: 1 addition & 1 deletion crates/sub_protocols/src/quotient_gkr/layers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ impl<'a, EF: ExtensionField<PF<EF>>> LayerStorage<'a, EF> {
}
}

pub fn materialise_in_full(self) -> (Vec<EF>, Vec<EF>) {
pub(super) fn materialise_in_full(self) -> (Vec<EF>, Vec<EF>) {
let natural = match self {
Self::Natural { .. } => self,
other => other.convert_to_natural(),
Expand Down
7 changes: 5 additions & 2 deletions crates/sub_protocols/src/stacked_pcs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,11 @@ pub fn stack_polynomials_and_commit(

let global_polynomial = MleOwned::Base(global_polynomial);

let inner_witness =
WhirConfig::new(whir_config_builder, stacked_n_vars).commit(prover_state, &global_polynomial, offset);
let inner_witness = WhirConfig::new(whir_config_builder, stacked_n_vars).commit_with_prefix_len(
prover_state,
&global_polynomial,
offset,
);
StackedPcsWitness {
stacked_n_vars,
inner_witness,
Expand Down
Loading
Loading